Teacher-Student Architecture in Plant Disease Classification

Saamahn Mahjouri
5 min read · Jan 21, 2020

In this article, we’ll briefly discuss a method for plant disease classification built on the Teacher-Student CNN architecture, which pairs two classifiers: a “teacher” and a “student”.

We’ll apply multitask learning to train both classifiers at once. The jointly learned representation coupling the two classifiers is then used to visualize the regions of the image most important for classification.
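
Concretely, the multitask objective is a weighted sum of the two classification losses; Keras handles this declaratively via loss_weights at compile time, as we’ll see later. A minimal sketch of the idea (alpha = 0.4 matches the value we'll use in the compile step):

# Weighted multitask loss: alpha weights the teacher's loss,
# (1 - alpha) weights the student's.
def multitask_loss(teacher_loss, student_loss, alpha=0.4):
    return alpha * teacher_loss + (1.0 - alpha) * student_loss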

According to the paper, this approach yields better results than comparable methods for plant disease classification. In our analysis, we’ll use the PlantVillage dataset, which contains 54,306 images of plants.

You can find the arXiv paper here:

https://arxiv.org/pdf/1905.13523.pdf

Getting Started

First make sure the following packages are installed:

  • tensorflow==1.9.0
  • Keras==2.2.4
  • matplotlib==3.0.2
  • OpenCV==3.0
  • Pillow==5.4.1
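
If you’re starting from a fresh environment, a pip command along these lines should work (note that OpenCV is published on pip as opencv-python; the exact pin may differ from the 3.0 listed above):

pip install tensorflow==1.9.0 Keras==2.2.4 matplotlib==3.0.2 opencv-python Pillow==5.4.1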

Let’s then import the required libraries.

import os
import glob

import numpy as np
import cv2
import tensorflow as tf
import matplotlib.pyplot as plt

from keras import backend as K
from keras import optimizers
from keras import __version__
from keras.applications.vgg16 import VGG16, preprocess_input
from keras.models import Model
from keras.layers import Dense, Flatten, Add, Reshape, Conv2D, UpSampling2D, concatenate
from keras.preprocessing import image
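
Optionally, print the library versions to confirm they match the pins above:

# Quick sanity check against the pinned versions.
print('Keras:', __version__)
print('TensorFlow:', tf.__version__)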

Next, we’ll download the pretrained weights:

wget http://download1500.mediafire.com/7ml55id66qfg/3494uuen1714dqy/black_models_15epochs_weights.h5

We’ll then set up our directory structure.

mkdir images
mkdir visualizations
mkdir model

And then move our downloaded weights to the correct folder.

mv black_models_15epochs_weights.h5 model

You can download the images from the original Github repo and place them into the images directory (or use your own):

https://github.com/Tahedi1/Teacher_Student_Architecture/tree/master/images

Our next step is to define the paths for the files and directories.

images_folder = 'images'
out_folder = 'visualizations'
model_weights_path = 'model/black_models_15epochs_weights.h5'

Building the Teacher-Student Graph

In this next section, we’ll build the Teacher/Student graph architecture described in the paper: a VGG16 teacher classifier, a U-Net-style decoder that reconstructs a mask using skip connections from the teacher’s convolutional blocks, and a VGG16 student that classifies the reconstructed mask.

# Teacher/Student graph: Teacher ---> Decoder ---> Student
def build_graph(input_shape=(224, 224, 3), nbr_of_classes=38, view_summary=False):
    # Teacher's graph: VGG16 backbone plus a small classification head.
    base_model1 = VGG16(include_top=False, weights='imagenet', input_shape=input_shape)
    x1_0 = base_model1.output
    x1_0 = Flatten(name='Flatten1')(x1_0)
    x1_1 = Dense(256, name='fc1', activation='relu')(x1_0)
    x1_2 = classif_out_encoder1 = Dense(nbr_of_classes, name='out1', activation='softmax')(x1_1)

    # Decoder's graph.
    # Get the teacher's intermediate tensors for skip connections.
    pool5 = base_model1.get_layer('block5_pool').output
    conv5 = base_model1.get_layer('block5_conv3').output
    conv4 = base_model1.get_layer('block4_conv3').output
    conv3 = base_model1.get_layer('block3_conv3').output
    conv2 = base_model1.get_layer('block2_conv2').output
    conv1 = base_model1.get_layer('block1_conv2').output

    # Invert the teacher's fully connected layers.
    inv_x1_1 = Dense(256, name='inv_x1_1', activation='relu')(x1_2)
    merge_x1_1 = Add(name='merge_x1_1')([inv_x1_1, x1_1])
    inv_x1_0 = Dense(7*7*512, name='inv_x1_0', activation='relu')(merge_x1_1)
    reshaped_inv_x1_0 = Reshape((7, 7, 512), name='reshape_inv_x1_0')(inv_x1_0)
    inv_x1_0 = Add(name='merge_x1_0')([reshaped_inv_x1_0, pool5])

    # DECONV Block 1
    up7 = Conv2D(512, 2, activation='relu', padding='same', kernel_initializer='he_normal')(UpSampling2D(size=(2, 2))(inv_x1_0))
    merge7 = concatenate([conv5, up7], axis=3)
    conv7 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge7)
    conv7 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv7)

    # DECONV Block 2
    up8 = Conv2D(512, 2, activation='relu', padding='same', kernel_initializer='he_normal')(UpSampling2D(size=(2, 2))(conv7))
    merge8 = concatenate([conv4, up8], axis=3)
    conv8 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge8)
    conv8 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv8)

    # DECONV Block 3
    up9 = Conv2D(256, 2, activation='relu', padding='same', kernel_initializer='he_normal')(UpSampling2D(size=(2, 2))(conv8))
    merge9 = concatenate([conv3, up9], axis=3)
    conv9 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge9)
    conv9 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv9)

    # DECONV Block 4
    up10 = Conv2D(128, 2, activation='relu', padding='same', kernel_initializer='he_normal')(UpSampling2D(size=(2, 2))(conv9))
    merge10 = concatenate([conv2, up10], axis=3)
    conv10 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge10)
    conv10 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv10)

    # DECONV Block 5
    up11 = Conv2D(64, 2, activation='relu', padding='same', kernel_initializer='he_normal')(UpSampling2D(size=(2, 2))(conv10))
    merge11 = concatenate([conv1, up11], axis=3)
    conv11 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge11)
    conv11 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv11)

    # Reconstructed image refinement.
    conv11 = Conv2D(2, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv11)
    mask = conv11 = Conv2D(3, 1, activation='sigmoid', name='Mask')(conv11)

    # Student's graph: a second VGG16 that classifies the reconstructed mask.
    base_model2 = VGG16(include_top=False, weights='imagenet', input_shape=(224, 224, 3))
    x2_0 = base_model2(mask)
    x2_0 = Flatten(name='Flatten2')(x2_0)
    x2_1 = Dense(256, name='fc2', activation='relu')(x2_0)
    classif_out_encoder2 = Dense(nbr_of_classes, name='out2', activation='softmax')(x2_1)

    # The full Teacher/Student model: one input, two softmax outputs.
    model = Model(inputs=base_model1.input, outputs=[classif_out_encoder1, classif_out_encoder2])
    if view_summary:
        model.summary()

    # Compile the model for multitask learning: categorical cross-entropy on
    # both outputs, weighted alpha for the teacher, (1 - alpha) for the student.
    losses = {
        'out1': 'categorical_crossentropy',
        'out2': 'categorical_crossentropy'
    }
    alpha = 0.4
    loss_weights = {'out1': alpha, 'out2': (1.0 - alpha)}
    model.compile(optimizer=optimizers.SGD(lr=1e-4, momentum=0.9),
                  loss=losses, loss_weights=loss_weights, metrics=['accuracy'])

    return model
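
As a quick sanity check, we can instantiate the model and confirm it exposes two outputs, one per classifier (this downloads the ImageNet VGG16 weights on first run):

# Build the Teacher/Student model; set view_summary=True to print the layers.
model = build_graph(input_shape=(224, 224, 3), nbr_of_classes=38)
print(len(model.outputs))   # 2 outputs: teacher and student predictions
print(model.output_names)   # ['out1', 'out2']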

With the graph-building function in place for our Teacher/Student architecture, we’ll need several helper functions.

Preprocess the images w/ the correct size

def preprocess_image(image_path, image_size=(224, 224)):
    # Load, resize, and convert the image to a VGG16-ready batch of one.
    img = image.load_img(image_path, target_size=image_size)
    img = image.img_to_array(img)
    img = np.expand_dims(img, axis=0)
    img = preprocess_input(img)
    return img
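
For example, a single image comes back as a batch of one (leaf.jpg is a hypothetical filename; substitute any image in your images folder):

batch = preprocess_image('images/leaf.jpg')
print(batch.shape)  # (1, 224, 224, 3)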

Build the visualization graph w/ the loaded weights

def build_visualization(model_weights_path):
    # Rebuild the graph, load the trained weights, and expose the 'Mask'
    # layer's output as a callable Keras backend function.
    model = build_graph()
    model.load_weights(model_weights_path)
    layer_name = 'Mask'
    NewInput = model.get_layer(layer_name).output
    visualization = K.function([model.input], [NewInput])
    return visualization
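
Here K.function compiles a forward pass from the model’s input straight to the 'Mask' layer, so we can extract the reconstructed mask without computing the class outputs. A minimal usage sketch (again with a hypothetical filename):

visualization = build_visualization(model_weights_path)
mask = visualization([preprocess_image('images/leaf.jpg')])[0][0]
print(mask.shape)  # (224, 224, 3), values in [0, 1] from the sigmoid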

Reduce the channels and create the heatmap

def reduce_channels_square(heatmap):
    # Collapse the 3-channel mask into a single channel by taking the
    # per-pixel Euclidean norm across channels.
    channel1 = heatmap[:, :, 0]
    channel2 = heatmap[:, :, 1]
    channel3 = heatmap[:, :, 2]
    new_heatmap = np.sqrt((channel1*channel1) + (channel2*channel2) + (channel3*channel3))
    return new_heatmap
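
Since this is just the per-pixel Euclidean norm across the three channels, an equivalent one-liner (an optional simplification, not from the original repo) would be:

new_heatmap = np.linalg.norm(heatmap, axis=-1)  # same result as the channel-wise version above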

Postprocess the visualizations w/ a threshold value and return the new heatmap

def postprocess_vis(heatmap1, threshold=0.9):
    # Normalize to [0, 1], collapse the channels, renormalize, then zero out
    # everything below the threshold and rescale to [0, 255].
    heatmap = heatmap1.copy()
    heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min())
    heatmap = reduce_channels_square(heatmap)
    heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min())
    heatmap[heatmap <= threshold] = 0
    heatmap = heatmap * 255
    return heatmap

Set up the visualization schema w/ the image path and output folder

def visualize_image(visualization, image_path, out_folder):
    # Run one image through the visualization function and save both the
    # raw mask and the thresholded heatmap.
    base = os.path.basename(image_path)
    image_name = os.path.splitext(base)[0]
    img = preprocess_image(image_path)
    vis = visualization([img])[0][0] * 255
    heatmap = postprocess_vis(vis)
    vis_path = os.path.join(out_folder, image_name + '_vis.jpg')
    cv2.imwrite(vis_path, vis)
    heatmap_path = os.path.join(out_folder, image_name + '_heatmap.jpg')
    cv2.imwrite(heatmap_path, heatmap)

Visualize the images given the image folder path and output folder

def visualize_folder(visualization, images_folder, out_folder):
    # Create the output folder if needed, then visualize every image.
    if not os.path.exists(out_folder):
        os.makedirs(out_folder)
    for path in os.listdir(images_folder):
        print(path)
        image_path = os.path.join(images_folder, path)
        visualize_image(visualization, image_path, out_folder)

Finally, we’ll build the visualization function and run it over the images folder.

visualization = build_visualization(model_weights_path)
visualize_folder(visualization, images_folder, out_folder)

Let’s list the files in the visualizations folder to confirm the results.

ls visualizations

Plot the visualized results for a quick sanity check

import glob
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

# Collect all generated JPGs.
images = []
for img_path in glob.glob('visualizations/*.jpg'):
    images.append(mpimg.imread(img_path))

# Arrange the images in a grid of 5 columns, then save the figure.
plt.figure(figsize=(20, 10))
columns = 5
for i, img in enumerate(images):
    plt.subplot(len(images) // columns + 1, columns, i + 1)
    plt.imshow(img)
plt.savefig("result.png")

This was a quick and dirty look at the paper. The takeaway: even when the dataset at hand stays the same, new architectures like this one are often worth exploring in image classification.

Here’s a link to the Github repository accompanying the paper:

https://github.com/Tahedi1/Teacher_Student_Architecture