Introduction to PyTorch (6/7)
In this unit, we look at how to load a model along with its persisted parameter state and use it to run inference. To load the model, we first define the model class, which contains the state and parameters of the neural network that was trained.
%matplotlib inline
import torch
import onnxruntime
from torch import nn
import torch.onnx as onnx
import torchvision.models as models
from torchvision import datasets
from torchvision.transforms import ToTensor
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        # Flatten each 28x28 image into a vector of 784 values
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits
When loading model weights, we need to instantiate the model class first, because the class defines the structure of the network. Next, we load the parameters using the load_state_dict() method.
model = NeuralNetwork()
model.load_state_dict(torch.load('data/model.pth'))
model.eval()
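The load step above depends on a file at data/model.pth that may not exist on your machine. As a minimal, self-contained sketch of the same pattern, the example below saves and reloads a state dictionary and then runs one prediction. The TinyNet class, the temporary file path, and the random input are all hypothetical stand-ins introduced for illustration; they are not part of this unit's model or dataset.

```python
import os
import tempfile

import torch
from torch import nn


# Hypothetical stand-in for NeuralNetwork: smaller, but saved and loaded the same way.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.stack = nn.Sequential(
            nn.Linear(28 * 28, 16),
            nn.ReLU(),
            nn.Linear(16, 10),
        )

    def forward(self, x):
        return self.stack(self.flatten(x))


# Save the parameters of one instance to a temporary file (illustrative path).
trained = TinyNet()
path = os.path.join(tempfile.mkdtemp(), "tiny_model.pth")
torch.save(trained.state_dict(), path)

# Re-create the architecture first, then copy the saved parameters into it.
restored = TinyNet()
restored.load_state_dict(torch.load(path))
restored.eval()

# Run inference on a random tensor standing in for one 28x28 grayscale image.
x = torch.rand(1, 28, 28)
with torch.no_grad():  # gradients are not needed for inference
    logits = restored(x)
    reference_logits = trained(x)

predicted_class = logits.argmax(dim=1).item()
print(f"Predicted class index: {predicted_class}")
```

Because the restored model holds the same parameters as the original instance, both produce identical logits for the same input. The class definition must match the one used when saving; otherwise load_state_dict() raises an error about missing or unexpected keys.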
Note: Be sure to call model.eval() before running inference to set the dropout and batch normalization layers to evaluation mode. Failing to do so yields inconsistent inference results.
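The NeuralNetwork class in this unit has no dropout or batch normalization layers, so the effect of model.eval() is easier to see on a model that does. The sketch below uses a small hypothetical network with an nn.Dropout layer (an assumption made purely for illustration) and compares repeated forward passes in training mode and in evaluation mode.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical network containing a dropout layer (not the model from this unit).
net = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5))
x = torch.ones(4, 8)

# Training mode: dropout zeroes a fresh random subset of activations on every call.
net.train()
with torch.no_grad():
    train_a = net(x)
    train_b = net(x)

# Evaluation mode: dropout is disabled, so the forward pass is deterministic.
net.eval()
with torch.no_grad():
    eval_a = net(x)
    eval_b = net(x)

print("train outputs identical:", torch.equal(train_a, train_b))
print("eval outputs identical: ", torch.equal(eval_a, eval_b))
```

In evaluation mode the two outputs match exactly, while in training mode they will almost certainly differ because each call draws a new dropout mask.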