Creating a Plant Disease Detector from scratch using Keras
Table of Contents
- Introduction
- Dataset
- Libraries
- Data Preprocessing
- Data Augmentation
- Model
- Training
- Evaluation
- Testing
- Reuse
Introduction
Plants are commonly affected by diseases, owing to factors such as fertilizer use, cultural practices, and environmental conditions. These diseases reduce agricultural yield and, in turn, hurt the economy that depends on it.
Any technique that helps overcome this problem, ideally by giving a warning before the plants are infected, would help farmers cultivate crops more efficiently, both qualitatively and quantitatively. Disease detection therefore plays a very important role in agriculture.
Because most ordinary machines have limited computational power, training the classification model locally is difficult. We therefore use the processing power offered by a Google Colab notebook, which connects us to a free TPU instance quickly and effortlessly.
So without further ado, let’s dive in!
Dataset
We use the PlantVillage Dataset, a well-known, publicly available dataset. It was published by crowdAI during the “PlantVillage Disease Classification Challenge”.
The dataset consists of 54,305 images of plant leaves collected under controlled environmental conditions. The images span the following 14 species:
Apple, Blueberry, Cherry, Corn, Grape, Orange, Peach, Bell Pepper, Potato, Raspberry, Soybean, Squash, Strawberry, and Tomato.
The dataset contains a total of 38 classes, covering both diseased and healthy leaves, listed below:
- Apple Scab
- Apple Black Rot
- Apple Cedar Rust
- Apple healthy
- Blueberry healthy
- Cherry healthy
- Cherry Powdery Mildew
- Corn Gray Leaf Spot
- Corn Common Rust
- Corn healthy
- Corn Northern Leaf Blight
- Grape Black Rot
- Grape Black Measles
- Grape Leaf Blight
- Grape healthy
- Orange Huanglongbing
- Peach Bacterial Spot
- Peach healthy
- Bell Pepper Bacterial Spot
- Bell Pepper healthy
- Potato Early Blight
- Potato healthy
- Potato Late Blight
- Raspberry healthy
- Soybean healthy
- Squash Powdery Mildew
- Strawberry healthy
- Strawberry Leaf Scorch
- Tomato Bacterial Spot
- Tomato Early Blight
- Tomato Late Blight
- Tomato Leaf Mold
- Tomato Septoria Leaf Spot
- Tomato Two Spotted Spider Mite
- Tomato Target Spot
- Tomato Mosaic Virus
- Tomato Yellow Leaf Curl Virus
- Tomato healthy
Note: The dataset also includes an additional background class, which helps distinguish leaves from background features.
The dataset file is stored in Google Drive (GDrive) and can be shared via a URL that contains a unique file id. Using this id, we download the dataset zip file from GDrive and unzip it.
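A minimal sketch of this step, assuming the third-party gdown package is installed (pip install gdown) and the file is shared publicly. The helper names and file names are hypothetical, not the original notebook's code; the unzip helper uses only the standard library.

```python
import os
import zipfile


def download_from_gdrive(file_id: str, output_path: str) -> str:
    """Download a publicly shared Google Drive file by its unique id.

    Assumes the third-party `gdown` package is installed.
    """
    import gdown  # imported lazily so unzip_dataset works without gdown

    url = f"https://drive.google.com/uc?id={file_id}"
    # gdown.download(url, output, quiet) returns the path it wrote to
    return gdown.download(url, output_path, quiet=False)


def unzip_dataset(zip_path: str, extract_dir: str) -> list:
    """Extract the dataset archive and return its top-level entries."""
    os.makedirs(extract_dir, exist_ok=True)
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(extract_dir)
    return sorted(os.listdir(extract_dir))
```

For example, call download_from_gdrive("<id from the share URL>", "plantvillage.zip") and then unzip_dataset("plantvillage.zip", "plantvillage").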
Libraries
We import all the necessary libraries required to process the data and build the classification model.
- NumPy: A library for the Python programming language, adding support for large, multi-dimensional arrays and matrices, along with a large collection of high-level mathematical functions to operate on these arrays.
- Pickle: Most Python objects can be pickled so that they can be saved to disk. Pickling converts a Python object (list, dict, etc.) into a byte stream that contains all the information needed to reconstruct the object in another Python script.
- Cv2 (OpenCV): OpenCV is a library designed to solve computer vision problems; cv2 is its Python binding.
- Os: The os module in Python provides functions for creating and removing directories (folders), listing their contents, and identifying or changing the current directory. It can also be used to automate many other operating system tasks.
- Sklearn: A free software machine learning library for the Python programming language. It features various classification, regression and clustering algorithms including support vector machines, random forests, gradient boosting, k-means, and DBSCAN, and is designed to interoperate with the Python numerical and scientific libraries NumPy and SciPy.
- Keras: Keras is an open-source neural network library written in Python. Designed to enable fast experimentation with deep neural networks, it focuses on being user-friendly, modular, and extensible.
- Matplotlib: A plotting library for the Python programming language and its numerical mathematics extension, NumPy.
Data Preprocessing
First, let’s define a couple of variables required to perform operations on the dataset images.
- We need to resize the raw dataset images to DEFAULT_IMAGE_SIZE so that they match the input shape of the first layer of the neural network.
- Each directory of the plant disease dataset folder varies in the number of images. Instead of taking them all, we select the first N_IMAGES from each directory to train our model.
- Finally, we set the path of the dataset in the root_dir to access plant images.
Now, we write a function to convert (resize) the input dataset images so that they are fit for training.
- cv2.imread() loads an image from the specified file. If the image cannot be read (for example, the file is missing or in an unsupported format), it returns None in Python rather than raising an error.
- cv2.resize() changes the dimensions of an image: its width, its height, or both. The aspect ratio of the original image is preserved only if the new width and height are chosen proportionally.
We then iterate through the plant disease dataset folder, resize the images from each of the folders and convert or load them into a NumPy array.
After loading the image dataset, we map each label or class of each plant disease to a unique value for the training task. Also, saving this transform to a pickle file will help us later in predicting a label or class of plant disease from the output of the classification model.
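One way to implement this mapping is scikit-learn's LabelBinarizer, which one-hot encodes the class names. The text does not name the encoder, so treat this choice, the function name, and the pickle file name as assumptions.

```python
import pickle

from sklearn.preprocessing import LabelBinarizer


def encode_labels(labels, pickle_path="label_transform.pkl"):
    """One-hot encode class names and pickle the fitted encoder for later use."""
    label_binarizer = LabelBinarizer()
    # Shape: (n_samples, n_classes). With exactly two classes,
    # LabelBinarizer returns a single 0/1 column instead of two.
    encoded = label_binarizer.fit_transform(labels)
    with open(pickle_path, "wb") as fh:
        pickle.dump(label_binarizer, fh)
    return encoded, label_binarizer
```

The fitted encoder's classes_ attribute lists the class names in sorted order, and inverse_transform() maps a model output row back to a class name.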
Finally, we split the loaded image dataset into a train set and a test set, holding out 20% of the images (a 0.2 split ratio) for testing. The train set is used to train the classification model, and the test set to validate the model during training.
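A sketch of the 80/20 split using scikit-learn's train_test_split; the function name and the fixed random_state are assumptions added for reproducibility.

```python
from sklearn.model_selection import train_test_split


def split_dataset(images, encoded_labels, test_size=0.2, seed=42):
    """Return x_train, x_test, y_train, y_test with test_size held out for testing."""
    return train_test_split(
        images,
        encoded_labels,
        test_size=test_size,
        random_state=seed,  # fixed seed so the split is reproducible (assumed value)
    )
```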
Data Augmentation
Data augmentation artificially increases the number and variety of images the model sees during training. We apply random operations such as shifts, rotations, zooms, and flips to the images to diversify our dataset.
Providing augmented images to a model helps it efficiently learn features from different areas of the same image and thus perform better on unseen image data.
Note: For now, we only create an ImageDataGenerator object; it will be used later to feed augmented batches to the model during training.
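A sketch of the augmentation object using Keras's ImageDataGenerator with the operations listed above. The numeric ranges and fill_mode are assumed values, not the author's. Recent TensorFlow releases no longer recommend ImageDataGenerator for new code (preprocessing layers are favored), but it matches the approach described here.

```python
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Assumed augmentation ranges; the article names the operations but not their values.
aug = ImageDataGenerator(
    rotation_range=25,        # random rotations of up to 25 degrees either way
    width_shift_range=0.1,    # horizontal shifts of up to 10% of the width
    height_shift_range=0.1,   # vertical shifts of up to 10% of the height
    zoom_range=0.2,           # random zoom in or out by up to 20%
    horizontal_flip=True,     # random left-right flips
    fill_mode="nearest",      # fill pixels exposed by a shift or rotation
)
```

Calling aug.flow(x_train, y_train, batch_size=...) then yields randomly augmented batches on the fly; nothing is written to disk.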
Model
Here, we define all the hyperparameters of our plant disease classification model. Executing them initially in a separate cell makes it easy for us to tweak them later if needed.
Now, we create a sequential model for the classification task. The model defaults to the “channels_last” data format (height, width, channels), and we also add a switch that reorders the input shape for backends that use “channels_first”.
For the model, we first create a 2D convolutional layer with 32 filters of size 3 x 3 and a ReLU (Rectified Linear Unit) activation. We then apply batch normalization, max pooling, and a 25% (0.25) dropout in the following layers.
Next, we add two 2D convolutional layers with 64 filters and ReLU activation, followed by a pooling layer and a dropout layer. We repeat this block with 128 filters in the Conv2D layers, which is the only difference, and finish the network with a set of FC (fully connected) layers.
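The architecture described above can be sketched with the Keras Sequential API as follows. The use of padding="same", the pool sizes, the 1024-unit dense layer, and the 0.5 dropout in the head are assumptions; the text specifies only the filter counts, kernel size, activations, batch normalization, pooling, and 25% dropout.

```python
from tensorflow.keras import backend as K
from tensorflow.keras.layers import (Activation, BatchNormalization, Conv2D,
                                     Dense, Dropout, Flatten, Input, MaxPooling2D)
from tensorflow.keras.models import Sequential


def build_model(width, height, depth, n_classes):
    """CNN following the layer sequence described in the text (some sizes assumed)."""
    # Default: channels_last, i.e. inputs shaped (height, width, channels)
    input_shape = (height, width, depth)
    chan_dim = -1  # axis that BatchNormalization normalizes over
    # Switch for backends configured to use channels_first
    if K.image_data_format() == "channels_first":
        input_shape = (depth, height, width)
        chan_dim = 1

    model = Sequential()
    model.add(Input(shape=input_shape))

    # Block 1: 32 filters of 3x3, ReLU, batch norm, max pooling, 25% dropout
    model.add(Conv2D(32, (3, 3), padding="same"))
    model.add(Activation("relu"))
    model.add(BatchNormalization(axis=chan_dim))
    model.add(MaxPooling2D(pool_size=(3, 3)))
    model.add(Dropout(0.25))

    # Blocks 2 and 3: two Conv2D layers each (64, then 128 filters),
    # followed by pooling and dropout
    for filters in (64, 128):
        for _ in range(2):
            model.add(Conv2D(filters, (3, 3), padding="same"))
            model.add(Activation("relu"))
            model.add(BatchNormalization(axis=chan_dim))
        model.add(MaxPooling2D(pool_size=(2, 2)))
        model.add(Dropout(0.25))

    # Fully connected head (1024 units and 50% dropout are assumptions)
    model.add(Flatten())
    model.add(Dense(1024))
    model.add(Activation("relu"))
    model.add(BatchNormalization())
    model.add(Dropout(0.5))
    model.add(Dense(n_classes))
    model.add(Activation("softmax"))  # one probability per class
    return model
```

For the 38 PlantVillage classes and 256 x 256 RGB inputs, this would be build_model(256, 256, 3, 38).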
Training
Before training our model, we initialize the optimizer with the learning rate and decay parameters defined above. We select the Adam optimizer because it usually converges faster than plain stochastic gradient descent with little tuning of the learning rate; like any gradient-based optimizer, however, it offers no guarantee of reaching a global minimum.
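A sketch of compiling and training the model with Adam and learning-rate decay. The hyperparameter values, the decay value init_lr / epochs, and the categorical_crossentropy loss are assumptions. Because recent Keras optimizers no longer accept a decay argument, the sketch expresses decay as an InverseTimeDecay schedule, which follows the same formula as the legacy decay.

```python
import tensorflow as tf

# Assumed hyperparameter values; the article defines its own in a separate cell.
EPOCHS = 25
INIT_LR = 1e-3
BS = 32


def compile_and_train(model, datagen, x_train, y_train, x_test, y_test,
                      epochs=EPOCHS, batch_size=BS, init_lr=INIT_LR):
    """Compile with Adam plus learning-rate decay, then train on augmented batches."""
    # InverseTimeDecay with decay_steps=1 gives lr = init_lr / (1 + decay * step),
    # the same rule as the legacy `decay=` optimizer argument.
    schedule = tf.keras.optimizers.schedules.InverseTimeDecay(
        initial_learning_rate=init_lr,
        decay_steps=1,
        decay_rate=init_lr / epochs,  # assumed decay value
    )
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=schedule),
        loss="categorical_crossentropy",  # one-hot labels with a softmax output
        metrics=["accuracy"],
    )
    return model.fit(
        datagen.flow(x_train, y_train, batch_size=batch_size),  # augmented batches
        validation_data=(x_test, y_test),
        steps_per_epoch=max(1, len(x_train) // batch_size),
        epochs=epochs,
    )
```

The function expects an ImageDataGenerator as datagen and returns the History object, whose history dict is used for plotting.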
Evaluation
We plot the training and validation accuracy and loss for each epoch to see how the model's performance evolves during training.
By studying the resulting graphs, we observe that as the training accuracy increases, the validation accuracy increases as well. Similarly, as the training loss decreases, the validation loss decreases too.
We may obtain better results by tweaking the learning rate, training on more images, or simply training the model for more epochs.
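A sketch of the plotting step with Matplotlib. It assumes the model was compiled with metrics=["accuracy"] and trained with validation data, so that the history dict contains the accuracy, val_accuracy, loss, and val_loss keys; the function name is hypothetical.

```python
import matplotlib.pyplot as plt


def plot_history(history):
    """Plot training vs. validation accuracy and loss per epoch.

    `history` is the dict stored in the History object returned by model.fit().
    """
    epochs = range(1, len(history["accuracy"]) + 1)
    fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(12, 4))

    ax_acc.plot(epochs, history["accuracy"], "b", label="Training accuracy")
    ax_acc.plot(epochs, history["val_accuracy"], "r", label="Validation accuracy")
    ax_acc.set_title("Training and validation accuracy")
    ax_acc.set_xlabel("Epoch")
    ax_acc.legend()

    ax_loss.plot(epochs, history["loss"], "b", label="Training loss")
    ax_loss.plot(epochs, history["val_loss"], "r", label="Validation loss")
    ax_loss.set_title("Training and validation loss")
    ax_loss.set_xlabel("Epoch")
    ax_loss.legend()

    fig.tight_layout()
    return fig
```

Call it as plot_history(history.history) followed by plt.show().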
To check the accuracy of the trained model on the test set (the same split used for validation during training), we use the evaluate() method and obtain a test accuracy of 98.75%.
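A minimal sketch of this step, assuming the model was compiled with a single accuracy metric so that evaluate() returns [loss, accuracy]; the function name is hypothetical.

```python
def report_test_accuracy(model, x_test, y_test):
    """Evaluate a compiled Keras model and return its accuracy as a percentage."""
    # With metrics=["accuracy"], evaluate() returns [loss, accuracy]
    loss, accuracy = model.evaluate(x_test, y_test, verbose=0)
    print(f"Test loss: {loss:.4f}  Test accuracy: {accuracy * 100:.2f}%")
    return accuracy * 100
```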
Testing
We write a predict_disease function to predict the class (disease) of a plant image. We just need to provide the complete path to the image, and it displays the image along with its predicted class.
For testing purposes, we randomly choose images from the dataset and try to predict the class (disease) of each plant image.
N.B. (Nota Bene): The images for testing here were chosen randomly, and coincidentally all of them were predicted correctly. However, this might not always work out as expected or as seen here. You may observe false positives or false negatives depending on the quality of the test image, the number of training epochs, the model architecture, and the type of data the model has been trained on.
Reuse
We can reuse our trained model and its associated labels by saving (dumping) them to two separate files. These files can be loaded at any time later to classify plant diseases from images.
To reload the trained model and labels stored above, we download the two files from GDrive.
We load the model and labels into model and image_labels, respectively.
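A sketch of the saving step. Here the model is saved with Keras's native model.save() rather than pickled, since Keras models are normally stored in their own format; the label encoder is pickled. The file names and the function name are hypothetical.

```python
import pickle


def save_artifacts(model, label_binarizer,
                   model_path="plant_disease_model.keras",
                   labels_path="plant_disease_label_transform.pkl"):
    """Save the trained Keras model and its fitted label encoder to two files."""
    model.save(model_path)  # Keras native format (.keras extension)
    with open(labels_path, "wb") as fh:
        pickle.dump(label_binarizer, fh)
    return model_path, labels_path
```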
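A matching sketch for reloading, assuming the model was saved in a Keras format (for example with model.save()) and the labels with pickle; the default file names and the function name are hypothetical.

```python
import pickle

from tensorflow.keras.models import load_model


def load_artifacts(model_path="plant_disease_model.keras",
                   labels_path="plant_disease_label_transform.pkl"):
    """Reload the saved model and label encoder; returns (model, image_labels)."""
    model = load_model(model_path)
    with open(labels_path, "rb") as fh:
        image_labels = pickle.load(fh)
    return model, image_labels
```

If image_labels is a fitted LabelBinarizer, image_labels.classes_ maps the argmax of the model's output back to a class name.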
Finally, we predict the disease of a plant from an image.
The full source code for this project is available on GitHub and the Google Colab notebook can be viewed here.
Feel free to comment if you have any suggestions or queries.
Thank you for reading!