Flower Classification using CNN
We come across many flowers every day, yet we often don't know their names. When we see a beautiful flower, we may wonder, “I wish my computer or phone could identify this.” That is the motivation behind this article: classifying flower images.
The main objective of this article is to use Convolutional Neural Networks (CNNs) to classify flower images into 10 categories.
DATASET
Kaggle dataset: https://www.kaggle.com/olgabelitskaya/flower-color-images
The 10 classes in the dataset are:
- Phlox
- Rose
- Calendula
- Iris
- Leucanthemum maximum (Shasta daisy)
- Campanula (Bellflower)
- Viola
- Rudbeckia laciniata (Goldquelle)
- Peony
- Aquilegia
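A small helper for working with these labels, sketched under one assumption: that the dataset's integer labels 0 to 9 follow the same order as the list above (check the label file shipped with the Kaggle dataset before relying on it). It also shows the one-hot target format that categorical cross-entropy expects.

```python
# Class names in the order listed above.
# ASSUMPTION: integer label i in the dataset corresponds to CLASS_NAMES[i].
CLASS_NAMES = [
    "Phlox",
    "Rose",
    "Calendula",
    "Iris",
    "Leucanthemum maximum (Shasta daisy)",
    "Campanula (Bellflower)",
    "Viola",
    "Rudbeckia laciniata (Goldquelle)",
    "Peony",
    "Aquilegia",
]

NUM_CLASSES = len(CLASS_NAMES)  # 10


def one_hot(label: int, num_classes: int = NUM_CLASSES) -> list[int]:
    """Encode an integer label as a one-hot vector (the target format
    used with categorical cross-entropy)."""
    if not 0 <= label < num_classes:
        raise ValueError(f"label {label} out of range [0, {num_classes})")
    vec = [0] * num_classes
    vec[label] = 1
    return vec


def decode(label: int) -> str:
    """Map an integer label back to its human-readable class name."""
    return CLASS_NAMES[label]
```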
IMPORTS
I use TensorFlow to implement the CNN, Matplotlib to plot graphs and display images, and Seaborn to display a heatmap.
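The article does not say what the heatmap shows. A common choice for a classifier is a confusion matrix of validation predictions, so the sketch below assumes that. It counts (true, predicted) label pairs with NumPy and draws them with `seaborn.heatmap`; Matplotlib and Seaborn are imported inside the plotting function so the counting part runs without them.

```python
import numpy as np


def confusion_matrix(y_true, y_pred, num_classes: int = 10) -> np.ndarray:
    """Count predictions per (true class, predicted class) pair.

    Rows index the true label, columns the predicted label.
    """
    cm = np.zeros((num_classes, num_classes), dtype=int)
    for t, p in zip(y_true, y_pred):
        cm[t, p] += 1
    return cm


def plot_confusion_matrix(cm: np.ndarray, class_names) -> None:
    """Draw the matrix as an annotated Seaborn heatmap."""
    # Imported lazily so confusion_matrix() works without plotting libraries.
    import matplotlib.pyplot as plt
    import seaborn as sns

    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                xticklabels=class_names, yticklabels=class_names)
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.tight_layout()
    plt.show()
```

The diagonal holds correct predictions, so off-diagonal cells show which flower classes the model confuses with each other.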
MODEL
The model consists of two Conv2D layers with 128 filters each, each followed by LeakyReLU, MaxPooling and Dropout layers, then a GlobalMaxPooling2D layer and two Dense layers.
I used LeakyReLU as the activation function; ReLU might also give good results here.
The loss is categorical cross-entropy and the optimizer is Adam.
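A minimal tf.keras sketch of this architecture, reconstructed from the layer order and shapes in the summary. The 128×128×3 input shape is inferred from the summary (a 3×3 convolution without padding turns 128 into 126, and 3,584 = 3·3·3·128 + 128 parameters implies three input channels). The dropout rate, the LeakyReLU slope (Keras default) and the softmax in the final Activation layer are assumptions, since the article does not state them.

```python
import tensorflow as tf
from tensorflow.keras import layers

NUM_CLASSES = 10
INPUT_SHAPE = (128, 128, 3)  # inferred from the model summary


def build_model(dropout: float = 0.25) -> tf.keras.Model:
    """Two conv blocks, global max pooling and two Dense layers.

    ASSUMPTION: dropout rate and LeakyReLU slope are illustrative defaults.
    """
    model = tf.keras.Sequential([
        tf.keras.Input(shape=INPUT_SHAPE),
        # Block 1: 3x3 conv, no padding (128 -> 126), then 2x2 pool (126 -> 63)
        layers.Conv2D(128, (3, 3)),
        layers.LeakyReLU(),
        layers.MaxPooling2D((2, 2)),
        layers.Dropout(dropout),
        # Block 2: 63 -> 61 after the conv, 61 -> 30 after pooling
        layers.Conv2D(128, (3, 3)),
        layers.LeakyReLU(),
        layers.MaxPooling2D((2, 2)),
        layers.Dropout(dropout),
        # Keep only the maximum of each of the 128 feature maps
        layers.GlobalMaxPooling2D(),
        layers.Dense(512),
        layers.LeakyReLU(),
        layers.Dropout(dropout),
        layers.Dense(NUM_CLASSES),
        # ASSUMPTION: softmax, matching the categorical cross-entropy loss
        layers.Activation("softmax"),
    ])
    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )
    return model
```

Calling `build_model().summary()` should list the same shapes and a total of 222,346 parameters; Keras will number the layers differently (for example `conv2d` instead of `conv2d_4`) depending on how many models were built in the session.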
Model Architecture:
Layer (type)                                  Output Shape             Param #
==============================================================================
conv2d_4 (Conv2D)                             (None, 126, 126, 128)    3584
leaky_re_lu_6 (LeakyReLU)                     (None, 126, 126, 128)    0
max_pooling2d_4 (MaxPooling2D)                (None, 63, 63, 128)      0
dropout_6 (Dropout)                           (None, 63, 63, 128)      0
conv2d_5 (Conv2D)                             (None, 61, 61, 128)      147584
leaky_re_lu_7 (LeakyReLU)                     (None, 61, 61, 128)      0
max_pooling2d_5 (MaxPooling2D)                (None, 30, 30, 128)      0
dropout_7 (Dropout)                           (None, 30, 30, 128)      0
global_max_pooling2d_2 (GlobalMaxPooling2D)   (None, 128)              0
dense_4 (Dense)                               (None, 512)              66048
leaky_re_lu_8 (LeakyReLU)                     (None, 512)              0
dropout_8 (Dropout)                           (None, 512)              0
dense_5 (Dense)                               (None, 10)               5130
activation_2 (Activation)                     (None, 10)               0
==============================================================================
Total params: 222,346
Trainable params: 222,346
Non-trainable params: 0
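The shapes and parameter counts in this summary follow from simple arithmetic, which the short sketch below reproduces in plain Python, assuming 128×128 RGB inputs, stride-1 convolutions without padding and non-overlapping 2×2 pooling. Only the Conv2D and Dense layers have weights; activation, pooling and dropout layers contribute zero parameters.

```python
def conv2d_params(kernel_h, kernel_w, in_channels, filters):
    # One kernel_h x kernel_w x in_channels weight tensor per filter, plus one bias each
    return kernel_h * kernel_w * in_channels * filters + filters


def dense_params(in_features, units):
    # Full weight matrix plus one bias per unit
    return in_features * units + units


def valid_conv_out(size, kernel):
    # Stride 1, no padding
    return size - kernel + 1


def pool_out(size, pool=2):
    # Non-overlapping pooling discards a leftover row/column (floor division)
    return size // pool


s1 = valid_conv_out(128, 3)  # 126
s2 = pool_out(s1)            # 63
s3 = valid_conv_out(s2, 3)   # 61
s4 = pool_out(s3)            # 30
spatial_sizes = [s1, s2, s3, s4]

param_counts = {
    "conv2d_4": conv2d_params(3, 3, 3, 128),    # 3584
    "conv2d_5": conv2d_params(3, 3, 128, 128),  # 147584
    "dense_4": dense_params(128, 512),          # 66048 (global pooling leaves 128 features)
    "dense_5": dense_params(512, 10),           # 5130
}
total_params = sum(param_counts.values())       # 222346
```

This also shows why the global pooling layer matters: with a Flatten layer instead, dense_4 would receive 30·30·128 = 115,200 features and need `dense_params(115200, 512)` = 58,982,912 parameters on its own.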
Callbacks
I defined two callbacks:
- ModelCheckpoint: saves the best model seen during training
- ReduceLROnPlateau: reduces the learning rate when the monitored metric stops improving
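The two callbacks can be set up with the standard `tf.keras.callbacks` classes roughly as follows. The file name, the monitored metrics and the values of `factor`, `patience` and `min_lr` are illustrative assumptions; the article does not give the settings it used.

```python
import tensorflow as tf

# ASSUMPTION: file name and monitored metric are illustrative.
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath="best_model.keras",
    monitor="val_accuracy",
    mode="max",            # higher validation accuracy is better
    save_best_only=True,   # write the file only when val_accuracy improves
    verbose=1,
)

# ASSUMPTION: factor, patience and min_lr are illustrative values.
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss",
    factor=0.5,   # new_lr = old_lr * factor
    patience=5,   # epochs without improvement before reducing
    min_lr=1e-6,  # lower bound on the learning rate
    verbose=1,
)

callbacks = [checkpoint, reduce_lr]
```

The list is then passed to `model.fit(..., callbacks=callbacks)`. Monitoring `val_accuracy` for checkpointing keeps the weights that generalize best, while reacting to `val_loss` lets the learning rate drop as soon as validation progress stalls.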
Train
The model is trained with a batch size of 32 for 75 epochs.
Since the dataset is small, we use image augmentation to synthesize additional training data (randomly transformed copies of the images) and train the model on it.
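A sketch of an augmented training pipeline with batch size 32, assuming NumPy arrays of 128×128×3 images and one-hot labels. It uses Keras preprocessing layers inside a `tf.data` pipeline; the article does not say which augmentation method or transforms it used (older notebooks often used `ImageDataGenerator`), so the flip, rotation and zoom settings here are assumptions.

```python
import tensorflow as tf
from tensorflow.keras import layers

BATCH_SIZE = 32
EPOCHS = 75

# ASSUMPTION: the transforms and their ranges are illustrative.
augment = tf.keras.Sequential([
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.1),  # up to +/-10% of a full turn (about 36 degrees)
    layers.RandomZoom(0.1),      # zoom in or out by up to 10%
])


def make_train_dataset(x_train, y_train):
    """Shuffle, batch and augment the training arrays on the fly.

    x_train: float32 array (N, 128, 128, 3); y_train: one-hot array (N, 10).
    """
    ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    ds = ds.shuffle(len(x_train)).batch(BATCH_SIZE)
    # training=True makes the random layers actually transform each batch
    ds = ds.map(lambda x, y: (augment(x, training=True), y),
                num_parallel_calls=tf.data.AUTOTUNE)
    return ds.prefetch(tf.data.AUTOTUNE)
```

Because augmentation runs inside the pipeline, every epoch sees a freshly transformed version of each image. Training then amounts to passing this dataset to `model.fit(..., epochs=EPOCHS)` together with the validation data and the callbacks; the validation set itself should not be augmented.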
RESULT
The model reached a validation accuracy of 80.95238%, which is quite decent, and it did not overfit much, so it is a reasonably good model.
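Accuracy is simply the fraction of correct predictions. As a side note on the reported figure, 80.95238% equals 17/21 to the digits shown, which would be consistent with, for example, 17 correct predictions on a 21-image validation set; the article does not state the split size, so this is only an inference. On a validation set that small, each image moves the accuracy by about 4.8 percentage points, so the value is best read as approximate. The sketch below computes accuracy and recovers that fraction with Python's `fractions` module.

```python
from fractions import Fraction


def accuracy(y_true, y_pred) -> float:
    """Fraction of predictions that match the true labels."""
    if len(y_true) != len(y_pred) or len(y_true) == 0:
        raise ValueError("need two non-empty sequences of equal length")
    correct = sum(t == p for t, p in zip(y_true, y_pred))
    return correct / len(y_true)


# Closest fraction with a denominator of at most 100 to the reported accuracy.
reported = 0.8095238
closest = Fraction(reported).limit_denominator(100)  # Fraction(17, 21)
```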
PREDICTIONS
Let’s look at some predictions made by our model on randomly chosen images.
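One way to pick random images and read off the model's predictions, assuming `probs` is the softmax output of `model.predict` on the validation images (shape `(num_images, 10)`). The function name `sample_predictions` is hypothetical; it returns each sampled index with its predicted class and confidence.

```python
import numpy as np


def sample_predictions(probs, class_names, n=5, seed=None):
    """Pick n random images and return (index, predicted name, confidence).

    probs: array of shape (num_images, num_classes) with class probabilities.
    """
    probs = np.asarray(probs)
    rng = np.random.default_rng(seed)
    # Sample distinct image indices
    idx = rng.choice(len(probs), size=min(n, len(probs)), replace=False)
    # The predicted class is the one with the highest probability
    preds = probs[idx].argmax(axis=1)
    return [(int(i), class_names[p], float(probs[i, p]))
            for i, p in zip(idx, preds)]
```

Each returned index can then be shown with `plt.imshow(x_val[i])` and titled with the predicted name, so correct and incorrect predictions are easy to compare against the true label.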
Notebook Link: Here
Credit: Rasswanth Shankar