Teacher-Student Architecture in Plant Disease Classification
In this article, we’ll briefly discuss a method for plant disease classification via the Teacher-Student CNN architecture. This uses two types of classifiers: the first being the “teacher” and the second being the “student”.
We’ll apply multitask learning to train both classifiers at once. The jointly coupled representation between these classifiers is then used to visualize the essential regions detected in image classification.
This approach yields better results than other methods used for plant disease classification. In our analysis, we’ll use the PlantVillage dataset, which contains 54,306 images of plants.
You can find a link to the arxiv paper here:
https://arxiv.org/pdf/1905.13523.pdf
Getting Started
First make sure the following packages are installed:
- tensorflow==1.9.0
- Keras==2.2.4
- matplotlib==3.0.2
- OpenCV==3.0
- Pillow==5.4.1
Let’s then import the required libraries.
import sys
import os
import numpy as np
import glob
import argparse
from keras import backend as K
from keras import __version__
import cv2
from keras.applications.vgg16 import VGG16, preprocess_input
from keras.models import *
from keras.layers import *
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers import SGD
from keras import optimizers
from keras import callbacks
from keras.regularizers import l2,l1
from keras.preprocessing import image
from PIL import Image
import tensorflow as tf
import matplotlib.pyplot as plt
from IPython.display import SVG
from keras.utils.vis_utils import model_to_dot
Next, we’ll download the pretrained weights found at this link:
http://download1500.mediafire.com/7ml55id66qfg/3494uuen1714dqy/black_models_15epochs_weights.h5
wget http://download1500.mediafire.com/7ml55id66qfg/3494uuen1714dqy/black_models_15epochs_weights.h5
We’ll then set up our directory structure.
mkdir images
mkdir visualizations
mkdir model
And then move our downloaded weights to the correct folder.
mv black_models_15epochs_weights.h5 model
You can download the images from the original Github repo and place them into the images directory (or use your own):
https://github.com/Tahedi1/Teacher_Student_Architecture/tree/master/images
Our next step is to define the paths for the files and directories:
# Folder holding the input images that will be visualized.
images_folder = 'images'
# Folder where the generated visualizations and heatmaps are written.
out_folder = 'visualizations'
# Path of the pretrained Teacher/Student weights file inside the model folder.
model_weights_path = 'model/black_models_15epochs_weights.h5'
Building the Teacher-Student Graph
In this next section, our focus will shift towards building the Teacher/Student graph architecture, as explained in the paper.
# Teacher/Student graph.
# Teacher ---> Decoder ---> Student
def build_graph(input_shape = (224,224,3),nbr_of_classes=38,view_summary=False):
    """Build and compile the joint Teacher/Decoder/Student model.

    The Teacher is a VGG16 classifier. The Decoder starts from the Teacher's
    class probabilities and, using skip connections to the Teacher's feature
    maps, reconstructs a 3-channel 'Mask' image. The Student is a second VGG16
    classifier that takes that mask as its input.

    Args:
        input_shape: Shape of the Teacher's input image. The decoder reshapes
            to (7, 7, 512) and the Student is built for (224, 224, 3), so
            values other than the default are unlikely to work.
        nbr_of_classes: Number of output classes of both classifiers.
        view_summary: When true, print the Keras summary of the model.

    Returns:
        A compiled Keras model with one input (the Teacher's input) and two
        softmax outputs named 'out1' (Teacher) and 'out2' (Student).
    """
    # Teacher's graph.
    base_model1 = VGG16(include_top=False, weights='imagenet',input_shape = input_shape)
    x1_0 = base_model1.output
    x1_0 = Flatten(name='Flatten1')(x1_0)
    x1_1 = Dense(256, name='fc1',activation='relu')(x1_0)
    x1_2 = classif_out_encoder1 = Dense(nbr_of_classes, name='out1', activation = 'softmax')(x1_1)

    # Decoder's graph.
    # Get Teacher's tensors for skip connections.
    pool5 = base_model1.get_layer('block5_pool').output
    conv5 = base_model1.get_layer('block5_conv3').output
    conv4 = base_model1.get_layer('block4_conv3').output
    conv3 = base_model1.get_layer('block3_conv3').output
    conv2 = base_model1.get_layer('block2_conv2').output
    conv1 = base_model1.get_layer('block1_conv2').output

    # Inverse of the Teacher's fully connected layers.
    inv_x1_1 = Dense(256, name='inv_x1_1',activation='relu')(x1_2)
    merge_x1_1 = Add(name='merge_x1_1')([inv_x1_1,x1_1])
    # NOTE(review): the layer names 'x1_1' and '' below look unintentional but
    # are kept as-is so the graph stays identical to the one the published
    # weights were saved from -- confirm before renaming.
    inv_x1_0 = Dense(7*7*512, name='x1_1',activation='relu')(merge_x1_1)
    reshaped_inv_x1_0 = Reshape((7, 7,512), name='')(inv_x1_0)
    inv_x1_0 = Add(name='merge_x1_0')([reshaped_inv_x1_0,pool5])

    # DECONV Block 1
    up7 = Conv2D(512, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(inv_x1_0))
    merge7 = concatenate([conv5,up7], axis = 3)
    conv7 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge7)
    conv7 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv7)

    # DECONV Block 2
    up8 = Conv2D(512, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv7))
    merge8 = concatenate([conv4,up8], axis = 3)
    conv8 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge8)
    conv8 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv8)

    # DECONV Block 3
    up9 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv8))
    merge9 = concatenate([conv3,up9], axis = 3)
    conv9 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge9)
    conv9 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv9)

    # DECONV Block 4
    up10 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv9))
    merge10 = concatenate([conv2,up10], axis = 3)
    conv10 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge10)
    conv10 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv10)

    # DECONV Block 5
    up11 = Conv2D(64, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv10))
    merge11 = concatenate([conv1,up11], axis = 3)
    conv11 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge11)
    conv11 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv11)

    # Reconstructed image refinement: the sigmoid 'Mask' layer is the tensor
    # fed to the Student.
    conv11 = Conv2D(2, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv11)
    mask = conv11 = Conv2D(3, 1, activation = 'sigmoid',name='Mask')(conv11)

    # Student's graph.
    base_model2 = VGG16(include_top=False, weights='imagenet',input_shape = (224,224,3))
    x2_0 = base_model2(mask)
    x2_0 = Flatten(name='Flatten2')(x2_0)
    x2_1 = Dense(256, name='fc2',activation='relu')(x2_0)
    classif_out_encoder2 = Dense(nbr_of_classes, name='out2',activation='softmax')(x2_1)

    # Get the Teacher/Student model (one input, two classification outputs).
    model = Model(inputs = base_model1.input, outputs = [classif_out_encoder1,classif_out_encoder2])
    if view_summary:
        # summary() prints by itself; the original referenced an undefined
        # name 'mode' here.
        model.summary()

    # Compile the model for multi-task learning: one loss per output,
    # combined with weights alpha (Teacher) and 1 - alpha (Student).
    losses = {
        "out1": 'categorical_crossentropy',
        "out2": 'categorical_crossentropy'
    }
    alpha=0.4
    lossWeights = {"out1": alpha, "out2": (1.0-alpha)}
    model.compile(optimizer=optimizers.SGD(lr=1e-4, momentum=0.9), loss=losses, loss_weights=lossWeights,metrics = ['accuracy'])
    return model
