Demo for Continuing Training with Checkpoints (in PyTorch)
Published in
3 min read · May 12, 2020
This is a quick notebook showing how to train deep learning models in phases. For example, you can train for 5 epochs and save a checkpoint, then later load the saved parameters and resume exactly where you left off. This is especially useful if you rely on free cloud GPUs, such as those in Google Colab or Kaggle notebooks, which disconnect frequently or limit usage hours.
Imports and Data I/O
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os, sys
import time
import torch
from torch.autograd import Variable
from tqdm.notebook import tqdm
import sklearn.preprocessing as pre
# Default size for every new matplotlib figure (figure.figsize is in inches).
plt.rcParams.update({'figure.figsize': [10, 10]})
# Toy dataset with three samples. The plastic/paper/glass columns are the
# model inputs; student/worker/elder form a one-hot target (one 1 per row).
df = pd.DataFrame({
    'plastic': [100, 50, 30],
    'paper': [0, 50, 0],
    'glass': [0, 50, 90],
    'student': [1, 0, 0],
    'worker': [0, 0, 1],
    'elder': [0, 1, 0],
})
data_x = df[["plastic", "paper", "glass"]].to_numpy(dtype=np.float32)
data_y = df[["student", "worker", "elder"]].to_numpy(dtype=np.float32)
# torch.from_numpy wraps the float32 arrays without copying (shared memory).
x_train, y_train = torch.from_numpy(data_x), torch.from_numpy(data_y)
Utility Functions
def train(model, optimizer, losslogger, start_epoch, num_epoch, run_id, X, y, checkpoint_name, verbose=False):
    """Train ``model`` for ``num_epoch`` full-batch epochs, then save a checkpoint.

    Epochs are numbered from ``start_epoch``, so a run resumed from a
    checkpoint continues the numbering of the previous run.

    Args:
        model: module to train; its parameters are updated in place.
        optimizer: optimizer over ``model``'s parameters; updated in place.
        losslogger: DataFrame with the losses of earlier runs. This run's rows
            are added to a new DataFrame (the argument itself is not mutated).
            ``None`` is treated as an empty log.
        start_epoch: index of the first epoch of this run.
        num_epoch: number of epochs to run in this call (may be 0).
        run_id: value written to the ``run`` column for every row of this run.
        X: input tensor; each epoch is one optimizer step on all of ``X``.
        y: target tensor compared against ``model(X)`` with MSE loss.
        checkpoint_name: destination passed to ``torch.save``; also recorded
            in the log.
        verbose: if true, print the model and optimizer ``state_dict`` after
            training.

    Returns:
        The updated loss log, i.e. the same DataFrame stored in the checkpoint.
        (Callers that ignore the return value are unaffected.)

    The checkpoint is a dict with keys ``epoch`` (index of the next epoch to
    run), ``state_dict``, ``optimizer`` and ``losslogger``.
    """
    loss_function = torch.nn.MSELoss()
    end_epoch = start_epoch + num_epoch
    records = []
    progress = tqdm(range(start_epoch, end_epoch))
    for epoch in progress:
        # Plain tensors suffice: autograd.Variable is a deprecated no-op wrapper.
        out = model(X)
        loss = loss_function(out, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_value = loss.item()
        # Report against the absolute last epoch so resumed runs read
        # "Epoch[150/200]" instead of "Epoch[150/100]".
        progress.set_description(
            'Epoch[{}/{}], loss: {:.6f}'.format(epoch + 1, end_epoch, loss_value))
        # NOTE: the misspelled 'chackpoint_name' key is kept so new rows line up
        # with loss logs already stored in existing checkpoints.
        records.append({'chackpoint_name': checkpoint_name, 'epoch': epoch,
                        'Loss': loss_value, 'run': run_id})

    if losslogger is None:
        losslogger = pd.DataFrame()
    # One concat per run: DataFrame.append was removed in pandas 2.0, and
    # appending a one-row frame every epoch was quadratic. Empty frames are
    # skipped so concat does not have to reconcile dtypes with them.
    frames = [frame for frame in (losslogger, pd.DataFrame(records)) if not frame.empty]
    if frames:
        losslogger = pd.concat(frames, ignore_index=True)

    if verbose:
        print("Model's state_dict:")
        for param_tensor, value in model.state_dict().items():
            print(param_tensor, "\t", value.size())
        print("Optimizer's state_dict:")
        for var_name, value in optimizer.state_dict().items():
            print(var_name, "\t", value)

    # Sample prediction on a fixed input (assumes the model takes 3 input features).
    print(model(torch.tensor([[500, 500, 500]], dtype=torch.float32)))

    # end_epoch equals last_epoch + 1, and stays defined when num_epoch == 0.
    state = {'epoch': end_epoch, 'state_dict': model.state_dict(),
             'optimizer': optimizer.state_dict(), 'losslogger': losslogger}
    torch.save(state, checkpoint_name)
    return losslogger
def load_checkpoint(model, optimizer, losslogger, filename, map_location=None):
    """Restore model, optimizer and loss-log state from a checkpoint file.

    The checkpoint is expected to be a dict with keys ``epoch``,
    ``state_dict``, ``optimizer`` and ``losslogger``. ``model`` and
    ``optimizer`` must already be constructed to match the saved state; this
    function only loads state into them in place and returns them.

    Args:
        model: module whose ``state_dict`` is restored.
        optimizer: optimizer whose ``state_dict`` is restored.
        losslogger: returned unchanged when no checkpoint file exists.
        filename: path of the checkpoint file.
        map_location: forwarded to ``torch.load`` (e.g. ``'cpu'`` to resume a
            GPU-saved checkpoint on a CPU-only machine). ``None`` keeps the
            original behaviour.

    Returns:
        ``(model, optimizer, start_epoch, losslogger)``. ``start_epoch`` is the
        checkpoint's ``epoch`` entry, or 0 when ``filename`` is not a file.
    """
    import inspect

    start_epoch = 0
    if not os.path.isfile(filename):
        print("=> no checkpoint found at '{}'".format(filename))
        return model, optimizer, start_epoch, losslogger

    print("=> loading checkpoint '{}'".format(filename))
    load_kwargs = {'map_location': map_location}
    # The checkpoint pickles a pandas DataFrame, which the weights-only
    # unpickler (torch.load's default since PyTorch 2.6) refuses to load.
    # Full unpickling can execute arbitrary code: only load trusted files.
    # Older PyTorch versions have no weights_only parameter at all.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    checkpoint = torch.load(filename, **load_kwargs)

    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    losslogger = checkpoint['losslogger']
    print("=> loaded checkpoint '{}' (epoch {})".format(filename, start_epoch))
    return model, optimizer, start_epoch, losslogger
First Training Loop
# Run 0: build a fresh model/optimizer, train from epoch 0, write a checkpoint.
checkpoint_name = 'checkpoint.pth.tar'  # every run saves to this same file
start_epoch, num_epoch = 0, 100         # first epoch index, epochs per run
losslogger = pd.DataFrame()             # loss log, empty before run 0

# Three 3->3 linear layers, each followed by ReLU, then a softmax over dim 1.
model = torch.nn.Sequential(
    torch.nn.Linear(3, 3, bias=True), torch.nn.ReLU(),
    torch.nn.Linear(3, 3, bias=True), torch.nn.ReLU(),
    torch.nn.Linear(3, 3, bias=True), torch.nn.ReLU(),
    torch.nn.Softmax(dim=1),
)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)

