Demo for Continuing Training with Checkpoints (in PyTorch)

Debanga Raj Neog, Ph.D · Published in Depurr · May 12, 2020 · 3 min read

This is a quick notebook on how to train deep learning models in phases: for example, you can train for 5 epochs, save a checkpoint, and later load the saved parameters and pick up exactly where you left off. This is especially useful if you rely on the free but ‘frequently disconnecting’ or ‘limited hours’ cloud GPUs available in Google Colab or Kaggle notebooks.
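
At its core the pattern is just a Python dict saved with torch.save and restored with torch.load. Here is a minimal, self-contained sketch of that pattern (the tiny model, optimizer and file name are placeholders; the full version this notebook actually uses is defined in the Utility Functions section below):

import torch

# Placeholder model/optimizer, just to illustrate the save/resume pattern
model = torch.nn.Linear(3, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

# Save: bundle everything needed to resume into a single dict
state = {
    'epoch': 5,                            # next epoch to run
    'state_dict': model.state_dict(),      # model weights
    'optimizer': optimizer.state_dict(),   # optimizer state (momentum buffers, etc.)
}
torch.save(state, 'demo_checkpoint.pth.tar')

# Resume: re-create model/optimizer with the same architecture, then restore their states
checkpoint = torch.load('demo_checkpoint.pth.tar')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch']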


Import and Data I/O

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os, sys
import time

import torch
from tqdm.notebook import tqdm

plt.rcParams['figure.figsize'] = [10,10]

# Import data
df = pd.DataFrame()
df['plastic'] = pd.Series([100, 50, 30])
df['paper'] = pd.Series([0, 50, 0])
df['glass'] = pd.Series([0, 50, 90])
df['student'] = pd.Series([1, 0, 0])
df['worker'] = pd.Series([0, 0, 1])
df['elder'] = pd.Series([0, 1, 0])

data_x = np.array(df[["plastic","paper","glass"]], dtype=np.float32)
data_y = np.array(df[["student","worker","elder"]], dtype=np.float32)

x_train = torch.from_numpy(data_x)
y_train = torch.from_numpy(data_y)

Utility Functions

def train(model, optimizer, losslogger, start_epoch, num_epoch, run_id, X, y, checkpoint_name, verbose=False):

    loss_function = torch.nn.MSELoss()

    t = tqdm(range(start_epoch, start_epoch + num_epoch))

    for epoch in t:
        # forward
        out = model(X)
        loss = loss_function(out, y)

        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # show
        t.set_description('Epoch[{}/{}], loss: {:.6f}'
                          .format(epoch + 1, start_epoch + num_epoch, loss.item()))

        # append this epoch's loss to the logger
        df = pd.DataFrame()
        df['checkpoint_name'] = pd.Series(checkpoint_name)
        df['epoch'] = pd.Series(epoch)
        df['Loss'] = pd.Series(loss.item())
        df['run'] = run_id
        losslogger = pd.concat([losslogger, df], ignore_index=True)

    if verbose:
        # Print model's state_dict
        print("Model's state_dict:")
        for param_tensor in model.state_dict():
            print(param_tensor, "\t", model.state_dict()[param_tensor].size())

        # Print optimizer's state_dict
        print("Optimizer's state_dict:")
        for var_name in optimizer.state_dict():
            print(var_name, "\t", optimizer.state_dict()[var_name])

    # quick prediction on an unseen input
    print(model(torch.tensor([[500, 500, 500]], dtype=torch.float32)))

    # save everything needed to resume: epoch counter, model, optimizer, and the loss log
    state = {'epoch': epoch + 1, 'state_dict': model.state_dict(),
             'optimizer': optimizer.state_dict(), 'losslogger': losslogger}
    torch.save(state, f'{checkpoint_name}')
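
To confirm what actually ends up inside the file, you can open a checkpoint directly; a quick sketch, assuming train has already written checkpoint.pth.tar (the printed keys match the dict assembled above):

checkpoint = torch.load('checkpoint.pth.tar')
print(checkpoint.keys())                 # dict_keys(['epoch', 'state_dict', 'optimizer', 'losslogger'])
print(checkpoint['epoch'])               # next epoch to resume from
print(checkpoint['losslogger'].tail())   # last few logged losses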

def load_checkpoint(model, optimizer, losslogger, filename):
    # Note: the input model & optimizer should be pre-defined. This routine only updates their states.
    start_epoch = 0
    if os.path.isfile(filename):
        print("=> loading checkpoint '{}'".format(filename))
        checkpoint = torch.load(filename)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        losslogger = checkpoint['losslogger']
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(filename, checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(filename))

    return model, optimizer, start_epoch, losslogger
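
One thing load_checkpoint does not handle is device placement. If the checkpoint was written on a GPU machine and you resume in a CPU-only session (or the other way around), passing map_location to torch.load avoids device-mismatch errors. A minimal sketch, assuming the same model, optimizer and file name as above:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
checkpoint = torch.load('checkpoint.pth.tar', map_location=device)  # remap tensors to the current device
model.load_state_dict(checkpoint['state_dict'])
model.to(device)
optimizer.load_state_dict(checkpoint['optimizer'])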

First Training Loop

# Start
start_epoch = 0

# number of epochs
num_epoch = 100

# Logger
losslogger = pd.DataFrame()

# Model
model = torch.nn.Sequential(
    torch.nn.Linear(3, 3, bias=True),
    torch.nn.ReLU(),
    torch.nn.Linear(3, 3, bias=True),
    torch.nn.ReLU(),
    torch.nn.Linear(3, 3, bias=True),
    torch.nn.ReLU(),
    torch.nn.Softmax(dim=1)
)

# optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)


# Checkpoint name
checkpoint_name = 'checkpoint.pth.tar'

train(model, optimizer, losslogger, start_epoch, num_epoch, 0, x_train, y_train, checkpoint_name)
time.sleep(8)

Load and Continue Training

# Load and continue training // run 1
model, optimizer, start_epoch, losslogger = load_checkpoint(model, optimizer, losslogger, filename=checkpoint_name)
!rm 'checkpoint.pth.tar'
train(model, optimizer, losslogger, start_epoch, num_epoch, 1, x_train, y_train, checkpoint_name)
time.sleep(8)

# Load and continue training // run 2
model, optimizer, start_epoch, losslogger = load_checkpoint(model, optimizer, losslogger, filename=checkpoint_name)
!rm 'checkpoint.pth.tar'
train(model, optimizer, losslogger, start_epoch, num_epoch, 2, x_train, y_train, checkpoint_name)
time.sleep(8)

# Load and continue training // run 3
model, optimizer, start_epoch, losslogger = load_checkpoint(model, optimizer, losslogger, filename=checkpoint_name)
!rm 'checkpoint.pth.tar'
train(model, optimizer, losslogger, start_epoch, num_epoch, 3, x_train, y_train, checkpoint_name)
time.sleep(8)

# Load and continue training // run 4
model, optimizer, start_epoch, losslogger = load_checkpoint(model, optimizer, losslogger, filename=checkpoint_name)
!rm 'checkpoint.pth.tar'
train(model, optimizer, losslogger, start_epoch, num_epoch, 4, x_train, y_train, checkpoint_name)
time.sleep(8)
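
Each of these blocks is identical apart from the run id, so the same flow can equally be written as a loop; a sketch using the functions and variable names defined above:

for run_id in range(1, 5):
    model, optimizer, start_epoch, losslogger = load_checkpoint(
        model, optimizer, losslogger, filename=checkpoint_name)
    os.remove(checkpoint_name)  # drop the old file; train() writes a fresh checkpoint
    train(model, optimizer, losslogger, start_epoch, num_epoch, run_id,
          x_train, y_train, checkpoint_name)
    time.sleep(8)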

Verify Loss is Continuous Across Loading Cycles

# Reload the final checkpoint so the logger also includes the last run
model, optimizer, start_epoch, losslogger = load_checkpoint(model, optimizer, losslogger, filename=checkpoint_name)

count = 0
for run_id in tqdm(range(len(losslogger.run.unique()))):
    df = losslogger[losslogger.run == run_id]
    n = len(df)
    plt.plot(np.arange(count, count + n), df.Loss.values)
    count += n
plt.show()
[Plot: training loss across the runs, continuous across the checkpoint-reload boundaries]
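
A quick numeric version of the same check: the last logged loss of each run should sit right next to the first logged loss of the following run, because both the model and the optimizer state were restored exactly:

runs = sorted(losslogger.run.unique())
for prev, nxt in zip(runs[:-1], runs[1:]):
    last_prev = losslogger[losslogger.run == prev].Loss.values[-1]
    first_next = losslogger[losslogger.run == nxt].Loss.values[0]
    print(f'run {prev} -> run {nxt}: {last_prev:.6f} -> {first_next:.6f}')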
