Introduction to PyTorch (6/7)

The V Notebook
Aug 28, 2023


In this unit, we will look at how to load a model along with its persisted parameter states and run inference to get model predictions. To load the model, we first define the model class, which describes the structure of the neural network that was trained and saved; the saved file then supplies the values of its parameters.

%matplotlib inline
import torch
import onnxruntime
from torch import nn
import torch.onnx as onnx
import torchvision.models as models
from torchvision import datasets
from torchvision.transforms import ToTensor


class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        # Flatten each 28x28 image into a 784-element vector
        self.flatten = nn.Flatten()
        # Three fully connected layers with ReLU activations; the last layer has 10 outputs
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

When loading model weights, we need to instantiate the model class first, because the class defines the structure of the network. Next, we load the parameters using the load_state_dict() method.
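To see why the class has to exist before the weights can be loaded, here is a minimal sketch of a save/load round trip. It assumes the same NeuralNetwork architecture as above and writes the state_dict to an in-memory io.BytesIO buffer instead of data/model.pth; the buffer and the equality check are illustrative only.

```python
import io

import torch
from torch import nn


class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
            nn.ReLU(),
        )

    def forward(self, x):
        x = self.flatten(x)
        return self.linear_relu_stack(x)


# Stand-in for a trained model whose parameters we want to persist.
source = NeuralNetwork()

# The state_dict maps parameter names to tensors; it stores no architecture.
print(list(source.state_dict().keys()))

# Save to an in-memory buffer (assumption: replaces the file on disk).
buffer = io.BytesIO()
torch.save(source.state_dict(), buffer)
buffer.seek(0)

# The class must be instantiated first so the tensors have somewhere to go.
restored = NeuralNetwork()
restored.load_state_dict(torch.load(buffer))

# Check that every restored parameter matches the original one.
same = all(
    torch.equal(a, b)
    for a, b in zip(source.state_dict().values(), restored.state_dict().values())
)
print(same)
```

Because load_state_dict() is strict by default, loading into a class whose layer names or tensor shapes differ from the saved ones raises a RuntimeError instead of silently skipping parameters.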

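model = NeuralNetwork()  # build the network structure first
model.load_state_dict(torch.load('data/model.pth'))  # then fill in the saved parameters
model.eval()  # switch to evaluation mode before inference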
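With the weights in place, a prediction is a single forward pass with gradient tracking turned off. The sketch below is a minimal illustration under two assumptions: a random 1×28×28 tensor stands in for a real image, and randomly initialized weights stand in for data/model.pth, so the predicted class index itself is meaningless; only the steps and shapes matter.

```python
import torch
from torch import nn


class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
            nn.ReLU(),
        )

    def forward(self, x):
        x = self.flatten(x)
        return self.linear_relu_stack(x)


torch.manual_seed(0)
model = NeuralNetwork()
# In the tutorial, model.load_state_dict(torch.load('data/model.pth')) would go here.
model.eval()

# One fake grayscale image: batch size 1, 1 channel, 28x28 pixels.
x = torch.rand(1, 1, 28, 28)

# no_grad() skips building the autograd graph, which inference does not need.
with torch.no_grad():
    logits = model(x)

print(logits.shape)  # torch.Size([1, 10])

# The predicted class is the index of the largest of the 10 outputs.
predicted = logits.argmax(dim=1)
print(predicted.item())
```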

Note: Be sure to call model.eval() before running inference to set the dropout and batch normalization layers to evaluation mode. If you skip this, dropout keeps randomly zeroing activations and batch normalization keeps using per-batch statistics, so inference results will be inconsistent.
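The NeuralNetwork in this unit contains only Flatten, Linear and ReLU layers, so eval() does not actually change its output; the difference appears in models with dropout or batch normalization. A minimal sketch with a hypothetical toy model containing nn.Dropout shows the effect:

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical toy model: the Dropout layer is what makes train/eval differ.
model = nn.Sequential(
    nn.Linear(8, 8),
    nn.Dropout(p=0.5),
    nn.Linear(8, 2),
)
x = torch.rand(4, 8)

with torch.no_grad():
    # Training mode (the default): dropout zeroes random activations on each call.
    model.train()
    train_a = model(x)
    train_b = model(x)

    # Evaluation mode: dropout passes activations through unchanged.
    model.eval()
    eval_a = model(x)
    eval_b = model(x)

print(torch.equal(train_a, train_b))  # almost certainly False
print(torch.equal(eval_a, eval_b))    # True
print(model.training)                 # False after eval()
```

Calling eval() on a model without such layers is harmless, so it is a good habit to call it before every inference run.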
