Lag-Llama: An Open-Source Base Model for Predicting Time Series Data

tom odhiambo
7 min read · Mar 16, 2024


Discover the structure of Lag-Llama and learn how to use it in a forecasting project with Python

Introduction

Lag-Llama is an open-source foundation model designed specifically for univariate probabilistic forecasting.

The model uses a general, frequency-agnostic method for turning time series into tokens, which allows it to adapt well to frequencies it has not seen before.

It uses a Transformer architecture combined with a distribution head to map the input tokens to forecasts of future values, together with their confidence intervals.

Given the complexity of this subject, let’s examine each of the main components in more detail.

Creating Tokens Using Lag Features

Lag-Llama tokenizes a series by building lagged features from a predefined set of lag indices.

Specifically, it selects the lags for all frequencies that apply to a given dataset from the following list:

- quarterly
- monthly
- weekly
- daily
- hourly
- every second

For example, given a daily dataset, Lag-Llama builds features from a one-day lag (t-1), a one-week lag (t-7), a one-month lag (t-30), and so on.
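A minimal NumPy sketch of this idea, assuming a toy daily series and only the three lags named above (the real model uses a longer list of lag indices). `lag_features` is a hypothetical helper for illustration, not part of the Lag-Llama API:

```python
import numpy as np

def lag_features(series, lags):
    """Return a (T, len(lags)) matrix where column j holds series[t - lags[j]].

    Rows without enough history for a given lag are filled with NaN.
    Assumes every lag is a positive integer.
    """
    series = np.asarray(series, dtype=float)
    out = np.full((len(series), len(lags)), np.nan)
    for j, lag in enumerate(lags):
        # Shift the series forward by `lag` steps so row t sees the value at t - lag.
        out[lag:, j] = series[:-lag]
    return out

# Daily data: one-day, one-week and (approximate) one-month lags.
daily_lags = [1, 7, 30]
series = np.arange(40, dtype=float)  # toy series: each value equals its time index
tokens = lag_features(series, daily_lags)
# Row t=35 holds the values at t-1, t-7 and t-30, i.e. 34, 28 and 5.
print(tokens[35])
```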

The following image illustrates this approach.

As seen in the above diagram, Lag-Llama also creates static date-time covariates, ranging from second-of-minute and hour-of-day up to quarter-of-year.
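As a rough illustration, the sketch below computes a few such calendar covariates for a single timestamp using only the standard library. `time_covariates` is a hypothetical helper; the scaling to [-0.5, 0.5] follows the style of GluonTS’s built-in time features, but the exact feature set Lag-Llama uses is an assumption here:

```python
from datetime import datetime

def time_covariates(ts):
    """Compute calendar features for one timestamp, each scaled to [-0.5, 0.5].

    Hypothetical helper: feature list and scaling are illustrative assumptions.
    """
    def scale(value, max_value):
        # Map an integer in [0, max_value] onto the interval [-0.5, 0.5].
        return value / max_value - 0.5

    return {
        "second_of_minute": scale(ts.second, 59),
        "minute_of_hour": scale(ts.minute, 59),
        "hour_of_day": scale(ts.hour, 23),
        "day_of_week": scale(ts.weekday(), 6),
        "day_of_month": scale(ts.day - 1, 30),
        "day_of_year": scale(ts.timetuple().tm_yday - 1, 365),
        "month_of_year": scale(ts.month - 1, 11),
        "quarter_of_year": scale((ts.month - 1) // 3, 3),
    }

feats = time_covariates(datetime(2024, 3, 16, 12, 30, 0))
print(feats["hour_of_day"], feats["quarter_of_year"])
```

These values are attached to every token alongside the lagged values, so the model sees both recent history and where each step falls in the calendar.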

Although this approach works across all kinds of time series, the fixed list of lag indices has a drawback: it can require a very long input history.

For instance, a monthly lag on hourly data reaches back 730 time steps, so the model needs at least 730 steps of history to build a single input token, on top of the static covariates attached to it.
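The 730 figure comes from averaging a month over a 365-day year: 24 × 365 / 12 = 730 hourly steps. A quick check in Python, assuming 365-day years (leap years ignored) and applying the same arithmetic to the other lags in the list:

```python
HOURS_PER_DAY = 24
DAYS_PER_YEAR = 365  # leap years ignored

# Number of hourly steps spanned by each lag, using average month/quarter lengths.
hourly_lag_steps = {
    "daily": HOURS_PER_DAY,
    "weekly": 7 * HOURS_PER_DAY,
    "monthly": HOURS_PER_DAY * DAYS_PER_YEAR // 12,
    "quarterly": HOURS_PER_DAY * DAYS_PER_YEAR // 4,
}
print(hourly_lag_steps)  # {'daily': 24, 'weekly': 168, 'monthly': 730, 'quarterly': 2190}
```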

Lag-Llama’s Structure

Lag-Llama’s architecture is based on LLaMA, Meta’s large language model. It is decoder-only, unlike the original Transformer, which has both an encoder and a decoder.

The schematic below illustrates this distinctive architecture.

As indicated in the diagram above, each input token combines lagged values with static covariates.

The input sequence is passed through a linear projection layer, which maps the features to the attention module’s hidden dimension in the decoder.
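To make this data flow concrete, here is a toy NumPy sketch of the first two steps: a linear projection of each token into the hidden dimension, followed by one causal (masked) self-attention layer, which is what prevents a decoder-only model from looking ahead. Sizes and random weights are arbitrary placeholders, and the sketch omits the normalization layers, positional encodings, and feed-forward blocks that a real LLaMA-style decoder layer contains:

```python
import numpy as np

rng = np.random.default_rng(0)

T, n_features, d_model = 5, 8, 16  # sequence length, token size, hidden size (toy values)
tokens = rng.normal(size=(T, n_features))  # one token (lags + covariates) per time step

# Linear projection: map each token to the decoder's hidden dimension.
W_in = rng.normal(scale=0.1, size=(n_features, d_model))
b_in = np.zeros(d_model)

def causal_self_attention(x, W_q, W_k, W_v):
    """Single-head self-attention where step t can only attend to steps <= t."""
    q, k, v = x @ W_q, x @ W_k, x @ W_v
    scores = q @ k.T / np.sqrt(x.shape[-1])
    # Mask out future positions (strict upper triangle).
    mask = np.triu(np.ones((len(x), len(x)), dtype=bool), k=1)
    scores = np.where(mask, -np.inf, scores)
    # Row-wise softmax; masked entries become exactly zero.
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

W_q, W_k, W_v = (rng.normal(scale=0.1, size=(d_model, d_model)) for _ in range(3))

hidden = tokens @ W_in + b_in
out = causal_self_attention(hidden, W_q, W_k, W_v)
print(out.shape)  # (5, 16)
```

Because of the mask, changing the last token alters only the last output row, which is the property that makes step-by-step forecasting possible.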

The decoder’s output is then passed to a distribution head, which produces the parameters of a probability distribution.

At inference time, the model uses the input sequence to produce the distribution of the next time step. It then samples a value from that distribution, appends it to the input, and repeats this autoregressive decoding until the forecast horizon is reached.

Repeating this sampling procedure many times yields a set of possible future trajectories (sample paths), from which the model derives uncertainty intervals for its forecasts.
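The sampling loop can be sketched in NumPy as follows. `toy_next_step_params` is a hypothetical placeholder for the model’s forward pass and distribution head (here just a damped copy of the last value with a fixed Student’s t spread); the point is the loop structure: sample one step, append it, repeat, then summarize many sample paths with quantiles.

```python
import numpy as np

rng = np.random.default_rng(42)

def toy_next_step_params(history):
    """Hypothetical stand-in for the model: returns (df, loc, scale) of a
    Student's t for the next value, given one sample path's history."""
    return 5.0, 0.9 * history[-1], 1.0

def sample_paths(context, horizon, num_samples):
    paths = np.empty((num_samples, horizon))
    for s in range(num_samples):
        history = list(context)
        for h in range(horizon):
            df, loc, scale = toy_next_step_params(history)
            # Draw the next value, then feed it back as input (autoregression).
            value = loc + scale * rng.standard_t(df)
            history.append(value)
            paths[s, h] = value
    return paths

paths = sample_paths(context=[10.0, 11.0, 12.0], horizon=6, num_samples=500)
median = np.median(paths, axis=0)
lower, upper = np.quantile(paths, [0.05, 0.95], axis=0)  # 90% interval per step
print(upper - lower)  # widths tend to grow as errors compound over the horizon
```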

The distribution head therefore plays a central role in Lag-Llama, so let’s look at it more closely.

Understanding the Distribution Head of Lag-Llama

As previously discussed, the distribution head of Lag-Llama is in charge of producing a probability distribution.

This is how the model can yield prediction intervals.

In the current version of the model, the distribution head uses a Student’s t-distribution: it outputs the distribution’s parameters (degrees of freedom, mean, and scale), from which the uncertainty intervals are built.
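A minimal sketch of such a head, assuming a single linear layer that maps one hidden vector to three raw values, with softplus keeping the scale and degrees of freedom positive. The shift of the degrees of freedom by 2 (so the variance stays finite) mirrors a common GluonTS convention and is an assumption, not code taken from Lag-Llama. The log-density is the standard Student’s t formula; training such a head means minimizing its negative log-likelihood on observed values.

```python
import math
import numpy as np

def softplus(x):
    # Numerically stable log(1 + exp(x)), always positive.
    return np.logaddexp(0.0, x)

def student_t_head(h, W, b):
    """Map a hidden vector h to Student's t parameters (df, loc, scale)."""
    raw_df, raw_loc, raw_scale = h @ W + b
    return 2.0 + softplus(raw_df), raw_loc, softplus(raw_scale)

def student_t_log_pdf(y, df, loc, scale):
    """Log-density of a location-scale Student's t at y."""
    z = (y - loc) / scale
    return (math.lgamma((df + 1) / 2) - math.lgamma(df / 2)
            - 0.5 * math.log(df * math.pi) - math.log(scale)
            - (df + 1) / 2 * math.log1p(z * z / df))

rng = np.random.default_rng(0)
h = rng.normal(size=16)                  # decoder output for one time step (toy)
W = rng.normal(scale=0.1, size=(16, 3))  # projection to the three raw parameters
b = np.zeros(3)
df, loc, scale = student_t_head(h, W, b)
nll = -student_t_log_pdf(1.0, df, loc, scale)  # loss for an observed value y = 1.0
```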

Technically, different distribution heads could be integrated, but such an experiment was not carried out and is left for future research.

Having gained a more in-depth understanding of Lag-Llama’s internal mechanisms, let’s explore how the model was trained.

Training Lag-Llama

As a foundation model, Lag-Llama was trained on a large corpus of time series data. This enables it to generalize to unseen time series and to perform zero-shot forecasting.

Specifically, Lag-Llama was trained on 27 time series datasets from diverse domains, including energy, transportation, and economics.

The training corpus comprises 7965 univariate time series, totaling approximately 352 million tokens.

All datasets are open source and include popular benchmarks such as ETTh, Exchange, and Weather.

Note that each dataset was split into training and test sets, which allowed the authors to use open-source data both to train and to evaluate the model.

You can refer to the complete list of datasets used for training.

Let’s now put Lag-Llama to the test in a small forecasting project.

Forecasting with Lag-Llama

In this small forecasting project, we first use Lag-Llama’s zero-shot forecasting capabilities and then compare its performance with data-specific models such as TFT and DeepAR.

Lag-Llama is implemented on top of GluonTS, so we use this library for the experiment.

Specifically, we use the Australian Electricity Demand dataset, which contains five univariate time series tracking energy demand at a half-hourly frequency. The dataset is available through the GluonTS dataset repository, as shown below.

The complete source code for this experiment can be found on GitHub.

Code walkthrough

To use Lag-Llama, we must first clone the repository and install the requirements.

!git clone https://github.com/time-series-foundation-models/lag-llama/
%cd lag-llama
!pip install -r requirements.txt --quiet

Once the packages are installed, we can download the model weights from Hugging Face.

!huggingface-cli download time-series-foundation-models/Lag-Llama lag-llama.ckpt --local-dir /content/lag-llama 

Load the dataset

Now, we can load the dataset and prepare it for inference.

We start with the required library imports.

import pandas as pd 
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import torch

from itertools import islice

from gluonts.evaluation import make_evaluation_predictions, Evaluator
from gluonts.dataset.repository.datasets import get_dataset
from lag_llama.gluon.estimator import LagLlamaEstimator

Then, we can load the dataset directly from GluonTS.

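# Load the dataset from the GluonTS dataset repository
dataset = get_dataset("australian_electricity_demand")
backtest_dataset = dataset.test
prediction_length = dataset.metadata.prediction_length
# Give the model three forecast horizons' worth of history
context_length = 3 * prediction_length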

Initialize the model

# Use the GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
ckpt = torch.load("lag-llama.ckpt", map_location=device)
estimator_args = ckpt["hyper_parameters"]["model_kwargs"]

# Reuse the architecture hyperparameters stored in the pretrained checkpoint
estimator = LagLlamaEstimator(
    ckpt_path="lag-llama.ckpt",
    prediction_length=prediction_length,
    context_length=context_length,
    input_size=estimator_args["input_size"],
    n_layer=estimator_args["n_layer"],
    n_embd_per_head=estimator_args["n_embd_per_head"],
    n_head=estimator_args["n_head"],
    scaling=estimator_args["scaling"],
    time_feat=estimator_args["time_feat"],
)

lightning_module = estimator.create_lightning_module()
transformation = estimator.create_transformation()
predictor = estimator.create_predictor(transformation, lightning_module)

We then generate zero-shot predictions using the make_evaluation_predictions function.

forecast_it, ts_it = make_evaluation_predictions(
dataset=backtest_dataset,
predictor=predictor)
forecasts = list(forecast_it) 
tss = list(ts_it)

Evaluation of Lag-Llama

GluonTS can conveniently calculate various performance metrics using the Evaluator object.

evaluator = Evaluator()

agg_metrics, ts_metrics = evaluator(iter(tss), iter(forecasts))
print(f"RMSE: {agg_metrics['RMSE']:.2f}")
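For intuition about what the RMSE measures, here is a simplified single-series sketch. It assumes the point forecast is the mean of the sample paths; `rmse_from_samples` is a hypothetical helper and does not reproduce GluonTS’s aggregation across series:

```python
import numpy as np

def rmse_from_samples(sample_paths, actual):
    """RMSE between the mean of the sample paths and the observed values."""
    point_forecast = np.mean(sample_paths, axis=0)  # average over samples
    return float(np.sqrt(np.mean((point_forecast - actual) ** 2)))

samples = np.array([[1.0, 2.0, 3.0],
                    [3.0, 4.0, 5.0]])  # 2 sample paths over a horizon of 3
actual = np.array([2.0, 3.0, 6.0])
print(rmse_from_samples(samples, actual))  # sqrt(4/3), about 1.1547
```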

Lag-Llama achieves an RMSE of 481.50. We can also visualize its predictions; for convenience, we plot only the first four series of the dataset.

plt.figure(figsize=(20, 15))
date_formatter = mdates.DateFormatter('%b, %d')
plt.rcParams.update({'font.size': 15})

# Plot the first four series, showing the last four prediction lengths of each
for idx, (forecast, ts) in islice(enumerate(zip(forecasts, tss)), 4):
    ax = plt.subplot(2, 2, idx + 1)
    plt.plot(ts[-4 * dataset.metadata.prediction_length:].to_timestamp(), label="target")
    forecast.plot(color='g')  # draws the forecast and its prediction intervals
    plt.xticks(rotation=60)
    ax.xaxis.set_major_formatter(date_formatter)
    ax.set_title(forecast.item_id)

plt.gcf().tight_layout()
plt.legend()
plt.show()

In the above figure, we can see that the model made sensible predictions on the data, although it does have difficulty with the fourth series (bottom right of the figure).

Furthermore, since Lag-Llama implements probabilistic predictions, we also get uncertainty intervals along with the predictions.

Now that we know how to use Lag-Llama for zero-shot forecasting, let’s compare its performance against data-specific models.

Comparing with TFT and DeepAR

For consistency, we continue using GluonTS and train TFT and DeepAR models on the dataset to see whether they outperform Lag-Llama.

To save time, we limit training to only five epochs.

from gluonts.torch import TemporalFusionTransformerEstimator, DeepAREstimator

# Same horizon and context window as Lag-Llama; "30min" matches the half-hourly data
tft_estimator = TemporalFusionTransformerEstimator(
    prediction_length=prediction_length,
    context_length=context_length,
    freq="30min",
    trainer_kwargs={"max_epochs": 5},
)

deepar_estimator = DeepAREstimator(
    prediction_length=prediction_length,
    context_length=context_length,
    freq="30min",
    trainer_kwargs={"max_epochs": 5},
)

tft_predictor = tft_estimator.train(dataset.train)
deepar_predictor = deepar_estimator.train(dataset.train)

After completing the training process, we create predictions and then calculate the Root Mean Squared Error (RMSE).

# Make predictions

tft_forecast_it, tft_ts_it = make_evaluation_predictions(
dataset=backtest_dataset,
predictor=tft_predictor)

deepar_forecast_it, deepar_ts_it = make_evaluation_predictions(
dataset=backtest_dataset,
predictor=deepar_predictor)

tft_forecasts = list(tft_forecast_it)
tft_tss = list(tft_ts_it)

deepar_forecasts = list(deepar_forecast_it)
deepar_tss = list(deepar_ts_it)

# Get evaluation metrics
tft_agg_metrics, tft_ts_metrics = evaluator(iter(tft_tss), iter(tft_forecasts))
deepar_agg_metrics, deepar_ts_metrics = evaluator(iter(deepar_tss), iter(deepar_forecasts))

Here are the results of the comparison:

1. Lag-Llama: RMSE of 481.50
2. TFT: RMSE of 272.62
3. DeepAR: RMSE of 445.51
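To put these numbers in perspective, the relative gaps can be computed directly from the three RMSE values above:

```python
lag_llama_rmse = 481.50
baselines = {"TFT": 272.62, "DeepAR": 445.51}

# Percentage reduction in RMSE relative to zero-shot Lag-Llama.
reductions = {
    model: 100 * (lag_llama_rmse - value) / lag_llama_rmse
    for model, value in baselines.items()
}
for model, reduction in reductions.items():
    print(f"{model}: {reduction:.1f}% lower RMSE than zero-shot Lag-Llama")
# TFT: 43.4% lower RMSE than zero-shot Lag-Llama
# DeepAR: 7.5% lower RMSE than zero-shot Lag-Llama
```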

Conclusion

While the performance of Lag-Llama may appear to be underwhelming, it’s important to remember that the model wasn’t fine-tuned, and zero-shot forecasting is inherently more challenging.

On the other hand, the data-specific models, TFT and DeepAR, were trained for only five epochs, and both still achieved better results than Lag-Llama. Zero-shot forecasting may save time up front, but training for five epochs is not particularly demanding in terms of time or compute.

However, it’s crucial to remember that Lag-Llama is in its early stages of development. When the capabilities for fine-tuning become available, the model’s performance is likely to improve.

Furthermore, this is not an exhaustive benchmark of Lag-Llama’s capabilities, so it’s always worth testing it against other methods on a project-by-project basis.

Lag-Llama is an open-source base model for univariate probabilistic forecasting. It utilizes a decoder-only Transformer architecture with a distribution head to generate probabilistic predictions, meaning that uncertainty intervals are instantly available.

The model employs a general tokenization strategy that creates lagged features and constructs static covariates such as time-of-day and day-of-week, among others.

It’s also built on top of GluonTS, meaning that we have to use this library to generate predictions from Lag-Llama for now.

As always, it’s vital to remember that each problem demands its unique solution. Be sure to test Lag-Llama against other methods to find the best fit for your specific needs.

Thank you for reading! I hope you found this informative and that you’ve learned something new!

References

Original repository of Lag-Llama, GitHub

Lag-Llama: Towards Foundation Models for Probabilistic Time Series Forecasting.
