Distributed Logistic Regression via Stochastic Gradient Descent¶
Logistic regression is a fundamental classification algorithm used to model the probability that a given input belongs to a particular category. In a distributed environment like Apache Spark, we can scale logistic regression training using Stochastic Gradient Descent (SGD), a simple yet powerful optimization algorithm.
This document outlines how to implement logistic regression from scratch using distributed SGD in Spark, including handling regularization and partition-wise gradient aggregation.
Mathematical Formulation¶
Given input features $ \mathbf{x} \in \mathbb{R}^d $ and binary target $ y \in \{0,1\} $, logistic regression models the probability as:
$$ P(y = 1 \mid \mathbf{x}) = \sigma(\mathbf{w}^\top \mathbf{x}) = \frac{1}{1 + e^{-\mathbf{w}^\top \mathbf{x}}} $$
The objective is to minimize the regularized log-loss:
$$ L(\mathbf{w}) = -\frac{1}{n} \sum_{i=1}^n \left[ y_i \log(\sigma(\mathbf{w}^\top \mathbf{x}_i)) + (1 - y_i)\log(1 - \sigma(\mathbf{w}^\top \mathbf{x}_i)) \right] + \frac{\lambda}{2} \| \mathbf{w} \|_2^2 $$
where $ \lambda $ is the regularization parameter; the factor of $ \tfrac{1}{2} $ makes the gradient of the penalty term simply $ \lambda \mathbf{w} $.
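Differentiating $ L(\mathbf{w}) $ gives the full-data gradient that the distributed procedure below approximates partition by partition:
$$ \nabla L(\mathbf{w}) = \frac{1}{n} \sum_{i=1}^n \left( \sigma(\mathbf{w}^\top \mathbf{x}_i) - y_i \right) \mathbf{x}_i + \lambda \mathbf{w} $$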
Distributed Training with SGD¶
Step 1: Data Partitioning¶
In Spark, data is partitioned across worker nodes. Each worker will compute gradients based on the subset of data it holds.
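The number of partitions determines how many local gradients are produced per epoch. A minimal sketch of inspecting and controlling the partition count (assuming a Spark DataFrame df of training rows, as built later in this notebook):
print(df.rdd.getNumPartitions())  # number of partitions = number of local gradients per epoch
df = df.repartition(8)            # redistribute the rows into 8 partitions if more parallelism is wanted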
Step 2: Broadcasting Model Weights¶
At the beginning of each epoch, the current weight vector $ \mathbf{w} $ is broadcast from the driver node to all workers.
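A minimal sketch of this broadcast step (assuming the SparkSession created later in this notebook and a NumPy weight vector on the driver):
import numpy as np

w = np.zeros(5)                                # current weights on the driver
w_broadcast = spark.sparkContext.broadcast(w)  # ship a read-only copy to every executor once
# worker-side code reads the weights with w_broadcast.value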
Step 3: Local Gradient Computation¶
Each worker computes its local gradient:
$$ \nabla L_p(\mathbf{w}) = \frac{1}{|P|} \sum_{i \in P} (\sigma(\mathbf{w}^\top \mathbf{x}_i) - y_i) \mathbf{x}_i + \lambda \mathbf{w} $$
where $ P $ is the set of data points in the partition.
Step 4: Gradient Aggregation¶
Local gradients are sent back to the driver and averaged:
$$ \nabla L(\mathbf{w}) = \frac{1}{k} \sum_{p=1}^k \nabla L_p(\mathbf{w}) $$
where $ k $ is the number of partitions; averaging the partition gradients with equal weight recovers the full gradient when the partitions have (roughly) equal size.
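The implementation below simply collects one gradient per partition and averages them on the driver, which is fine for a modest number of partitions. With many partitions, RDD.treeAggregate can combine the gradients on the executors before they reach the driver; a sketch, assuming grad_rdd is an RDD of NumPy gradient vectors (one per non-empty partition) and dim, np as in the code below:
zero = (np.zeros(dim), 0)  # running (gradient sum, partition count)
total, k = grad_rdd.map(lambda g: (g, 1)).treeAggregate(
    zero,
    lambda acc, x: (acc[0] + x[0], acc[1] + x[1]),  # fold a (gradient, 1) pair into the accumulator
    lambda a, b: (a[0] + b[0], a[1] + b[1]),        # merge two partial accumulators
)
avg_grad = total / k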
Step 5: Weight Update¶
The driver updates the weights:
$$ \mathbf{w} \leftarrow \mathbf{w} - \eta \nabla L(\mathbf{w}) $$
where $ \eta $ is the learning rate.
Step 6: Iteration¶
Steps 2–5 are repeated for multiple epochs until convergence.
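Convergence is usually detected by monitoring how much the weights move between epochs. A sketch of such a stopping rule, where compute_average_gradient is a hypothetical stand-in for Steps 3–4 and tol is an assumed tolerance:
tol = 1e-4  # assumed stopping tolerance
for epoch in range(epochs):
    avg_grad = compute_average_gradient(w)  # hypothetical placeholder for Steps 3-4
    w_new = w - learning_rate * avg_grad
    if np.linalg.norm(w_new - w) < tol:
        print(f"Converged after {epoch + 1} epochs")
        break
    w = w_new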
Advantages of Distributed SGD¶
- Scalability: Handles datasets that don’t fit in a single machine’s memory.
- Parallelism: Each worker independently computes gradients.
- Efficiency: Reduces communication overhead using broadcasting and aggregation.
Practical Notes¶
- Use broadcast variables for model parameters so the weights are shipped to each executor once per epoch instead of being serialized into every task closure.
- Normalize or standardize the features to improve convergence (see the sketch after this list).
- Consider using mini-batches per partition for stability.
- Use checkpointing or logging to monitor convergence.
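As noted above, standardizing the features usually speeds up convergence of gradient methods. A minimal sketch that standardizes the pandas DataFrame built in the code below before it is handed to Spark (the column names x0–x4 match that synthetic dataset):
feature_cols = [f"x{i}" for i in range(5)]
pdf[feature_cols] = (pdf[feature_cols] - pdf[feature_cols].mean()) / pdf[feature_cols].std()
df = spark.createDataFrame(pdf)  # rebuild the Spark DataFrame from the standardized features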
import findspark
findspark.init('/opt/apps/SPARK3/spark-current')
from pyspark.sql import SparkSession
spark = SparkSession.builder \
    .appName("Distributed logistic regression using pyspark") \
    .config("spark.executor.cores", "4") \
    .config("spark.executor.memory", "14g") \
    .config("spark.executor.instances", "4") \
    .getOrCreate()
25/05/05 17:38:19 WARN [Thread-6] Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.
25/05/05 17:38:19 WARN [Thread-6] Client: Neither spark.yarn.jars nor spark.yarn.archive is set, falling back to uploading libraries under SPARK_HOME.
# Logistic Regression from Scratch in PySpark (with L2 Regularization)
from pyspark.sql import SparkSession
import numpy as np
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.metrics import accuracy_score
# -----------------------------
# Step 1: Generate synthetic binary classification data
# -----------------------------
X, y = make_classification(n_samples=1000, n_features=5, random_state=42)
# Create pandas DataFrame with features and labels
pdf = pd.DataFrame(X, columns=[f"x{i}" for i in range(X.shape[1])])
pdf["y"] = y
# Convert to Spark DataFrame for distributed processing
df = spark.createDataFrame(pdf)
# -----------------------------
# Step 2: Initialize model parameters
# -----------------------------
dim = 5 # Number of features
w = np.zeros(dim) # Initial weights set to zeros
w_broadcast = spark.sparkContext.broadcast(w) # Broadcast weights to workers
learning_rate = 0.1 # Step size for SGD
l2_lambda = 0.01 # L2 regularization strength
epochs = 100 # Number of SGD passes over data
# -----------------------------
# Step 3: Define utility functions
# -----------------------------
def sigmoid(z):
    """Sigmoid function, mapping any real value into (0, 1)."""
    return 1 / (1 + np.exp(-z))
def logistic_gradient(X, y, w, l2_lambda):
    """
    Compute the gradient of the logistic loss with L2 regularization.
    X: feature matrix (batch)
    y: label vector (batch)
    w: current weights
    l2_lambda: regularization coefficient
    """
    preds = sigmoid(X.dot(w))
    error = preds - y
    grad = X.T.dot(error) / len(y)  # Gradient of the logistic loss
    grad += l2_lambda * w           # Add the L2 regularization term
    return grad
# -----------------------------
# Step 4: Train model using distributed SGD
# -----------------------------
for epoch in range(epochs):

    def sgd_partition(iterator):
        """
        Runs on each partition: accumulates the partition's X and y,
        computes the local gradient, and returns it.
        """
        X_batch = []
        y_batch = []
        for row in iterator:
            x = np.array([row[f"x{i}"] for i in range(dim)])
            y_val = row["y"]
            X_batch.append(x)
            y_batch.append(y_val)
        if not X_batch:
            return iter([])  # Skip empty partitions
        X = np.array(X_batch)
        y = np.array(y_batch)
        w = w_broadcast.value  # Read the broadcast weights
        grad = logistic_gradient(X, y, w, l2_lambda)
        return iter([grad])  # Return the local gradient

    # Aggregate gradients across partitions and update weights
    grads = df.rdd.mapPartitions(sgd_partition).collect()
    avg_grad = np.mean(grads, axis=0)  # Average of per-partition gradients (each partition weighted equally)
    w -= learning_rate * avg_grad      # Gradient descent update
    w_broadcast = spark.sparkContext.broadcast(w)  # Re-broadcast the updated weights
    print(f"Epoch {epoch+1}: weights = {w}")
Epoch 1: weights = [ 0.04986998 -0.02152319 -0.00114528 -0.02654043 -0.00078016]
Epoch 2: weights = [ 0.09601851 -0.04104163 -0.00222253 -0.0509112 -0.00247573]
Epoch 3: weights = [ 0.13879198 -0.05876462 -0.00323557 -0.07332502 -0.00494608]
Epoch 4: weights = [ 0.1785323 -0.07489668 -0.00418739 -0.0939909 -0.00805742]
Epoch 5: weights = [ 0.21555908 -0.08962683 -0.0050807 -0.11310316 -0.01169 ]
...
Epoch 99: weights = [ 1.13425687 -0.35801528 -0.01442079 -0.54124458 -0.33896575]
Epoch 100: weights = [ 1.13718781 -0.35856141 -0.01434937 -0.54246335 -0.34076732]
# -----------------------------
# Step 5: Evaluate model accuracy on training data
# -----------------------------
X_test = pdf[[f"x{i}" for i in range(dim)]].values
y_test = pdf["y"].values
y_pred = (sigmoid(X_test.dot(w)) > 0.5).astype(int)
print("Accuracy:", accuracy_score(y_test, y_pred))
Accuracy: 0.863
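Note that this accuracy is measured on the same data the model was trained on. A minimal sketch of a held-out evaluation instead: split pdf before building the Spark DataFrame, run the SGD loop above on the training portion only, then score the held-out rows (train_test_split comes from scikit-learn):
from sklearn.model_selection import train_test_split

pdf_train, pdf_holdout = train_test_split(pdf, test_size=0.2, random_state=42)
df = spark.createDataFrame(pdf_train)  # use this DataFrame in the SGD loop above

# After retraining w on the training portion only:
X_holdout = pdf_holdout[[f"x{i}" for i in range(dim)]].values
y_holdout = pdf_holdout["y"].values
y_pred_holdout = (sigmoid(X_holdout.dot(w)) > 0.5).astype(int)
print("Held-out accuracy:", accuracy_score(y_holdout, y_pred_holdout))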
Do It in MLlib¶
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.feature import VectorAssembler
# Assemble features into a single vector
assembler = VectorAssembler(inputCols=[f"x{i}" for i in range(5)], outputCol="features")
df_vec = assembler.transform(df)
# Train a logistic regression model with Spark MLlib
# (regParam=0.01 with elasticNetParam=0.0 gives pure L2 regularization;
#  MLlib also fits an intercept by default, unlike the from-scratch model above)
lr = LogisticRegression(featuresCol="features", labelCol="y", regParam=0.01, elasticNetParam=0.0, maxIter=100)
model = lr.fit(df_vec)
# Summarize model
summary = model.summary
print("Accuracy:", summary.accuracy)
print("Coefficients:", model.coefficients)
print("Intercept:", model.intercept)
Accuracy: 0.864
Coefficients: [0.9578491147209527,-0.5288889779493914,-0.0035872139406808628,-1.200261794979123,-0.6101891453239969]
Intercept: 0.16691036814177213
spark.stop()