Understanding and using the class structure for creating models with PyTorch 🔥
To create simple models in PyTorch, we can use the nn.Sequential API to build our neural network model:
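```python
import torch.nn as nn
model = nn.Sequential(nn.Conv2d(1, 20, 5), nn.ReLU(), nn.Conv2d(20, 64, 5), nn.ReLU(), nn.Flatten())
```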
The problem with this structure is that it only chains layers one after another: each layer's output is passed straight into the next one. This makes it hard to build hybrid or more complicated models, and you don't have full control over how the data is passed through the model.
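To make the limitation concrete, here is a minimal sketch (the layer sizes are borrowed from the class example in this post, and the 28x28 input size is an assumption for illustration) showing that nn.Sequential simply calls its layers one after another. Running the same layers by hand in a plain loop gives exactly the same output, and there is no step in this chain where a branch, a skip connection, or a second input could be added.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # fixed seed so the run is repeatable

# A plain chain of layers: conv -> relu -> conv -> relu -> flatten
model = nn.Sequential(
    nn.Conv2d(1, 20, 5),
    nn.ReLU(),
    nn.Conv2d(20, 64, 5),
    nn.ReLU(),
    nn.Flatten(),
)

x = torch.randn(2, 1, 28, 28)  # dummy batch: 2 single-channel 28x28 images

# Sequential's forward behaves like this loop over its layers
out_manual = x
for layer in model:
    out_manual = layer(out_manual)

out_seq = model(x)
print(out_seq.shape)                        # torch.Size([2, 25600])
print(torch.allclose(out_seq, out_manual))  # True
```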
To have full control over the model, and to build the kind of models we need in real machine learning projects, we can instead write a class that inherits from torch.nn.Module, PyTorch's base class for all neural network modules:
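```python
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()  # set up nn.Module internals first
        self.lay1 = nn.Conv2d(1, 20, 5)   # 1 -> 20 channels, 5x5 kernel
        self.lay2 = nn.Conv2d(20, 64, 5)  # 20 -> 64 channels, 5x5 kernel
        self.flatten = nn.Flatten()
        self.relu = nn.ReLU()

    def forward(self, inputs):
        # we decide exactly how the data flows through the layers
        ins = self.relu(self.lay1(inputs))
        ins = self.relu(self.lay2(ins))
        output = self.flatten(ins)
        return output
```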
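A quick way to try the class is to create an instance and pass a dummy batch through it. The 28x28 input size below is an assumption (MNIST-sized images) chosen only for illustration: each 5x5 convolution without padding shrinks the height and width by 4, so 28 becomes 24 and then 20, and flattening 64 channels of 20x20 gives 25,600 features per image. The class is repeated so the snippet runs on its own.

```python
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.lay1 = nn.Conv2d(1, 20, 5)
        self.lay2 = nn.Conv2d(20, 64, 5)
        self.flatten = nn.Flatten()
        self.relu = nn.ReLU()

    def forward(self, inputs):
        ins = self.relu(self.lay1(inputs))
        ins = self.relu(self.lay2(ins))
        return self.flatten(ins)

model = MyModel()               # runs __init__ and creates the layers
x = torch.randn(4, 1, 28, 28)   # dummy batch of 4 grayscale 28x28 images
out = model(x)                  # calling the model runs forward()
print(out.shape)                # torch.Size([4, 25600])
```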
The advantage here is that you have full control over how the data flows through the model: you can combine multiple layers freely and build hybrid structures like the one shown in the image below.
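As a rough illustration of such a hybrid structure, here is a sketch of a hypothetical HybridModel (its name, layer sizes, and the extra 10-feature input are all made up for this example) with a skip connection, two parallel branches, and a second input joined in at the end. None of this fits in a single nn.Sequential chain, but inside forward it is just ordinary Python.

```python
import torch
import torch.nn as nn

class HybridModel(nn.Module):  # hypothetical example model
    def __init__(self):
        super(HybridModel, self).__init__()
        self.conv = nn.Conv2d(1, 8, 3, padding=1)      # keeps the 28x28 size
        self.branch_a = nn.Conv2d(8, 8, 3, padding=1)  # 3x3 branch
        self.branch_b = nn.Conv2d(8, 8, 1)             # 1x1 branch
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        # 16 channels of 28x28 image features + 10 extra features -> 2 outputs
        self.fc = nn.Linear(16 * 28 * 28 + 10, 2)

    def forward(self, image, extra):
        x = self.relu(self.conv(image))
        a = self.relu(self.branch_a(x)) + x           # skip connection
        b = self.relu(self.branch_b(x))               # parallel branch
        merged = torch.cat([a, b], dim=1)             # 8 + 8 = 16 channels
        features = torch.cat([self.flatten(merged), extra], dim=1)  # second input
        return self.fc(features)

model = HybridModel()
out = model(torch.randn(3, 1, 28, 28), torch.randn(3, 10))
print(out.shape)  # torch.Size([3, 2])
```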
Since you now control the whole flow of data inside your model, let's go through what each part of the code above does.
What is __init__ in the class structure?
__init__ is a special (reserved) method in Python's object-oriented programming, and it acts as the constructor of our class.
It is called automatically as soon as we create an object of our model, and it initializes the contents of that object. In our example, it creates the Conv layers (and would create any Linear layers you add), and PyTorch initializes their weights with random values.
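A small sketch to check this behaviour, reusing the layers from MyModel with a print added to __init__ (the print is only for demonstration): every new object runs __init__ again and gets its own randomly initialized weights. The parameter count is just the weights plus biases of the two conv layers.

```python
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        print("__init__ called")  # shows when the constructor runs
        self.lay1 = nn.Conv2d(1, 20, 5)
        self.lay2 = nn.Conv2d(20, 64, 5)
        self.flatten = nn.Flatten()
        self.relu = nn.ReLU()

    def forward(self, inputs):
        ins = self.relu(self.lay1(inputs))
        ins = self.relu(self.lay2(ins))
        return self.flatten(ins)

model_a = MyModel()  # prints "__init__ called"
model_b = MyModel()  # prints "__init__ called" again

# Each object gets its own randomly initialized weights
same = torch.equal(model_a.lay1.weight, model_b.lay1.weight)
print(same)  # False (the two random initializations differ)

# 20*1*5*5 + 20 (lay1) + 64*20*5*5 + 64 (lay2)
n_params = sum(p.numel() for p in model_a.parameters())
print(n_params)  # 32584
```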
What is super(MyModel, self).__init__() doing?
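super(MyModel, self).__init__() calls the constructor of the parent class, nn.Module. Because MyModel is a subclass of nn.Module, it inherits all of nn.Module's methods, but nn.Module's own __init__ still has to run to set up the internal bookkeeping that keeps track of the layers and parameters we assign to self. Once it has run, we are ready to use all the pre-built code already implemented in PyTorch's nn.Module class.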
In simple words, the super call lets your model use everything implemented in the torch.nn.Module class.
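To see why the super call matters, here is a sketch with two hypothetical classes: BrokenModel, which skips super().__init__(), and WorkingModel, which calls it. In the PyTorch versions I'm aware of, assigning a layer to self before nn.Module.__init__() has run raises an AttributeError, because the internal registry for sub-modules does not exist yet.

```python
import torch.nn as nn

class BrokenModel(nn.Module):  # hypothetical: forgets the super call
    def __init__(self):
        self.lay1 = nn.Conv2d(1, 20, 5)  # fails: nn.Module is not set up yet

class WorkingModel(nn.Module):  # hypothetical: calls the super constructor
    def __init__(self):
        super(WorkingModel, self).__init__()  # creates nn.Module's registries
        self.lay1 = nn.Conv2d(1, 20, 5)

try:
    BrokenModel()
    raised = False
except AttributeError:
    raised = True
print(raised)  # True

model = WorkingModel()
# The layer was registered, so its weight and bias show up as parameters
print(len(list(model.parameters())))  # 2
```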
What is self in a class?
self represents the instance of the class. Through self we can access the attributes and methods (functions) defined inside the class. In the example above, we use self to store the layers in __init__ so that we can use them in the forward method.
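A quick sketch of what storing layers on self buys us: nn.Module records every layer assigned as an attribute of self, in the order it was assigned, and named_children() lists them. The extra local_layer is a hypothetical addition showing that a layer kept in a local variable instead of on self is not registered.

```python
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # Layers stored on self become attributes of this instance
        self.lay1 = nn.Conv2d(1, 20, 5)
        self.lay2 = nn.Conv2d(20, 64, 5)
        self.flatten = nn.Flatten()
        self.relu = nn.ReLU()
        local_layer = nn.Linear(4, 4)  # NOT on self: lost when __init__ returns

    def forward(self, inputs):
        # forward reaches the same layers through self
        ins = self.relu(self.lay1(inputs))
        ins = self.relu(self.lay2(ins))
        return self.flatten(ins)

model = MyModel()
names = [name for name, _ in model.named_children()]
print(names)  # ['lay1', 'lay2', 'flatten', 'relu']
```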
What is the forward method?
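This is the driving method of our model: it defines how the input data flows through the layers to produce the output. When we call model(inputs), PyTorch runs forward for us, so in practice we call the model itself rather than calling forward directly. Note that the weights are not updated inside forward: in the training loop, loss.backward() computes the gradients, and the optimizer's step() method then updates the weights.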
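Here is a minimal sketch of one training step that separates the two phases. TinyModel, the SGD optimizer, the learning rate, and the MSE loss are assumptions chosen only for illustration. Calling the model runs forward and leaves the weights unchanged; only optimizer.step(), after loss.backward(), modifies them.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # fixed seed so the run is repeatable

class TinyModel(nn.Module):  # hypothetical small model for the demo
    def __init__(self):
        super(TinyModel, self).__init__()
        self.fc = nn.Linear(3, 1)

    def forward(self, inputs):
        return self.fc(inputs)

model = TinyModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

x = torch.randn(8, 3)  # dummy inputs
y = torch.randn(8, 1)  # dummy targets
before = model.fc.weight.detach().clone()

pred = model(x)  # runs forward(); no weights change here
unchanged_after_forward = torch.equal(model.fc.weight, before)

loss = loss_fn(pred, y)
optimizer.zero_grad()  # clear any old gradients
loss.backward()        # compute gradients
optimizer.step()       # this is where the weights are updated
changed_after_step = not torch.equal(model.fc.weight, before)

print(unchanged_after_forward, changed_after_step)  # True True
```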
Thanks for reading my blog :) Follow for more, and say hi to me in the comments; it encourages me to write more blogs :) Have a good day :)