train(model, optimizer, losslogger, start_epoch, num_epoch, 0,
      x_train, y_train, checkpoint_name)
# Pause before the next phase (presumably mimicking a session break).
time.sleep(8)
Load and Continue Training
# Runs 1-4: each run restores the latest checkpoint, deletes the checkpoint
# file, trains for another num_epoch epochs and writes a new checkpoint.
# os.remove replaces the IPython-only `!rm`, so this is valid Python; the
# existence check keeps a missing file non-fatal, as `!rm` was.
for run_id in range(1, 5):
    model, optimizer, start_epoch, losslogger = load_checkpoint(
        model, optimizer, losslogger, filename=checkpoint_name)
    if os.path.isfile(checkpoint_name):
        os.remove(checkpoint_name)
    train(model, optimizer, losslogger, start_epoch, num_epoch, run_id,
          x_train, y_train, checkpoint_name)
    time.sleep(8)

# The last run's losses live in the checkpoint it just wrote; reload it so
# `losslogger` includes them (otherwise the final run is missing from the plot).
model, optimizer, start_epoch, losslogger = load_checkpoint(
    model, optimizer, losslogger, filename=checkpoint_name)
Verify Loss is Continuous Across Loading Cycles
# Plot each run's loss against the epoch numbers recorded in the log. If
# resuming works, every run's curve starts at the epoch where the previous
# run stopped and continues its loss trend; gaps or overlaps would show up.
if losslogger.empty:
    print("No losses logged; nothing to plot.")
else:
    # groupby visits the run ids actually present (sorted), so ids need not
    # be contiguous from 0, and each run is selected only once.
    for run_id, run_log in losslogger.groupby('run', sort=True):
        plt.plot(run_log['epoch'].to_numpy(), run_log['Loss'].to_numpy(),
                 label=f'run {run_id}')
    plt.xlabel('epoch')
    plt.ylabel('MSE loss')
    plt.legend()
    plt.show()