PyTorch Module — quick reference
Hi there! Welcome back to the Torch Thursdays (*delayed) blog. In this quick reference, we will cover what a PyTorch model is, what torch.nn.Module is, and the common PyTorch layer types. We will also look at a small example of defining your own custom modules in PyTorch, along with the save and load API.
Models in PyTorch
- Models are built in PyTorch from two classes: torch.nn.Module and torch.nn.Parameter.
- torch.nn.Module is the fundamental unit of a model in PyTorch. Modules are the building blocks of stateful computation. You can define custom layer types as subclasses of this type.
- torch.nn.Parameter is a subclass of torch.Tensor (covered in this blog) and is used to represent the learnable weights of a given module / layer.
- By default, parameters and floating-point buffers for modules provided by torch.nn are initialized during module instantiation as 32-bit floating point values on the CPU using an initialization scheme determined to perform well historically for the module type.
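As a quick check of the defaults described above, here is a minimal sketch that inspects a freshly created nn.Linear layer. The layer sizes and variable names are arbitrary choices for illustration, and the sketch assumes the global default dtype and device have not been changed.

```python
import torch
from torch import nn

# Layer sizes (4 inputs, 2 outputs) are arbitrary, chosen only for illustration.
layer = nn.Linear(4, 2)

# nn.Parameter is a subclass of torch.Tensor, so the weight is also a tensor.
is_tensor = isinstance(layer.weight, torch.Tensor)

# With default settings, parameters are float32 tensors living on the CPU
# and are marked as requiring gradients.
weight_dtype = layer.weight.dtype
weight_device = layer.weight.device.type
needs_grad = layer.weight.requires_grad

print(is_tensor, weight_dtype, weight_device, needs_grad)
```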
Why should you use this module class?
- Modules make it simple to specify learnable parameters for PyTorch’s Optimizers to update.
- Modules are straightforward to save and restore, transfer between CPU / GPU / TPU devices, prune, quantize, and more.
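To make the two points above concrete, here is a minimal sketch, assuming a toy nn.Linear model, random data, and plain SGD (all chosen only for illustration). It moves the module to an available device, hands its parameters to an optimizer, and runs a single update step.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Pick a GPU if one is available, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# A toy model; .to(device) moves all of its parameters and buffers.
model = nn.Linear(3, 1).to(device)

# model.parameters() yields every learnable parameter for the optimizer.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Random inputs and targets, used only to produce a gradient.
x = torch.randn(8, 3, device=device)
target = torch.randn(8, 1, device=device)

weight_before = model.weight.detach().clone()

loss = nn.functional.mse_loss(model(x), target)
optimizer.zero_grad()
loss.backward()
optimizer.step()  # updates the parameters in place

weight_changed = not torch.equal(weight_before, model.weight.detach())
print(weight_changed)
```

Moving the model before constructing the optimizer is the commonly recommended order, so the optimizer is built over the parameters on their final device.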
Common Layer types
- PyTorch implements many of the common layers used in ML models, such as fully connected / Linear layers, Conv2d, BatchNorm, etc., along with their forward pass and the gradient computation (via autograd) for the backward pass.
- Users can implement their own layer types by inheriting from the nn.Module class. For example, the HuggingFace transformer model implementations use custom layer definitions.
- Shown here is an example of how to implement a Linear layer.
- The forward() implementation for a given module can perform arbitrary computation involving any number of inputs and outputs. As we know, a linear layer computes y = xW + b, where y is the output, x is the input, W is the weight matrix, and b is the bias. Note that xW is a matrix product, not an element-wise product.
import torch
from torch import nn

class MyLinear(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        # Wrapping tensors in nn.Parameter registers them as learnable weights.
        self.weight = nn.Parameter(torch.randn(in_features, out_features))
        self.bias = nn.Parameter(torch.randn(out_features))

    def forward(self, input):
        # Matrix multiplication (@), followed by a broadcast bias add.
        return (input @ self.weight) + self.bias
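Here is a minimal usage sketch of a custom linear layer like the one above, written with matrix multiplication (`@`). The feature sizes and the batch size are arbitrary values picked for illustration.

```python
import torch
from torch import nn

class MyLinear(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(in_features, out_features))
        self.bias = nn.Parameter(torch.randn(out_features))

    def forward(self, input):
        return (input @ self.weight) + self.bias

layer = MyLinear(4, 3)      # 4 input features, 3 output features
x = torch.randn(2, 4)       # a batch of 2 samples
y = layer(x)                # calling the module runs forward()

out_shape = tuple(y.shape)  # one row per sample, one column per output feature

# Parameters assigned as attributes are registered automatically.
param_names = [name for name, _ in layer.named_parameters()]
print(out_shape, param_names)
```

Calling `layer(x)` rather than `layer.forward(x)` is the usual convention, since the call also runs any hooks registered on the module.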
- We can also define models which contain PyTorch defined layer types, as shown below.
class CustomModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Submodules assigned as attributes are registered automatically.
        self.linear1 = torch.nn.Linear(100, 200)
        self.act = torch.nn.ReLU()

    def forward(self, x):
        x = self.linear1(x)
        x = self.act(x)
        return x
Instantiate a PyTorch model
custom_model = CustomModel()
# These parameters are passed to the optimizer during training.
learnable_params_model = custom_model.parameters()
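To see what the model exposes as learnable parameters, here is a small sketch that lists them by name and counts them. It redefines a CustomModel equivalent to the one above so that it runs on its own; the `shapes` and `total` names are introduced only for this example.

```python
import torch

class CustomModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(100, 200)
        self.act = torch.nn.ReLU()

    def forward(self, x):
        return self.act(self.linear1(x))

custom_model = CustomModel()

# named_parameters() walks the submodules and yields (name, parameter) pairs.
shapes = {name: tuple(p.shape) for name, p in custom_model.named_parameters()}

# Total number of learnable scalars: 200 * 100 weights plus 200 biases.
total = sum(p.numel() for p in custom_model.parameters())

print(shapes)
print(total)
```

Note that nn.Linear stores its weight with shape (out_features, in_features), and that ReLU contributes no parameters.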
Save model to disk
The various types of state a module can have:
- Parameters: learnable aspects of computation; contained within the state_dict.
- Buffers: non-learnable aspects of computation
  - Persistent buffers: contained within the state_dict (i.e. serialized when saving & loading)
  - Non-persistent buffers: not contained within the state_dict (i.e. left out of serialization)
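The three kinds of state listed above can be sketched with a small module. The class name WithBuffers and the attribute names scale, running_mean, and scratch are hypothetical, made up purely for this illustration.

```python
import torch
from torch import nn

class WithBuffers(nn.Module):  # hypothetical module, for illustration only
    def __init__(self):
        super().__init__()
        # Parameter: learnable, and saved in the state_dict.
        self.scale = nn.Parameter(torch.ones(3))
        # Persistent buffer (the default): not learnable, but saved.
        self.register_buffer("running_mean", torch.zeros(3))
        # Non-persistent buffer: not learnable and left out of the state_dict.
        self.register_buffer("scratch", torch.zeros(3), persistent=False)

    def forward(self, x):
        return (x - self.running_mean) * self.scale

module = WithBuffers()

state_keys = sorted(module.state_dict().keys())
buffer_names = sorted(name for name, _ in module.named_buffers())
param_names = [name for name, _ in module.named_parameters()]

print(state_keys)    # the parameter and the persistent buffer only
print(buffer_names)  # both buffers, persistent or not
print(param_names)
```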
If we want to save the trained model to disk, we can do so by saving its state_dict (i.e. “state dictionary”):
torch.save(custom_model.state_dict(), 'custom_model.pt')
Now, if we want to load this saved model, we can do so as shown below.
# Load the module later on
custom_model = CustomModel()
custom_model.load_state_dict(torch.load('custom_model.pt'))
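Putting save and load together, here is a minimal round-trip sketch. It assumes a CustomModel like the one defined earlier, writes to a temporary directory instead of the working directory, and passes map_location="cpu" so the checkpoint loads on a machine without a GPU. These choices are illustrative, not required.

```python
import os
import tempfile

import torch

class CustomModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(100, 200)
        self.act = torch.nn.ReLU()

    def forward(self, x):
        return self.act(self.linear1(x))

trained = CustomModel()   # stands in for a trained model
x = torch.randn(4, 100)   # random input used only to compare outputs

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "custom_model.pt")
    torch.save(trained.state_dict(), path)

    # A fresh instance starts with different random weights...
    restored = CustomModel()
    # ...until the saved state_dict is loaded into it.
    restored.load_state_dict(torch.load(path, map_location="cpu"))

restored.eval()  # switch to inference mode before evaluating
with torch.no_grad():
    same_output = torch.equal(trained(x), restored(x))

print(same_output)
```

Only the state is saved here, not the class itself, which is why the model class has to be available and instantiated before load_state_dict is called.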
And that’s it for this quick blog. As always, if you learned anything from this post, drop a clap below so it can reach more folks on the platform. If you want a specific topic covered, drop your suggestion below. Thank you~