Porting machine learning models to mobile for Edge inference
This post is the second in a series of four articles on building an image classification model for Edge inference. We built this model using PyTorch and ported it to a mobile-compatible format. The first article in this series talked about training an image classification model in PyTorch on a web device.
In this article, we’ll talk about converting the trained machine learning model into a mobile-optimized format.
We built the machine learning model using the PyTorch framework and generated a .pt file. Since .pt files are not directly supported on mobile devices, we need to convert them to .ptl format, which is mobile-optimized. Below is the code snippet for converting a .pt file to .ptl format:
We first define the structure of our model and set it to evaluation mode.
# rebuild the model architecture so the trained weights can be loaded into it
# and converted into .ptl format
import torch
import torch.nn as nn
from torchvision.models import efficientnet_b0
# trainDS (the training dataset) and DEVICE come from the training step covered
# in the first article
model = efficientnet_b0(pretrained=True)
# since we are using the efficientnet_b0 model as a feature extractor, we set
# its parameters to non-trainable (by default they are trainable)
for param in model.parameters():
    param.requires_grad = False
# append a new classification head to our feature extractor and move it to the current device
model.classifier[1] = nn.Sequential(nn.Linear(in_features=1280, out_features=len(trainDS.classes)), nn.Softmax(dim=1))
model = model.to(DEVICE)
# set the model to evaluation mode before converting
model.eval()
We then load the trained weights and convert the model to .ptl format using TorchScript, PyTorch's JIT compiler, which compiles the model into a serialized program that can run without a regular Python runtime. Optimizations such as layer fusion and quantization can then be applied on top of the scripted model; an optional sketch of the fusion step follows the conversion snippet below.
# load the trained weights from the .pt file and convert the model to .ptl format
state_dict = torch.load(source_model_path + "path to the .pt file", map_location=torch.device('cpu'))
model.load_state_dict(state_dict)
scripted_module = torch.jit.script(model)
scripted_module._save_for_lite_interpreter(source_model_path + "model_weights_best.ptl")
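Note that torch.jit.script mainly produces the serialized TorchScript program itself; graph rewrites such as layer fusion come from a separate pass, torch.utils.mobile_optimizer.optimize_for_mobile. The sketch below is optional and not part of the pipeline above (the optimized file name is our own placeholder): it first runs a cheap sanity check that the scripted model still matches the eager one, then applies the mobile optimization pass, assuming the export runs on a CPU and that inputs are the 224×224 RGB images EfficientNet-B0 expects.
# cheap sanity check before shipping: the scripted and eager models should
# produce the same output for the same dummy input (224x224 RGB assumed)
with torch.no_grad():
    dummy_input = torch.rand(1, 3, 224, 224).to(DEVICE)
    assert torch.allclose(model(dummy_input), scripted_module(dummy_input), atol=1e-5)
# optional: apply mobile-oriented graph rewrites (e.g. operator fusion) before
# saving; optimize_for_mobile assumes the scripted model lives on the CPU
from torch.utils.mobile_optimizer import optimize_for_mobile
optimized_module = optimize_for_mobile(scripted_module)
optimized_module._save_for_lite_interpreter(source_model_path + "model_weights_best_optimized.ptl")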
We now have a .ptl file that can be used to run inference on mobile devices. Watch out for the next post, which details how to run inference on mobile devices using this .ptl file.
Stay tuned…