Forecasting book sales with Temporal Fusion Transformer

Mouna Labiadh
Published in DataNess.AI
11 min read, Aug 2, 2023

A step-by-step guide on how to use the Temporal Fusion Transformer for book sales forecasting. We use the model implementation available in the Pytorch Forecasting library, along with Kaggle’s “Tabular Playground Series” dataset.

Photo by Luisa Brimble on Unsplash

This is Part 2 of my previous post about the Temporal Fusion Transformer (TFT) [1].

In this part, we train and test a TFT on Kaggle’s Tabular Playground Series (Sept 2022) dataset [2]. For the model implementation, we use the open-source Python package Pytorch Forecasting [3]. You can install it using pip:

$ pip install pytorch-forecasting

Pytorch Forecasting is based on Pytorch Lightning and integrates Optuna for hyperparameter tuning.

Data exploration

The Tabular Playground Series dataset contains the daily number of sales of 4 books in two stores, in each of six countries. The training data covers four years, from 2017 to 2020. The goal is to predict the number of sales in 2021 for each book, in each store and each country.

We train our model on 2017~2019 and keep 2020 for validation. As you might expect, a major challenge in this use case is handling the effect of the COVID-19 pandemic on the 2020 data.

Let’s start by taking a look at the data:

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

DATAPATH = Path("./tabular-playground-series-sep-2022/")

train_df = pd.read_csv(DATAPATH / "train.csv", parse_dates=["date"])
test_df = pd.read_csv(DATAPATH / "test.csv", parse_dates=["date"])

# Daily total sales per product (book title)
daily_per_product = train_df.groupby(["date", "product"], as_index=False)["num_sold"].sum()
fig, ax = plt.subplots(1, 1, figsize=(23, 8))
sns.lineplot(x="date", y="num_sold", hue="product", data=daily_per_product, ax=ax)
ax.set_title("sum of num_sold per product")
Daily total sales per book title

Similarly, we plot the daily total sales per store and per country:

Daily total sales per store
Daily total sales per country
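The per-store and per-country plots repeat the per-product aggregation with a different column. Below is a minimal sketch of a reusable helper, assuming a DataFrame with the date, num_sold and grouping columns of the Kaggle files. The function names daily_totals and plot_daily_totals and the toy data are hypothetical, and the sketch uses plain pandas/matplotlib plotting rather than seaborn:

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch also runs outside a notebook
import matplotlib.pyplot as plt
import pandas as pd


def daily_totals(df: pd.DataFrame, by: str) -> pd.DataFrame:
    """Sum num_sold per date, with one column per distinct value of `by`."""
    return df.pivot_table(index="date", columns=by, values="num_sold", aggfunc="sum")


def plot_daily_totals(df: pd.DataFrame, by: str):
    """Line plot of daily total sales, one line per distinct value of `by`."""
    totals = daily_totals(df, by)
    fig, ax = plt.subplots(figsize=(23, 8))
    totals.plot(ax=ax)
    ax.set_title(f"sum of num_sold per {by}")
    return fig, ax, totals


# Toy stand-in for train_df (hypothetical values)
toy = pd.DataFrame({
    "date": pd.to_datetime(["2017-01-01", "2017-01-01", "2017-01-02", "2017-01-02"]),
    "store": ["store_A", "store_B", "store_A", "store_B"],
    "num_sold": [10, 4, 12, 5],
})
fig, ax, totals = plot_daily_totals(toy, by="store")
```

With the real data, calling plot_daily_totals(train_df, by=col) for col in ["store", "country"] would produce the two figures.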

Key observations:

  • The dataset contains no missing values or missing time steps.
  • The data pattern changes significantly in 2020 compared to previous years. The drop in sales around March~May is mostly due to the COVID-19 lockdowns and restrictions in Europe.
  • Average sales increase significantly in 2020 compared to 2017~2019. Similarly, average sales are higher from 11/01/2020 to 29/02/2020 than in the second half of the same year.
  • In general, the data shows weekly and annual seasonalities, with higher sales on weekends. We also notice a peak during the end-of-year holidays.
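The first observation (no missing values or time steps) can be checked programmatically. A minimal sketch, assuming the date, country, store, product and num_sold columns of train.csv; the function name check_completeness and the toy data are hypothetical:

```python
import pandas as pd


def check_completeness(df: pd.DataFrame, group_cols=("country", "store", "product")) -> dict:
    """Report missing values and calendar gaps for each individual time series."""
    keys = list(group_cols)
    report = {
        "missing_values": int(df.isna().sum().sum()),
        "n_series": df.groupby(keys).ngroups,
        "series_with_gaps": [],
    }
    for key, grp in df.groupby(keys):
        # every day between the first and last date of the series should be present
        expected = pd.date_range(grp["date"].min(), grp["date"].max(), freq="D")
        if len(expected.difference(grp["date"])) > 0:
            report["series_with_gaps"].append(key)
    return report


# Toy data: the ("A", "s2", "p") series is missing 2017-01-02
toy = pd.DataFrame({
    "date": pd.to_datetime(["2017-01-01", "2017-01-02", "2017-01-03", "2017-01-01", "2017-01-03"]),
    "country": ["A"] * 5,
    "store": ["s1", "s1", "s1", "s2", "s2"],
    "product": ["p"] * 5,
    "num_sold": [1, 2, 3, 4, 5],
})
report = check_completeness(toy)
```

On the Kaggle training data, the observation above corresponds to missing_values == 0, an empty series_with_gaps list and n_series == 48.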

We start by addressing the annual shift in average sales for each country, bringing the 2017~2019 data to the same level as 2020. Training and validating on like-for-like data improves the generalization performance of our model.

# Work on a date-indexed copy of the training data
train_df_processed = train_df.set_index("date")
train_df_processed["num_sold"] = train_df_processed["num_sold"].astype(float)

# Add a year column
train_df_processed["year"] = train_df_processed.index.year

# Get average sales for each country in each year
mean_country_year = train_df_processed.groupby(["country", "year"])["num_sold"].mean()

# Scale 2017~2019 sales so that each year's mean matches the 2020 mean of the same country
for country in train_df_processed["country"].unique():
    mean_2020 = mean_country_year.loc[(country, 2020)]
    for year in train_df_processed["year"].unique():
        if year == 2020:
            continue
        factor = mean_2020 / mean_country_year.loc[(country, year)]
        rows = (train_df_processed["country"] == country) & (train_df_processed["year"] == year)
        train_df_processed.loc[rows, "num_sold"] *= factor
Sales data after addressing annual shifts. An outlier value can be seen on 01/01/2020
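The per-country rescaling loop can also be written without explicit loops using groupby().transform. A minimal sketch, assuming a DataFrame with a date column (rather than a date index) plus country and num_sold; the function name rescale_to_reference_year and the toy data are hypothetical. Each row is multiplied by the ratio between its country's reference-year mean and its country's mean for its own year, so reference-year rows are left unchanged:

```python
import pandas as pd


def rescale_to_reference_year(df: pd.DataFrame, ref_year: int = 2020) -> pd.DataFrame:
    """Scale num_sold so that every (country, year) mean equals the country's ref_year mean."""
    out = df.copy()
    year = out["date"].dt.year
    # mean of the row's own (country, year) group, broadcast back to every row
    yearly_mean = out.groupby([out["country"], year])["num_sold"].transform("mean")
    # mean of the country's ref_year rows only (other years are masked out as NaN)
    ref_mean = out["num_sold"].where(year == ref_year).groupby(out["country"]).transform("mean")
    out["num_sold"] = out["num_sold"] * ref_mean / yearly_mean
    return out


# Toy data: the 2019 mean is 2 and the 2020 mean is 6, so 2019 values are multiplied by 3
toy = pd.DataFrame({
    "date": pd.to_datetime(["2019-03-01", "2019-03-02", "2020-03-01", "2020-03-02"]),
    "country": ["X"] * 4,
    "num_sold": [1.0, 3.0, 4.0, 8.0],
})
rescaled = rescale_to_reference_year(toy)
```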

The lockdown period from March~May 2020 does not reflect the general patterns we want our model to learn, as these patterns would not generalize to the following years (hopefully!). Including data from these months would hurt the model's performance.

Different strategies can be considered to handle the impact of the pandemic on time series modeling use cases [4]. The most straightforward one is to simply exclude the problematic span from modeling, treating it as an outlier. This works if, after the recovery from the pandemic, data patterns return to what they were before.

However, we have a limited amount of historical data, and we use the pandemic-affected 2020 data for validation (i.e. the model is trained on pre-pandemic data and validated on pandemic-year data). We therefore replace these values with our estimate of “what would have happened if it wasn’t for the pandemic”. This approach yielded the best blind-test results on 2021.

The outlier on January 1st, 2020 is also dropped and replaced.

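As an imputation strategy, we use the average sales of the same period in the previous two years. For March~May, we look back in multiples of 52 weeks (364 days) rather than 365 days, so that each imputed value comes from the same weekday and the weekly seasonality is preserved: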

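# Rows are assumed sorted by date, with the same 48 (country, store, product)
# rows for every day, so shifting by 48 * n rows goes back n days in each series.
ROWS_PER_DAY = 48


def impute_from_previous_years(df, days_back, n_years=2):
    """Fill NaN num_sold values with the mean of the values days_back, 2 * days_back, ... days earlier."""
    shifted = pd.concat(
        [df["num_sold"].shift(periods=days_back * ROWS_PER_DAY * x).reset_index(drop=True)
         for x in range(1, n_years + 1)],
        axis=1,
    )
    # mean(axis=1) skips NaN, so a missing look-back value does not propagate
    df["num_sold"] = np.where(df["num_sold"].isna(), shifted.mean(axis=1).to_numpy(), df["num_sold"])
    return df


# handle 01-01-2020 outlier (same calendar day, 365 days back)
train_df_processed.loc[pd.Timestamp("2020-01-01"), "num_sold"] = np.nan
train_df_processed = impute_from_previous_years(train_df_processed, days_back=365)

# handle March~May lockdown outliers (52 weeks back, so weekdays stay aligned)
lockdown = ((train_df_processed["year"] == 2020) & train_df_processed.index.month.isin([3, 4, 5])).to_numpy()
train_df_processed.loc[lockdown, "num_sold"] = np.nan
train_df_processed = impute_from_previous_years(train_df_processed, days_back=52 * 7)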
Daily total sales per country after pre-processing
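The row-shifting imputation relies on the rows being sorted by date with exactly 48 rows per day. A minimal sketch of a date-aligned alternative that looks values up by date within each series instead, assuming a DataFrame with a date column, the three group columns and num_sold; the function name impute_from_past_weeks and the toy data are hypothetical. With days_back=364 the looked-up values fall on the same weekday:

```python
import numpy as np
import pandas as pd


def impute_from_past_weeks(df, mask, days_back=364, n_years=2,
                           group_cols=("country", "store", "product")):
    """Set num_sold to NaN where `mask` is True, then fill it with the mean of the same
    series' values days_back, 2 * days_back, ... days earlier (NaN look-ups are skipped)."""
    keys = list(group_cols)
    out = df.copy()
    out["num_sold"] = out["num_sold"].astype(float)
    out.loc[mask, "num_sold"] = np.nan
    looked_up = out[keys + ["date"]]
    for k in range(1, n_years + 1):
        past = out[keys + ["date", "num_sold"]].copy()
        # the value observed on day d is attached to day d + k * days_back
        past["date"] = past["date"] + pd.Timedelta(days=days_back * k)
        looked_up = looked_up.merge(
            past.rename(columns={"num_sold": f"past_{k}"}), on=keys + ["date"], how="left"
        )
    estimate = looked_up[[f"past_{k}" for k in range(1, n_years + 1)]].mean(axis=1)
    out["num_sold"] = out["num_sold"].fillna(pd.Series(estimate.to_numpy(), index=out.index))
    return out


# Toy series: three 364-day blocks valued 10, 20 and 999 (the last block plays the outlier)
dates = pd.date_range("2018-01-01", periods=364 * 3, freq="D")
toy = pd.DataFrame({"date": dates, "country": "X", "store": "s", "product": "p",
                    "num_sold": [10.0] * 364 + [20.0] * 364 + [999.0] * 364})
mask = np.arange(len(toy)) >= 2 * 364
imputed = impute_from_past_weeks(toy, mask)
```

In this toy example every masked value becomes the mean of the values 364 and 728 days earlier, (20 + 10) / 2 = 15.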

To account for the weekly and annual seasonalities in the data, we add the following calendar features: a weekend flag, the weekday and the week of the year. We use cyclical encoding for the weekday and week features to reflect their cyclical nature.

from typing import List, Union

CALENDAR_CYCLES = {
    "weekday": 7,
    "week": 52,
    "month": 12,
}


def add_cyclical_calendar_features(df: pd.DataFrame, features: Union[str, List[str]]) -> pd.DataFrame:
    """Cyclical (sin/cos) encoding of calendar features of a date-indexed DataFrame"""
    if isinstance(features, str):
        features = [features]
    for feat in features:
        assert feat in CALENDAR_CYCLES, f"Cyclical encoding is not available for {feat}"

        if feat == "week":
            # DatetimeIndex.week was removed in pandas 2.0; use the ISO week number instead
            values = df.index.isocalendar().week.to_numpy(dtype=float)
        else:
            values = np.asarray(getattr(df.index, feat), dtype=float)
        df[f"{feat}_sin"] = np.sin(2 * np.pi * values / CALENDAR_CYCLES[feat])
        df[f"{feat}_cos"] = np.cos(2 * np.pi * values / CALENDAR_CYCLES[feat])
    return df


# add calendar features (train_df_processed is already indexed by date)
train_df_processed = add_cyclical_calendar_features(train_df_processed, features=["weekday", "week"])
train_df_processed["weekend"] = (train_df_processed.index.dayofweek > 4).astype(int)
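To see why the sin/cos pair is used, here is a short check with a hypothetical helper cyclical_encode that applies the same formula as add_cyclical_calendar_features to plain integers. With raw integers, Sunday (6) and Monday (0) are 6 apart, whereas on the unit circle they are exactly as close as Monday and Tuesday:

```python
import numpy as np


def cyclical_encode(values, period):
    """Map integer calendar values v to (sin(2*pi*v/P), cos(2*pi*v/P)) on the unit circle."""
    angle = 2 * np.pi * np.asarray(values, dtype=float) / period
    return np.stack([np.sin(angle), np.cos(angle)], axis=-1)


weekdays = cyclical_encode(np.arange(7), period=7)  # Monday=0 ... Sunday=6, as in DatetimeIndex.weekday
dist_sun_mon = np.linalg.norm(weekdays[6] - weekdays[0])
dist_mon_tue = np.linalg.norm(weekdays[0] - weekdays[1])
print(f"Sunday-Monday: {dist_sun_mon:.4f}, Monday-Tuesday: {dist_mon_tue:.4f}")
```

Both components are kept because the sine alone gives the same value for two different positions on the circle.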

We also add boolean flags for national holidays in each country (using the holidays package) and for the end-of-year holidays (December 25th to 31st).

import holidays

# add national holidays flag, per country
train_df_processed["holidays"] = 0
years = sorted({int(y) for y in train_df_processed.index.year})
for country in train_df_processed["country"].unique():
    # assumes each country name is also a class name in the holidays package (e.g. holidays.Germany)
    holiday_dates = pd.to_datetime(list(getattr(holidays, country)(years=years).keys()))
    rows = (train_df_processed["country"] == country).to_numpy()
    train_df_processed.loc[rows, "holidays"] = train_df_processed.index[rows].isin(holiday_dates).astype(int)

# add end-of-year holidays flag (December 25th to 31st)
index = train_df_processed.index
train_df_processed["newyear"] = ((index.month == 12) & (index.day >= 25)).astype(int)
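The same flags are recomputed for the 2021 test set later on. One way to avoid duplicating that logic is a single function. A minimal sketch follows; the function name add_calendar_flags and its holiday_dates_per_country argument are hypothetical. The holiday dates are passed in, so the sketch itself does not depend on the holidays package:

```python
import datetime

import pandas as pd


def add_calendar_flags(df: pd.DataFrame, holiday_dates_per_country: dict) -> pd.DataFrame:
    """Add weekend, national-holiday and end-of-year flags to a frame indexed by date.

    holiday_dates_per_country maps a country name to an iterable of holiday dates.
    """
    out = df.copy()
    out["weekend"] = (out.index.dayofweek > 4).astype(int)
    out["holidays"] = 0
    for country, dates in holiday_dates_per_country.items():
        rows = (out["country"] == country).to_numpy()
        out.loc[rows, "holidays"] = out.index[rows].isin(pd.to_datetime(list(dates))).astype(int)
    # December 25th to 31st, as in the training code
    out["newyear"] = ((out.index.month == 12) & (out.index.day >= 25)).astype(int)
    return out


# Toy frame: Thursday 24, Friday 25 (holiday) and Saturday 26 December 2020
idx = pd.DatetimeIndex(["2020-12-24", "2020-12-25", "2020-12-26"], name="date")
toy = pd.DataFrame({"country": ["X", "X", "X"]}, index=idx)
flags = add_calendar_flags(toy, {"X": [datetime.date(2020, 12, 25)]})
```

In this article's setting, the dictionary could be built with the same getattr(holidays, country)(years=...) call used above.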

Model Training

The Pytorch Forecasting library requires a TimeSeriesDataSet object that encapsulates the data, the feature names (target, static covariates, time-varying known and unknown variables, lagged values), feature encoders, scalers, the minimum and maximum lengths of the look-back window and of the prediction horizon of the sequence-to-sequence block, etc.

As a reminder, static covariates in TFT consist of static metadata about the measured entities (i.e. books) that does not depend on time (e.g. country, store, product). Static covariates are used to define groups that separate the individual time series contained in the dataset. The number of groups is the number of all possible combinations of static covariate values.

In our use case, we have 48 individual time series: 6 countries × 2 stores × 4 products.

TimeSeriesDataSet expects the data to contain an integer column denoting the time index. It should start at 0 and increase by 1 at each time step of each individual time series when there are no missing time steps.

We replace our date column with time_idx:

train = train_df_processed.reset_index()
# map each unique date (in chronological order) to an integer index 0, 1, 2, ...
date_to_idx = train[["date"]].drop_duplicates(ignore_index=True).rename_axis("time_idx").reset_index()
train = train.merge(date_to_idx, on="date").drop(["date", "row_id"], axis=1)
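Because the data has no missing days, the same index can also be derived directly from the date difference. A minimal sketch with a hypothetical function add_time_idx; passing the training start date for the 2021 rows yields indices that continue the training ones, which is what the test-set code later achieves by adding train["time_idx"].max() + 1:

```python
import pandas as pd


def add_time_idx(df: pd.DataFrame, start=None) -> pd.DataFrame:
    """Add an integer day index: 0 on `start` (default: first date in df), +1 per calendar day."""
    origin = df["date"].min() if start is None else pd.Timestamp(start)
    out = df.copy()
    out["time_idx"] = (out["date"] - origin).dt.days
    return out


train_toy = pd.DataFrame({"date": pd.to_datetime(["2017-01-01", "2017-01-02", "2017-01-01", "2020-12-31"])})
test_toy = pd.DataFrame({"date": pd.to_datetime(["2021-01-01"])})
train_idx = add_time_idx(train_toy)
test_idx = add_time_idx(test_toy, start="2017-01-01")  # same origin as the training data
```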

We use this time index to hold out the last max_prediction_length time steps of each individual time series, making sure that the validation set is completely excluded from training.

max_prediction_length = 365  # a whole year
min_encoder_length = 365
max_encoder_length = train["time_idx"].nunique()  # length of each individual series, in days

# keep the validation set held out
training_cutoff = train["time_idx"].max() - max_prediction_length  # validation on 2020

Since the magnitude of sales differs between groups (country, store and product), we use a GroupNormalizer to normalize each time series individually.
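Conceptually, group normalization puts series of very different magnitudes on a common scale using each series' own statistics. The pandas sketch below illustrates the idea with plain per-series standardization. It is not GroupNormalizer's exact computation (which, among other things, applies the softplus transformation configured below), and the function name standardize_by_group and the toy data are hypothetical:

```python
import pandas as pd


def standardize_by_group(df, group_cols=("country", "store", "product"), target="num_sold"):
    """Return the target standardized with the mean and std of its own series."""
    grouped = df.groupby(list(group_cols))[target]
    center = grouped.transform("mean")
    scale = grouped.transform("std")
    return (df[target] - center) / scale


# Two series whose values differ by a factor of 100 end up on the same scale
toy = pd.DataFrame({
    "country": ["X"] * 6,
    "store": ["s"] * 6,
    "product": ["small"] * 3 + ["large"] * 3,
    "num_sold": [1.0, 2.0, 3.0, 100.0, 200.0, 300.0],
})
normalized = standardize_by_group(toy)
```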

To account for the autocorrelation in the sales data, we explicitly add lagged features for t-7 (same day of previous week) and t-365 (previous year).
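The lags argument makes the library compute these lagged values internally. Purely as an illustration of what a lag of 7 or 365 refers to (not how Pytorch Forecasting implements it), here is a pandas sketch with a hypothetical function add_target_lags. The first lag values of each series are undefined (NaN):

```python
import pandas as pd


def add_target_lags(df, lags=(7, 365), group_cols=("country", "store", "product"), target="num_sold"):
    """Add columns holding the target value `lag` time steps earlier in the same series."""
    keys = list(group_cols)
    out = df.sort_values(keys + ["time_idx"]).copy()
    for lag in lags:
        out[f"{target}_lag_{lag}"] = out.groupby(keys)[target].shift(lag)
    return out


# Toy series of 10 days with values 100, 101, ..., 109
toy = pd.DataFrame({
    "country": "X", "store": "s", "product": "p",
    "time_idx": range(10),
    "num_sold": [float(v) for v in range(100, 110)],
})
lagged = add_target_lags(toy, lags=(7,))
```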

from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import QuantileLoss

# Create training set
training_dataset = TimeSeriesDataSet(
    train[lambda x: x.time_idx <= training_cutoff],
    time_idx="time_idx",
    target="num_sold",  # target variable
    group_ids=["country", "store", "product"],  # static covariates identifying each series
    max_encoder_length=max_encoder_length,  # maximum size of the look-back window
    min_encoder_length=max_encoder_length // 2,
    max_prediction_length=max_prediction_length,  # maximum size of the horizon window
    min_prediction_length=max_prediction_length,
    time_varying_known_reals=[
        "time_idx", "weekday_cos", "weekday_sin", "week_cos", "week_sin", "weekend", "holidays", "newyear"],
    time_varying_unknown_categoricals=[],
    time_varying_unknown_reals=["num_sold"],
    target_normalizer=GroupNormalizer(
        groups=["country", "store", "product"], transformation="softplus"
    ),  # use softplus transformation and normalize by group
    lags={"num_sold": [7, 365]},  # add lagged values of the target variable
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
)

We then create a validation TimeSeriesDataSet that has the same parameters as the training dataset object as follows:

# create validation set (predict=True)
validation_dataset = TimeSeriesDataSet.from_dataset(
training_dataset, # dataset from which to copy parameters (encoders, scalers, ...)
train, # data from which new dataset will be generated
predict=True, # predict the decoder length on the last entries in the time index
stop_randomization=True,
)

Similarly, we use the training TimeSeriesDataSet object to create the TFT model:

# Create network from TimeSeriesDataSet
tft = TemporalFusionTransformer.from_dataset(
training_dataset,
learning_rate=0.03,
hidden_size=16,
attention_head_size=1,
dropout=0.1,
hidden_continuous_size=8,
output_size=7, # number of quantiles
loss=QuantileLoss(),
log_interval=10, # logging every 10 batches
reduce_on_plateau_patience=4,
)
print(f"Number of parameters in network: {tft.size()/1e3:.1f}k")
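QuantileLoss trains the model with the quantile (pinball) loss, with one output per quantile, which is why output_size=7. Below is a minimal NumPy sketch of the pinball loss averaged over quantiles. The quantile list assumes the library's default QuantileLoss quantiles, and the library's own implementation may differ in scaling and reduction:

```python
import numpy as np

# assumed default quantiles of pytorch_forecasting's QuantileLoss
QUANTILES = [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98]


def pinball_loss(y_true, y_pred, quantiles=QUANTILES):
    """Mean pinball loss; y_true has shape (n,), y_pred has shape (n, len(quantiles))."""
    y = np.asarray(y_true, dtype=float)[:, None]
    q = np.asarray(quantiles, dtype=float)[None, :]
    errors = y - np.asarray(y_pred, dtype=float)
    # under-prediction (error > 0) costs q per unit, over-prediction costs (1 - q) per unit
    return float(np.mean(np.maximum(q * errors, (q - 1) * errors)))


median_index = QUANTILES.index(0.5)  # position of the median among the model outputs
```

The asymmetry is what makes each output converge to its quantile: for q = 0.9, under-predicting by 2 costs 1.8, while over-predicting by 2 costs only 0.2.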

The Pytorch Forecasting implementation is based on Pytorch Lightning. Hence, to train the model, we need to define a Trainer object and dataloaders.

import lightning.pytorch as pl  # with pytorch-forecasting < 1.0, use pytorch_lightning instead
import torch
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger

SAVE_DIR = "lightning_logs"  # output directory for logs and checkpoints (any path works)

# define callbacks
early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=10, verbose=False, mode="min")
lr_logger = LearningRateMonitor()  # log the learning rate
logger = TensorBoardLogger(save_dir=SAVE_DIR)  # log results to TensorBoard

# create trainer
trainer = pl.Trainer(
    max_epochs=50,
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1,
    gradient_clip_val=0.1,
    limit_train_batches=30,  # 30 training batches per epoch (validation runs after each epoch)
    log_every_n_steps=10,
    # fast_dev_run=True,  # uncomment to check that the network or dataset has no serious bugs
    callbacks=[lr_logger, early_stop_callback],
    logger=logger,
)

# create training and validation dataloaders
batch_size = 128
train_dataloader = training_dataset.to_dataloader(train=True, batch_size=batch_size, num_workers=8)
val_dataloader = validation_dataset.to_dataloader(train=False, batch_size=batch_size * 10, num_workers=8)

# fit network
trainer.fit(
    tft,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)

Hyper-parameters Tuning

In the code above, we used fixed hyperparameters to train our model. Pytorch Forecasting provides a wrapper function to tune them on the validation set using Optuna:

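import pickle

from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters

# Hyperparameter tuning with Optuna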
# create study
study = optimize_hyperparameters(
train_dataloader,
val_dataloader,
model_path="optuna_test",
n_trials=200,
max_epochs=50,
gradient_clip_val_range=(0.01, 1.0),
hidden_size_range=(8, 128),
hidden_continuous_size_range=(8, 128),
attention_head_size_range=(1, 4),
learning_rate_range=(0.001, 0.1),
dropout_range=(0.1, 0.3),
trainer_kwargs=dict(limit_train_batches=30),
reduce_on_plateau_patience=4,
use_learning_rate_finder=False, # use Optuna to find ideal learning rate or use in-built learning rate finder
)

# pickle study results
with open("study.pkl", "wb") as fout:
pickle.dump(study, fout)

# show best hyperparameters
study.best_trial.params

Validation

Once trained, we load the model from the checkpoint with the best validation loss.

# load the best model w.r.t. the validation loss
best_model_path = trainer.checkpoint_callback.best_model_path
best_tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path)

Since the validation dataset was created with predict=True, predictions cover the last max_prediction_length time steps (365 days) of each individual time series. We set return_x to True to also return the network inputs, which we plot alongside the predictions.

# get raw prediction results (network output and inputs)
val_prediction_results = best_tft.predict(val_dataloader, mode="raw", return_x=True)
# Plot actuals vs predictions and attention
for idx in range(val_prediction_results.output.prediction.shape[0]):  # number of group combinations
    fig, ax = plt.subplots(figsize=(23, 5))
    best_tft.plot_prediction(
        val_prediction_results.x,  # network input
        val_prediction_results.output,  # network output
        idx=idx,
        add_loss_to_title=True,
        ax=ax,
    )
Prediction results on the validation set for the first group

The gray line shows the attention weights for the different points in time. Exploring these weights can help improve the model: for example, if attention peaks are located at the start of the look-back window, one should increase the size of this window to include all relevant past time steps.

Inferring 2021

Now we infer the number of sales in 2021 for each individual time series. For this, we append to the training data the static covariates and time-varying known values for 2021, which in our case are the a priori known calendar features.

# add calendar features
test = add_cyclical_calendar_features(test_df.set_index("date"), features=["weekday", "week"])
test["weekend"] = (test.index.dayofweek > 4).astype(int)

# add holidays flag, per country
test["holidays"] = 0
test_years = sorted({int(y) for y in test.index.year})
for country in test["country"].unique():
    holiday_dates = pd.to_datetime(list(getattr(holidays, country)(years=test_years).keys()))
    rows = (test["country"] == country).to_numpy()
    test.loc[rows, "holidays"] = test.index[rows].isin(holiday_dates).astype(int)

# add end-of-year holidays flag (December 25th to 31st)
test["newyear"] = ((test.index.month == 12) & (test.index.day >= 25)).astype(int)

# Add the required time_idx column, continuing from the last time index of the training data
test = test.reset_index()
date_to_idx = test[["date"]].drop_duplicates(ignore_index=True).rename_axis("time_idx").reset_index()
test = test.merge(date_to_idx, on="date")
test["time_idx"] += train["time_idx"].max() + 1

# Drop unused columns
test = test.drop(["date", "row_id"], axis=1)

# Vertically stack the test data at the end of the training data
# (missing values, such as the unknown 2021 num_sold, are filled with 0)
test = pd.concat([train, test], ignore_index=True).fillna(0.0)

We then create a TimeSeriesDataSet object, a dataloader for our test set and launch the prediction loop.

# Create test dataset
test_dataset = TimeSeriesDataSet.from_dataset(training_dataset,
test,
predict=True,
stop_randomization=True)

# Create test dataloader
test_dataloader = test_dataset.to_dataloader(train=False, batch_size=batch_size, num_workers=8)

# Get prediction results
test_prediction_results = best_tft.predict(
test_dataloader,
mode="raw",
return_index=True, # return the prediction index in the same order as the output
return_x=True, # return network inputs in the same order as prediction output
)
Prediction results on the test set for the first group

The 2021 predictions have shape (number of group combinations, number of time steps, number of quantiles), i.e. (48, 365, 7) here.

To construct the predictions dataframe and associate each predicted value with its store, country and product, we need to know which group combination corresponds to each entry along the first dimension of the predictions tensor. This is why we set return_index to True.

The time_idx in the index dataframe corresponds to the time_idx of the first prediction.

# Create predictions dataframe
predictions_df = test_df.copy()
# Add num_sold column
predictions_df["num_sold"] = np.nan

# get 0.5 quantile (median) predictions: with the default QuantileLoss quantiles
# [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98], the median is at index 3
median_predictions = test_prediction_results.output.prediction.cpu().numpy()[:, :, 3]

# add sales predictions w.r.t. the group combination (test_df rows are in date order within each group)
for i, row in test_prediction_results.index.reset_index(drop=True).iterrows():
    predictions_df.loc[
        (predictions_df["country"] == row["country"])
        & (predictions_df["store"] == row["store"])
        & (predictions_df["product"] == row["product"]),
        "num_sold",
    ] = median_predictions[i]
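The row-by-row loop can also be replaced by building a long-format frame directly. A minimal sketch, assuming (as stated above) that the index frame holds one row per series with the group columns and the time_idx of its first prediction, and that predictions cover consecutive daily steps; the function name predictions_to_long and the toy data are hypothetical. The result can then be merged with the test rows on the group columns and time_idx:

```python
import numpy as np
import pandas as pd


def predictions_to_long(index_df, point_preds, group_cols=("country", "store", "product")):
    """Turn an (n_series, horizon) array into one row per (series, time_idx)."""
    n_series, horizon = point_preds.shape
    base = index_df[list(group_cols) + ["time_idx"]].reset_index(drop=True)
    # repeat each series row `horizon` times, then offset time_idx by 0, 1, ..., horizon - 1
    out = base.loc[base.index.repeat(horizon)].reset_index(drop=True)
    out["time_idx"] = out["time_idx"] + np.tile(np.arange(horizon), n_series)
    out["num_sold"] = point_preds.reshape(-1)  # row-major order matches the repetition above
    return out


index_toy = pd.DataFrame({
    "country": ["X", "Y"], "store": ["s", "s"], "product": ["p", "p"], "time_idx": [1461, 1461],
})
preds_toy = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
long_df = predictions_to_long(index_toy, preds_toy)
```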

Interpretability

Interpretability is a major property of the Temporal Fusion Transformer. Pytorch Forecasting provides utility functions to compute and plot interpretation results from the trained model.

Interpretability in TFT is established through:

  • Shared weights in the multi-head attention mechanism, which make it possible to trace back the most relevant historical data points. The attention weights associated with each data point in the look-back window can be plotted, as previously shown in the Validation section (gray line).
  • Dedicated variable selection blocks that weigh the importance of each input variable. The learned weights can be retrieved from the raw prediction output:

# plot variable importance
interpretation = best_tft.interpret_output(
    val_prediction_results.output,  # raw network output computed in the Validation section
    reduction="sum",  # sum attentions
)
best_tft.plot_interpretation(interpretation)

Pytorch Forecasting also provides a function to plot predictions against actual values as a function of each input variable, in order to detect possible dependencies between residuals and input features. This can guide the identification of explanatory features.

val_prediction_results = best_tft.predict(
val_dataloader,
mode="prediction", # get only median predictions
return_x=True,
)
predictions_vs_actuals = best_tft.calculate_prediction_actual_by_variable(val_prediction_results.x, val_prediction_results.output)

# remove added lagged features
features = list(set(predictions_vs_actuals['support'].keys())-set(['num_sold_lagged_by_365', 'num_sold_lagged_by_7']))

# plot cross_plots
for feature in features:
best_tft.plot_prediction_actual_by_variable(predictions_vs_actuals, name=feature);

Key takeaways

Pytorch Forecasting is an open-source Python library. It provides an implementation of the Temporal Fusion Transformer based on Pytorch Lightning, supports hyperparameter tuning with Optuna, and offers visualizations of the model's built-in interpretability.

In addition to TFT, Pytorch Forecasting provides implementations of other deep learning models for time series forecasting, such as DeepAR, N-BEATS and N-HiTS.

Thank you for reading!

All the code is available as a Jupyter Notebook on my Github: https://github.com/mounalab/Temporal-fusion-transformer_Pytorch-ForecastingForecasting

References

[1] Lim, Bryan, et al. “Temporal fusion transformers for interpretable multi-horizon time series forecasting.” International Journal of Forecasting 37.4 (2021): 1748–1764.

[2] “Tabular Playground Series (Sep 2022)” dataset: https://www.kaggle.com/competitions/tabular-playground-series-sep-2022/overview

[3] Pytorch Forecasting step-by-step tutorial: https://pytorch-forecasting.readthedocs.io/en/stable/tutorials/stallion.html

Good Reads

[4] A helpful GitHub discussion on how to handle shocks in time series modeling use cases: https://github.com/facebook/prophet/issues/1416

Mouna Labiadh
DataNess.AI

Data scientist | PhD | Machine Learning | Deep Learning | Time Series