Distributed Linear Regression via Gradient Descent¶
Distributed linear regression using gradient descent leverages the computational power of multiple nodes to efficiently train a model on large-scale datasets. The core idea is to parallelize the gradient computation and synchronize model updates, enabling scalable machine learning.
1. Distribute the Dataset¶
The dataset $D = \{(x_i, y_i)\}_{i=1}^n$ is partitioned across worker nodes in the Spark cluster.
Each worker node receives a subset $ D_k \subset D$ to process locally.
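In PySpark this partitioning is exactly what sc.parallelize (or reading a distributed file) does. A minimal sketch, assuming a running SparkContext sc like the one created in the setup cell further below and a small in-memory list of (x, y) pairs:
# Sketch: split an in-memory dataset into partitions D_k across the cluster
pairs = [(1.0, 3.1), (2.0, 6.2), (3.0, 8.9), (4.0, 12.1)]
toy_rdd = sc.parallelize(pairs, numSlices=4)   # one slice per partition D_k
print(toy_rdd.getNumPartitions())              # -> 4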
2. Local Gradient Calculation¶
Each worker computes the gradient of the loss function (e.g., MSE) with respect to the current model parameters $\theta$, using its local data $D_k$.
For linear regression:
$$ \text{Loss} = \frac{1}{2n} \sum_{i=1}^n (\hat{y}_i - y_i)^2, \quad \hat{y}_i = x_i^\top \theta $$
The local gradient on node $k$ is:
$$ \nabla_k = \frac{1}{|D_k|} \sum_{(x_i, y_i) \in D_k} x_i (x_i^\top \theta - y_i) $$
This step ensures each node computes its own partial gradient independently.
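Concretely, for scalar features (as in the simulation later in this notebook) the per-partition gradient is just an average of $x_i(x_i\theta - y_i)$ terms. A minimal local sketch with NumPy, using made-up values for one partition:
import numpy as np
x_k = np.array([1.0, 2.0, 3.0])                 # features held by node k (assumed values)
y_k = np.array([3.1, 6.2, 8.9])                 # targets held by node k (assumed values)
theta = 0.0
grad_k = np.mean(x_k * (x_k * theta - y_k))     # (1/|D_k|) * sum_i x_i (x_i*theta - y_i)
print(grad_k)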
3. Aggregate Gradients Across Nodes¶
Once all local gradients $\nabla_k$ are computed:
- Reduce: each node sends its local gradient (together with its partition size $|D_k|$) to the driver, which combines them into the global gradient: $$ \nabla = \frac{1}{n} \sum_k |D_k| \, \nabla_k $$
- All-Reduce (efficient alternative): the nodes cooperatively sum their contributions and each receives the global gradient, avoiding the communication bottleneck at the driver.
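With the RDD API, the reduce pattern can be written directly over per-partition (gradient-sum, count) pairs, while treeReduce gives a more all-reduce-like, multi-stage aggregation. A minimal sketch, assuming a hypothetical RDD grads_rdd that holds one such pair per partition (the simulation below builds exactly this kind of RDD):
# Sketch: combine per-partition (grad_sum, count) pairs into the global gradient.
# grads_rdd is an assumed RDD with one (grad_sum, count) tuple per partition.
grad_sum, count = grads_rdd.reduce(lambda a, b: (a[0] + b[0], a[1] + b[1]))
global_gradient = grad_sum / count
# treeReduce aggregates in stages across executors instead of funneling
# every partial result straight to the driver:
grad_sum, count = grads_rdd.treeReduce(lambda a, b: (a[0] + b[0], a[1] + b[1]), depth=2)
global_gradient = grad_sum / count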
4. Update Global Model Parameters¶
Using the aggregated gradient $\nabla$, update model parameters:
$$ \theta \leftarrow \theta - \eta \cdot \nabla $$
where $\eta$ is the learning rate.
This update can be performed:
- Centrally on the driver (with reduce)
- In a distributed fashion (with all-reduce)
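Either way, the update itself is a simple elementwise operation. A minimal local sketch with placeholder values (the re-broadcast to the workers is only noted in a comment):
import numpy as np
eta = 0.01                               # learning rate (assumed value)
theta = np.zeros(3)                      # current parameters
gradient = np.array([1.2, -0.4, 0.7])    # aggregated global gradient (placeholder values)
theta = theta - eta * gradient           # theta <- theta - eta * gradient
# In the cluster setting, the driver then re-broadcasts the updated theta to all workers.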
5. Iterate Until Convergence¶
Repeat steps 2–4 until one of the following is met:
- Loss change is below a threshold
- Maximum number of iterations reached
- Gradient norm is below a tolerance
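These criteria translate directly into the driver loop. The self-contained sketch below uses a local NumPy gradient as a stand-in for the aggregated distributed gradient; the tolerances and data are illustrative assumptions:
import numpy as np
X_loc = np.array([1.0, 2.0, 3.0, 4.0])            # toy data (assumed values)
y_loc = 3 * X_loc
theta, eta, max_iter = 0.0, 0.01, 10000
tol_loss, tol_grad = 1e-10, 1e-6
prev_loss = np.inf
for it in range(max_iter):                         # stop: maximum number of iterations
    grad = np.mean(X_loc * (X_loc * theta - y_loc))  # stands in for the aggregated gradient
    theta -= eta * grad
    loss = 0.5 * np.mean((X_loc * theta - y_loc) ** 2)
    if abs(prev_loss - loss) < tol_loss or abs(grad) < tol_grad:
        break                                      # stop: loss change or gradient norm below tolerance
    prev_loss = loss
print(it, theta)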
6. Convergence and Final Model¶
After convergence, the final model parameters $\theta^*$ are returned. This model can then be:
- Used for prediction
- Stored and deployed
- Evaluated on validation/test datasets
import findspark
findspark.init('/opt/apps/SPARK3/spark-current')  # point findspark at the local Spark installation
import pyspark
from pyspark.sql import SparkSession
# Start a Spark session and grab its SparkContext for the RDD API used below
spark = SparkSession.builder.appName("DistributedLinearRegressionGD").getOrCreate()
sc = spark.sparkContext
Simulation¶
## Generate Synthetic Data
# We simulate 1000 data points where the true relationship is y = 3x + noise.
import numpy as np
np.random.seed(42)
n_samples = 1000
X = np.random.rand(n_samples) * 10
y = 3 * X + np.random.randn(n_samples)
data = list(zip(X, y))
data_rdd = sc.parallelize(data, numSlices=4)
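Optionally, one can confirm how the 1000 points were split across the 4 partitions (a quick check, not part of the original notebook):
# Optional check: number of records held by each partition
print(data_rdd.glom().map(len).collect())   # partition sizes, e.g. [250, 250, 250, 250]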
## Define Gradient Function for Each Partition
# This function computes the partial gradient for a given partition of the data.
def compute_gradient(partition, theta):
    grad_sum = 0.0
    count = 0
    for x, y in partition:
        y_pred = theta * x
        error = y_pred - y
        grad_sum += error * x
        count += 1
    return [(grad_sum, count)]  # Wrapped in list
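Because mapPartitions just passes an iterator to this function, it can be sanity-checked locally on a plain Python list before running on the cluster (a quick check, not part of the original notebook):
# Local sanity check on a tiny in-memory "partition"
sample_partition = [(1.0, 3.0), (2.0, 6.0)]
print(compute_gradient(iter(sample_partition), theta=0.0))   # -> [(-15.0, 2)]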
## Gradient Descent Loop
# Here we iterate to update theta using distributed gradient computation.
import matplotlib.pyplot as plt
theta = 0.0
learning_rate = 0.001
n_iter = 500
theta_history = []
for i in range(n_iter):
    theta_broadcast = sc.broadcast(theta)
    grads = data_rdd.mapPartitions(lambda part: compute_gradient(part, theta_broadcast.value))
    if grads.isEmpty():
        raise ValueError("No gradients computed — check partition contents.")
    grad_total, count_total = grads.reduce(lambda a, b: (a[0] + b[0], a[1] + b[1]))
    gradient = grad_total / count_total
    theta -= learning_rate * gradient
    theta_history.append(theta)
    if i % 5 == 0 or i == n_iter - 1:
        print(f"Iter {i:2d}: theta = {theta:.4f}, grad = {gradient:.4f}")
Iter 0: theta = 0.0980, grad = -98.0360
Iter 5: theta = 0.5424, grad = -83.0813
Iter 10: theta = 0.9189, grad = -70.4079
Iter 15: theta = 1.2380, grad = -59.6677
Iter 20: theta = 1.5084, grad = -50.5658
Iter 25: theta = 1.7376, grad = -42.8524
Iter 30: theta = 1.9318, grad = -36.3156
Iter 35: theta = 2.0964, grad = -30.7759
Iter 40: theta = 2.2359, grad = -26.0813
Iter 45: theta = 2.3541, grad = -22.1028
Iter 50: theta = 2.4543, grad = -18.7311
Iter 55: theta = 2.5392, grad = -15.8738
Iter 60: theta = 2.6111, grad = -13.4524
Iter 65: theta = 2.6721, grad = -11.4003
Iter 70: theta = 2.7238, grad = -9.6613
Iter 75: theta = 2.7676, grad = -8.1875
Iter 80: theta = 2.8047, grad = -6.9386
Iter 85: theta = 2.8361, grad = -5.8802
Iter 90: theta = 2.8628, grad = -4.9832
Iter 95: theta = 2.8854, grad = -4.2230
Iter 100: theta = 2.9045, grad = -3.5788
Iter 105: theta = 2.9207, grad = -3.0329
Iter 110: theta = 2.9345, grad = -2.5703
Iter 115: theta = 2.9461, grad = -2.1782
Iter 120: theta = 2.9560, grad = -1.8459
Iter 125: theta = 2.9644, grad = -1.5643
Iter 130: theta = 2.9714, grad = -1.3257
Iter 135: theta = 2.9775, grad = -1.1235
Iter 140: theta = 2.9825, grad = -0.9521
Iter 145: theta = 2.9869, grad = -0.8069
Iter 150: theta = 2.9905, grad = -0.6838
Iter 155: theta = 2.9936, grad = -0.5795
Iter 160: theta = 2.9962, grad = -0.4911
Iter 165: theta = 2.9985, grad = -0.4162
Iter 170: theta = 3.0004, grad = -0.3527
Iter 175: theta = 3.0020, grad = -0.2989
Iter 180: theta = 3.0033, grad = -0.2533
Iter 185: theta = 3.0045, grad = -0.2147
Iter 190: theta = 3.0054, grad = -0.1819
Iter 195: theta = 3.0063, grad = -0.1542
Iter 200: theta = 3.0070, grad = -0.1306
Iter 205: theta = 3.0075, grad = -0.1107
Iter 210: theta = 3.0080, grad = -0.0938
Iter 215: theta = 3.0085, grad = -0.0795
Iter 220: theta = 3.0088, grad = -0.0674
Iter 225: theta = 3.0091, grad = -0.0571
Iter 230: theta = 3.0094, grad = -0.0484
Iter 235: theta = 3.0096, grad = -0.0410
Iter 240: theta = 3.0098, grad = -0.0348
Iter 245: theta = 3.0100, grad = -0.0295
Iter 250: theta = 3.0101, grad = -0.0250
Iter 255: theta = 3.0102, grad = -0.0212
Iter 260: theta = 3.0103, grad = -0.0179
Iter 265: theta = 3.0104, grad = -0.0152
Iter 270: theta = 3.0105, grad = -0.0129
Iter 275: theta = 3.0105, grad = -0.0109
Iter 280: theta = 3.0106, grad = -0.0092
Iter 285: theta = 3.0106, grad = -0.0078
Iter 290: theta = 3.0106, grad = -0.0066
Iter 295: theta = 3.0107, grad = -0.0056
Iter 300: theta = 3.0107, grad = -0.0048
Iter 305: theta = 3.0107, grad = -0.0040
Iter 310: theta = 3.0107, grad = -0.0034
Iter 315: theta = 3.0108, grad = -0.0029
Iter 320: theta = 3.0108, grad = -0.0025
Iter 325: theta = 3.0108, grad = -0.0021
Iter 330: theta = 3.0108, grad = -0.0018
Iter 335: theta = 3.0108, grad = -0.0015
Iter 340: theta = 3.0108, grad = -0.0013
Iter 345: theta = 3.0108, grad = -0.0011
Iter 350: theta = 3.0108, grad = -0.0009
Iter 355: theta = 3.0108, grad = -0.0008
Iter 360: theta = 3.0108, grad = -0.0007
Iter 365: theta = 3.0108, grad = -0.0006
Iter 370: theta = 3.0108, grad = -0.0005
Iter 375: theta = 3.0108, grad = -0.0004
Iter 380: theta = 3.0108, grad = -0.0003
Iter 385: theta = 3.0108, grad = -0.0003
Iter 390: theta = 3.0108, grad = -0.0002
Iter 395: theta = 3.0108, grad = -0.0002
Iter 400: theta = 3.0108, grad = -0.0002
Iter 405: theta = 3.0108, grad = -0.0001
Iter 410: theta = 3.0108, grad = -0.0001
Iter 415: theta = 3.0108, grad = -0.0001
Iter 420: theta = 3.0108, grad = -0.0001
Iter 425: theta = 3.0108, grad = -0.0001
Iter 430: theta = 3.0108, grad = -0.0001
Iter 435: theta = 3.0108, grad = -0.0001
Iter 440: theta = 3.0108, grad = -0.0000
Iter 445: theta = 3.0108, grad = -0.0000
Iter 450: theta = 3.0108, grad = -0.0000
Iter 455: theta = 3.0108, grad = -0.0000
Iter 460: theta = 3.0108, grad = -0.0000
Iter 465: theta = 3.0108, grad = -0.0000
Iter 470: theta = 3.0108, grad = -0.0000
Iter 475: theta = 3.0108, grad = -0.0000
Iter 480: theta = 3.0108, grad = -0.0000
Iter 485: theta = 3.0108, grad = -0.0000
Iter 490: theta = 3.0108, grad = -0.0000
Iter 495: theta = 3.0108, grad = -0.0000
Iter 499: theta = 3.0108, grad = -0.0000
plt.figure(figsize=(8, 5))
plt.plot(range(n_iter), theta_history, marker='o')
plt.xlabel("Iteration")
plt.ylabel("Theta (slope estimate)")
plt.title("Convergence of Theta in Gradient Descent")
plt.grid(True)
plt.show()
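As a sanity check (not part of the original notebook), the closed-form least-squares slope for this no-intercept model can be computed locally from the same X and y arrays and should closely match the value gradient descent converged to (about 3.01):
# Closed-form OLS slope for a no-intercept model: theta* = sum(x*y) / sum(x^2)
theta_closed_form = np.sum(X * y) / np.sum(X * X)
print(f"Closed-form slope: {theta_closed_form:.4f}  vs  GD estimate: {theta:.4f}")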
spark.stop()
Real data: bike-sharing¶
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import LinearRegression
from pyspark.sql.functions import col
# Step 1: Start Spark session
spark = SparkSession.builder.appName("BikeSharingLinearRegression").getOrCreate()
# Step 2: Load the dataset (assumes day.csv is available)
df = spark.read.csv("day.csv", header=True, inferSchema=True)
# Step 3: Select features and label
# Example: temp, humidity, windspeed as features
feature_cols = ["temp", "hum", "windspeed"]
df = df.select(*feature_cols, col("cnt").alias("label"))
# Step 4: Assemble features into a vector column
assembler = VectorAssembler(inputCols=feature_cols, outputCol="features")
df = assembler.transform(df).select("features", "label")
# Step 5: Fit linear regression model
lr = LinearRegression(featuresCol="features", labelCol="label")
model = lr.fit(df)
# Step 6: Output results
print("Coefficients:", model.coefficients)
print("Intercept:", model.intercept)
print("R²:", model.summary.r2)
print("RMSE:", model.summary.rootMeanSquaredError)
Coefficients: [6625.532709711166,-3100.1231349121067,-4806.929324810166]
Intercept: 4084.3633844519604
R²: 0.4608950096446509
RMSE: 1421.4004390323018
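The model above is fit and summarized on the full dataset. A hedged sketch of a held-out evaluation with RegressionEvaluator, reusing the df and lr defined above (the 80/20 split ratio and seed are assumptions):
# Sketch: evaluate on a held-out split instead of the training data
from pyspark.ml.evaluation import RegressionEvaluator
train_df, test_df = df.randomSplit([0.8, 0.2], seed=42)   # 80/20 split (assumed)
model_split = lr.fit(train_df)
preds = model_split.transform(test_df)
evaluator = RegressionEvaluator(labelCol="label", predictionCol="prediction", metricName="rmse")
print("Held-out RMSE:", evaluator.evaluate(preds))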
spark.stop()