Recovering from training failures

Jaideep Ray
Published in Better ML
Dec 16, 2023

Checkpointing & resume.

Motivation:

Distributed training runs for large models can span many GPUs and last days or weeks. Training jobs can be interrupted for various reasons:

  1. Hardware failures on the device (GPU) or the host machine (e.g., a power outage).
  2. System-level failures (e.g., out of memory).
  3. Code issues (e.g., a bug shipped in the post-training eval stage).
  4. Preemptible VM instances (spot instances in AWS terminology) being reclaimed.

Recovering a training job quickly from an interruption, without losing the expensive GPU work already done, is an essential trainer feature.

What is checkpointing?

  • Checkpointing lets you recover from failures without repeating expensive compute. Checkpoints are snapshots of a training job saved in persistent storage.
  • The trainer interface should implement two APIs: save_checkpoint(…) and resume_from_checkpoint(…). The first saves a checkpoint to a persistent store such as Amazon S3 or Amazon Elastic File System (EFS). The second reads the checkpoint back from storage and resumes training where it left off. To resume training, the checkpoint must hold enough state, such as the model's state_dict, the optimizer's state_dict, the epoch, and the last training loss.
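To make the two APIs concrete, here is a minimal, framework-free sketch. It is not the author's implementation: `ToyTrainer`, its one-weight "model", and the local JSON file standing in for S3 or EFS are all hypothetical. It checkpoints at the end of every epoch, simulates a crash, and checks that a fresh trainer resumed from the checkpoint ends in the same state as an uninterrupted run.

```python
import json
import os
import tempfile


class ToyTrainer:
    """Hypothetical trainer exposing save_checkpoint / resume_from_checkpoint.

    The "model" is a single weight w fit to minimize (w - 3)^2 with plain
    gradient descent; a local JSON file stands in for S3 or EFS.
    """

    def __init__(self, lr=0.1):
        self.lr = lr       # optimizer state (plain SGD only has a learning rate)
        self.w = 0.0       # model state
        self.epoch = 0     # number of completed epochs
        self.loss = None   # last training loss

    def train_one_epoch(self):
        grad = 2.0 * (self.w - 3.0)
        self.w -= self.lr * grad
        self.loss = (self.w - 3.0) ** 2
        self.epoch += 1

    def save_checkpoint(self, path):
        state = {"epoch": self.epoch, "model": {"w": self.w},
                 "optimizer": {"lr": self.lr}, "loss": self.loss}
        with open(path, "w") as f:
            json.dump(state, f)

    def resume_from_checkpoint(self, path):
        with open(path) as f:
            state = json.load(f)
        self.w = state["model"]["w"]
        self.lr = state["optimizer"]["lr"]
        self.epoch = state["epoch"]  # fit() continues with the next epoch
        self.loss = state["loss"]

    def fit(self, num_epochs, ckpt_path, crash_after=None):
        while self.epoch < num_epochs:
            self.train_one_epoch()
            self.save_checkpoint(ckpt_path)  # checkpoint at the end of every epoch
            if crash_after is not None and self.epoch == crash_after:
                raise RuntimeError("simulated failure")


def demo():
    directory = tempfile.mkdtemp()
    reference = ToyTrainer()
    reference.fit(5, os.path.join(directory, "reference.json"))  # no interruption

    ckpt_path = os.path.join(directory, "ckpt.json")
    try:
        ToyTrainer().fit(5, ckpt_path, crash_after=3)  # dies after epoch 3
    except RuntimeError:
        pass
    resumed = ToyTrainer()                     # a fresh process after the crash
    resumed.resume_from_checkpoint(ckpt_path)
    resumed.fit(5, ckpt_path)                  # runs only epochs 4 and 5
    return resumed, reference
```

Plain SGD only needs its learning rate, but optimizers such as Adam keep per-parameter state, which is why the optimizer's state_dict belongs in a real checkpoint alongside the model's.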

What is recovery time?

When a training job crashes, it has to be resumed from the last known checkpoint. All GPU compute performed between the last checkpoint and the moment of the crash is lost and has to be redone. Recovery time is the time spent bringing the model back to the state it was in just before the crash.

[Figure: a crash between epoch i and epoch i+1]
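As a rough back-of-the-envelope sketch, recovery time can be modeled as a restart overhead plus the compute lost since the last checkpoint. The function names and the fixed restart-overhead term are assumptions, not from the source. If a crash is equally likely at any point within a checkpoint interval of length T, the expected lost work is T/2.

```python
def recovery_time(crash_time_s, last_checkpoint_time_s, restart_overhead_s):
    """Seconds needed to get back to the pre-crash state.

    restart_overhead_s is a hypothetical fixed cost (rescheduling the job,
    loading the checkpoint); the rest is the lost work that must be redone.
    """
    if crash_time_s < last_checkpoint_time_s:
        raise ValueError("the crash cannot happen before the last checkpoint")
    lost_work_s = crash_time_s - last_checkpoint_time_s
    return restart_overhead_s + lost_work_s


def expected_lost_work(checkpoint_interval_s):
    """Mean lost work if the crash time is uniform within one checkpoint interval."""
    return checkpoint_interval_s / 2


# Hourly checkpoints, crash at 3h40m, 10 minutes to restart and reload:
# 40 minutes of lost compute + 10 minutes of overhead = 3000 seconds.
example = recovery_time(3 * 3600 + 40 * 60, 3 * 3600, 10 * 60)
```

With per-epoch checkpoints, a crash between epoch i and i+1 loses all partial progress into epoch i+1, so long epochs argue for step-based or time-based checkpoints.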

Checkpoint APIs

ML frameworks (PyTorch, TensorFlow) and the trainer frameworks built on top of them (Lightning, Hugging Face Trainer) have built-in checkpoint support. There are two main APIs: save and resume. Here's a draft implementation in native PyTorch.

save_checkpoint in PyTorch:

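import torch

# net, optimizer, EPOCH, LOSS and PATH come from your training loop.
torch.save({
    'epoch': EPOCH,                                  # last completed epoch
    'model_state_dict': net.state_dict(),            # parameters and buffers
    'optimizer_state_dict': optimizer.state_dict(),  # e.g., momentum or Adam moments
    'loss': LOSS,                                    # last training loss
}, PATH)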

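resume_from_checkpoint in PyTorch:

import torch

# Recreate the same architecture (net in the save snippet) and optimizer type.
model = ...
optimizer = ...

# map_location='cpu' lets a GPU-saved checkpoint load on any host;
# load_state_dict then copies the values into the model's existing tensors.
checkpoint = torch.load(PATH, map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']  # last completed epoch; continue from epoch + 1
loss = checkpoint['loss']

model.eval()   # when using the checkpoint for inference
# - or -
model.train()  # when resuming training

Checkpointing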

Next, we discuss common challenges with checkpointing and strategies to resolve them.

  1. Checkpoint frequency: How often to checkpoint is an important decision for model owners. The trade-off is between the cost of checkpointing (training stalls, write bandwidth, storage) and recovery time: checkpointing more often means less work is lost per failure. Common defaults are to checkpoint every epoch, every N hours, or every K steps.
  2. Atomic copies: Checkpointing requires the model parameters to be copied atomically before further processing and storage. Without atomicity, training may keep updating the parameters while they are being written to storage, leaving an inconsistent checkpoint that mixes values from different steps. One way to avoid this is to stall training while the snapshot is created. As soon as the snapshot is ready, the GPUs resume training and a separate process serializes the checkpoint and writes it to the store.
  3. Training stall: As described in the previous item, synchronous checkpointing stalls the training process. This overhead lowers training throughput. One way to reduce stalls is to split checkpointing into two phases, snapshot and persist, where only the snapshot phase stalls training and persist runs separately.
     • snapshot(): Serializes the state and copies it into an in-memory buffer. The snapshot can be taken on the GPU or the CPU; serializing and snapshotting on the GPU is 10x+ faster than on the CPU. If there is sufficient spare GPU memory, the snapshot can be taken on the GPU. Otherwise, the parameters are copied to the CPU and the snapshot is created there; this GPU-to-CPU copy creates a stall.
     • persist(): Writes the serialized contents to durable storage, using fsync or a similar call to make sure they actually reach the disk.

  4. Write bandwidth and storage capacity: As model size grows, both the write bandwidth to the checkpoint store and its storage capacity become bottlenecks. One optimization is to write only quantized checkpoints. Other guardrails are keeping only the top K checkpoints (e.g., K = 3 or 5) and starting a new checkpoint only after the previous one has been fully written to the store.
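The frequency choices in item 1 can be expressed as a small policy object. This is a hypothetical helper, not an API from PyTorch, Lightning, or the Hugging Face Trainer; the names `CheckpointPolicy`, `every_k_steps`, and `every_n_seconds` are illustrative. Any trigger left as `None` is disabled, and the clock is injectable so the policy can be tested without waiting.

```python
import time


class CheckpointPolicy:
    """Hypothetical policy: checkpoint at epoch end, every K steps, or every N seconds."""

    def __init__(self, every_k_steps=None, every_n_seconds=None,
                 at_epoch_end=True, clock=time.monotonic):
        self.every_k_steps = every_k_steps
        self.every_n_seconds = every_n_seconds
        self.at_epoch_end = at_epoch_end
        self.clock = clock
        self.last_save_time = clock()

    def should_checkpoint(self, step, end_of_epoch=False):
        if self.at_epoch_end and end_of_epoch:
            return True
        if self.every_k_steps and step > 0 and step % self.every_k_steps == 0:
            return True
        if (self.every_n_seconds is not None
                and self.clock() - self.last_save_time >= self.every_n_seconds):
            return True
        return False

    def mark_saved(self):
        # Call after each checkpoint so the time-based trigger restarts.
        self.last_save_time = self.clock()
```

Lowering `every_k_steps` or `every_n_seconds` shrinks the work lost per failure but adds more stalls and more writes, which is the trade-off described in item 1.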
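Items 2 and 3 can be sketched as a two-phase checkpointer: a synchronous snapshot that training waits for, and a background persist that writes to a temporary file, calls fsync, and atomically renames it into place. Assumptions: model and optimizer state are plain Python dicts standing in for state_dicts of tensors (with PyTorch, the snapshot would instead copy each tensor, e.g. with clone(), into spare GPU memory or onto the CPU), a background thread stands in for the separate persist process, and a local directory stands in for the checkpoint store. `AsyncCheckpointer` and its methods are hypothetical names.

```python
import copy
import os
import pickle
import tempfile
import threading


class AsyncCheckpointer:
    """Two-phase sketch: snapshot() stalls training, persist runs in the background."""

    def __init__(self, directory):
        self.directory = directory
        self._thread = None

    def snapshot(self, state):
        # Phase 1 (training is stalled): copy state into a private buffer so
        # later in-place updates cannot change what is about to be written.
        return copy.deepcopy(state)

    def _persist(self, snap, step):
        # Phase 2 (off the training path): write to a temp file, fsync it, then
        # atomically rename so readers never see a half-written checkpoint.
        # (Fully durable renames on POSIX would also fsync the directory.)
        final_path = os.path.join(self.directory, f"ckpt_{step:08d}.pkl")
        fd, tmp_path = tempfile.mkstemp(dir=self.directory, suffix=".tmp")
        with os.fdopen(fd, "wb") as f:
            pickle.dump(snap, f)
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp_path, final_path)

    def save(self, state, step):
        # Don't start a new checkpoint until the previous one is on disk.
        self.wait()
        snap = self.snapshot(state)
        # Errors raised in the background thread are not propagated in this sketch.
        self._thread = threading.Thread(target=self._persist, args=(snap, step))
        self._thread.start()

    def wait(self):
        if self._thread is not None:
            self._thread.join()
            self._thread = None


def demo():
    directory = tempfile.mkdtemp()
    ckpt = AsyncCheckpointer(directory)
    state = {"step": 0, "weights": [0.0, 0.0]}
    for step in range(1, 4):
        state["step"] = step
        for i in range(len(state["weights"])):
            state["weights"][i] += 1.0  # stand-in for an in-place optimizer update
        if step == 2:
            ckpt.save(state, step)      # training waits only for the deepcopy
    ckpt.wait()
    with open(os.path.join(directory, "ckpt_00000002.pkl"), "rb") as f:
        saved = pickle.load(f)
    return saved, state
```

The checkpoint written at step 2 stays consistent even though step 3 mutates the live weights while the write may still be in flight. The wait() at the start of save() also enforces the item 4 guardrail of not starting a new checkpoint before the previous one has been written.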
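Item 4 mentions writing quantized checkpoints. As one illustration, here is symmetric per-tensor int8 quantization, which stores one float scale plus one signed byte per value, roughly a 4x size reduction versus float32. The scheme is an assumption, since the source does not say which quantization to use, and it is lossy: each restored value is off by at most half a quantization step.

```python
from array import array


def quantize_int8(values):
    """Symmetric per-tensor int8 quantization: returns (int8 codes, float scale)."""
    max_abs = max((abs(v) for v in values), default=0.0)
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    # Clamp to [-127, 127] so the code range is symmetric around zero.
    codes = array("b", (max(-127, min(127, round(v / scale))) for v in values))
    return codes, scale


def dequantize_int8(codes, scale):
    """Approximate reconstruction of the original float values."""
    return [c * scale for c in codes]


weights = [0.5, -1.27, 1.0, 0.0, 0.333]
codes, scale = quantize_int8(weights)
restored = dequantize_int8(codes, scale)
fp32_bytes = len(array("f", weights).tobytes())  # 4 bytes per value
int8_bytes = len(codes.tobytes())                 # 1 byte per value, plus one scale
```

Whether the accuracy loss is acceptable for resuming training depends on the model; storing optimizer state at reduced precision is a separate decision.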
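For the retention guardrail in item 4, a sketch that keeps only the K most recent checkpoints in a directory. Assumptions: checkpoint files follow a hypothetical naming convention like ckpt_00000300.pkl (zero-padded step number), and "top K" is read as "most recent K"; ranking by a validation metric would be another valid reading. `prune_checkpoints` is an illustrative name.

```python
import os
import re
import tempfile

# Hypothetical naming convention: ckpt_<step>.pkl
_CKPT_RE = re.compile(r"^ckpt_(\d+)\.pkl$")


def prune_checkpoints(directory, keep_last_k=3):
    """Delete all but the keep_last_k newest checkpoints (by step); return deleted names."""
    found = []
    for name in os.listdir(directory):
        match = _CKPT_RE.match(name)
        if match:  # ignore files that are not checkpoints
            found.append((int(match.group(1)), name))
    found.sort()  # oldest step first
    to_delete = found[:-keep_last_k] if keep_last_k > 0 else found
    for _, name in to_delete:
        os.remove(os.path.join(directory, name))
    return [name for _, name in to_delete]


def demo():
    directory = tempfile.mkdtemp()
    for step in (100, 200, 300, 400, 500):
        open(os.path.join(directory, f"ckpt_{step:08d}.pkl"), "wb").close()
    open(os.path.join(directory, "notes.txt"), "w").close()
    deleted = prune_checkpoints(directory, keep_last_k=3)
    return deleted, sorted(os.listdir(directory))
```

Pruning should run only after the newest checkpoint has been fully written, so a crash during a write never leaves the job without a usable checkpoint.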

Conclusion:

  • We covered the key decision variables for checkpointing and resuming.
  • It's better to implement a basic version of checkpointing and resume first, observe the bottlenecks, and then invest in the right design.

To learn more about checkpointing, here are two excellent references:

CheckFreq: Frequent, Fine-Grained DNN Checkpointing (USENIX FAST '21)

Check-N-Run: A Checkpointing System for Training Deep Learning Recommendation Models (USENIX NSDI '22)
