Experiment Tracking With AWS SageMaker and PyTorch Lightning

Tobias Senst
Published in idealo Tech Blog
Oct 19, 2023

For training or developing machine learning models, it is essential to log convergence curves and evaluation metrics such as loss, precision, or F1 score. This applies not only during training but also during validation and testing. Moreover, being able to compare these metrics across experiments and runs is a key capability in our machine learning life cycle at idealo.de. Besides evaluation metrics, experiment data may also include model architectures, hyperparameter settings, and training data or preprocessing configurations: everything that is needed to reproduce an experiment.

Photo by Alex Kondratiev on Unsplash

In modern machine learning lifecycles, these features are provided by tools such as MLflow, Weights & Biases, or TensorBoard. With SageMaker Experiments, AWS offers a service to manage, analyze, and compare machine learning experiments within the SageMaker domain.

This article presents an experiment data logger for PyTorch Lightning that is based on AWS SageMaker Experiments.

With PyTorch Lightning, experiment tracking based on AWS SageMaker can be implemented with only a few lines of code:

from pytorch_lightning import Trainer
from experiments_addon.logger import SagemakerExperimentsLogger

sagemaker_logger = SagemakerExperimentsLogger(
    experiment_name="TestExp",
    run_name="TestRun"
)

trainer = Trainer(
    logger=sagemaker_logger,
    ...
)
trainer.fit(...)

The SagemakerExperimentsLogger implementation is available in our GitHub project.

You can install the latest version with pip:

pip install sagemaker-experiments-logger

MNIST Classification Example

We will demonstrate the experiment logger with the minimal MNIST image classification example from the Lightning.ai website.

Let’s implement a simple classification model.

First, we define a set of evaluation metrics using the MetricCollection class from the TorchMetrics package.

from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassF1Score, MulticlassAccuracy

def create_metric_collection(no_classes: int) -> MetricCollection:
    metrics = {
        "F1": MulticlassF1Score(num_classes=no_classes, average="macro"),
        "Accuracy": MulticlassAccuracy(num_classes=no_classes),
    }

    return MetricCollection(metrics)
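
The following dependency-free sketch shows what the two metrics average over for integer class labels. The function name macro_metrics is hypothetical. The code is an illustration, not TorchMetrics' implementation, and its edge-case handling (for example, classes that never occur) may differ between TorchMetrics versions. One detail is worth knowing: in recent TorchMetrics versions, MulticlassAccuracy also defaults to average="macro". The "Accuracy" entry is therefore the mean of the per-class recalls, not the plain fraction of correct predictions.

```python
from collections import Counter
from typing import Dict, Sequence


def macro_metrics(preds: Sequence[int], target: Sequence[int], num_classes: int) -> Dict[str, float]:
    """Macro-averaged F1 and accuracy (mean per-class recall) for integer labels."""
    tp, fp, fn = Counter(), Counter(), Counter()
    for p, t in zip(preds, target):
        if p == t:
            tp[t] += 1
        else:
            fp[p] += 1  # predicted class p, but the sample belongs to another class
            fn[t] += 1  # the true class t was missed

    f1_scores, recalls = [], []
    for c in range(num_classes):
        if tp[c] + fp[c] + fn[c] == 0:
            continue  # class appears in neither list: leave it out of the average
        f1_scores.append(2 * tp[c] / (2 * tp[c] + fp[c] + fn[c]))
        support = tp[c] + fn[c]
        recalls.append(tp[c] / support if support else 0.0)

    if not f1_scores:  # empty input
        return {"F1": 0.0, "Accuracy": 0.0}
    return {
        "F1": sum(f1_scores) / len(f1_scores),
        "Accuracy": sum(recalls) / len(recalls),
    }


result = macro_metrics([0, 0, 1, 2], [0, 1, 1, 2], num_classes=3)
# F1 = 7/9 (about 0.778), Accuracy = 2.5/3 (about 0.833)
```

Here three of four predictions are correct (0.75), but class 1 has a recall of only 0.5. The macro average weights every class equally, so it reports (1 + 0.5 + 1) / 3.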

The model consists of a single linear layer that implements a simple, logistic-regression-like classifier.

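# Imports used by both parts of the MNISTModel example
import time
from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from pytorch_lightning import LightningModule
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST


class MNISTModel(LightningModule):
    def __init__(
        self,
        data_dir: str,
        learning_rate: float = 0.02,
        train_batch_size: int = 32,
        val_batch_size: int = 32,
        test_batch_size: int = 32,
        no_classes: int = 10,
    ):
        super().__init__()
        self.data_dir = data_dir
        self.save_hyperparameters()
        self.transform = transforms.ToTensor()
        # create_metric_collection() is defined in the previous snippet
        self.train_metrics = create_metric_collection(
            no_classes=no_classes
        ).clone(prefix="Train-")
        self.val_metrics = create_metric_collection(
            no_classes=no_classes
        ).clone(prefix="Val-")
        # One output score per class for each flattened 28x28 image
        self.l1 = torch.nn.Linear(28 * 28, no_classes)

        # Running sums for the average losses shown in the progress bar.
        # The validation counters must exist before the first validation_step
        # (e.g. Lightning's sanity check), not only after on_validation_epoch_end.
        self.train_loss = 0
        self.train_step_count = 0
        self.val_loss = 0
        self.val_step_count = 0
        self.tic = time.time()

    def configure_optimizers(self) -> torch.optim.Optimizer:
        return torch.optim.Adam(
            self.parameters(), lr=self.hparams.learning_rate
        )

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def prepare_data(self):
        # Download once; setup() then loads the data from disk
        MNIST(self.data_dir, train=True, download=True)

    def setup(self, stage=None):
        if stage == "fit" or stage is None:
            mnist_full = MNIST(
                self.data_dir, train=True, transform=self.transform
            )
            self.mnist_train, self.mnist_val = random_split(
                mnist_full, [55000, 5000]
            )

    def train_dataloader(self):
        return DataLoader(
            self.mnist_train, batch_size=self.hparams.train_batch_size
        )

    def val_dataloader(self):
        return DataLoader(
            self.mnist_val, batch_size=self.hparams.val_batch_size
        )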

To compute the metrics, we override the following methods, which also enable the training and validation loops of the LightningModule:

  • training_step
  • on_train_epoch_end
  • validation_step
  • on_validation_epoch_end

class MNISTModel(LightningModule):

    ...

    def training_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor], batch_nb: Optional[int]
    ) -> torch.Tensor:
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)

        self.train_loss += loss.item()
        self.train_step_count += 1

        metrics = self.train_metrics(preds, y)
        metrics["loss"] = self.train_loss / self.train_step_count
        self.log_dict(metrics, on_epoch=True, prog_bar=True, logger=False)
        return loss

    def on_train_epoch_end(self) -> None:
        self.train_epoch_runtime = time.time() - self.tic
        self.tic = time.time()
        self.train_loss = 0
        self.train_step_count = 0
        metric_dict = self.train_metrics.compute()
        metric_dict["Runtime"] = self.train_epoch_runtime
        self.log_dict(metric_dict, on_epoch=True, prog_bar=True, logger=True)
        self.train_metrics.reset()

    def validation_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor], batch_nb: Optional[int]
    ) -> None:
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)

        metrics = self.val_metrics(preds, y)
        self.val_loss += loss.item()
        self.val_step_count += 1
        metrics["loss"] = self.val_loss / self.val_step_count
        self.log_dict(metrics, on_epoch=True, prog_bar=True, logger=False)

    def on_validation_epoch_end(self) -> None:
        self.val_loss = 0
        self.val_step_count = 0
        metric_dict = self.val_metrics.compute()
        self.log_dict(metric_dict, on_epoch=True, prog_bar=True, logger=True)
        self.val_metrics.reset()
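
The loss used in training_step and validation_step works as follows. F.cross_entropy applies a log-softmax to the logits and returns the negative log-probability of the true class, averaged over the batch. torch.argmax then picks the highest-scoring class. The pure-Python sketch below computes both for a single sample. The function name is hypothetical, and the logits in the example are made up.

```python
import math
from typing import List, Tuple


def cross_entropy_and_pred(logits: List[float], target: int) -> Tuple[float, int]:
    """Cross-entropy loss and argmax prediction for one sample."""
    m = max(logits)  # shift by the max logit for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    loss = log_sum_exp - logits[target]  # equals -log(softmax(logits)[target])
    # Like torch.argmax, return the first index when several logits tie
    pred = max(range(len(logits)), key=lambda i: logits[i])
    return loss, pred


loss, pred = cross_entropy_and_pred([0.0, 0.0, 0.0, 0.0], target=2)
# Uniform logits: loss = log(4), about 1.386; pred = 0 (first of the tied maxima)
```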

The training and validation metrics are updated iteratively in the step methods. Consider this code block from training_step:

...
metrics = self.train_metrics(preds, y)
metrics["loss"] = self.train_loss / self.train_step_count
self.log_dict(metrics, on_epoch=True, prog_bar=True, logger=False)
...

Calling the metric collection updates the training metrics and returns a dictionary with the interim results for the current batch. You can extend this dictionary to suit your needs, as done here with the running loss. The log_dict method logs the metrics and passes them to the logger. In this code block, however, logger=False disables the logger, so only the progress bar is updated.
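
The running loss in training_step is a plain running average of the per-batch losses, and on_train_epoch_end resets it. Without Lightning, the bookkeeping looks like the following sketch. RunningMean is a hypothetical helper, not part of the example code. TorchMetrics' MeanMetric class offers a similar aggregation as a metric object.

```python
class RunningMean:
    """Running average of per-batch values (hypothetical helper)."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def update(self, value: float) -> float:
        """Add one batch value and return the current average."""
        self.total += value
        self.count += 1
        return self.total / self.count

    def reset(self) -> None:
        """Start over, e.g. at the end of an epoch."""
        self.total = 0.0
        self.count = 0


running = RunningMean()
shown = [running.update(loss) for loss in [2.0, 1.0, 0.0]]
# shown == [2.0, 1.5, 1.0]: the values the progress bar would display
running.reset()
```

This averages batch means without weighting them by batch size. With 55,000 training images and a batch size of 32 (55,000 = 1,718 × 32 + 24), the smaller last batch therefore counts as much as a full one. The difference from a per-sample mean is small.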

The logger is disabled in the step methods because the SageMaker Experiments API has a rate limit that can easily be exceeded. For models with fast training steps, we recommend logging only at the end of training, validation, or testing epochs.
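
A quick count shows the difference. The example uses 55,000 training images and a default batch size of 32, so an epoch has ceil(55000 / 32) = 1719 training steps. Logging in every step would therefore mean roughly 1719 logging rounds per epoch instead of one. How many API requests each round turns into depends on the logger, so treat these numbers as an estimate. If per-step curves are needed, the log_every_n_steps argument of the PyTorch Lightning Trainer (default 50) is a middle ground. The hypothetical helper below only approximates how Lightning schedules those calls.

```python
import math


def logging_rounds_per_epoch(num_samples: int, batch_size: int, every_n_steps: int = 1) -> int:
    """Approximate number of step-level logging rounds in one epoch."""
    steps = math.ceil(num_samples / batch_size)  # the last batch may be partial
    return steps // every_n_steps  # steps whose number is a multiple of n


print(logging_rounds_per_epoch(55000, 32))                    # 1719
print(logging_rounds_per_epoch(55000, 32, every_n_steps=50))  # 34
```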

How to Set Up the Logger

The SagemakerExperimentsLogger class can easily be passed to the PyTorch Lightning Trainer class to track experiment data during model training, validation, and testing. SageMaker Experiments are organized into experiments and runs, where a run is a subunit of an experiment. When logging the data of a model training, you have to specify a run name and an experiment name. Both are used to create a run context, which every metric-writing method of the SageMaker API requires. But first, let's set the MNIST data directory and create our simple MNIST model:

import os
from example.model import MNISTModel

PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
mnist_model = MNISTModel(data_dir=PATH_DATASETS)
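
Conceptually, a Lightning logger built on SageMaker Experiments has to translate Lightning's log_hyperparams(params) and log_metrics(metrics, step) calls into calls on a run object. The dependency-free sketch below illustrates that mapping under stated assumptions. FakeRun only records calls in memory. It stands in for sagemaker.experiments.run.Run and imitates that class's log_parameters and log_metric(name=..., value=..., step=...) methods. RunAdapter is hypothetical and is not the actual SagemakerExperimentsLogger implementation.

```python
from typing import Any, Dict, List, Optional, Tuple


class FakeRun:
    """Stand-in for a SageMaker Run: records calls instead of sending them to AWS."""

    def __init__(self, experiment_name: str, run_name: str) -> None:
        self.experiment_name = experiment_name
        self.run_name = run_name
        self.metrics: List[Tuple[str, float, Optional[int]]] = []
        self.parameters: Dict[str, Any] = {}

    def log_metric(self, name: str, value: float, step: Optional[int] = None) -> None:
        self.metrics.append((name, value, step))

    def log_parameters(self, parameters: Dict[str, Any]) -> None:
        self.parameters.update(parameters)


class RunAdapter:
    """Hypothetical adapter exposing Lightning-style logging methods on top of a run."""

    def __init__(self, run: FakeRun) -> None:
        self.run = run

    def log_hyperparams(self, params: Dict[str, Any]) -> None:
        self.run.log_parameters(dict(params))

    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
        # One log_metric call per metric name, all sharing the same step
        for name, value in metrics.items():
            self.run.log_metric(name=name, value=float(value), step=step)


run = FakeRun("TestExp", "TestRun")
adapter = RunAdapter(run)
adapter.log_hyperparams({"learning_rate": 0.02})
adapter.log_metrics({"Val-F1": 0.91, "Val-Accuracy": 0.92}, step=0)
print(run.metrics)  # [('Val-F1', 0.91, 0), ('Val-Accuracy', 0.92, 0)]
```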

Use the Logger Within a Run-Context

For use in a notebook, we recommend creating a Run object with a with statement:

from pytorch_lightning import Trainer
from sagemaker.experiments.run import Run
from experiments_addon.logger import SagemakerExperimentsLogger

with Run(experiment_name="testExperiment", run_name="testRun1"):
    logger = SagemakerExperimentsLogger()
    trainer = Trainer(
        logger=logger,
        max_epochs=3,
    )
    trainer.fit(mnist_model)
    trainer.test()

All log methods have to be called inside the with block. Because the Trainer calls the logger's log methods internally, trainer.fit() and trainer.test() need to be inside the with block, too. The SagemakerExperimentsLogger object is created without experiment_name and run_name because the logger retrieves both from the run context.

Please note: if you use the logger inside a SageMaker training job that already has a run context configured, you can use it as above but without the with statement.

Use the Logger by Explicitly Passing in Run and Experiment Name

There may be cases where the with statement is not practical. For these cases, the logger can obtain the run context separately for each log operation. To use the SagemakerExperimentsLogger without the with statement, provide experiment_name and run_name when creating the object:

from pytorch_lightning import Trainer
from experiments_addon.logger import SagemakerExperimentsLogger

logger = SagemakerExperimentsLogger(
    experiment_name="TestExp",
    run_name="TestRun"
)

trainer = Trainer(
    logger=logger,
    max_epochs=3,
)
trainer.fit(mnist_model)
trainer.test()

Internally, the sagemaker.experiments.run.load_run function is called with the experiment and run names. Note that load_run prints a warning to stdout when the run already exists, so the warning appears after each logging operation. This can clutter the output.
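
If the repeated message is a nuisance, one option is to wrap the logging calls in a helper that temporarily silences it. The quiet() context manager below is a hypothetical sketch. It makes no assumption about whether the message is written with print or through Python's logging module, so it both redirects stdout and disables logging up to WARNING level. Both effects are process-wide while the block runs, so other warnings are hidden too. Also, contextlib.redirect_stdout is not suitable for multithreaded code.

```python
import contextlib
import io
import logging
from typing import Iterator


@contextlib.contextmanager
def quiet(level: int = logging.WARNING) -> Iterator[None]:
    """Temporarily discard stdout and log records up to `level`."""
    logging.disable(level)
    try:
        with contextlib.redirect_stdout(io.StringIO()):
            yield
    finally:
        # Re-enable logging; assumes logging was not disabled before entering
        logging.disable(logging.NOTSET)


def noisy_load() -> str:
    """Stand-in for a call that prints and logs a warning on every use."""
    print("Run already exists, loading it.")
    logging.getLogger("demo").warning("Run already exists.")
    return "run"


with quiet():
    run = noisy_load()  # nothing is printed or logged
```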

Use Multiple Loggers

The PyTorch Lightning Trainer supports multiple loggers, and this also works with the SagemakerExperimentsLogger: just pass a list of loggers.

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from sagemaker.experiments.run import Run
from experiments_addon.logger import SagemakerExperimentsLogger

# TensorBoardLogger requires a save_dir; "tb_logs" is just an example path
tensorboard_logger = TensorBoardLogger(save_dir="tb_logs")
with Run(experiment_name="testExperiment", run_name="testRun2"):
    logger = SagemakerExperimentsLogger()
    trainer = Trainer(
        logger=[logger, tensorboard_logger],
        max_epochs=3,
    )
    trainer.fit(mnist_model)
    trainer.test()

Example view of the SageMaker Studio UI for experiment tracking. It shows a table with the evaluation results and some parameters of four experiments with a transformer model. Below the table are a scatter plot of training time (x-axis) against validation macro F1 score (y-axis) for all four experiments and a plot of their loss curves.

Conclusion

You should now be able to track your machine learning model training and experiment data with PyTorch Lightning in your AWS SageMaker environment and run experiments of your own. We hope you find our small project useful.

Do you love agile product development? Have a look at our vacancies.