Building a PyTorch-based image classifier for on-device inference

Madhur Zanwar · Published in Eumentis · 3 min read · Jan 30, 2024
Machine Learning on device.

This article is the third in a series of four articles on building an image classification model in PyTorch for on-device inference. The first two articles of the series can be found here:

  1. Training an image classification model in PyTorch
  2. Porting machine learning models to mobile for Edge inference

In this article, we’ll talk about running the mobile-optimized model, which we now have as a .ptl file, on mobile devices in React Native.

Our mobile app had already been developed in React Native with extensive features, and our goal was to run the model inside this app. Rebuilding the app was not an option, as it would have resulted in considerable duplication of effort. So we went looking for a React Native package that would integrate with PyTorch Mobile and allow us to:

  • perform some pre-processing steps on the input image.
  • run on-device inference on the processed image.

We went with PlayTorch’s react-native-pytorch-core, developed by the Facebook (Meta) research team. I would like to take a moment to thank the good souls who do such great work and make it open source.

Install react-native-pytorch-core

yarn add react-native-pytorch-core

While no additional steps are necessary for iOS, you’ll need to perform some extra setup on Android for react-native-pytorch-core to work smoothly. Follow the steps provided here for the required Android setup.

Steps to integrate a PyTorch-based ML model into React Native using react-native-pytorch-core

Before using the model for inference, it’s crucial to apply the same pre-processing steps that were used when the model was trained on the web. The following code snippet illustrates how to load an image, convert it to a tensor, and apply the essential pre-processing steps within React Native. These steps mainly involve resizing the tensor and normalizing its values, exactly as was done during training on the web.
We take special care to replicate the same steps because it allows us to make meaningful comparisons of the results at the end. Replicating them is not strictly required for the model to run, but doing so gives us a benchmark, our web model, against which we can compare the results of the mobile (.ptl) model.
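
Before running the snippet below, the relevant modules from react-native-pytorch-core need to be imported and the .ptl model loaded. Here is a minimal sketch, assuming the code runs inside an async function; the model URL is a placeholder for illustration, and MobileModel.download and torch.jit._loadForMobile are the utilities the package exposes for fetching and loading a mobile-optimized model.

import { ImageUtil, media, MobileModel, torch, torchvision } from 'react-native-pytorch-core';

// shorthand for the torchvision-style transforms used in the snippet below
const T = torchvision.transforms;

// fetch (or resolve) the .ptl file to a local path and load it for mobile inference
// 'https://example.com/model.ptl' is a placeholder URL for illustration
const filePath = await MobileModel.download('https://example.com/model.ptl');
const model = await torch.jit._loadForMobile(filePath);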

// loading the image from a local file path ('path to image' is a placeholder)
const image = await ImageUtil.fromFile('path to image');

// getting the image height and width
let imageWidth = image.getWidth();
let imageHeight = image.getHeight();

// converting the image to a blob
const blob = media.toBlob(image);

// converting the blob to a tensor of shape [height, width, channels]
let tensor = torch.fromBlob(blob, [imageHeight, imageWidth, 3]);

// Rearrange the tensor shape to be [CHW]
tensor = tensor.permute([2, 0, 1]);

// Divide the tensor values by 255 to get values between [0, 1]
tensor = tensor.div(255);

// resize the tensor to [224, 224], the input size used during training
const resize = T.resize([224,224]);
tensor = resize(tensor);

// normalize with the same (ImageNet) mean and std used during training on the web
const normalize = T.normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]);
tensor = normalize(tensor);

// unsqueeze to add a batch dimension. The shape now will be [1, 3, 224, 224]
const formattedInputTensor = tensor.unsqueeze(0);

// running inference
const output = (await model.forward(formattedInputTensor))[0];
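
The output is a tensor of raw class scores. A small post-processing step turns it into a prediction. The sketch below is illustrative: CLASS_LABELS is a hypothetical array of our own class names in training order, and it assumes the tensor API’s softmax, argmax, item, and data helpers behave like their PyTorch counterparts.

// convert the raw scores (logits) into probabilities
const probabilities = output.softmax(-1);

// index of the highest-scoring class
const maxIdx = probabilities.argmax().item();

// CLASS_LABELS is a placeholder: our class names, in the order used during training
const predictedLabel = CLASS_LABELS[maxIdx];
console.log(predictedLabel, probabilities.data()[maxIdx]);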

With this, the pre-processing and inference steps are complete on the mobile device, and we are ready to compare the web model’s results with the mobile model’s results. Check out our next post for a comparison of the web (.pt) and mobile (.ptl) models’ inference results.
