Walmart unit sales data¶
Data Structure¶
Hierarchy Level | Description | Number of Series |
---|---|---|
1 | Unit sales of all products, aggregated for all stores/states | 1 |
2 | Unit sales of all products, aggregated for each State | 3 |
3 | Unit sales of all products, aggregated for each store | 10 |
4 | Unit sales of all products, aggregated for each category | 3 |
5 | Unit sales of all products, aggregated for each department | 7 |
6 | Unit sales of all products, aggregated for each State and category | 9 |
7 | Unit sales of all products, aggregated for each State and department | 21 |
8 | Unit sales of all products, aggregated for each store and category | 30 |
9 | Unit sales of all products, aggregated for each store and department | 70 |
10 | Unit sales of product x, aggregated for all stores/states | 3,049 |
11 | Unit sales of product x, aggregated for each State | 9,147 |
12 | Unit sales of product x, aggregated for each store | 30,490 |
Total | | 42,840 |
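Every level above the bottom one can be reproduced by aggregating the 30,490 bottom-level series. As an illustration (assuming the long-format grid_df with state_id, cat_id, d and sales that is constructed later in this notebook), level 6 (State × category) comes from a simple group-by:
# Sketch: aggregate bottom-level unit sales to hierarchy level 6 (State x category).
# Assumes the long-format grid_df (state_id, cat_id, d, sales) built later in this notebook.
from pyspark.sql import functions as F
level6_df = (
    grid_df
    .groupBy("state_id", "cat_id", "d")
    .agg(F.sum("sales").alias("sales"))
)
level6_df.show(5)  # 3 states x 3 categories = 9 series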
Features for sales data¶
Feature | Description |
---|---|
sell_price | Price of the item in the given store on the given date. |
event_type | 108 categorical events, e.g. sporting, cultural, religious. |
event_name | 157 event names for event_type, e.g. Super Bowl, Valentine's Day, President's Day. |
event_name_2 | Name of the second event feature as given in the competition data. |
event_type_2 | Type of the second event feature as given in the competition data. |
snap_CA, snap_TX, snap_WI | Binary indicators for SNAP purchase eligibility in CA, TX, WI. |
release | Release week of the item in the store. |
- Hierarchical structure of daily sales data: 42,840 series in total, spanning 1,941 days.
Engineered price and calendar features¶
Feature | Description |
---|---|
price_max, price_min | Maximum and minimum price of the item in the store over the train data. |
price_mean, price_std, price_norm | Mean, standard deviation, and max-normalized price of the item in the store over the train data. |
item_nunique, price_nunique | Number of distinct items sold at that price in the store, and number of distinct prices of the item in the store. |
price_diff_w | Weekly price change of the item in the store (sketched below this table). |
price_diff_m | Price change of the item in the store relative to its monthly mean. |
price_diff_y | Price change of the item in the store relative to its yearly mean. |
tm_d | Day of month. |
tm_w | Week of year. |
tm_m | Month of year. |
tm_y | Year index within the train data. |
tm_wm | Week of month. |
tm_dw | Day of week. |
tm_w_end | Weekend indicator. |
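The price_diff_* features above are relative price changes. As a minimal sketch (assuming a prices_df with store_id, item_id, wm_yr_wk and sell_price, as loaded later in this notebook), the weekly variant could be computed with a lag window; the notebook below builds the closely related price_momentum features in the same way:
# Sketch: weekly price change of an item in a store, assuming a prices_df
# with columns store_id, item_id, wm_yr_wk, sell_price (as in the M5 data).
from pyspark.sql import functions as F
from pyspark.sql.window import Window

w = Window.partitionBy("store_id", "item_id").orderBy("wm_yr_wk")
prices_with_diff = prices_df.withColumn(
    "price_diff_w",
    F.col("sell_price") / F.lag("sell_price", 1).over(w) - 1  # relative change vs. previous week
)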
In [59]:
import findspark
findspark.init('/opt/apps/SPARK3/spark-current')

# Then you can import the `pyspark` module
import pyspark
from pyspark.sql import SparkSession  # build a Spark session

spark = SparkSession.builder \
    .appName("tsfeatures") \
    .config("spark.executor.cores", "4") \
    .config("spark.executor.memory", "14g") \
    .config("spark.executor.instances", "4") \
    .getOrCreate()
25/05/23 13:58:55 WARN [Thread-6] Client: Neither spark.yarn.jars nor spark.yarn.archive is set, falling back to uploading libraries under SPARK_HOME.
In [63]:
# Load Data
train_df = spark.read.csv("/data/m5-forecasting-accuracy/sales_train_evaluation.csv", header=True, inferSchema=True)
prices_df = spark.read.csv("/data/m5-forecasting-accuracy/sell_prices.csv", header=True, inferSchema=True)
calendar_df = spark.read.csv("/data/m5-forecasting-accuracy/calendar.csv", header=True, inferSchema=True)
TARGET = 'sales' # Our main target
END_TRAIN = 1913+28 # Last day in train set
MAIN_INDEX = ['id','d'] # We can identify item by these columns
In [64]:
from pyspark.sql.functions import col, lit, expr
from pyspark.sql.types import StringType
import numpy as np
# calendar_df.printSchema()

# Define index columns
index_columns = ['id', 'item_id', 'dept_id', 'cat_id', 'store_id', 'state_id']

# Melt train_df from wide to long format using stack()
day_columns = [c for c in train_df.columns if c not in index_columns]
train_df_long = train_df.selectExpr(
    *index_columns,
    "stack(" + str(len(day_columns)) +
    "".join([f", '{c}', {c}" for c in day_columns]) + ") as (d, sales)"
)

# Strip the "d_" prefix so the "d" column matches the format of a Pandas melt
train_df_long = train_df_long.withColumn("d", expr("substring(d, 3, length(d)-2)"))

# Count rows
print("Train rows:", train_df.count(), train_df_long.count())

# Create a "test set" grid for the 28 future dates
from pyspark.sql.functions import monotonically_increasing_id
add_grid = train_df.select(*index_columns).dropDuplicates()
add_grid = add_grid.crossJoin(
    spark.createDataFrame([(f"d_{END_TRAIN+i}", np.nan) for i in range(1, 29)], ["d", TARGET])
)

# Combine train and test grids
grid_df = train_df_long.union(add_grid)

# PySpark has no direct memory-usage helper, so report the row count instead
print(f"Total rows in grid_df: {grid_df.count()}")
Train rows: 30490 59181090
Total rows in grid_df: 60034810
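The stack() expression above is Spark SQL's generator for unpivoting wide columns into rows (the equivalent of a Pandas melt). A tiny, self-contained toy example (made-up column names, not the real d_1...d_1941 columns) shows what it produces:
# Toy example of stack(): unpivot two day columns into (d, sales) rows
toy = spark.createDataFrame([("item_A", 3, 5)], ["id", "d_1", "d_2"])
toy.selectExpr("id", "stack(2, 'd_1', d_1, 'd_2', d_2) as (d, sales)").show()
# +------+---+-----+
# |    id|  d|sales|
# +------+---+-----+
# |item_A|d_1|    3|
# |item_A|d_2|    5|
# +------+---+-----+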
In [23]:
# Show a few rows
grid_df.show(20)
+--------------------+-------------+---------+-------+--------+--------+---+-----+
| id| item_id| dept_id| cat_id|store_id|state_id| d|sales|
+--------------------+-------------+---------+-------+--------+--------+---+-----+
|HOBBIES_1_001_CA_...|HOBBIES_1_001|HOBBIES_1|HOBBIES| CA_1| CA| 1| 0.0|
|HOBBIES_1_001_CA_...|HOBBIES_1_001|HOBBIES_1|HOBBIES| CA_1| CA| 2| 0.0|
|HOBBIES_1_001_CA_...|HOBBIES_1_001|HOBBIES_1|HOBBIES| CA_1| CA| 3| 0.0|
|HOBBIES_1_001_CA_...|HOBBIES_1_001|HOBBIES_1|HOBBIES| CA_1| CA| 4| 0.0|
|HOBBIES_1_001_CA_...|HOBBIES_1_001|HOBBIES_1|HOBBIES| CA_1| CA| 5| 0.0|
|HOBBIES_1_001_CA_...|HOBBIES_1_001|HOBBIES_1|HOBBIES| CA_1| CA| 6| 0.0|
|HOBBIES_1_001_CA_...|HOBBIES_1_001|HOBBIES_1|HOBBIES| CA_1| CA| 7| 0.0|
|HOBBIES_1_001_CA_...|HOBBIES_1_001|HOBBIES_1|HOBBIES| CA_1| CA| 8| 0.0|
|HOBBIES_1_001_CA_...|HOBBIES_1_001|HOBBIES_1|HOBBIES| CA_1| CA| 9| 0.0|
|HOBBIES_1_001_CA_...|HOBBIES_1_001|HOBBIES_1|HOBBIES| CA_1| CA| 10| 0.0|
|HOBBIES_1_001_CA_...|HOBBIES_1_001|HOBBIES_1|HOBBIES| CA_1| CA| 11| 0.0|
|HOBBIES_1_001_CA_...|HOBBIES_1_001|HOBBIES_1|HOBBIES| CA_1| CA| 12| 0.0|
|HOBBIES_1_001_CA_...|HOBBIES_1_001|HOBBIES_1|HOBBIES| CA_1| CA| 13| 0.0|
|HOBBIES_1_001_CA_...|HOBBIES_1_001|HOBBIES_1|HOBBIES| CA_1| CA| 14| 0.0|
|HOBBIES_1_001_CA_...|HOBBIES_1_001|HOBBIES_1|HOBBIES| CA_1| CA| 15| 0.0|
|HOBBIES_1_001_CA_...|HOBBIES_1_001|HOBBIES_1|HOBBIES| CA_1| CA| 16| 0.0|
|HOBBIES_1_001_CA_...|HOBBIES_1_001|HOBBIES_1|HOBBIES| CA_1| CA| 17| 0.0|
|HOBBIES_1_001_CA_...|HOBBIES_1_001|HOBBIES_1|HOBBIES| CA_1| CA| 18| 0.0|
|HOBBIES_1_001_CA_...|HOBBIES_1_001|HOBBIES_1|HOBBIES| CA_1| CA| 19| 0.0|
|HOBBIES_1_001_CA_...|HOBBIES_1_001|HOBBIES_1|HOBBIES| CA_1| CA| 20| 0.0|
+--------------------+-------------+---------+-------+--------+--------+---+-----+
only showing top 20 rows
In [65]:
from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType
# Group by store_id and item_id to get the earliest (min) wm_yr_wk (release week)
release_df = prices_df.groupBy("store_id", "item_id").agg(F.min("wm_yr_wk").alias("release"))
# Merge release_df with grid_df
grid_df = grid_df.join(release_df, on=["store_id", "item_id"], how="left")
# Remove release_df to free memory
del release_df
# Merge with calendar_df to get wm_yr_wk column
calendar_df = calendar_df.withColumn("d", expr("substring(d, 3, length(d)-2)"))
grid_df = grid_df.join(calendar_df.select("wm_yr_wk", "d"), on="d", how="left")
# Remove rows where wm_yr_wk is earlier than release
grid_df = grid_df.filter(F.col("wm_yr_wk") >= F.col("release"))
# Replace "id" with a monotonically increasing surrogate key
# (note: this overwrites the original string id; it can be rebuilt from item_id and store_id)
grid_df = grid_df.withColumn("id", F.monotonically_increasing_id())
# Minify the release values
min_release = grid_df.agg(F.min("release")).collect()[0][0] # Get minimum release week
grid_df = grid_df.withColumn("release", (F.col("release") - min_release).cast(IntegerType()))
# Show the transformed grid_df schema and a few rows
grid_df.printSchema()
# grid_df.show(5)
root
 |-- d: string (nullable = true)
 |-- store_id: string (nullable = true)
 |-- item_id: string (nullable = true)
 |-- id: long (nullable = false)
 |-- dept_id: string (nullable = true)
 |-- cat_id: string (nullable = true)
 |-- state_id: string (nullable = true)
 |-- sales: double (nullable = true)
 |-- release: integer (nullable = true)
 |-- wm_yr_wk: integer (nullable = true)
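An optional sanity check on the release handling above: after minification the smallest release value should be 0, and the grid should have shrunk from the 60,034,810 rows reported earlier because pre-release weeks were dropped:
# Optional sanity checks for the release handling above
from pyspark.sql import functions as F
grid_df.agg(F.min("release").alias("min_release"),
            F.max("release").alias("max_release")).show()
print("Rows after dropping pre-release weeks:", grid_df.count())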
In [66]:
from pyspark.sql import functions as F
from pyspark.sql.window import Window
# Window ordered by week, used for sequential computations (e.g. lag)
store_item_window = Window.partitionBy("store_id", "item_id").orderBy("wm_yr_wk")
# Unordered window over all weeks of an item in a store, for whole-history aggregates
store_item_agg_window = Window.partitionBy("store_id", "item_id")
# Compute basic price aggregations per item and store
prices_df = prices_df.withColumn("price_max", F.max("sell_price").over(store_item_agg_window))
prices_df = prices_df.withColumn("price_min", F.min("sell_price").over(store_item_agg_window))
prices_df = prices_df.withColumn("price_std", F.stddev("sell_price").over(store_item_agg_window))
prices_df = prices_df.withColumn("price_mean", F.mean("sell_price").over(store_item_agg_window))
# Normalize prices by the maximum price of the item in the store
prices_df = prices_df.withColumn("price_norm", F.col("sell_price") / F.col("price_max"))
# Compute distinct counts separately (fix for DISTINCT not allowed in window functions)
price_nunique_df = prices_df.groupBy("store_id", "item_id").agg(F.countDistinct("sell_price").alias("price_nunique"))
item_nunique_df = prices_df.groupBy("store_id", "sell_price").agg(F.countDistinct("item_id").alias("item_nunique"))
# Join distinct count results back to prices_df
prices_df = prices_df.join(price_nunique_df, on=["store_id", "item_id"], how="left")
prices_df = prices_df.join(item_nunique_df, on=["store_id", "sell_price"], how="left")
# Fix: Select only necessary columns from calendar_df to avoid ambiguity
calendar_prices = calendar_df.select(
F.col("wm_yr_wk"),
F.col("month").alias("calendar_month"), # Renaming to avoid ambiguity
F.col("year").alias("calendar_year")
).dropDuplicates(["wm_yr_wk"])
In [67]:
# Merge calendar information into prices_df
prices_df = prices_df.join(calendar_prices, on=["wm_yr_wk"], how="left")
# Compute price momentum features
# price_momentum: ratio of this week's price to the previous week's price
prices_df = prices_df.withColumn(
    "price_momentum",
    F.col("sell_price") / F.lag("sell_price", 1).over(store_item_window)
)
# Note: because the windows below include ORDER BY, the mean is an expanding mean
# over weeks up to the current one; drop the orderBy for a full-period mean.
prices_df = prices_df.withColumn(
    "price_momentum_m",
    F.col("sell_price") / F.mean("sell_price").over(
        Window.partitionBy("store_id", "item_id", "calendar_month").orderBy("wm_yr_wk")
    )
)
prices_df = prices_df.withColumn(
    "price_momentum_y",
    F.col("sell_price") / F.mean("sell_price").over(
        Window.partitionBy("store_id", "item_id", "calendar_year").orderBy("wm_yr_wk")
    )
)
# Drop temporary columns
prices_df = prices_df.drop("calendar_month", "calendar_year")
# Show schema and verify results
prices_df.printSchema()
# prices_df.show(5)
root
 |-- wm_yr_wk: integer (nullable = true)
 |-- store_id: string (nullable = true)
 |-- sell_price: double (nullable = true)
 |-- item_id: string (nullable = true)
 |-- price_max: double (nullable = true)
 |-- price_min: double (nullable = true)
 |-- price_std: double (nullable = true)
 |-- price_mean: double (nullable = true)
 |-- price_norm: double (nullable = true)
 |-- price_nunique: long (nullable = true)
 |-- item_nunique: long (nullable = true)
 |-- price_momentum: double (nullable = true)
 |-- price_momentum_m: double (nullable = true)
 |-- price_momentum_y: double (nullable = true)
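To sanity-check the momentum features, it can help to look at one (store, item) pair's weekly price history next to the derived columns; the ids below are just an example pair from sell_prices.csv:
# Inspect price history and momentum for one (store, item) pair -- illustrative ids
(prices_df
    .filter((F.col("store_id") == "TX_2") & (F.col("item_id") == "HOBBIES_2_105"))
    .orderBy("wm_yr_wk")
    .select("wm_yr_wk", "sell_price", "price_momentum", "price_momentum_m", "price_momentum_y")
    .show(10))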
Forecasting with gradient boosted trees¶
GBTRegressor
(Gradient-Boosted Tree Regressor) in Spark MLlib is a supervised learning algorithm that builds an ensemble of decision trees via gradient boosting to improve predictive accuracy.
How GBTRegressor Works¶
Gradient Boosted Trees (GBT) work by training decision trees sequentially, where:
- The first tree makes an initial prediction.
- Each subsequent tree learns from the errors (residuals) of the previous trees.
- The final prediction is the sum of the trees' outputs, with later trees scaled by the learning rate.
This technique is effective for handling non-linear relationships in data and reducing bias and variance.
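To make the residual-fitting idea concrete, here is a toy sketch of the boosting loop built from plain DecisionTreeRegressor stages. It is purely illustrative: GBTRegressor does this internally, with a proper initial estimate and loss-based tree weights.
# Toy illustration of gradient boosting for squared error:
# each tree is fit to the residuals of the current ensemble prediction.
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import DecisionTreeRegressor
from pyspark.sql import functions as F

toy = spark.createDataFrame([(float(x), float(2 * x + 1)) for x in range(20)], ["x", "y"])
toy = VectorAssembler(inputCols=["x"], outputCol="features").transform(toy)

step_size = 0.1
toy = toy.withColumn("pred", F.lit(0.0))             # initial prediction (a real GBT starts from the mean)
for i in range(3):                                    # a few boosting rounds
    toy = toy.withColumn("residual", F.col("y") - F.col("pred"))
    tree = DecisionTreeRegressor(featuresCol="features", labelCol="residual", maxDepth=2).fit(toy)
    toy = tree.transform(toy).withColumn(
        "pred", F.col("pred") + step_size * F.col("prediction")
    ).drop("prediction", "residual")

toy.select("x", "y", "pred").show(5)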
Code Example¶
from pyspark.ml.regression import GBTRegressor
from pyspark.ml.evaluation import RegressionEvaluator

# Initialize the GBT regressor
# (train_df / test_df are assumed to already contain an assembled "features" vector column)
gbt = GBTRegressor(featuresCol="features", labelCol="sales", maxIter=50, maxDepth=5, stepSize=0.1)
# Train the model
model = gbt.fit(train_df)
# Make predictions
predictions = model.transform(test_df)
# Evaluate using RMSE
evaluator = RegressionEvaluator(labelCol="sales", predictionCol="prediction", metricName="rmse")
rmse = evaluator.evaluate(predictions)
print(f"Root Mean Squared Error (RMSE): {rmse}")
Key Hyperparameters¶
Parameter | Description |
---|---|
maxIter | Number of trees in the ensemble (higher = more complex model). |
maxDepth | Maximum depth of each tree (higher = greater risk of overfitting). |
stepSize | Learning rate (default 0.1, kept small for stability). |
subsamplingRate | Fraction of the data used for each tree (default 1.0, the full dataset). |
maxBins | Number of bins for continuous-feature discretization (default 32). |
minInstancesPerNode | Minimum number of instances required per node (default 1). |
Recommended Settings¶
- For small datasets → maxIter=20, maxDepth=3
- For large datasets → maxIter=50, maxDepth=5
- For fine-tuning → adjust stepSize (0.05 - 0.2); see the tuning sketch below
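The fine-tuning bullet above can be automated with Spark's built-in tuning utilities. A sketch (hypothetical grid values; on the full M5 grid this is expensive, so a down-sampled train_df is advisable):
# Hyperparameter search sketch using TrainValidationSplit (cheaper than CrossValidator)
from pyspark.ml.regression import GBTRegressor
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.tuning import ParamGridBuilder, TrainValidationSplit

gbt = GBTRegressor(featuresCol="features", labelCol="sales")
param_grid = (ParamGridBuilder()
              .addGrid(gbt.maxDepth, [3, 5])
              .addGrid(gbt.maxIter, [20, 50])
              .addGrid(gbt.stepSize, [0.05, 0.1, 0.2])
              .build())

evaluator = RegressionEvaluator(labelCol="sales", predictionCol="prediction", metricName="rmse")
tvs = TrainValidationSplit(estimator=gbt, estimatorParamMaps=param_grid,
                           evaluator=evaluator, trainRatio=0.8, seed=42)
best_model = tvs.fit(train_df).bestModel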
Advantages and Limitations¶
✅ Handles complex non-linear relationships
✅ More accurate than a single Decision Tree
✅ Built-in feature selection (important features contribute more)
✅ Requires little preprocessing (no feature scaling; note that Spark MLlib still needs missing feature values to be imputed or skipped)
🚨 Slower training compared to Random Forest (sequential training of trees)
🚨 Prone to overfitting with large maxDepth
🚨 Not suited for real-time applications (expensive to update)
In [68]:
# Perform Left Join with prices_df
grid_df = grid_df.join(prices_df, on=['store_id', 'item_id', 'wm_yr_wk'], how="left")
# We don't need prices_df anymore
del prices_df
# Show Schema and Sample Data
grid_df.printSchema()
grid_df.show(10)
root
 |-- store_id: string (nullable = true)
 |-- item_id: string (nullable = true)
 |-- wm_yr_wk: integer (nullable = true)
 |-- d: string (nullable = true)
 |-- id: long (nullable = false)
 |-- dept_id: string (nullable = true)
 |-- cat_id: string (nullable = true)
 |-- state_id: string (nullable = true)
 |-- sales: double (nullable = true)
 |-- release: integer (nullable = true)
 |-- sell_price: double (nullable = true)
 |-- price_max: double (nullable = true)
 |-- price_min: double (nullable = true)
 |-- price_std: double (nullable = true)
 |-- price_mean: double (nullable = true)
 |-- price_norm: double (nullable = true)
 |-- price_nunique: long (nullable = true)
 |-- item_nunique: long (nullable = true)
 |-- price_momentum: double (nullable = true)
 |-- price_momentum_m: double (nullable = true)
 |-- price_momentum_y: double (nullable = true)
+--------+-------------+--------+---+-----------+---------+-------+--------+-----+-------+----------+---------+---------+-------------------+-----------------+------------------+-------------+------------+--------------+----------------+----------------+
|store_id| item_id|wm_yr_wk| d| id| dept_id| cat_id|state_id|sales|release|sell_price|price_max|price_min| price_std| price_mean| price_norm|price_nunique|item_nunique|price_momentum|price_momentum_m|price_momentum_y|
+--------+-------------+--------+---+-----------+---------+-------+--------+-----+-------+----------+---------+---------+-------------------+-----------------+------------------+-------------+------------+--------------+----------------+----------------+
| CA_4| FOODS_3_442| 11239|631|25769803776| FOODS_3| FOODS| CA| 0.0| 138| 2.48| 2.48| 2.48| 0.0|2.480000000000006| 1.0| 1| 132| NULL| 1.0| 1.0|
| CA_4| FOODS_3_442| 11239|632|25769803777| FOODS_3| FOODS| CA| 0.0| 138| 2.48| 2.48| 2.48| 0.0|2.480000000000006| 1.0| 1| 132| NULL| 1.0| 1.0|
| CA_4| FOODS_3_442| 11239|633|25769803778| FOODS_3| FOODS| CA| 0.0| 138| 2.48| 2.48| 2.48| 0.0|2.480000000000006| 1.0| 1| 132| NULL| 1.0| 1.0|
| CA_4| FOODS_3_442| 11239|634|25769803779| FOODS_3| FOODS| CA| 0.0| 138| 2.48| 2.48| 2.48| 0.0|2.480000000000006| 1.0| 1| 132| NULL| 1.0| 1.0|
| CA_4| FOODS_3_442| 11239|635|25769803780| FOODS_3| FOODS| CA| 1.0| 138| 2.48| 2.48| 2.48| 0.0|2.480000000000006| 1.0| 1| 132| NULL| 1.0| 1.0|
| CA_4| FOODS_3_442| 11239|636|25769803781| FOODS_3| FOODS| CA| 2.0| 138| 2.48| 2.48| 2.48| 0.0|2.480000000000006| 1.0| 1| 132| NULL| 1.0| 1.0|
| CA_4| FOODS_3_442| 11239|637|25769803782| FOODS_3| FOODS| CA| 0.0| 138| 2.48| 2.48| 2.48| 0.0|2.480000000000006| 1.0| 1| 132| NULL| 1.0| 1.0|
| TX_2|HOBBIES_2_105| 11109| 57|34359738375|HOBBIES_2|HOBBIES| TX| 0.0| 7| 2.47| 2.77| 2.47|0.13977135561823964|2.564909090909094|0.8916967509025271| 2| 64| 1.0| 1.0| 1.0|
| TX_2|HOBBIES_2_105| 11109| 58|34359738376|HOBBIES_2|HOBBIES| TX| 0.0| 7| 2.47| 2.77| 2.47|0.13977135561823964|2.564909090909094|0.8916967509025271| 2| 64| 1.0| 1.0| 1.0|
| TX_2|HOBBIES_2_105| 11109| 59|34359738377|HOBBIES_2|HOBBIES| TX| 0.0| 7| 2.47| 2.77| 2.47|0.13977135561823964|2.564909090909094|0.8916967509025271| 2| 64| 1.0| 1.0| 1.0|
+--------+-------------+--------+---+-----------+---------+-------+--------+-----+-------+----------+---------+---------+-------------------+-----------------+------------------+-------------+------------+--------------+----------------+----------------+
only showing top 10 rows
In [69]:
from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType, BooleanType
from math import ceil
icols = ['date', 'd', 'event_name_1', 'event_type_1', 'event_name_2', 'event_type_2',
'snap_CA', 'snap_TX', 'snap_WI']
grid_df = grid_df.join(calendar_df.select(*icols), on=['d'], how="left")
grid_df.printSchema()
root
 |-- d: string (nullable = true)
 |-- store_id: string (nullable = true)
 |-- item_id: string (nullable = true)
 |-- wm_yr_wk: integer (nullable = true)
 |-- id: long (nullable = false)
 |-- dept_id: string (nullable = true)
 |-- cat_id: string (nullable = true)
 |-- state_id: string (nullable = true)
 |-- sales: double (nullable = true)
 |-- release: integer (nullable = true)
 |-- sell_price: double (nullable = true)
 |-- price_max: double (nullable = true)
 |-- price_min: double (nullable = true)
 |-- price_std: double (nullable = true)
 |-- price_mean: double (nullable = true)
 |-- price_norm: double (nullable = true)
 |-- price_nunique: long (nullable = true)
 |-- item_nunique: long (nullable = true)
 |-- price_momentum: double (nullable = true)
 |-- price_momentum_m: double (nullable = true)
 |-- price_momentum_y: double (nullable = true)
 |-- date: date (nullable = true)
 |-- event_name_1: string (nullable = true)
 |-- event_type_1: string (nullable = true)
 |-- event_name_2: string (nullable = true)
 |-- event_type_2: string (nullable = true)
 |-- snap_CA: integer (nullable = true)
 |-- snap_TX: integer (nullable = true)
 |-- snap_WI: integer (nullable = true)
In [70]:
# Extract Date Features
grid_df = grid_df.withColumn("tm_d", F.dayofmonth("date"))
grid_df = grid_df.withColumn("tm_w", F.weekofyear("date"))
grid_df = grid_df.withColumn("tm_m", F.month("date"))
grid_df = grid_df.withColumn("tm_y", F.year("date"))
# Normalize `tm_y` (Subtract Minimum Year)
min_year = grid_df.agg(F.min("tm_y")).collect()[0][0]
grid_df = grid_df.withColumn("tm_y", (F.col("tm_y") - min_year))
# Compute `tm_wm` (week of month); note this stays a double (use F.ceil for an integer)
grid_df = grid_df.withColumn("tm_wm", (F.col("tm_d") / 7 + 0.99))  # ~ceil(tm_d / 7)
# Compute `tm_dw` (day of week, Monday=0) and `tm_w_end` (weekend indicator)
# Spark's dayofweek() returns Sunday=1..Saturday=7, so remap to the pandas convention Monday=0..Sunday=6
grid_df = grid_df.withColumn("tm_dw", (F.dayofweek("date") + 5) % 7)
grid_df = grid_df.withColumn("tm_w_end", (F.col("tm_dw") >= 5).cast(IntegerType()))  # Saturday/Sunday
# Drop `wm_yr_wk` column
grid_df = grid_df.drop("wm_yr_wk")
# Show schema & sample data
grid_df.printSchema()
# grid_df.show(10)
root
 |-- d: string (nullable = true)
 |-- store_id: string (nullable = true)
 |-- item_id: string (nullable = true)
 |-- id: long (nullable = false)
 |-- dept_id: string (nullable = true)
 |-- cat_id: string (nullable = true)
 |-- state_id: string (nullable = true)
 |-- sales: double (nullable = true)
 |-- release: integer (nullable = true)
 |-- sell_price: double (nullable = true)
 |-- price_max: double (nullable = true)
 |-- price_min: double (nullable = true)
 |-- price_std: double (nullable = true)
 |-- price_mean: double (nullable = true)
 |-- price_norm: double (nullable = true)
 |-- price_nunique: long (nullable = true)
 |-- item_nunique: long (nullable = true)
 |-- price_momentum: double (nullable = true)
 |-- price_momentum_m: double (nullable = true)
 |-- price_momentum_y: double (nullable = true)
 |-- date: date (nullable = true)
 |-- event_name_1: string (nullable = true)
 |-- event_type_1: string (nullable = true)
 |-- event_name_2: string (nullable = true)
 |-- event_type_2: string (nullable = true)
 |-- snap_CA: integer (nullable = true)
 |-- snap_TX: integer (nullable = true)
 |-- snap_WI: integer (nullable = true)
 |-- tm_d: integer (nullable = true)
 |-- tm_w: integer (nullable = true)
 |-- tm_m: integer (nullable = true)
 |-- tm_y: integer (nullable = true)
 |-- tm_wm: double (nullable = true)
 |-- tm_dw: integer (nullable = true)
 |-- tm_w_end: integer (nullable = true)
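An optional check that the weekday remapping above behaves as intended (tm_dw running Monday=0 to Sunday=6, with tm_w_end flagging Saturday and Sunday):
# Optional check: one sample date per weekday, with its weekend flag
(grid_df.select("date", "tm_dw", "tm_w_end")
        .dropDuplicates(["tm_dw"])
        .orderBy("tm_dw")
        .show(7))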
Save your prepared features¶
In [11]:
import os
grid_df.write.mode("overwrite").csv(os.path.expanduser("./data/m5_features"))
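CSV is fine for a quick export, but it stores everything as text, so column types (the date, the integer flags) have to be re-inferred on reload. If the features will be read back into Spark, Parquet preserves the schema; a minimal alternative with an example output path:
import os
# Parquet keeps the schema and compresses well; example output path
grid_df.write.mode("overwrite").parquet(os.path.expanduser("./data/m5_features_parquet"))
# Reload later with:
# features_df = spark.read.parquet(os.path.expanduser("./data/m5_features_parquet"))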
In [71]:
from pyspark.ml.feature import VectorAssembler, StringIndexer
from pyspark.ml.regression import GBTRegressor
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.sql import functions as F
from pyspark.sql.types import DoubleType, IntegerType
# Feature Selection (Exclude categorical & string columns)
# FEATURES = [
# "release", "sell_price", "price_max", "price_min", "price_std", "price_mean",
# "price_norm", "price_nunique", "item_nunique", "price_momentum", "price_momentum_m",
# "price_momentum_y", "tm_d", "tm_w", "tm_m", "tm_y", "tm_wm", "tm_dw", "tm_w_end",
# "snap_CA", "snap_TX", "snap_WI"
#]
FEATURES = [
"release", "sell_price", "price_max", "price_min", "price_std", "price_mean"
]
TARGET = "sales"
# Convert sales to DoubleType (required for GBTRegressor)
grid_df = grid_df.withColumn(TARGET, F.col(TARGET).cast(DoubleType()))
# Replace NULL in values
grid_df = grid_df.na.fill(0)
# Assemble feature columns into a single 'features' vector
vector_assembler = VectorAssembler(inputCols=FEATURES, outputCol="features")
grid_df = vector_assembler.transform(grid_df)
# Train-test split on the day index
# "d" is stored as a string, so cast it before the numeric comparison
train_df = grid_df.filter(F.col("d").cast("int") < 1914)   # Training Data
test_df = grid_df.filter(F.col("d").cast("int") >= 1914)   # Test/Validation Data
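StringIndexer is imported above but unused with the reduced feature list. If the categorical id columns are to be fed to the tree model as well, one option (a sketch, with hypothetical *_idx column names) is to index them and extend FEATURES before assembling:
# Sketch: encode categorical id columns so they can be added to FEATURES
from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer

cat_cols = ["dept_id", "cat_id", "store_id", "state_id"]
indexers = [StringIndexer(inputCol=c, outputCol=f"{c}_idx", handleInvalid="keep") for c in cat_cols]
grid_df = Pipeline(stages=indexers).fit(grid_df).transform(grid_df)
FEATURES = FEATURES + [f"{c}_idx" for c in cat_cols]
# Re-run the VectorAssembler after extending FEATURES
Note that the tree model will treat these indices as ordered numbers; one-hot encoding or marking them as categorical (e.g. via VectorIndexer) is the more principled choice.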
Checking Features¶
In [72]:
from pyspark.sql.functions import col, count, when, lit
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import GBTRegressor
# Step 1: Check which feature columns contain NULL values
print("Checking for NULL values in feature columns:")
train_df.select([count(when(col(c).isNull(), 1)).alias(c) for c in FEATURES]).show()
# Step 2: Replace NULL values in numeric feature columns with 0
for col_name in FEATURES:
train_df = train_df.withColumn(col_name, when(col(col_name).isNull(), lit(0)).otherwise(col(col_name)))
# Step 3: Drop NULL values in target column `sales`
train_df = train_df.dropna(subset=["sales"])
# Step 4: Drop existing `features` column if it exists
if "features" in train_df.columns:
train_df = train_df.drop("features")
# Step 5: Recreate `VectorAssembler` with `handleInvalid="skip"`
vector_assembler = VectorAssembler(inputCols=FEATURES, outputCol="features", handleInvalid="skip")
train_df = vector_assembler.transform(train_df).select("features", "sales")
print(f"Training Data Count: {train_df.count()}")
print(f"Test Data Count: {test_df.count()}")
Checking for NULL values in feature columns:
+-------+----------+---------+---------+---------+----------+
|release|sell_price|price_max|price_min|price_std|price_mean|
+-------+----------+---------+---------+---------+----------+
|      0|         0|        0|        0|        0|         0|
+-------+----------+---------+---------+---------+----------+
Training Data Count: 46027957
Test Data Count: 853720
In [ ]:
# Train GBT Model
gbt = GBTRegressor(featuresCol="features", labelCol="sales", maxIter=50, maxDepth=5, stepSize=0.1)
model = gbt.fit(train_df)
# Make Predictions
predictions = model.transform(test_df)
# Evaluate Model Performance
evaluator = RegressionEvaluator(labelCol=TARGET, predictionCol="prediction", metricName="rmse")
rmse = evaluator.evaluate(predictions)
print(f"Root Mean Squared Error (RMSE): {rmse}")
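Once training finishes, the model's featureImportances vector can be mapped back to the FEATURES names to see which inputs the ensemble actually relies on (optional inspection):
# Optional: inspect which features the trained GBT relies on most
importances = model.featureImportances.toArray()
for name, score in sorted(zip(FEATURES, importances), key=lambda t: -t[1]):
    print(f"{name:>12s}: {score:.4f}")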
Save Model and Load for Future Predictions¶
In [ ]:
# Save the trained model and the predictions
model.write().overwrite().save("m5_gbt_forecasting_model")
predictions.select("sales", "prediction").write.mode("overwrite").parquet("m5_gbt_predictions.parquet")

# Load the model for future predictions
from pyspark.ml.regression import GBTRegressionModel
loaded_model = GBTRegressionModel.load("m5_gbt_forecasting_model")

# Make new predictions with the loaded model
new_predictions = loaded_model.transform(test_df)
new_predictions.show(10)
In [58]:
spark.stop()
Further Reading¶
- Januschowski, T., Wang, Y., Torkkola, K., Erkkilä, T., Hasson, H., & Gasthaus, J. (2022). Forecasting with trees. International Journal of Forecasting, 38(4), 1473-1481. https://doi.org/10.1016/j.ijforecast.2021.10.004