How to save our model to Google Drive and reuse it

Avinash
Dec 16, 2018

If you are using Google Colab and the runtime restarts during training, you lose your trained model and have to start again from scratch, which is not optimal. Instead, you should save a model checkpoint to Google Drive at every epoch and reload it the next time you start. In this short tutorial, I will explain how to mount Google Drive, save a model checkpoint to it, and load it back.

How to mount Google Drive

Google Colab runs isolated from Google Drive, so you cannot access your Drive files directly. To access them, you need to authenticate, grant Colab permission to access your Drive, and mount the drive.

Add the following code to a cell:

from google.colab import drive
drive.mount('/content/gdrive')

When you run it, it prints a URL and asks for an authorization code.

Visit that URL, authorize with your Google account, and copy the code it displays. Paste the code back into the Colab notebook. Once authorized, Colab reports that the drive has been mounted.

You should be able to see your drive in the sidebar, under the Files tab. Notice that your Drive files are mounted under the path `/content/gdrive/My Drive`. If you run `ls` on that path, you should see your drive contents:

!ls "/content/gdrive/My Drive"
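
If you prefer plain Python over a shell command, a minimal sketch of the same listing with `pathlib` is below; it also sidesteps shell quoting of the space in `My Drive`. The function name `list_drive` and its `pattern` parameter are hypothetical; the default root simply matches the mount path used above.

```python
from pathlib import Path

def list_drive(root="/content/gdrive/My Drive", pattern="*"):
    """Return the sorted names of entries in `root` that match a glob `pattern`."""
    root_path = Path(root)
    if not root_path.is_dir():
        # Most likely cause in Colab: drive.mount() has not been run yet.
        raise FileNotFoundError(f"{root} not found; is Google Drive mounted?")
    return sorted(p.name for p in root_path.glob(pattern))

# Example (in Colab, after mounting): list only PyTorch checkpoint files.
# print(list_drive(pattern="*.pt"))
```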

How to save your model in Google Drive

Make sure you have mounted your Google Drive. To save our model checkpoint (or any file) to Drive, we simply write it under the drive's mounted path.
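
One way to turn "make sure you have mounted your Google Drive" into a check is a small guard that fails early instead of silently writing into a missing folder. This is a sketch under assumptions: `drive_path`, `DRIVE_ROOT` and the optional `subdir` argument are hypothetical names, not part of Colab or PyTorch.

```python
import os

DRIVE_ROOT = "/content/gdrive/My Drive"  # mount path used earlier in this post

def drive_path(filename, root=DRIVE_ROOT, subdir=""):
    """Build a path inside the mounted Drive, failing early if it isn't mounted."""
    if not os.path.isdir(root):
        raise RuntimeError(f"{root} does not exist; run drive.mount() first.")
    folder = os.path.join(root, subdir)
    os.makedirs(folder, exist_ok=True)  # create e.g. a checkpoints/ folder if needed
    return os.path.join(folder, filename)

# Example (in Colab): torch.save(model.state_dict(), drive_path("classifier.pt"))
```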

To save our model, we use the torch.save method:

import torch

model_save_name = 'classifier.pt'
path = f"/content/gdrive/My Drive/{model_save_name}"
torch.save(model.state_dict(), path)  # `model` is your trained nn.Module

Now, if you visit your Google Drive at https://drive.google.com/drive/my-drive, you will see the classifier.pt file saved there!
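
Since the whole point is surviving a runtime restart, it can help to avoid leaving a half-written checkpoint if the runtime dies during `torch.save`. A common pattern is to write to a temporary file first and then move it over the target with `os.replace`. The helper name `atomic_save` is hypothetical; it accepts any save function with the `save_fn(obj, path)` shape, which matches `torch.save(obj, f)`. `os.replace` is atomic on a local POSIX filesystem; whether the Drive mount gives the same guarantee is an assumption we have not verified.

```python
import os
import tempfile

def atomic_save(save_fn, obj, path):
    """Save `obj` with `save_fn(obj, tmp_path)`, then move the file over `path`.

    Because the data is written to a temporary file first, a crash mid-write
    leaves any previous file at `path` untouched.
    """
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    os.close(fd)  # save_fn opens the file by name itself
    try:
        save_fn(obj, tmp_path)
        os.replace(tmp_path, path)  # swap the finished file into place
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)  # don't leave half-written temp files behind
        raise

# Usage with PyTorch (torch.save takes the object first, then the path):
# atomic_save(torch.save, model.state_dict(), path)
```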

How to load the model from Google Drive

Make sure you have mounted your Google Drive. We can now load the saved checkpoint from Drive and use it. We know its path, so we pass that path to torch.load:

import torch

model_save_name = 'classifier.pt'
path = f"/content/gdrive/My Drive/{model_save_name}"
model.load_state_dict(torch.load(path))  # `model` must have the same architecture

That’s it! Make sure you load the state dict before training starts, so that training resumes from the saved weights instead of from scratch.
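
If you save a checkpoint at every epoch, as suggested at the start, you also need to find the most recent one when the runtime comes back. A minimal sketch, assuming a hypothetical naming scheme `<prefix>_epoch<N>.pt` (e.g. `classifier_epoch7.pt`); `latest_checkpoint` and `train_one_epoch` are illustrative names, not library functions.

```python
import os
import re

def latest_checkpoint(folder, prefix="classifier"):
    """Return (path, epoch) of the highest-epoch checkpoint, or (None, -1) if none.

    Assumes files are named like "<prefix>_epoch<N>.pt", e.g. classifier_epoch7.pt.
    """
    pattern = re.compile(rf"^{re.escape(prefix)}_epoch(\d+)\.pt$")
    best_path, best_epoch = None, -1
    if not os.path.isdir(folder):
        return best_path, best_epoch
    for name in os.listdir(folder):
        match = pattern.match(name)
        if match and int(match.group(1)) > best_epoch:
            best_path = os.path.join(folder, name)
            best_epoch = int(match.group(1))
    return best_path, best_epoch

# Usage sketch (in Colab; `model`, `num_epochs` and `train_one_epoch` are yours):
# folder = "/content/gdrive/My Drive"
# path, last_epoch = latest_checkpoint(folder)
# if path is not None:
#     model.load_state_dict(torch.load(path))
# for epoch in range(last_epoch + 1, num_epochs):
#     train_one_epoch(model)
#     torch.save(model.state_dict(), os.path.join(folder, f"classifier_epoch{epoch}.pt"))
```

The epoch number is compared as an integer rather than by sorting file names, so `classifier_epoch10.pt` correctly wins over `classifier_epoch7.pt`.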

Pro tip: include the model's accuracy in its file name, so that you can pick the best model available. Hope this helps!
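
Putting the accuracy in the file name could look like the sketch below. The helpers `checkpoint_name` and `best_checkpoint`, the `_acc<value>` suffix and the four-decimal format are all assumptions for illustration; use whatever metric and format you actually track.

```python
import os
import re

def checkpoint_name(prefix, epoch, accuracy):
    """Encode epoch and accuracy in the file name, e.g. classifier_epoch3_acc0.9123.pt."""
    return f"{prefix}_epoch{epoch}_acc{accuracy:.4f}.pt"

def best_checkpoint(names, prefix="classifier"):
    """Return the file name with the highest encoded accuracy, or None if none match."""
    pattern = re.compile(rf"^{re.escape(prefix)}_epoch\d+_acc(\d+(?:\.\d+)?)\.pt$")
    scored = []
    for name in names:
        match = pattern.match(name)
        if match:
            scored.append((float(match.group(1)), name))
    return max(scored)[1] if scored else None

# Usage sketch (in Colab; `val_acc` is your own validation accuracy):
# folder = "/content/gdrive/My Drive"
# torch.save(model.state_dict(), os.path.join(folder, checkpoint_name("classifier", epoch, val_acc)))
# best = best_checkpoint(os.listdir(folder))
```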

If you need any help, feel free to ping me on Slack. My ID is avinashss.
