Wheat Disease Detection using Keras

Lakshay Goyal
Published in Analytics Vidhya
8 min read, Feb 16, 2021
Source: Wallpaperflare

Table of Contents

  1. Introduction
  2. Dataset
  3. Libraries
  4. Data Preprocessing
  5. Data Augmentation
  6. Model
  7. Training
  8. Evaluation
  9. Testing

Introduction

India is diverse not only in culture but also in food. It is an agricultural country, where 75% of the population relies on agriculture. Wheat is a major source of minerals such as selenium and magnesium, nutrients that are necessary for good health. Fungal and viral diseases of wheat are common, and leaf rust is the one that damages the wheat leaf the most.

Wheat diseases show up mainly on the leaves, so they can be identified from leaf images using deep learning and computer vision techniques.

Dataset

The dataset used is the Large Wheat Disease Classification Dataset (LWDCD2020). It consists of around 4,500 images of three classes of wheat diseases and one normal class. The images have been curated for dimensional uniformity.

The dataset contains a total of 4 classes as listed below:

  1. Leaf Rust
  2. Crown and Root Rot
  3. Healthy Wheat
  4. Wheat Loose Smut

Note: Dataset images include complex backgrounds, varied capture conditions, different stages of disease evolution (early, middle, and late), and similar features between different wheat diseases.
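
Before training, it can help to check how many images each class actually contains. The sketch below assumes the dataset is stored with one subfolder per class, which is what the label extraction later in this article relies on. The function name `count_images_per_class` and the list of file extensions are illustrative choices, not part of the dataset's documentation.

```python
import os
from collections import Counter

# Assumption: these are the image file types present in the dataset folders.
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp")


def count_images_per_class(dataset_dir):
    """Count image files in each immediate subfolder of dataset_dir."""
    counts = Counter()
    for class_name in sorted(os.listdir(dataset_dir)):
        class_dir = os.path.join(dataset_dir, class_name)
        if not os.path.isdir(class_dir):
            continue  # skip stray files at the top level
        counts[class_name] = sum(
            1
            for name in os.listdir(class_dir)
            if name.lower().endswith(IMAGE_EXTENSIONS)
        )
    return counts
```

Calling `count_images_per_class(dataset)` with the dataset root path returns a mapping from class name to image count, which makes class imbalance easy to spot.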

Libraries

We import all the required libraries needed to process the data and build the classification model.

from keras.preprocessing.image import ImageDataGenerator
from keras.layers import AveragePooling2D, Dropout, Flatten, Dense, Input
from keras.models import Model, load_model
from keras.optimizers import Adam
from keras.applications import VGG19
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelBinarizer
from sklearn.metrics import classification_report
from imutils import paths
from collections import deque
import matplotlib.pyplot as plt
import numpy as np
import cv2
import os
import pickle

  • Matplotlib: A plotting library for the Python programming language. We use it to draw the training plots and save them as .png image files.
  • Keras: An open-source neural network library written in Python, running on top of the machine learning platform TensorFlow.
  • Sklearn: From scikit-learn, we use the LabelBinarizer to one-hot encode our class labels. The train_test_split function splits our dataset into training and testing sets. We also print a classification_report in the traditional text format.
  • Numpy: A library of multidimensional array objects and a collection of routines for processing those arrays.
  • Pickle: For serializing our label binarizer to disk. The idea is that the resulting byte stream contains all the information necessary to reconstruct the object in another Python script.
  • cv2 (OpenCV): The Python bindings for OpenCV, a library designed to solve computer vision problems.
  • os: The operating system module is used to get the correct file path separator, which is OS-dependent.
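
As a small illustration of the pickle idea described above, the sketch below serializes an object to bytes and reconstructs it. The list `class_names` is a stand-in used here for illustration; in the article the pickled object is the fitted label binarizer.

```python
import pickle

# Stand-in for the fitted label binarizer: any picklable object behaves the same way.
class_names = ["Crown and Root Rot", "Healthy Wheat", "Leaf Rust", "Wheat Loose Smut"]

payload = pickle.dumps(class_names)  # object -> bytes
restored = pickle.loads(payload)     # bytes -> an equal, independent copy
```

The restored object compares equal to the original but is a separate copy, which is why another script can load it without access to the original variables.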

Data Preprocessing

Let’s proceed to initialize our LABELS and load our data:

# `dataset` must hold the path to the dataset root folder,
# which contains one subfolder per class
LABELS = set(["Crown and Root Rot", "Healthy Wheat", "Leaf Rust", "Wheat Loose Smut"])
imagePaths = list(paths.list_images(dataset))
data = []
labels = []
# loop over the image paths
for imagePath in imagePaths:
    # extract the class label from the name of the parent folder
    label = imagePath.split(os.path.sep)[-2]
    # if the label of the current image is not part of the labels
    # we are interested in, then ignore the image
    if label not in LABELS:
        continue
    # load the image, convert it to RGB channel ordering, and resize
    # it to be a fixed 224x224 pixels, ignoring aspect ratio
    image = cv2.imread(imagePath)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (224, 224))
    # update the data and labels lists, respectively
    data.append(image)
    labels.append(label)

  1. We define the set of class LABELS that our dataset will consist of. Any label not present in this set is excluded from the dataset. Here the set contains all four classes; to save on training time, you can remove classes from the set and work with fewer of them.
  2. We initialize our data and labels lists, then loop over all imagePaths. In the loop, we first extract the class label from each imagePath.
  3. We then load and preprocess the image. Preprocessing includes swapping color channels for OpenCV to Keras compatibility and resizing to 224×224px.
  4. The image and label are then added to the data and labels lists.
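
The label extraction in step 2 can be tried in isolation. The sketch below uses made-up file paths and a smaller label set purely for illustration; the helper name `label_from_path` is not part of the original code.

```python
import os


def label_from_path(image_path):
    """Return the name of the folder that directly contains the image."""
    return image_path.split(os.path.sep)[-2]


# Hypothetical paths, built with os.path.join so the separator matches the OS.
example_paths = [
    os.path.join("dataset", "Leaf Rust", "img_001.jpg"),
    os.path.join("dataset", "Healthy Wheat", "img_002.jpg"),
    os.path.join("dataset", "Unknown", "img_003.jpg"),
]

# Illustrative subset of classes to keep.
KEEP = {"Leaf Rust", "Healthy Wheat"}
kept_labels = [
    label_from_path(p) for p in example_paths if label_from_path(p) in KEEP
]
```

Images in folders that are not listed in the set are dropped, which is how the loop above filters the dataset.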

We will one-hot encode our labels and partition our data:

# convert the data and labels to NumPy arrays
data = np.array(data)
labels = np.array(labels)
# perform one-hot encoding on the labels
lb = LabelBinarizer()
labels = lb.fit_transform(labels)
# partition the data into training and testing splits using 75% of
# the data for training and the remaining 25% for testing
(trainX, testX, trainY, testY) = train_test_split(data, labels,
    test_size=0.25, stratify=labels, random_state=42)
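
To make the one-hot encoding step concrete, here is a plain-Python sketch of the same idea. It is not how LabelBinarizer is implemented; it only mirrors the behaviour for four classes, where each label becomes a row with a single 1 in the column of its class (classes are ordered alphabetically, as in `lb.classes_`).

```python
def one_hot_encode(labels):
    """Map each label to a row with a 1 in the column of its class."""
    classes = sorted(set(labels))
    index = {name: i for i, name in enumerate(classes)}
    encoded = []
    for name in labels:
        row = [0] * len(classes)
        row[index[name]] = 1
        encoded.append(row)
    return classes, encoded


# Illustrative labels, one entry per image.
sample_labels = [
    "Leaf Rust",
    "Healthy Wheat",
    "Leaf Rust",
    "Wheat Loose Smut",
    "Crown and Root Rot",
]
classes, encoded = one_hot_encode(sample_labels)
```

The column order in `classes` is what later lets us turn a predicted column index back into a class name.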

Data Augmentation

Image data augmentation is a technique that can be used to artificially expand the size of a training dataset by creating modified versions of images in the dataset.

# initialize the training data augmentation object
trainAug = ImageDataGenerator(
    rotation_range=30,
    zoom_range=0.15,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.15,
    horizontal_flip=True,
    fill_mode="nearest")
# initialize the validation/testing data augmentation object (which
# we'll be adding mean subtraction to)
valAug = ImageDataGenerator()
# define the ImageNet mean subtraction (in RGB order) and set the
# mean subtraction value for each of the data augmentation objects
# (note: Keras only applies `mean` when featurewise_center=True)
mean = np.array([123.68, 116.779, 103.939], dtype="float32")
trainAug.mean = mean
valAug.mean = mean

Note: The trainAug object performs random rotations, zooms, shifts, shears, and flips on our data.
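
Mean subtraction itself is simple arithmetic: each channel value is shifted by the ImageNet mean of that channel. The sketch below shows it for a single RGB pixel in plain Python; the helper name `subtract_mean` and the example pixel are illustrative.

```python
# ImageNet channel means in RGB order, as used in the code above.
IMAGENET_MEAN_RGB = (123.68, 116.779, 103.939)


def subtract_mean(pixel_rgb, mean_rgb=IMAGENET_MEAN_RGB):
    """Center one RGB pixel by subtracting the per-channel mean."""
    return tuple(value - mean for value, mean in zip(pixel_rgb, mean_rgb))


# A bright orange pixel: red is far above the mean, blue far below it.
centered = subtract_mean((255, 128, 0))
```

After centering, values can be negative, which is why the image is converted to a floating-point type before the subtraction.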

Model

VGG19 is pre-trained on a large dataset (ImageNet) to build image representations. The VGG architecture reaches around 92.7% top-5 test accuracy on ImageNet. It achieves competitive classification accuracy compared to more complicated nets, at the expense of slower evaluation and a much larger model size. It is notable for its very simple structure: all convolutional layers use a 3x3 kernel with stride 1.

There are five sets of conv layers: the first has two conv layers with 64 filters, the second has two conv layers with 128 filters, the third has four conv layers with 256 filters, and the last two have four conv layers each, with 512 filters. Each set ends with a max-pooling layer that uses a 2x2 filter with a stride of 2 (pixels). The output of the last pooling layer is flattened and fed to a fully connected layer with 4096 neurons. Its output goes to another fully connected layer with 4096 neurons, whose output is fed into a final fully connected layer with 1000 neurons. The convolutional layers and the two 4096-neuron layers are ReLU activated, and the final layer is followed by a softmax. The full network has 143,667,240 parameters, all of them trainable (0 non-trainable).
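
The parameter count of the standard VGG19 layout can be worked out by hand from the layer sizes. The sketch below assumes a 224x224 RGB input and counts weights plus biases for each layer; the helper names are illustrative.

```python
def conv_params(in_channels, out_channels, kernel=3):
    """Weights plus biases of one conv layer with a kernel x kernel filter."""
    return kernel * kernel * in_channels * out_channels + out_channels


def dense_params(in_units, out_units):
    """Weights plus biases of one fully connected layer."""
    return in_units * out_units + out_units


# (number of conv layers, filters) for the five sets of conv layers in VGG19.
VGG19_BLOCKS = [(2, 64), (2, 128), (4, 256), (4, 512), (4, 512)]

conv_total = 0
channels = 3  # RGB input
for num_layers, filters in VGG19_BLOCKS:
    for _ in range(num_layers):
        conv_total += conv_params(channels, filters)
        channels = filters

# Five 2x2 poolings shrink 224 to 7, so the flattened vector has 7*7*512 values.
flattened = 7 * 7 * 512
fc_total = (
    dense_params(flattened, 4096)
    + dense_params(4096, 4096)
    + dense_params(4096, 1000)
)
total = conv_total + fc_total
```

Here `conv_total` is 20,024,384 and `total` is 143,667,240. Most of the parameters sit in the first fully connected layer, which is why dropping the top of the network makes the model much lighter.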

# load the VGG19 network, ensuring the head FC layer sets are left
# off
headmodel = VGG19(weights="imagenet", include_top=False,
    input_tensor=Input(shape=(224, 224, 3)))
# construct the head of the model that will be placed on top of
# the base model
model = headmodel.output
model = AveragePooling2D(pool_size=(5, 5))(model)
model = Flatten(name="flatten")(model)
model = Dense(512, activation="relu")(model)
model = Dropout(0.4)(model)
model = Dense(len(lb.classes_), activation="softmax")(model)
# place the head FC model on top of the base model (this will become
# the actual model we will train)
moodel = Model(inputs=headmodel.input, outputs=model)
# loop over all layers in the base model and freeze them so they will
# *not* be updated during the training process
for layer in headmodel.layers:
    layer.trainable = False
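
With the base model frozen, only the new head is trained. The sketch below traces the tensor shapes through the head and counts its trainable parameters, assuming a 224x224 input, four classes, and the Keras default of a pooling stride equal to the pool size with no padding.

```python
def pooled_size(size, pool):
    """Output size of a pooling layer with stride == pool size and no padding."""
    return (size - pool) // pool + 1


feature_map = 224 // 2 ** 5                   # 7: VGG19 halves the size five times
pooled = pooled_size(feature_map, 5)          # 1: the 5x5 average pool leaves 1x1
flat_units = pooled * pooled * 512            # 512 values after Flatten

dense_512 = flat_units * 512 + 512            # weights + biases of Dense(512)
num_classes = 4
dense_out = 512 * num_classes + num_classes   # weights + biases of the softmax layer
trainable_params = dense_512 + dense_out
```

Under these assumptions the head has 264,708 trainable parameters, a small fraction of the frozen convolutional base.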

Training

We initialize our optimizer with a learning rate of 1e-3. We pick the Adam optimizer because it often converges faster than other optimization techniques and usually works well with little tuning.
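
To give a feel for what Adam does with that learning rate, here is a plain-Python sketch of the textbook Adam update for a single parameter. It is not the Keras implementation; the default values for `beta1`, `beta2`, and `eps` are common choices assumed here, and the toy objective f(x) = (x - 3)^2 is only for illustration.

```python
import math


def adam_step(param, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-7):
    """Apply one textbook Adam update and return (param, m, v)."""
    m = beta1 * m + (1 - beta1) * grad          # running mean of gradients
    v = beta2 * v + (1 - beta2) * grad * grad   # running mean of squared gradients
    m_hat = m / (1 - beta1 ** t)                # bias correction for early steps
    v_hat = v / (1 - beta2 ** t)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v


# Toy problem: minimize f(x) = (x - 3)^2, whose gradient is 2 * (x - 3).
x, m, v = 0.0, 0.0, 0.0
for t in range(1, 6):
    grad = 2 * (x - 3)
    x, m, v = adam_step(x, grad, m, v, t)
```

Because the update is normalized by the gradient magnitude, each early step moves the parameter by roughly the learning rate, regardless of how large the raw gradient is.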

# compile our model (this needs to be done after setting our
# layers to being non-trainable)
opt = Adam(learning_rate=1e-3)
moodel.compile(loss="categorical_crossentropy", optimizer=opt,
    metrics=["accuracy"])
# train the head of the network for a few epochs (all other layers
# are frozen) -- this will allow the new FC layers to start to become
# initialized with actual "learned" values versus pure random
H = moodel.fit(
    trainAug.flow(trainX, trainY, batch_size=64),
    steps_per_epoch=len(trainX) // 64,
    # use the same batch size that validation_steps assumes
    validation_data=valAug.flow(testX, testY, batch_size=64),
    validation_steps=len(testX) // 64,
    epochs=30)
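
The step counts passed to fit() follow directly from the split sizes. The sketch below works through the arithmetic under the assumption that the dataset has exactly 4,500 images (the article only says "around 4,500"), with the 25% test split and batch size of 64 used above.

```python
import math

total_images = 4500   # assumption: the real dataset size is only approximately this
test_fraction = 0.25
batch_size = 64

test_count = math.ceil(total_images * test_fraction)   # images held out for testing
train_count = total_images - test_count                # images used for training

steps_per_epoch = train_count // batch_size            # full training batches per epoch
validation_steps = test_count // batch_size            # full validation batches per epoch

# Images that do not fill a complete batch within one pass.
train_leftover = train_count - steps_per_epoch * batch_size
```

With these numbers there are 3,375 training images and 1,125 test images, giving 52 training steps and 17 validation steps per epoch.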

Evaluation

We print a classification report for the test set and plot the accuracy and loss of the model over the training epochs.

# evaluate the network
predictions = moodel.predict(testX, batch_size=64)
print(classification_report(testY.argmax(axis=1),
    predictions.argmax(axis=1), target_names=lb.classes_))
# plot the training and test accuracy
N = 30
plt.figure()
plt.plot(np.arange(0, N), H.history['accuracy'], label="Training Accuracy")
plt.plot(np.arange(0, N), H.history['val_accuracy'], label="Test Accuracy")
plt.title('VGG19 Model Train vs Test Accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(loc='lower right')
# save before plt.show(), otherwise the saved file can come out blank
plt.savefig(r"E:\Wheat Disease Detection\Accuracy_Plot.png")
plt.show()
# plot the training and test loss
plt.figure()
plt.plot(H.history['loss'], label="Training Loss")
plt.plot(H.history['val_loss'], label="Test Loss")
plt.title('VGG19 Model Train vs Test Loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(loc='upper right')
plt.savefig(r"E:\Wheat Disease Detection\Loss_Plot.png")
plt.show()

Accuracy and Loss Graphs

By reading the above graphs, we see that as the training accuracy increases, validation accuracy increases. Likewise, as the training loss decreases, the validation loss decreases too.

We may achieve better results by tweaking the learning rate, by training on more images, or simply by training the model for more epochs.

Using the evaluate() method, we obtain a test accuracy of 97.85%.
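
Accuracy, precision, and recall all come from comparing the predicted class (the argmax of the softmax output) with the true class. The sketch below computes them in plain Python on a tiny made-up example; the data and function names are illustrative and unrelated to the reported result.

```python
def argmax(values):
    """Index of the largest value in a sequence."""
    return max(range(len(values)), key=lambda i: values[i])


def accuracy(one_hot_labels, predicted_probs):
    """Fraction of samples whose predicted class matches the true class."""
    correct = sum(
        argmax(t) == argmax(p) for t, p in zip(one_hot_labels, predicted_probs)
    )
    return correct / len(one_hot_labels)


def precision_recall(one_hot_labels, predicted_probs, class_index):
    """Precision and recall for one class (0.0 when undefined)."""
    pairs = [(argmax(t), argmax(p)) for t, p in zip(one_hot_labels, predicted_probs)]
    tp = sum(1 for t, p in pairs if t == class_index and p == class_index)
    fp = sum(1 for t, p in pairs if t != class_index and p == class_index)
    fn = sum(1 for t, p in pairs if t == class_index and p != class_index)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall


# Made-up example: four samples, one per class; the third one is misclassified.
true_labels = [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]
probs = [
    [0.70, 0.10, 0.10, 0.10],
    [0.20, 0.60, 0.10, 0.10],
    [0.10, 0.50, 0.30, 0.10],
    [0.05, 0.05, 0.10, 0.80],
]
toy_accuracy = accuracy(true_labels, probs)
toy_precision, toy_recall = precision_recall(true_labels, probs, class_index=1)
```

In this toy case three of four predictions are correct, and class 1 has perfect recall but only 50% precision because one sample of another class was assigned to it.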

# save the model and the label binarizer to disk
# (a raw string keeps the backslashes in the Windows path literal)
moodel.save(r"E:\Wheat Disease Detection\activity_model.h5")
with open("label", "wb") as f:
    f.write(pickle.dumps(lb))

Testing

# raw strings keep the backslashes in the Windows paths literal
model_path = r"E:\Wheat Disease Detection\activity_model.h5"
input_path = r"E:\Wheat Disease Detection\input_image.png"
# load the trained model and label binarizer from disk
# (the binarizer was saved to a file named "label")
moodel = load_model(model_path)
with open("label", "rb") as f:
    lb = pickle.loads(f.read())
# initialize the image mean for mean subtraction along with the
# predictions queue
mean = np.array([123.68, 116.779, 103.939], dtype="float32")
Q = deque(maxlen=128)
vs = cv2.VideoCapture(input_path)
(W, H) = (None, None)
while True:
    # read the next frame; stop when there is nothing left to read
    (grabbed, frame) = vs.read()
    if not grabbed:
        break
    if W is None or H is None:
        (H, W) = frame.shape[:2]
    # keep a copy for display, then preprocess the frame like the training data
    output = frame.copy()
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    frame = cv2.resize(frame, (224, 224)).astype("float32")
    frame -= mean
    # predict, then average over the predictions collected so far
    preds = moodel.predict(np.expand_dims(frame, axis=0))[0]
    Q.append(preds)
    results = np.array(Q).mean(axis=0)
    i = np.argmax(results)
    label = lb.classes_[i]
    text = "PREDICTION: {}".format(label.upper())
    cv2.putText(output, text, (4, 4), cv2.FONT_HERSHEY_SIMPLEX,
        0.25, (200, 255, 155), 2)
    # show the output image
    cv2.imshow("Output", output)
    key = cv2.waitKey(10) & 0xFF
    # if the `q` key was pressed, break from the loop
    if key == ord("q"):
        break
vs.release()

For testing, we randomly choose images and try predicting the class or disease of the wheat image.
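
The testing loop above keeps recent predictions in a deque and averages them before picking a class. The plain-Python sketch below shows that rolling average on a made-up stream of two-class probabilities; the function name `rolling_prediction` and the data are illustrative.

```python
from collections import deque


def rolling_prediction(prob_stream, window=128):
    """Pick a class per step from the average of the last `window` predictions."""
    queue = deque(maxlen=window)  # old entries fall out once the window is full
    picks = []
    for probs in prob_stream:
        queue.append(probs)
        count = len(queue)
        averaged = [sum(p[i] for p in queue) / count for i in range(len(probs))]
        picks.append(max(range(len(averaged)), key=lambda i: averaged[i]))
    return picks


# Made-up stream: the third prediction briefly favours class 1.
stream = [[0.9, 0.1], [0.8, 0.2], [0.3, 0.7], [0.9, 0.1]]
smoothed = rolling_prediction(stream, window=128)
unsmoothed = rolling_prediction(stream, window=1)
```

With averaging, the brief flip in the third prediction is smoothed away, while a window of 1 follows every single prediction. For a single input image the queue holds only one prediction, so the average equals that prediction.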

The full source code for this project is available on GitHub

Feel free to comment if you have any suggestions or queries.
Thank you for reading!
