Can Machines Learn and Predict? Training a Deep Neural Network with PyTorch on the Iris Dataset

ARNOLD SACHITH A HANS
Published in Analytics Vidhya, Mar 25, 2020

Using a Python program, train a Deep Neural Network on the Iris dataset and predict the class of an Iris plant.

In a fast-moving world where machines are trained to predict outcomes, the number of applications that rely on Artificial Intelligence is growing rapidly.

This post walks you through training a Deep Neural Network on one of the best-known introductory datasets: the IRIS DATASET. By the end, the network will be able to predict the species of an Iris flower from 4 physical measurements given as input.

Photo by Ben Mater on Unsplash

We will follow these steps:

  1. What is the Iris dataset? Let’s first try to understand the dataset, so that we know which attributes serve as the input and what exactly the network should output.
  2. Splitting the dataset: split the dataset into training, validation and test sets.
  3. Developing layers for the neural network: create the Deep Neural Network, which includes an input layer, hidden layers and an output layer.
  4. Defining the training parameters: the most important part of a DNN is defining the hyperparameters (learning rate, batch size, number of epochs, etc.), the activation function and the loss function, which we will discuss as we write the program.
  5. Training the network: after completing all the steps above, we train the Deep Neural Network.
  6. Prediction: once the network reaches a good accuracy, we give it a new set of input values and check whether it predicts the correct species.

What exactly is the Iris Dataset?

The Iris dataset has 150 rows and 5 columns (excluding the ID column) and it looks something like this:

The first 5 rows of the dataset are shown.

The Iris dataset records 4 physical attributes of each flower, namely:

  1. sepal length in cm
  2. sepal width in cm
  3. petal length in cm
  4. petal width in cm

and one attribute to be predicted: the “Species” column, i.e. the class of the Iris plant. The dataset contains three species, “Iris-setosa”, “Iris-versicolor” and “Iris-virginica”, with 50 samples each.

So the 4 physical attributes serve as the input to the network, and based on that input the network gives one output: whether the flower belongs to “Iris-setosa”, “Iris-versicolor” or “Iris-virginica”.
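To see these inputs and labels concretely, here is a minimal loading sketch. It assumes scikit-learn is installed and uses its bundled copy of Iris (`load_iris`), which stores the species as integer labels 0, 1, 2 rather than the “Iris-setosa” style strings of the CSV used in this post.

```python
# Sketch: load Iris from scikit-learn's bundled copy (an assumption; this post
# works from a CSV file that also has an Id column and string species names).
import numpy as np
from sklearn.datasets import load_iris

iris = load_iris()
X = iris.data    # shape (150, 4): sepal length, sepal width, petal length, petal width (cm)
y = iris.target  # shape (150,): species encoded as 0, 1, 2

print(X.shape, y.shape)              # (150, 4) (150,)
print(iris.target_names.tolist())    # ['setosa', 'versicolor', 'virginica']
print(np.bincount(y).tolist())       # [50, 50, 50]: the classes are balanced
```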

Let’s start coding by importing the necessary packages:

[Any additional information is provided as # comments in the program.]
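The original import list is not reproduced here; the packages below are an assumed minimal set covering the PyTorch steps that follow (tensors, layers, activation functions, data loaders), plus NumPy. Fixing the random seeds keeps weight initialization and batch shuffling repeatable between runs.

```python
import numpy as np
import torch
import torch.nn as nn                   # layers and loss functions
import torch.nn.functional as F         # activation functions such as F.relu
from torch.utils.data import DataLoader, TensorDataset  # batching utilities

# Fix the seeds so weight initialization and batch shuffling repeat between runs.
torch.manual_seed(0)
np.random.seed(0)

print("PyTorch version:", torch.__version__)
```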

Defining the Deep Neural Network:
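A minimal sketch of such a network, assuming two fully connected hidden layers with ReLU activations. The hidden sizes (16 and 12) are illustrative choices, not the values from the original program, and `IrisNet` is a hypothetical name.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class IrisNet(nn.Module):
    """4 input features -> two hidden layers -> 3 class scores (logits)."""

    def __init__(self, in_features=4, hidden1=16, hidden2=12, n_classes=3):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden1)  # input layer -> hidden layer 1
        self.fc2 = nn.Linear(hidden1, hidden2)      # hidden layer 1 -> hidden layer 2
        self.out = nn.Linear(hidden2, n_classes)    # hidden layer 2 -> output layer

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.out(x)  # raw logits; CrossEntropyLoss applies softmax internally

model = IrisNet()
print(model(torch.randn(5, 4)).shape)  # torch.Size([5, 3])
```

Returning raw logits (no softmax in `forward`) is the usual pairing with `nn.CrossEntropyLoss`, which expects unnormalized scores.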

Defining the hyperparameters:
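One tidy way to keep the hyperparameters in a single place is a small frozen dataclass. Every value below is an assumed starting point for Iris, not the setting from the original program, and `HyperParams` is a hypothetical name.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class HyperParams:
    # Illustrative values only; tune them against the validation loss.
    learning_rate: float = 0.01   # step size used by the optimizer
    batch_size: int = 16          # samples per gradient update
    epochs: int = 100             # full passes over the training set
    hidden1: int = 16             # units in the first hidden layer
    hidden2: int = 12             # units in the second hidden layer

hp = HyperParams()
print(hp)
```

`HyperParams(epochs=200)` creates a variant for an experiment without touching the defaults.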

Choosing the optimizer and loss function, and splitting the data into training, validation and test sets:
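For a three-class problem the usual loss is cross-entropy computed on the raw output scores (logits) of the last layer. The sketch below assumes the Adam optimizer with a learning rate of 0.01 and a stand-in `nn.Sequential` network of shape 4 → 16 → 12 → 3; none of these choices is confirmed by the original program.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in network: 4 inputs -> 16 -> 12 -> 3 class scores (hidden sizes are assumptions).
model = nn.Sequential(
    nn.Linear(4, 16), nn.ReLU(),
    nn.Linear(16, 12), nn.ReLU(),
    nn.Linear(12, 3),
)

# CrossEntropyLoss applies log-softmax itself, so the model returns raw logits.
# It expects logits of shape (N, 3) and integer class labels of shape (N,).
criterion = nn.CrossEntropyLoss()

# Adam is an assumed choice; torch.optim.SGD(model.parameters(), lr=0.01) also works.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Shape check on a dummy batch of 8 random rows (placeholder data, not Iris).
logits = model(torch.randn(8, 4))
labels = torch.randint(0, 3, (8,))
print(criterion(logits, labels).item())
```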

Check how many records are in each set by displaying their shapes:
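A minimal sketch of the split together with the shape check, assuming scikit-learn's `train_test_split`: 20% of the rows are held out and then divided evenly into validation and test sets, which yields the 120 / 15 / 15 sizes. Stratifying by species and the `random_state` value are assumptions, not taken from the original program.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)

# Hold out 20% (30 rows), then split those 30 evenly into validation and test sets.
# stratify keeps the three species in equal proportions in every set.
x_train, x_rest, y_train, y_rest = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42)
x_val, x_test, y_val, y_test = train_test_split(
    x_rest, y_rest, test_size=0.5, stratify=y_rest, random_state=42)

print(x_train.shape, x_val.shape, x_test.shape)   # (120, 4) (15, 4) (15, 4)
print(y_train.shape, y_val.shape, y_test.shape)   # (120,) (15,) (15,)
```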

x_train has 120 rows and 4 columns, while the test set (x_test) and the validation set (x_val) each have 15 rows and 4 columns. Correspondingly, y_train, y_test and y_val hold 120, 15 and 15 labels. Your output should look like this:

Create the dataloaders for PyTorch:
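A minimal sketch using PyTorch's `TensorDataset` and `DataLoader`; `make_loader` is a hypothetical helper and the batch size of 16 is an assumption. To stay self-contained, the demo wraps all 150 rows; in the walkthrough you would build one loader per split.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from sklearn.datasets import load_iris

def make_loader(x, y, batch_size=16, shuffle=False):
    """Wrap NumPy features/labels in a DataLoader (hypothetical helper)."""
    # nn.Linear needs float32 inputs; CrossEntropyLoss needs int64 (long) labels.
    dataset = TensorDataset(torch.tensor(x, dtype=torch.float32),
                            torch.tensor(y, dtype=torch.long))
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

# Demo on all 150 rows; in practice use shuffle=True only for the training loader.
X, y = load_iris(return_X_y=True)
loader = make_loader(X, y, batch_size=16, shuffle=True)

xb, yb = next(iter(loader))
print(xb.shape, yb.shape)   # torch.Size([16, 4]) torch.Size([16])
print(len(loader))          # 10 batches, since 150 / 16 rounds up to 10
```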

Define a function to train the model:
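A minimal sketch of such a training function (`train_one_epoch` is a hypothetical name). It assumes `loader` yields `(features, labels)` batches, as a PyTorch `DataLoader` does, and averages the loss per sample so that a smaller final batch does not skew the result.

```python
import torch
import torch.nn as nn

def train_one_epoch(model: nn.Module, loader, criterion: nn.Module,
                    optimizer: torch.optim.Optimizer) -> float:
    """One pass over `loader`, updating the weights; returns the mean training loss."""
    model.train()                        # put layers such as dropout in training mode
    total_loss, total_samples = 0.0, 0
    for xb, yb in loader:
        optimizer.zero_grad()            # clear gradients left from the previous batch
        loss = criterion(model(xb), yb)  # forward pass + loss
        loss.backward()                  # back-propagate to compute gradients
        optimizer.step()                 # update the weights
        total_loss += loss.item() * xb.size(0)  # weight by batch size
        total_samples += xb.size(0)
    return total_loss / total_samples
```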

Similarly, define a function to test the model:
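A matching evaluation sketch under the same assumptions (`evaluate` is a hypothetical name). It switches the model to evaluation mode, disables gradient tracking with `torch.no_grad()`, and reports both the mean loss and the accuracy without updating any weights.

```python
import torch
import torch.nn as nn

@torch.no_grad()   # no gradient bookkeeping needed when only measuring performance
def evaluate(model: nn.Module, loader, criterion: nn.Module):
    """Return (mean loss, accuracy) of `model` over `loader` without changing its weights."""
    model.eval()                       # put layers such as dropout in inference mode
    total_loss, correct, total = 0.0, 0, 0
    for xb, yb in loader:
        logits = model(xb)
        total_loss += criterion(logits, yb).item() * xb.size(0)
        correct += (logits.argmax(dim=1) == yb).sum().item()  # predicted class = largest score
        total += xb.size(0)
    return total_loss / total, correct / total
```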

It’s time to train the model:
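Putting the pieces together, here is a self-contained sketch of a training run that records the training and validation loss for every epoch. The architecture, Adam with a learning rate of 0.01, a batch size of 16, 50 epochs and the stratified 120 / 15 / 15 split are all assumptions standing in for the original program’s settings; `make_loader` and `run_epoch` are hypothetical helpers.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

torch.manual_seed(0)

# Data: stratified 120 / 15 / 15 split (an assumed strategy); x_test is kept for later.
X, y = load_iris(return_X_y=True)
x_train, x_rest, y_train, y_rest = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42)
x_val, x_test, y_val, y_test = train_test_split(
    x_rest, y_rest, test_size=0.5, stratify=y_rest, random_state=42)

def make_loader(x, y, shuffle):
    """Float32 features and int64 labels, served in batches of 16."""
    ds = TensorDataset(torch.tensor(x, dtype=torch.float32),
                       torch.tensor(y, dtype=torch.long))
    return DataLoader(ds, batch_size=16, shuffle=shuffle)

train_loader = make_loader(x_train, y_train, shuffle=True)
val_loader = make_loader(x_val, y_val, shuffle=False)

# Model, loss and optimizer (layer sizes and learning rate are assumptions).
model = nn.Sequential(nn.Linear(4, 16), nn.ReLU(),
                      nn.Linear(16, 12), nn.ReLU(),
                      nn.Linear(12, 3))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

def run_epoch(loader, train):
    """Mean per-sample loss over `loader`; weights are updated only if train=True."""
    model.train(train)
    total, n = 0.0, 0
    with torch.set_grad_enabled(train):
        for xb, yb in loader:
            loss = criterion(model(xb), yb)
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total += loss.item() * xb.size(0)
            n += xb.size(0)
    return total / n

epochs = 50
train_losses, val_losses = [], []
for epoch in range(1, epochs + 1):
    train_losses.append(run_epoch(train_loader, train=True))
    val_losses.append(run_epoch(val_loader, train=False))
    if epoch % 10 == 0:  # print every 10th epoch to keep the log short
        print(f"epoch {epoch:3d}  train loss {train_losses[-1]:.4f}  "
              f"val loss {val_losses[-1]:.4f}")
```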

During training, you can see the training loss and validation loss for every epoch; the output looks something like this:

Let’s visualize how the training loss and validation loss vary from epoch to epoch:
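A minimal plotting sketch, assuming matplotlib and two lists of per-epoch losses such as those collected during training. `plot_losses` is a hypothetical name, and the figure is saved to a file so the code also runs without a display.

```python
from pathlib import Path

import matplotlib
matplotlib.use("Agg")  # file-only backend; omit this line in a Jupyter notebook
import matplotlib.pyplot as plt

def plot_losses(train_losses, val_losses, path="loss_curve.png"):
    """Plot per-epoch training and validation loss and save the figure to `path`."""
    epochs = range(1, len(train_losses) + 1)
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(epochs, train_losses, label="training loss")
    ax.plot(epochs, val_losses, label="validation loss")
    ax.set_xlabel("epoch")
    ax.set_ylabel("cross-entropy loss")
    ax.legend()
    fig.tight_layout()
    fig.savefig(path)
    plt.close(fig)  # free the figure's memory
    return Path(path)
```

A validation curve that starts rising while the training curve keeps falling is the usual sign of overfitting.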

Looking at the numbers (weights and biases) learned by each layer:
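To inspect the learned numbers yourself, iterate over the model’s named parameters. The sketch assumes a stand-in `nn.Sequential` network with hidden sizes 16 and 12; each `nn.Linear` layer stores a weight matrix of shape (out_features, in_features) and a bias vector.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(
    nn.Linear(4, 16), nn.ReLU(),
    nn.Linear(16, 12), nn.ReLU(),
    nn.Linear(12, 3),
)

# Names come from each layer's position in the Sequential: "0.weight", "0.bias", ...
for name, param in model.named_parameters():
    print(f"{name:9s} shape={tuple(param.shape)}")

# The actual values, e.g. the first two rows of the first layer's weight matrix.
print(model[0].weight[:2].detach())

n_params = sum(p.numel() for p in model.parameters())
print("trainable parameters:", n_params)  # 4*16+16 + 16*12+12 + 12*3+3 = 323
```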

Check how your model performs on the test set:
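A small sketch of the test-set check, assuming the test features and labels are NumPy arrays (as produced by `train_test_split`); `predict_classes` and `accuracy` are hypothetical helper names. The predicted class is the index of the largest output score.

```python
import numpy as np
import torch
import torch.nn as nn

@torch.no_grad()
def predict_classes(model: nn.Module, x) -> np.ndarray:
    """Predicted class index (0, 1 or 2) for each row of an (N, 4) array."""
    model.eval()
    logits = model(torch.tensor(x, dtype=torch.float32))
    return logits.argmax(dim=1).numpy()  # index of the largest score per row

def accuracy(model: nn.Module, x, y) -> float:
    """Fraction of rows whose predicted class matches the label in `y`."""
    return float((predict_classes(model, x) == np.asarray(y)).mean())
```

With a trained model and the 15-row test split, `accuracy(model, x_test, y_test)` returns the test accuracy as a value between 0 and 1.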

We also use a confusion matrix to evaluate the model on the test set:
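To make the layout explicit, here is a hand-written NumPy sketch of a three-class confusion matrix (`iris_confusion_matrix` is a hypothetical name). `sklearn.metrics.confusion_matrix` uses the same convention: rows are true classes, columns are predicted classes.

```python
import numpy as np

def iris_confusion_matrix(y_true, y_pred, n_classes=3):
    """Entry [i, j] counts samples whose true class is i and predicted class is j."""
    cm = np.zeros((n_classes, n_classes), dtype=int)
    # np.add.at accumulates correctly even when the same (i, j) pair repeats.
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return cm
```

Correct predictions land on the diagonal, so `np.trace(cm) / cm.sum()` gives the accuracy, and any off-diagonal count shows which species are being confused with each other.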

This is what the confusion matrix looks like:

Finally, it’s PREDICTION TIME:
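A minimal prediction sketch, assuming the species were encoded as 0 = Iris-setosa, 1 = Iris-versicolor, 2 = Iris-virginica; `predict_species` and `SPECIES` are hypothetical names. Softmax turns the network’s logits into class probabilities, so the function returns both the predicted species and the model’s confidence.

```python
import torch
import torch.nn as nn

# Assumed label encoding; it must match the one used when training the model.
SPECIES = ["Iris-setosa", "Iris-versicolor", "Iris-virginica"]

@torch.no_grad()
def predict_species(model: nn.Module, x_new):
    """x_new: rows of [sepal length, sepal width, petal length, petal width] in cm."""
    model.eval()
    logits = model(torch.tensor(x_new, dtype=torch.float32))
    probs = torch.softmax(logits, dim=1)  # scores -> probabilities per class
    conf, idx = probs.max(dim=1)          # most likely class and its probability
    return [(SPECIES[i], c) for i, c in zip(idx.tolist(), conf.tolist())]

# One flower's measurements: the values of the first row of the Iris data.
x_new = [[5.1, 3.5, 1.4, 0.2]]
```

With the trained network, `predict_species(model, x_new)` returns one `(species, probability)` pair per row; since this row is an Iris-setosa sample, a well-trained model would be expected to label it Iris-setosa.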

You can change the values in the x_new variable and check whether your model predicts the correct class of Iris plant.

That’s how you train a Deep Neural Network and use the resulting model to make predictions.

Suggestions regarding this post are always welcome.

Feel free to connect with me through LinkedIn, Instagram or Facebook.

Cheers :)

Arnold Sachith

An Aspiring AI engineer|M.Tech (Artificial Intelligence)|B.E (Mechatronics Engineering)| Writer| Robots Rule| AI for the betterment of the society|