After building the graph method for our Teacher/Student architecture, we’ll need several helper functions.
Preprocess the images w/ the correct size
def preprocess_image(image_path, image_size = (224, 224)):
    """Load an image file and turn it into a VGG16-ready batch of one.

    Args:
        image_path: Path of the image file to load.
        image_size: Target size the image is resized to when loaded.

    Returns:
        The image as an array with a leading batch axis added, after
        VGG16's ``preprocess_input`` has been applied.
    """
    img = image.load_img(image_path, target_size=image_size)
    img = image.img_to_array(img)
    # The model expects a batch, so add a leading axis of length 1.
    img = np.expand_dims(img, axis=0)
    img = preprocess_input(img)
    return img
Build the visualization graph w/ the loaded weights
def build_visualization(model_weights_path):
    """Return a backend function mapping an input batch to the 'Mask' output.

    The Teacher/Student graph is built with its default arguments, the
    weights stored at ``model_weights_path`` are loaded into it, and the
    output tensor of the layer named 'Mask' is exposed through a
    ``K.function`` that takes the model input.
    """
    teacher_student = build_graph()
    teacher_student.load_weights(model_weights_path)
    mask_tensor = teacher_student.get_layer('Mask').output
    return K.function([teacher_student.input], [mask_tensor])
Reduce the channels and create the heatmap
def reduce_channels_sequare(heatmap):
    """Collapse the first three channels of ``heatmap`` into one.

    Each output value is the square root of the sum of the squares of
    channels 0, 1 and 2 at that position (indexed along the third axis).
    """
    first, second, third = (heatmap[:, :, idx] for idx in range(3))
    squared_sum = (first * first) + (second * second) + (third * third)
    return np.sqrt(squared_sum)
