Segmentation of spectral images with deep learning using Keras

Antón Garcia
Abraia
May 12, 2021 · 6 min read


In this post, we’ll see how to train and test a 3D deep learning model for HSI segmentation using Keras. We start with a short discussion on the best type of models for hyperspectral images, and then we go through the three main steps: data preparation, model training, and model validation.

Deep models for hyperspectral images

Compared to simple approaches based on an SVM classifier and purely spectral features, deep models manage to learn features that combine both spectral and spatial information. Since many surfaces are characterized more by chromatic textures than by simple chromatic signatures, this capability confers a key advantage in many applications.

While deep learning has had a big impact on image segmentation problems, generalizing it properly to hyperspectral and multispectral imaging applications hasn’t been straightforward. However, recent advances have confirmed a clear advantage of deep CNNs, making deep models the state of the art for spectral image classification and segmentation. Unlike with ordinary images, in HSI work we don’t find the same boundary between segmentation and classification models: segmentation is usually performed by applying classification models on a pixel-by-pixel basis. This reflects the lower maturity of the field.

There are two main types of deep models applied to HSI images: 2D and 3D. 2D models use 2D CNNs, just like those commonly used with RGB images. They can deal with two spatial dimensions, like any image, and up to three chromatic components. 3D models use 3D CNNs, which can handle the two spatial dimensions plus one additional dimension for spectral information. This means 3D models can handle as many spectral components as we like, of course at the price of a higher computational effort. Broadly speaking, 2D models are faster but 3D models are potentially more accurate.
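To make the difference concrete, here is a minimal illustrative sketch in Keras (the shapes are arbitrary examples, not taken from any dataset):

## 2D CNNs see (height, width, channels); 3D CNNs add a spectral axis
from tensorflow.keras.layers import Input, Conv2D, Conv3D
rgb = Input((145, 145, 3))       ## 2D case: height, width, a few channels
hsi = Input((145, 145, 200, 1))  ## 3D case: height, width, spectral bands, channel
features_2d = Conv2D(8, (3, 3))(rgb)     ## convolves over space only
features_3d = Conv3D(8, (3, 3, 7))(hsi)  ## convolves over space and spectrum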

In this post we’ll use a hybrid model that combines 3D CNNs followed by 2D CNNs, but the basic steps described remain valid for any other choice of topology. We provide a ready-to-use notebook to reproduce everything at the end of the post.

Preparation of HSI data

We’ll use the Indian Pines dataset, which consists of a 145x145-pixel image with 200 spectral bands, along with annotations covering 16 different categories related to the use of the land.

Indian Pines dataset: six random bands (top and centre), ground truth annotations (bottom left), and categories (bottom right)

We have already described and briefly analysed this dataset in a previous post on spectral image classification with an SVM.

As explained there, spectral bands in the IP dataset (like in most HSI images) are highly correlated, so reducing redundancy makes perfect sense. After loading the dataset, we start by reducing the number of spectral components through PCA.

## Load IP dataset (X: samples, y: annotations) and reduce spectral components
X, y = loadData('IP')
X, pca = applyPCA(X, numComponents=30)
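Here loadData and applyPCA are helper functions from the accompanying notebook. As a rough sketch of what applyPCA might look like, assuming scikit-learn:

import numpy as np
from sklearn.decomposition import PCA

def applyPCA(X, numComponents=30):
    ## Flatten the spatial dimensions, keeping the spectral bands as features
    newX = np.reshape(X, (-1, X.shape[2]))
    pca = PCA(n_components=numComponents, whiten=True)
    newX = pca.fit_transform(newX)
    ## Restore the spatial dimensions with the reduced spectral depth
    newX = np.reshape(newX, (X.shape[0], X.shape[1], numComponents))
    return newX, pca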

Then, we get the annotated image points and extract image patches around them, which we’ll use to train and validate our network. We set the size of these patches to 25 pixels, that is, we extract cubes of 25x25x30 around each annotated pixel in the image.

X, y = createImageCubes(X, y, windowSize=25)
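createImageCubes is also a notebook helper. A possible implementation, assuming zero-padding at the image borders and labels re-indexed from 0, could look like this:

def createImageCubes(X, y, windowSize=25):
    margin = windowSize // 2
    ## Zero-pad the spatial borders so that every pixel gets a full patch
    Xp = np.pad(X, ((margin, margin), (margin, margin), (0, 0)), mode='constant')
    patches, labels = [], []
    for r in range(X.shape[0]):
        for c in range(X.shape[1]):
            if y[r, c] > 0:  ## keep annotated pixels only
                patches.append(Xp[r:r + windowSize, c:c + windowSize, :])
                labels.append(y[r, c] - 1)
    return np.array(patches), np.array(labels)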

Next, we split the data into separate sets, one for training and another for validation.

testRatio = 0.7
Xtrain, Xtest, ytrain, ytest = splitTrainTestSet(X, y, test_size=testRatio)

We use a test ratio of 0.7, which means training the model with 30% of the data and validating it with the remaining 70%.
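splitTrainTestSet can be a thin stratified wrapper around scikit-learn’s train_test_split; a minimal sketch (the random_state value is an arbitrary choice):

from sklearn.model_selection import train_test_split

def splitTrainTestSet(X, y, test_size=0.7, random_state=345):
    ## Stratify so that every class keeps the same proportion in both sets
    return train_test_split(X, y, test_size=test_size,
                            random_state=random_state, stratify=y)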

Defining and training the model

Now we should define the model topology. Here, we’ll use HybridSN, proposed by Roy et al., which combines 3D CNN and 2D CNN layers.

We start with the input layer, using the spectral cube dimensions of our data, that is, the patch size (S=25) and the reduced number of spectral components (L=30) already specified.
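The snippets below assume the usual Keras imports and the dimension constants; one possible setup, using tensorflow.keras:

from tensorflow.keras.layers import Input, Conv3D, Conv2D, Reshape, Flatten, Dense, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

## Patch size, spectral components, and number of classes for Indian Pines
S, L, output_units = 25, 30, 16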

## Input layer
input_layer = Input((S, S, L, 1))

Then, we proceed with the rest of the layers.

## 3D CNN layers
conv_layer1 = Conv3D(filters=8, kernel_size=(3, 3, 7), activation='relu')(input_layer)
conv_layer2 = Conv3D(filters=16, kernel_size=(3, 3, 5), activation='relu')(conv_layer1)
conv_layer3 = Conv3D(filters=32, kernel_size=(3, 3, 3), activation='relu')(conv_layer2)
## Reshape to merge the spectral and filter dimensions before the 2D convolution
conv3d_shape = conv_layer3.shape
conv_layer3 = Reshape((conv3d_shape[1], conv3d_shape[2], conv3d_shape[3]*conv3d_shape[4]))(conv_layer3)
## 2D CNN
conv_layer4 = Conv2D(filters=64, kernel_size=(3, 3), activation='relu')(conv_layer3)
flatten_layer = Flatten()(conv_layer4)
## FC layers
dense_layer1 = Dense(units=256, activation='relu')(flatten_layer)
dense_layer1 = Dropout(0.4)(dense_layer1)
dense_layer2 = Dense(units=128, activation='relu')(dense_layer1)
dense_layer2 = Dropout(0.4)(dense_layer2)
## Output layer
output_layer = Dense(units=output_units, activation='softmax')(dense_layer2)

Next, we define the model and compile it, specifying an optimizer.

# Define and compile
model = Model(inputs=input_layer, outputs=output_layer)
adam = Adam(lr=0.001, decay=1e-06)
model.compile(loss='categorical_crossentropy', optimizer=adam, metrics=['accuracy'])

Once our model is ready, we can train it. Here we provide only the training data. We need to specify the batch size, that is, the number of samples from the training set that are propagated through the network before the internal parameters are updated. We also need to set the number of epochs, that is, the number of times the entire training dataset is used to update the model parameters. Typically, an epoch consists of a number of batches.
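The fit call below also relies on a few details that are easy to miss: the samples need an explicit channel axis to match the (S, S, L, 1) input, the labels must be one-hot encoded for categorical_crossentropy, and callbacks_list has to be defined. A minimal sketch (the checkpoint filename is an arbitrary choice):

from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import ModelCheckpoint

## Add the channel axis expected by the 3D convolutions
Xtrain = Xtrain.reshape(-1, 25, 25, 30, 1)
Xtest = Xtest.reshape(-1, 25, 25, 30, 1)
## One-hot encode the labels for categorical_crossentropy
ytrain = to_categorical(ytrain)
ytest = to_categorical(ytest)
## Save the best weights seen during training
callbacks_list = [ModelCheckpoint('best-model.hdf5', monitor='accuracy', save_best_only=True)]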

history = model.fit(x=Xtrain, y=ytrain, batch_size=256, epochs=100, callbacks=callbacks_list)

After training, we may visualize the learning curve, which shows how the accuracy of the model evolved over training epochs.
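For instance, with matplotlib (the 'accuracy' key assumes the metric name used at compile time):

import matplotlib.pyplot as plt

plt.plot(history.history['accuracy'])
plt.xlabel('Epoch')
plt.ylabel('Training accuracy')
plt.show()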

The learning curve shows how model accuracy has evolved over successive epochs

We should note that this accuracy has been measured on the training data and should not be taken as a final validation accuracy. But it helps to visualize the stability of the learning stage and to check that the number of epochs is adequate.

Model validation

To validate the model, we run it on the validation set.

Y_pred_test = model.predict(Xtest)

Since we held this set apart, the model didn’t see these samples during training. This is important because validation allows us to rule out overfitting (when the model simply memorises the training data) and to verify that the model generalises to unseen data. That is, validation provides a more reliable measurement of accuracy. We can look at the validation results using scikit-learn’s classification report.

y_pred_test = np.argmax(Y_pred_test, axis=1)
classification = classification_report(np.argmax(ytest, axis=1), y_pred_test)
print(classification)

The classification report provides the metrics that characterize the performance of our model: precision, recall, and f-score. For each category, precision is the percentage of samples classified under that category that actually belong to it, while recall is the percentage of samples from that category that our model detects. The f-score, the harmonic mean of precision and recall, provides a single overall metric of model performance.

Finally, we can segment an image by applying our deep HSI classifier on a pixel-by-pixel basis. That is, we apply the model to every pixel in the Indian Pines scene and then assess the quality of our segmentation by visual comparison with the ground truth.
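As a rough sketch of this step (X_pca and y_gt are hypothetical names for the PCA-reduced cube and the ground-truth map loaded earlier, and we reuse the createImageCubes sketch from above):

## Predict a class for every annotated pixel and rebuild a 2D map
X_all, y_all = createImageCubes(X_pca, y_gt, windowSize=25)
X_all = X_all.reshape(-1, 25, 25, 30, 1)
labels = np.argmax(model.predict(X_all), axis=1)
## Place the predictions back at their pixel positions in the 145x145 scene
segmentation_map = np.zeros(y_gt.shape)
segmentation_map[y_gt > 0] = labels + 1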

A deep HSI model produces a highly accurate segmentation on the IP scene

When we compare this result to what we get using an SVM classifier on the same HSI image, we immediately notice the big improvement in performance.

Final comment on public HSI datasets

In this post, we have used the IP dataset and followed the training and validation practices described in state-of-the-art work on HSI (see for instance Li et al.). However, the IP dataset is a single hyperspectral image, and the train/validation split described above may raise questions about the reliability of the validation. While we used each image patch either for training or for validation, never for both, ideally training and validation patches should come from different, non-overlapping images. But this is also an issue with the other datasets widely used in state-of-the-art studies on deep models applied to HSI, like the University of Pavia dataset or the Salinas dataset.

Notebook for reproduction

To facilitate reproduction and avoid the need for any installation or download, we provide a simple Python notebook that follows all the steps (loading and manipulating HSI data as needed) and can be executed on Google Colab. The notebook makes use of state-of-the-art libraries wrapped together with Abraia’s MULTIPLE API.

References

Li, Shutao, et al. “Deep learning for hyperspectral image classification: An overview.” IEEE Transactions on Geoscience and Remote Sensing 57.9 (2019): 6690–6709.

Roy, Swalpa Kumar, et al. “HybridSN: Exploring 3-D–2-D CNN feature hierarchy for hyperspectral image classification.” IEEE Geoscience and Remote Sensing Letters 17.2 (2019): 277–281.
