Creating a Plant Disease Detector from scratch using Keras

Keval Nagda
8 min read · Jun 26, 2020

Table of Contents

  1. Introduction
  2. Dataset
  3. Libraries
  4. Data Preprocessing
  5. Data Augmentation
  6. Model
  7. Training
  8. Evaluation
  9. Testing
  10. Reuse

Introduction

Plants are commonly affected by disease due to factors such as fertilizer use, cultural practices, and environmental conditions. These diseases hurt agricultural yield and, eventually, the economy that depends on it.

Any technique that warns farmers before their plants are infected would help them cultivate crops more efficiently, improving both quality and quantity. Disease detection in plants therefore plays a very important role in agriculture.

Because most ordinary machines have limited computational power, training the classification model locally is difficult. We therefore use a Google Colab notebook, which connects us to a free TPU instance quickly and effortlessly.

So without further ado, let’s dive in!

Dataset

We use the publicly available and well-known PlantVillage dataset. The dataset was published by crowdAI during the “PlantVillage Disease Classification Challenge”.

The dataset consists of 54,305 images of plant leaves collected under controlled environmental conditions. The plant images span the following 14 species:

Apple, Blueberry, Cherry, Corn, Grape, Orange, Peach, Bell Pepper, Potato, Raspberry, Soybean, Squash, Strawberry, and Tomato.

The dataset contains a total of 38 classes, covering both diseased and healthy leaves, listed below:

  1. Apple Scab
  2. Apple Black Rot
  3. Apple Cedar Rust
  4. Apple healthy
  5. Blueberry healthy
  6. Cherry healthy
  7. Cherry Powdery Mildew
  8. Corn Gray Leaf Spot
  9. Corn Common Rust
  10. Corn healthy
  11. Corn Northern Leaf Blight
  12. Grape Black Rot
  13. Grape Black Measles
  14. Grape Leaf Blight
  15. Grape healthy
  16. Orange Huanglongbing
  17. Peach Bacterial Spot
  18. Peach healthy
  19. Bell Pepper Bacterial Spot
  20. Bell Pepper healthy
  21. Potato Early Blight
  22. Potato healthy
  23. Potato Late Blight
  24. Raspberry healthy
  25. Soybean healthy
  26. Squash Powdery Mildew
  27. Strawberry healthy
  28. Strawberry Leaf Scorch
  29. Tomato Bacterial Spot
  30. Tomato Early Blight
  31. Tomato Late Blight
  32. Tomato Leaf Mold
  33. Tomato Septoria Leaf Spot
  34. Tomato Two Spotted Spider Mite
  35. Tomato Target Spot
  36. Tomato Mosaic Virus
  37. Tomato Yellow Leaf Curl Virus
  38. Tomato healthy

Note: The dataset also contains an additional background class to differentiate leaves from their background features.

The dataset file is stored in Google Drive (GDrive) and is shared via a URL that contains a unique id. Using this id, we download the dataset zip file from GDrive and unzip it.

Downloading dataset
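
The download-and-unzip step could look roughly like the sketch below. It uses only the standard library. FILE_ID is a hypothetical placeholder, the helper names are illustrative, and the URL pattern is the commonly used direct-download form for publicly shared files, which is an assumption here rather than the exact call from the original notebook.

```python
import urllib.request
import zipfile

# Hypothetical placeholder: replace with the id from the shared Drive URL.
FILE_ID = "YOUR_FILE_ID"


def drive_download_url(file_id):
    """Build a direct-download URL for a publicly shared Drive file (assumed pattern)."""
    return f"https://drive.google.com/uc?export=download&id={file_id}"


def download_file(url, dest_path):
    """Fetch `url` and write the response body to `dest_path`."""
    urllib.request.urlretrieve(url, dest_path)
    return dest_path


def unzip(zip_path, out_dir):
    """Extract every member of the archive and return the member names."""
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(out_dir)
        return archive.namelist()
```

Calling download_file(drive_download_url(FILE_ID), "dataset.zip") followed by unzip("dataset.zip", "data") would reproduce the two steps. Large Drive files may return a confirmation page instead of the archive, in which case a helper such as the gdown package is a common alternative.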

Libraries

We import all the necessary libraries required to process the data and build the classification model.

Importing libraries
  1. NumPy: A library for the Python programming language, adding support for large, multi-dimensional arrays and matrices, along with a large collection of high-level mathematical functions to operate on these arrays.
  2. Pickle: Most Python objects can be pickled so that they can be saved on disk. Pickling converts a Python object (list, dict, etc.) into a byte stream. The idea is that this byte stream contains all the information necessary to reconstruct the object in another Python script.
  3. Cv2 (OpenCV): OpenCV is a library of bindings designed to solve computer vision problems.
  4. Os: The OS module in Python provides functions for creating and removing a directory (folder), fetching its contents, changing and identifying the current directory, etc. It is also possible to automatically perform many operating system tasks.
  5. Sklearn: A free software machine learning library for the Python programming language. It features various classification, regression and clustering algorithms including support vector machines, random forests, gradient boosting, k-means, and DBSCAN, and is designed to interoperate with the Python numerical and scientific libraries NumPy and SciPy.
  6. Keras: Keras is an open-source neural network library written in Python. Designed to enable fast experimentation with deep neural networks, it focuses on being user-friendly, modular, and extensible.
  7. Matplotlib: A plotting library for the Python programming language and its numerical mathematics extension, NumPy.

Data Preprocessing

First, let’s define a couple of variables required to perform operations on the dataset images.

Defining image dataset variables
  1. We need to resize the raw dataset images to DEFAULT_IMAGE_SIZE so that they match the input shape of the first layer of the neural network.
  2. Each directory of the plant disease dataset folder varies in the number of images. Instead of taking them all, we select the first N_IMAGES from each directory to train our model.
  3. Finally, we set the path of the dataset in the root_dir to access plant images.

Now, we write a function to resize the input dataset images so that they are fit for training.

Image resizing function
  1. cv2.imread() loads an image from the specified file. If the image cannot be read (for example, because the file is missing or the format is unsupported), it returns an empty matrix, which is None in the Python bindings.
  2. cv2.resize() changes the dimensions of an image, be it width alone, height alone, or both. The aspect ratio of the original image can be preserved by choosing the target dimensions accordingly.

We then iterate through the plant disease dataset folder, resize the images from each of the folders and convert or load them into a NumPy array.

Loading image dataset
Plant disease classes
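
As a rough sketch of the directory walk described above, the helper below collects the first N_IMAGES file paths from each class folder together with the folder name as the label. It assumes a layout of root_dir/<class name>/<image files>, takes "first" to mean first in sorted filename order, and uses a placeholder value for N_IMAGES; all three are assumptions rather than details from the original notebook.

```python
import os

N_IMAGES = 100  # assumed cap on images taken from each class folder


def list_image_paths(root_dir, n_images=N_IMAGES):
    """Return (paths, labels) using the first `n_images` files of each class folder."""
    paths, labels = [], []
    for class_name in sorted(os.listdir(root_dir)):
        class_dir = os.path.join(root_dir, class_name)
        if not os.path.isdir(class_dir):
            continue  # skip stray files next to the class folders
        for file_name in sorted(os.listdir(class_dir))[:n_images]:
            paths.append(os.path.join(class_dir, file_name))
            labels.append(class_name)
    return paths, labels
```

Each path can then be passed through the resizing function, and the resulting arrays stacked into a single NumPy array with np.array(...).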

After loading the image dataset, we map each label or class of each plant disease to a unique value for the training task. Also, saving this transform to a pickle file will help us later in predicting a label or class of plant disease from the output of the classification model.

Creating label transform
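
One way to build and persist such a label transform is scikit-learn's LabelBinarizer, sketched below on four stand-in labels. Whether the original notebook uses exactly this class is an assumption, and the helper names are illustrative.

```python
import pickle

from sklearn.preprocessing import LabelBinarizer

# Stand-in labels; in practice this is the list of class names of all loaded images.
labels = ["Apple healthy", "Tomato healthy", "Apple Scab", "Tomato healthy"]

label_binarizer = LabelBinarizer()
one_hot = label_binarizer.fit_transform(labels)  # one row per image, one column per class


def save_label_transform(binarizer, path):
    """Pickle the fitted transform so predictions can be mapped back to names later."""
    with open(path, "wb") as f:
        pickle.dump(binarizer, f)


def load_label_transform(path):
    """Load a previously pickled label transform."""
    with open(path, "rb") as f:
        return pickle.load(f)
```

The fitted transform exposes the class names in label_binarizer.classes_, and inverse_transform() maps one-hot rows (or model outputs) back to class names.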

Finally, we split the loaded image dataset into train and test sets with a 0.2 split ratio, so that 20% of the images go to the test set. The train set is used to train the classification model and the test set to validate the model while training.

Splitting data into training and validation set
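
A minimal sketch of this split using scikit-learn's train_test_split follows. The function name and the fixed seed are assumptions added for reproducibility.

```python
import numpy as np
from sklearn.model_selection import train_test_split


def split_dataset(images, labels, test_size=0.2, seed=42):
    """Hold out `test_size` of the samples; returns x_train, x_test, y_train, y_test."""
    return train_test_split(images, labels, test_size=test_size, random_state=seed)
```

With 10 samples and test_size=0.2, this yields 8 training samples and 2 test samples.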

Data Augmentation

The data augmentation technique is used to significantly increase the number of images in a dataset. We perform various operations such as shift, rotation, zoom, and flip on the image dataset to diversify our dataset.

Providing augmented images to a model helps it efficiently learn features from different areas of the same image and thus perform better on unseen image data.

Data augmentation

Note: For now we only create an ImageDataGenerator object; it is passed to the model later, during training.
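
An augmentation object covering the operations mentioned above (shift, rotation, zoom, and flip) might be configured as in the sketch below. The specific ranges are illustrative assumptions, not values taken from the original notebook. Recent TensorFlow releases mark ImageDataGenerator as deprecated in favor of preprocessing layers, so treat this as a sketch of the approach the article describes.

```python
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# All ranges below are assumed values chosen for illustration.
augment = ImageDataGenerator(
    rotation_range=25,        # rotate by up to +/- 25 degrees
    width_shift_range=0.1,    # shift horizontally by up to 10% of the width
    height_shift_range=0.1,   # shift vertically by up to 10% of the height
    zoom_range=0.2,           # zoom in or out by up to 20%
    horizontal_flip=True,     # randomly mirror images left to right
    fill_mode="nearest",      # fill pixels exposed by a transform
)
```

Nothing is transformed at this point; random transforms are applied batch by batch when augment.flow(...) is iterated during training.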

Model

Here, we define all the hyperparameters of our plant disease classification model. Executing them initially in a separate cell makes it easy for us to tweak them later if needed.

Hyperparameters of neural network

Now, we create a sequential model for the classification task. In the model, we default to the “channels_last” data format but also create a switch for backends that use “channels_first”.

For the model, we first create a 2D Convolutional layer with 32 filters, a 3 x 3 kernel, and a ReLU (Rectified Linear Unit) activation. We then apply batch normalization, max pooling, and 25% (0.25) dropout in the following layers.

Next, we create two blocks of 2D Convolutional layers with 64 filters and ReLU activation, followed by a pooling and dropout layer. We repeat this step with 128 filters in the Conv2D layers as the only difference, and finish with the set of FC (Fully Connected) layers.

Building classification model
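
The architecture described above can be sketched as follows. The function name, padding, pool sizes, the 1024-unit dense layer, and the 0.5 dropout before the output are assumptions filled in for illustration; only the filter counts, the 3 x 3 kernel of the first layer, ReLU, batch normalization, pooling, and 0.25 dropout come from the text.

```python
from tensorflow import keras
from tensorflow.keras import backend as K
from tensorflow.keras import layers


def build_model(height, width, depth, n_classes):
    """Small VGG-style CNN: one 32-filter block, then two-layer 64- and 128-filter blocks."""
    input_shape = (height, width, depth)  # channels_last default
    chan_dim = -1
    if K.image_data_format() == "channels_first":
        input_shape = (depth, height, width)
        chan_dim = 1

    model = keras.Sequential()
    model.add(keras.Input(shape=input_shape))

    # Block 1: Conv(32) -> BatchNorm -> MaxPool -> Dropout(0.25)
    model.add(layers.Conv2D(32, (3, 3), padding="same", activation="relu"))
    model.add(layers.BatchNormalization(axis=chan_dim))
    model.add(layers.MaxPooling2D(pool_size=(3, 3)))
    model.add(layers.Dropout(0.25))

    # Blocks 2 and 3: two Conv layers each (64, then 128 filters), then pool and dropout
    for filters in (64, 128):
        for _ in range(2):
            model.add(layers.Conv2D(filters, (3, 3), padding="same", activation="relu"))
            model.add(layers.BatchNormalization(axis=chan_dim))
        model.add(layers.MaxPooling2D(pool_size=(2, 2)))
        model.add(layers.Dropout(0.25))

    # Fully connected head ending in one softmax unit per class
    model.add(layers.Flatten())
    model.add(layers.Dense(1024, activation="relu"))
    model.add(layers.BatchNormalization())
    model.add(layers.Dropout(0.5))
    model.add(layers.Dense(n_classes, activation="softmax"))
    return model
```

Calling build_model(256, 256, 3, n_classes).summary() prints the layer stack and parameter counts, which is a quick way to check the shapes before training.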

Training

Before training our model, we initialize the optimizer with the learning rate and decay parameters we defined above. We select the Adam optimizer because it typically converges quickly and needs little tuning compared with many other optimization techniques.

Compiling classification model
Training classification model
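
The compile and training steps might look like the sketch below. The hyperparameter values, the function name, and the loss function are assumptions. Because recent Keras optimizers no longer accept a decay argument, time-based decay is expressed here with an InverseTimeDecay schedule, which computes lr / (1 + decay_rate * step); this is a swapped-in equivalent, not necessarily how the original notebook sets the decay.

```python
from tensorflow import keras

# Assumed hyperparameters; tune them for your own run.
EPOCHS = 25
BATCH_SIZE = 32
LR = 1e-3


def compile_and_train(model, x_train, y_train, x_val, y_val, aug=None,
                      epochs=EPOCHS, batch_size=BATCH_SIZE, lr=LR):
    """Compile with Adam plus time-based decay, then fit; returns the Keras History."""
    schedule = keras.optimizers.schedules.InverseTimeDecay(
        initial_learning_rate=lr, decay_steps=1, decay_rate=lr / epochs)
    model.compile(optimizer=keras.optimizers.Adam(learning_rate=schedule),
                  loss="categorical_crossentropy", metrics=["accuracy"])
    if aug is not None:
        # Stream randomly augmented batches from the ImageDataGenerator.
        train_data = aug.flow(x_train, y_train, batch_size=batch_size)
        return model.fit(train_data, validation_data=(x_val, y_val),
                         epochs=epochs, verbose=1)
    return model.fit(x_train, y_train, batch_size=batch_size,
                     validation_data=(x_val, y_val), epochs=epochs, verbose=1)
```

The returned History object stores per-epoch metrics in history.history, which is what the plots in the next section are drawn from.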

Evaluation

We plot the training and validation accuracy and loss for each epoch to see how the model behaved during the training phase.

Plotting graphs for accuracy and loss
Training v/s validation accuracy
Training v/s validation loss
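
Plots like the two above can be produced from the training history with a few lines of Matplotlib. The sketch below assumes the history dictionary uses the keys accuracy, val_accuracy, loss, and val_loss, which is what recent Keras versions record when the model is compiled with metrics=["accuracy"]; the function name is illustrative.

```python
import matplotlib.pyplot as plt


def plot_history(history_dict):
    """Plot training vs. validation accuracy and loss side by side; returns the figure."""
    epochs = range(1, len(history_dict["loss"]) + 1)
    fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(10, 4))

    ax_acc.plot(epochs, history_dict["accuracy"], label="Training accuracy")
    ax_acc.plot(epochs, history_dict["val_accuracy"], label="Validation accuracy")
    ax_acc.set_title("Training v/s validation accuracy")
    ax_acc.set_xlabel("Epoch")
    ax_acc.legend()

    ax_loss.plot(epochs, history_dict["loss"], label="Training loss")
    ax_loss.plot(epochs, history_dict["val_loss"], label="Validation loss")
    ax_loss.set_title("Training v/s validation loss")
    ax_loss.set_xlabel("Epoch")
    ax_loss.legend()
    return fig
```

In a notebook, plot_history(history.history) followed by plt.show() displays both panels.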

By studying the above graphs, we observe that as the training accuracy increases, validation accuracy increases. Similarly, as the training loss decreases, the validation loss decreases too.

We can obtain better results by tweaking the learning rate, by training on more images, or simply by training the model for more epochs.

To check the actual (test) accuracy of the model we trained, we use the evaluate() method and obtain a test accuracy of 98.75%.

Test accuracy
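
In Keras, model.evaluate(x_test, y_test) returns the loss followed by the compiled metrics. As a rough illustration of what a categorical accuracy figure measures, the NumPy sketch below computes the fraction of samples whose highest-scoring class matches the true class. It is a hand-written equivalent for intuition, not the Keras implementation, and the function name is hypothetical.

```python
import numpy as np


def categorical_accuracy(y_true_one_hot, y_pred_probs):
    """Fraction of rows where the predicted argmax equals the true class index."""
    true_idx = np.argmax(y_true_one_hot, axis=1)
    pred_idx = np.argmax(y_pred_probs, axis=1)
    return float(np.mean(true_idx == pred_idx))
```

Passing the one-hot test labels and the output of model.predict(x_test) to this function should give a value close to the accuracy reported by evaluate(), provided the model was compiled with a categorical accuracy metric.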

Testing

We write the following predict_disease function to predict the class or disease of a plant image. We just need to provide the complete path to the image, and it displays the image along with its predicted class, i.e. the plant disease.

Predict function
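
The core of such a predict function is sketched below. To keep the example short, it takes an already loaded and resized image array instead of a path and skips displaying the image, so its signature differs from the function described above. The scaling to [0, 1] is an assumption and must match whatever preprocessing was used during training.

```python
import numpy as np


def predict_disease(model, class_names, image):
    """Return (label, confidence) for one image of shape (height, width, channels)."""
    # Assumed preprocessing: scale pixels to [0, 1] and add a batch dimension.
    batch = np.expand_dims(image.astype("float32") / 255.0, axis=0)
    probs = model.predict(batch)[0]  # one probability per class
    best = int(np.argmax(probs))
    return class_names[best], float(probs[best])
```

Here class_names is the ordered list of class labels, for example the classes_ attribute of the saved label transform, and model is any object with a Keras-style predict() method.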

For testing purposes, we randomly choose images from the dataset and try to predict the class or disease of each plant image.

Blueberry healthy
Potato Early blight
Tomato Target spot
Orange Citrus greening

N.B. (Nota Bene): Images for testing here are chosen randomly and coincidentally all of them were predicted correctly. However, this might not always work out as expected or seen here. You may observe false positives or false negatives depending on the quality of the testing image, training epochs, model architecture, and the type of data the model has been trained on.

Reuse

We can reuse our trained model and labels associated with it by saving (dumping) its values in two separate files. These files can be loaded anytime later for plant disease classification from an image.

Saving trained model and labels

Now to reload the trained model and its labels we stored above, we download the files from the GDrive.

Downloading pretrained model and labels

We load the model and labels into model and image_labels respectively.

Loading pretrained model and labels
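
The save and load steps of this section can be sketched as a round trip. The helper names and file paths are hypothetical, and the sketch uses Keras' own model.save() and load_model() for the model and pickle for the labels; the original notebook may store the model differently.

```python
import pickle

from tensorflow import keras


def save_artifacts(model, labels_obj, model_path, labels_path):
    """Write the trained model and its label object to two separate files."""
    model.save(model_path)  # e.g. a path ending in .h5
    with open(labels_path, "wb") as f:
        pickle.dump(labels_obj, f)


def load_artifacts(model_path, labels_path):
    """Load the model and label object saved by save_artifacts."""
    model = keras.models.load_model(model_path)
    with open(labels_path, "rb") as f:
        image_labels = pickle.load(f)
    return model, image_labels
```

After model, image_labels = load_artifacts(...), the pair can be used for prediction exactly as the freshly trained model and its labels were.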

Finally, we predict the disease of a plant from an image.

Disease prediction

The full source code for this project is available on GitHub and the Google Colab notebook can be viewed here.

Feel free to comment if you have any suggestions or queries.
Thank you for reading!

References

  1. Paper on PlantVillage Dataset.
  2. Plant Disease Detection using Convolutional Neural Network.
