Logistic Regression on Fashion-MNIST using PyTorch

Harsh R
5 min read · Jun 4, 2020

Welcome to my second post in the series on “Deep learning with PyTorch: Zero to GANs” taught by the team at jovian.ml. This post demonstrates how to perform logistic regression on Fashion-MNIST.

If you are not familiar with PyTorch and want to learn its basic tensor operations, visit Beginners guide to Tensor operations in PyTorch.

About the Data

Fashion-MNIST is a dataset of Zalando’s article images, divided into train and test sets. The training set has 60,000 samples and the test set has 10,000 samples. Fashion-MNIST has the same 28x28 image size and the same train/test split structure as the MNIST dataset. The images are grayscale and are labeled with one of these 10 classes:

  1. T-shirt/top
  2. Trouser
  3. Pullover
  4. Dress
  5. Coat
  6. Sandal
  7. Shirt
  8. Sneaker
  9. Bag
  10. Ankle boot

Let’s dive into code:

Imports

First, make sure you import all the packages we are going to need.

# Imports
import torch
import jovian
import torchvision
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.datasets import FashionMNIST
from torch.utils.data import random_split
from torch.utils.data import DataLoader

If you encounter any errors while importing packages, just use pip or conda to install the missing package.

For example: pip install <package-name> or conda install <package-name>

Set Up Hyperparameters

Now it’s time to set up the hyperparameters and constants for later use.

# Hyperparameters
batch_size = 128
learning_rate = 0.001
# Other constants
input_size = 28*28
num_classes = 10

As you can see, each image is 28x28 pixels and there are 10 classes.

After importing the packages and setting the hyperparameters, we will now load the dataset.

We first download the Fashion-MNIST dataset from the torchvision package that we imported earlier. The torchvision package includes many popular datasets such as MNIST, Fashion-MNIST, CIFAR10 and many more. Notice that while downloading, we pass ToTensor() from torchvision.transforms to convert the images to PyTorch tensors, which prepares the data for use with the regression model. We then use random_split to randomly divide the data into a training set (train_ds), which we will use to train the model, and a validation set (val_set), which will be used to evaluate the model and tune the hyperparameters. The training set will contain 50,000 images and the validation set 10,000 images. The test set contains 10,000 images and will be used for comparing different versions of the model and reporting the accuracy of its predictions.
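
Here is a minimal sketch of this step (the root='data/' path and the test_ds name are illustrative; the other names follow the description above):

# Download the data and split off a validation set
dataset = FashionMNIST(root='data/', download=True, transform=transforms.ToTensor())
test_ds = FashionMNIST(root='data/', train=False, transform=transforms.ToTensor())
train_ds, val_set = random_split(dataset, [50000, 10000])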

Now, we need to make our data iterable by using the DataLoader class from the torch.utils.data package. The DataLoader loads the data into memory in batches of the specified batch_size. shuffle is set to True for the training data so that different images are loaded in each batch at each epoch, which makes the model more robust and helps it generalize better.
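
A sketch of the dataloaders, reusing the variable names from the snippet above:

# Wrap the datasets in dataloaders for batched, iterable access
train_loader = DataLoader(train_ds, batch_size, shuffle=True)
val_loader = DataLoader(val_set, batch_size)
test_loader = DataLoader(test_ds, batch_size)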

Let’s take a look at our data.
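
For example, one quick way to inspect a sample with matplotlib (the index 0 here is arbitrary):

# Display one training image along with its label
image, label = train_ds[0]
plt.imshow(image.squeeze(), cmap='gray')
plt.title(f'Label: {label}')
plt.show()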

The model

Now that we have loaded and prepared our data, it’s finally time to build a logistic regression model.

But first, let’s understand: what is logistic regression?

Logistic regression is a linear model whose output is passed through a nonlinear transform, producing predicted probabilities in the range of 0 to 1. It measures the relationship between a dependent variable and one or more independent variables. In its classic form the dependent variable is binary (for a multi-class problem like ours, the softmax generalization is used), and the independent variables should have no multicollinearity. A logistic regression model is very similar to a linear regression model, which is why we will use a linear layer, i.e. nn.Linear, in our model.
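
Here is a minimal sketch of the model, based on the description that follows; the accuracy helper is my assumption, since only its behavior is described:

# Helper to compute the fraction of correct predictions (assumed implementation)
def accuracy(outputs, labels):
    _, preds = torch.max(outputs, dim=1)
    return torch.tensor(torch.sum(preds == labels).item() / len(preds))

class FMnistModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(input_size, num_classes)

    def forward(self, xb):
        xb = xb.reshape(-1, input_size)  # flatten 1x28x28 images to vectors of size 784
        return self.linear(xb)

    def training_step(self, batch):
        images, labels = batch
        out = self(images)                   # generate predictions
        return F.cross_entropy(out, labels)  # calculate loss

    def validation_step(self, batch):
        images, labels = batch
        out = self(images)
        loss = F.cross_entropy(out, labels)
        acc = accuracy(out, labels)
        return {'val_loss': loss, 'val_acc': acc}

    def validation_epoch_end(self, outputs):
        # Average the losses and accuracies over all validation batches
        batch_losses = [x['val_loss'] for x in outputs]
        batch_accs = [x['val_acc'] for x in outputs]
        return {'val_loss': torch.stack(batch_losses).mean().item(),
                'val_acc': torch.stack(batch_accs).mean().item()}

model = FMnistModel()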

As you can see in the FMnistModel class, the linear layer accepts input_size, which is 28x28, i.e. the size of the images, and num_classes, which is 10 as mentioned earlier.

In forward(), we use the .reshape method because the shape of each input image is 1x28x28; we need to flatten it to a vector of size 784. The forward method takes a batch xb as input, reshapes it, and passes it to the linear layer.

The training_step method generates predictions and calculates the loss on a batch of data from the training set. To calculate the losses during training and validation, I have used cross_entropy, which applies a log-softmax to the model’s outputs (so the resulting class probabilities range between 0 and 1) and then computes the negative log-likelihood loss. The validation_step method assesses the accuracy and calculates the loss on a batch of data from the validation set. The validation_epoch_end method takes the list of outputs generated by validation_step and averages the losses and accuracies over all batches.

The training process

Here, we define the fit function, which performs the training process. The fit function uses the stochastic gradient descent (SGD) optimizer. The torch.optim package provides many other optimizers, such as Adagrad, Adam, Adamax, RMSprop and many more. To know more about optimizers, check out different gradient descent optimization algorithms.

There are two phases. In the training phase, we perform gradient descent on batches of data from the training dataloader and then reset the gradients. In the validation phase, we store the evaluation results in the result variable, which in turn is appended to the history list. In other words, we store the losses and accuracies of every training epoch in history.
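
A sketch of fit and its evaluate helper following this description (the opt_func default and the exact loop structure are my assumptions):

# Evaluate the model on all validation batches and average the results
def evaluate(model, val_loader):
    outputs = [model.validation_step(batch) for batch in val_loader]
    return model.validation_epoch_end(outputs)

def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.SGD):
    history = []
    optimizer = opt_func(model.parameters(), lr)
    for epoch in range(epochs):
        # Training phase: gradient descent on each batch, then reset gradients
        for batch in train_loader:
            loss = model.training_step(batch)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        # Validation phase: store evaluation results in history
        result = evaluate(model, val_loader)
        history.append(result)
    return history

It can then be called as, for example, history = fit(10, learning_rate, model, train_loader, val_loader), where the number of epochs is up to you.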

After training with different learning rates and numbers of epochs, I got a validation accuracy of 0.84 and a validation loss of 0.46.

Prediction
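
A sketch of a prediction helper, assuming the model and test_ds from earlier:

# Predict the class index of a single image
def predict_image(img, model):
    xb = img.unsqueeze(0)  # add a batch dimension: 1x1x28x28
    yb = model(xb)
    _, preds = torch.max(yb, dim=1)
    return preds[0].item()

img, label = test_ds[0]
print('Label:', label, ', Predicted:', predict_image(img, model))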

As you can see, the model has performed well and predicted most of the image classes correctly, except for a sandal that it classified as a sneaker. However, the model can be further improved by using convolutional neural networks.

Final Words

In this post, we discussed the Fashion-MNIST dataset and how to download it. Then we covered logistic regression and the workings of a logistic regression model. I have also tried a similar experiment on the CIFAR10 dataset, but the accuracy was much lower; if you want to check it out, I have put the link to my source code in the reference section. I’d recommend using a CNN to improve the model’s performance. If you have any doubts, please let me know, and don’t hesitate to leave a clap!
