PyTorch Module — quick reference

geekgirldecodes
HowsOfCoding
3 min read · Feb 5, 2022


Hi there! Welcome back to the torch thursdays (*delayed) blog. In this quick reference, we will touch upon what a PyTorch model is, what torch.nn.Module is, and common PyTorch layer types, and we'll also look at a small example of defining your own custom modules in PyTorch, along with the save and load APIs.

Models in PyTorch

  • Models can be designed in PyTorch using two classes: torch.nn.Module and torch.nn.Parameter.
  • torch.nn.Module is the fundamental unit of a model in PyTorch. Modules are the building blocks of stateful computation, and you can define custom layer types as sub-classes of this type.
  • torch.nn.Parameter is a sub-class of a torch tensor (covered in this blog) and is used to represent the learnable weights of a given module / layer.
  • By default, parameters and floating-point buffers for modules provided by torch.nn are initialized during module instantiation as 32-bit floating point values on the CPU, using an initialization scheme historically determined to perform well for the module type (the snippet below shows a quick check of these defaults).
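As a quick sanity check, this minimal sketch inspects a built-in nn.Linear layer and prints the dtype and device of its parameters:

import torch
from torch import nn

layer = nn.Linear(4, 2)  # a built-in module with weight and bias parameters

# Parameters are registered automatically and default to float32 on the CPU.
for name, param in layer.named_parameters():
    print(name, param.dtype, param.device)
# weight torch.float32 cpu
# bias torch.float32 cpu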

Why should you use this module class?

  • Modules make it simple to specify learnable parameters for PyTorch’s Optimizers to update.
  • Modules are straightforward to save and restore, transfer between CPU / GPU / TPU devices, prune, quantize, and more (see the sketch below).
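Here is a minimal sketch of both points, using a built-in layer (the device transfer assumes a CUDA-capable GPU is available):

import torch
from torch import nn

model = nn.Linear(4, 2)

# All learnable parameters can be handed to an optimizer in one call.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Parameters and buffers move between devices in one call as well.
if torch.cuda.is_available():
    model = model.to("cuda")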

Common Layer types

  • PyTorch provides implementations of common layers used in ML models, such as fully connected / Linear layers, Conv2d, BatchNorm, etc., with their forward pass methods as well as gradient computation (autograd) for the backward pass.
  • Users can implement their own layer types by inheriting from the nn.Module class. For example, the HuggingFace transformers library implements its models with custom layer definitions.
  • Shown below is an example of how to implement a Linear layer.
  • The forward() implementation for a given module can perform arbitrary computation involving any number of inputs and outputs. A linear layer computes y = xW + b, where y is the output, x is the input, W is the weight matrix, and b is the bias.
import torch
from torch import nn

class MyLinear(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        # Wrapping tensors in nn.Parameter registers them as learnable weights.
        self.weight = nn.Parameter(torch.randn(in_features, out_features))
        self.bias = nn.Parameter(torch.randn(out_features))

    def forward(self, input):
        # A linear layer is a matrix multiply plus bias: y = xW + b.
        return (input @ self.weight) + self.bias
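A quick, illustrative usage of this custom layer (the shapes here are chosen arbitrarily):

layer = MyLinear(3, 4)
x = torch.randn(2, 3)   # a batch of 2 samples with 3 features each

y = layer(x)            # calling the module invokes forward()
print(y.shape)          # torch.Size([2, 4])

# Both tensors wrapped in nn.Parameter are visible to optimizers.
print([name for name, _ in layer.named_parameters()])  # ['weight', 'bias']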
  • We can also define models which contain PyTorch-defined layer types, as shown below.
class CustomModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(100, 200)
        self.act = torch.nn.ReLU()

    def forward(self, x):
        x = self.linear1(x)
        x = self.act(x)
        return x

Instantiate a PyTorch model

custom_model = CustomModel()
learnable_params_model = custom_model.parameters()
# these parameters are passed to the optimizer during training (see the sketch below).
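For illustration, here is a minimal sketch of one training step with these parameters, assuming an MSE loss and random stand-in data:

import torch

optimizer = torch.optim.SGD(learnable_params_model, lr=0.01)

x = torch.randn(8, 100)       # hypothetical input batch
target = torch.randn(8, 200)  # hypothetical targets

loss = torch.nn.functional.mse_loss(custom_model(x), target)

optimizer.zero_grad()   # clear gradients from the previous step
loss.backward()         # autograd computes gradients for all parameters
optimizer.step()        # the optimizer updates the registered parameters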

Save model to disk

The various types of state a module can have:

  • Parameters: learnable aspects of computation; contained within the state_dict.
  • Buffers: non-learnable aspects of computation.
      • Persistent buffers: contained within the state_dict (i.e. serialized when saving & loading).
      • Non-persistent buffers: not contained within the state_dict (i.e. left out of serialization).
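The distinction is easiest to see with register_buffer(). A minimal sketch (the module and buffer names here are made up for illustration):

import torch
from torch import nn

class StatsTracker(nn.Module):
    def __init__(self):
        super().__init__()
        # Persistent buffer: included in the state_dict and serialized.
        self.register_buffer("running_mean", torch.zeros(10))
        # Non-persistent buffer: lives on the module but is not serialized.
        self.register_buffer("scratch", torch.zeros(10), persistent=False)

tracker = StatsTracker()
print(tracker.state_dict().keys())  # odict_keys(['running_mean'])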

If we want to save the trained model to disk, we can do so by saving its state_dict (i.e. “state dictionary”):

torch.save(custom_model.state_dict(), 'custom_model.pt')

Now, if we want to load this saved model, we can do so as shown below.

# Load the module later on
custom_model = CustomModel()
custom_model.load_state_dict(torch.load('custom_model.pt'))
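One thing worth noting: if the model contains layers like dropout or batch norm (ours does not, but many do), switch it to evaluation mode before running inference:

custom_model.eval()  # disables dropout and uses running stats in batch norm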

And that’s it for this quick blog. As always, if you learnt something from this blog, drop a clap below so it can reach more folks on the platform. If you want a specific topic covered, drop your suggestion below. Thank you~

References

  1. PyTorch video tutorial: https://www.youtube.com/watch?v=OSqIP-mOWOI
  2. PyTorch docs, notes on modules: https://pytorch.org/docs/stable/notes/modules.html#modules