Postprocess the visualizations w/ a threshold value and return the new heatmap
def _rescale_to_unit(values):
    """Min-max rescale ``values`` to [0, 1]; a constant array becomes zeros."""
    low = values.min()
    span = values.max() - low
    if span == 0:
        # Constant input: dividing by the zero range would produce NaNs.
        # (values - low) is already all zeros, so divide by 1 instead.
        span = 1
    return (values - low) / span


def postprocess_vis(heatmap1, threshold = 0.9):
    """Turn a 3-channel visualization into a thresholded single-channel heatmap.

    The input is min-max normalized, its three channels are combined with
    ``reduce_channels_sequare``, the result is normalized again, every value
    less than or equal to ``threshold`` is set to 0, and the heatmap is
    scaled by 255.

    Args:
        heatmap1: Array indexed as [row, column, channel] with at least three
            channels. It is not modified.
        threshold: Cut-off applied to the normalized single-channel heatmap.

    Returns:
        A new 2-D array with values in [0, 255].
    """
    heatmap = _rescale_to_unit(heatmap1)
    heatmap = reduce_channels_sequare(heatmap)
    heatmap = _rescale_to_unit(heatmap)
    heatmap[heatmap <= threshold] = 0
    heatmap = heatmap*255
    return heatmap
Set up the visualization schema w/ the image path and output folder
def visualize_image(visualization, image_path, out_folder):
    """Write the visualization and heatmap of one image to ``out_folder``.

    Args:
        visualization: Callable taking a list with one preprocessed image
            batch and returning a list whose first element is the batch of
            visualizations (as returned by ``build_visualization``).
        image_path: Path of the input image.
        out_folder: Existing folder that receives '<name>_vis.jpg' and
            '<name>_heatmap.jpg', where <name> is the input file name
            without its extension.
    """
    base = os.path.basename(image_path)
    image_name = os.path.splitext(base)[0]
    img = preprocess_image(image_path)
    # First output, first batch element, scaled by 255 before writing.
    vis = visualization([img])[0][0]*255
    heatmap = postprocess_vis(vis)
    vis_path = os.path.join(out_folder, image_name+'_vis.jpg')
    cv2.imwrite(vis_path, vis)
    heatmap_path = os.path.join(out_folder, image_name+'_heatmap.jpg')
    cv2.imwrite(heatmap_path, heatmap)
Visualize the images given the image folder path and output folder
def visualize_folder(visualization, images_folder, out_folder):
    """Run ``visualize_image`` on every entry of ``images_folder``.

    Args:
        visualization: Visualization callable passed through to
            ``visualize_image``.
        images_folder: Folder whose entries are processed one by one.
        out_folder: Output folder; created (with parents) if it is missing.
    """
    # exist_ok avoids the separate exists() check and the race it implies.
    os.makedirs(out_folder, exist_ok=True)
    for path in os.listdir(images_folder):
        print(path)
        image_path = os.path.join(images_folder, path)
        visualize_image(visualization, image_path, out_folder)
Finally, we’ll build out the visualization and return the results
# Build the function that maps an input batch to the 'Mask' layer output,
# using the weights stored at model_weights_path.
visualization = build_visualization(model_weights_path)
# Write a visualization and a heatmap into out_folder for every entry of
# images_folder.
visualize_folder(visualization, images_folder, out_folder)
Let’s list the files in the visualization folder to confirm the results
ls visualizations
Plot the visualized results for a quick sanity check
import glob
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

# Load every generated JPEG; sorted so the grid order is reproducible.
images = [mpimg.imread(img_path) for img_path in sorted(glob.glob('visualizations/*.jpg'))]

plt.figure(figsize=(20,10))
columns = 5
# Integer division: plt.subplot expects whole numbers of rows and columns.
rows = len(images) // columns + 1
# The loop variable is named img (not image) so it does not overwrite the
# keras.preprocessing.image module imported at the top of the file.
for i, img in enumerate(images):
    plt.subplot(rows, columns, i + 1)
    plt.imshow(img)
# Save once, after every subplot has been drawn.
plt.savefig("result.png")
This was a quick and dirty look into the paper and our takeaway should be that new architectures are often worth exploring in image classification even while the dataset at hand remains the same.
Here’s a link to the Github repository accompanying the paper: