Supervised Wheat Classification Using PyTorch’s TorchGeo — Combining Satellite Imagery and Python!

End-to-End Python-Based Code for Satellite Imagery: Gathering, Pre-processing, and Training a Wheat and Non-Wheat Segmentation Model

Suleman Hamdani
10 min read · Jul 29, 2023

TorchGeo is a PyTorch domain library that makes it easier to load, preprocess, and sample satellite imagery and other geospatial data for training deep learning models in Python. In this article, we walk through a number of its functions that you can use to train your own model. We focus on a wheat classification model built on Landsat 8/9 imagery, but the same techniques carry over to other classification and segmentation problems. For deeper insight into the library’s capabilities, the research paper “TorchGeo: Deep Learning with Geospatial Data” serves as a comprehensive guide.

Exploring a new library can be challenging, so this article is organized around four main stages:

  1. Setting up TorchGeo: First, we show you how to set up TorchGeo on your machine.
  2. Gathering Landsat 8 and CDL Data: This section walks you through acquiring the Landsat 8 imagery and the Cropland Data Layer (CDL) labels that form the foundation for model training.
  3. Pre-processing the Gathered Data: Next, we cover the pre-processing steps needed to clean, transform, and prepare the data for training. Proper pre-processing is key to accurate and meaningful results.
  4. Model Training: The final section covers training a UNET model for wheat classification, an approach that can be extended to other classification and segmentation tasks.

1. Setting Up TorchGeo

Setting up TorchGeo is as simple as using the familiar pip install command. By executing the following line, you will have the library installed and ready to use:

!pip install torchgeo

Once TorchGeo is successfully installed, you can import the essential modules for dataset creation and data pre-processing.

from torch.utils.data import DataLoader
from torchgeo.datasets import CDL, Landsat7, Landsat8, Landsat9, stack_samples
from torchgeo.samplers import RandomGeoSampler

The torchgeo.datasets module is particularly helpful for loading data from a directory into a dataset. The datasets used here, and how to use them, are explained in the sections that follow.

2. Gathering Landsat 8 Imagery

Gathering data is a crucial, if somewhat monotonous, part of the process. To train your model on labels from the Cropland Data Layer (CDL), you’ll need Landsat 8 or Landsat 9 imagery of farming regions in the United States. Here are the steps:

  1. Go to the USGS Earth Explorer Website and create an account.
  2. Select the region you want to extract satellite imagery from. In our case, this was the western half of Kansas, which is known for its extensive winter wheat fields.
A snapshot of USGS Earth Explorer Website — the red region highlights the region that we chose to gather imagery from — source: https://earthexplorer.usgs.gov/

3. Select the date range for which you want Landsat imagery. We chose January to mid-July, which covers most of the winter wheat crop cycle and captures the different stages it goes through. We chose one tile per month.

4. Set the cloud cover range. We chose only images with 0% cloud cover for every month.

5. Select the dataset to download. We downloaded Landsat 8-9 Collection 2 Level-2 data.

6. Select the tile and download the imagery along with its bands. You can either bulk download the full product or download the seven surface reflectance bands (SR_B1 to SR_B7) individually.

Snapshot of the USGS Earth Explorer Site — Shows an overlapping tile with the selected region — source: https://earthexplorer.usgs.gov/
Snapshot of the USGS Earth Explorer Site — Downloading the tile — source:

7. Repeat this process for each of the seven months. Once you have downloaded and extracted the imagery, your folder should look like this:

Unzipped Landsat 8 data — image by author
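Since you end up with one scene per month, it helps to read the acquisition date directly from the file names. Below is a minimal sketch, assuming the standard Landsat Collection 2 naming layout (sensor, processing level, WRS-2 path/row, acquisition date, processing date, collection, tier, then an optional band suffix). parse_landsat_name is a hypothetical helper, not part of TorchGeo.

```python
from datetime import datetime


def parse_landsat_name(filename: str) -> dict:
    """Split a Landsat Collection 2 file name into its fields.

    Assumes the layout LXSS_LLLL_PPPRRR_YYYYMMDD_yyyymmdd_CC_TX,
    optionally followed by a band suffix such as SR_B1.
    """
    stem = filename.split(".")[0]  # drop the .TIF extension, if any
    parts = stem.split("_")
    if len(parts) < 7:
        raise ValueError(f"Not a Landsat product name: {filename}")
    sensor, level, path_row, acquired, processed, collection, tier = parts[:7]
    return {
        "sensor": sensor,                 # e.g. LC08 = Landsat 8 OLI/TIRS
        "level": level,                   # e.g. L2SP = Level-2 science product
        "path": int(path_row[:3]),        # WRS-2 path
        "row": int(path_row[3:]),         # WRS-2 row
        "acquired": datetime.strptime(acquired, "%Y%m%d").date(),
        "processed": datetime.strptime(processed, "%Y%m%d").date(),
        "collection": collection,
        "tier": tier,
        "band": "_".join(parts[7:]),      # e.g. SR_B1, or "" for a bare product ID
    }


info = parse_landsat_name("LC08_L2SP_030033_20220218_20220301_02_T1_SR_B1.TIF")
print(info["acquired"].month, info["band"])  # prints: 2 SR_B1
```

Grouping files by info["acquired"].month is then enough to check that each of the seven months has a complete set of bands.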

3. Gathering the Cropland Data Layer

The Cropland Data Layer provides the labels required for supervised classification. The official CropScape website describes the data as follows:

The Cropland Data Layer (CDL), hosted on CropScape, provides a raster, geo-referenced, crop-specific land cover map for the continental United States. The CDL also includes a crop mask layer and planting frequency layers, as well as boundary, water and road layers. The Boundary Layer options provided are County, Agricultural Statistics Districts (ASD), State, and Region. The data is created annually using moderate resolution satellite imagery and extensive agricultural ground truth.

Snapshot of CDL website — filtered by ‘Winter Wheat’ in Western Kansas — source: https://croplandcros.scinet.usda.gov/

How do we download it? TorchGeo makes this easy. Running the code cell below downloads the CDL data and loads it into cdl, an instance of TorchGeo’s built-in CDL dataset class.

cdl = CDL(root="...", download=True, checksum=True)

Alternatively, if the CDL data has already been downloaded into a directory, pass that path as the root argument to load it without downloading it again:

cdl = CDL(root="your_directory_path_here")

4. Pre-Processing and Loading Landsat 8/9 Data

With the data downloaded, it’s time to pre-process the Landsat 8/9 imagery. To load it into a torchgeo.datasets dataset, you first have to rename the .tif files of the seven bands, because the TorchGeo version we used expects band names such as B1 rather than SR_B1. (Recent TorchGeo releases name the Landsat 8 bands SR_B1 to SR_B7 directly, so check the documentation for your installed version before renaming.)

Currently, your files will be named something like: LC08_L2SP_030033_20220218_20220301_02_T1_SR_B1

Replace the _SR_ with a single underscore (_). After renaming, the file above should look like this: LC08_L2SP_030033_20220218_20220301_02_T1_B1 (notice that _SR_ is gone). Do this for the .tif files of all seven bands. You can then load the data into a TorchGeo Landsat 8 or Landsat 9 dataset by running the following code cell with the path of the directory that holds the renamed files.

kansas = Landsat8(root="your_directory_path_here", bands=["B1", "B2", "B3", "B4", "B5", "B6", "B7"])
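The renaming step described above can be scripted instead of done by hand. Here is a minimal sketch, assuming the band files sit somewhere under one directory and keep the original Collection 2 names; strip_sr_prefix is a hypothetical helper, and it only touches the seven surface reflectance bands (SR_B1 to SR_B7), leaving files such as SR_QA_AEROSOL alone.

```python
import re
from pathlib import Path

# Matches the seven surface reflectance bands, e.g. ..._T1_SR_B4.TIF
SR_BAND = re.compile(r"_SR_B[1-7]\.tiff?$", re.IGNORECASE)


def strip_sr_prefix(directory, dry_run=False):
    """Rename *_SR_B1.TIF ... *_SR_B7.TIF to *_B1.TIF ... *_B7.TIF under `directory`.

    Returns a list of (old_name, new_name) pairs. With dry_run=True,
    nothing is renamed; the pairs are only reported.
    """
    renamed = []
    # sorted() materializes the listing first, so renaming cannot disturb the iteration
    for path in sorted(Path(directory).rglob("*")):
        if not path.is_file() or not SR_BAND.search(path.name):
            continue
        # Replace the first "_SR_" with a single underscore: ..._T1_SR_B1 -> ..._T1_B1
        new_name = path.name.replace("_SR_", "_", 1)
        if not dry_run:
            path.rename(path.with_name(new_name))
        renamed.append((path.name, new_name))
    return renamed
```

Running it once with dry_run=True is a cheap way to preview the changes before renaming anything.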

5. Creating the Dataset

Now that the imagery and the labels are loaded, let’s combine them into a single dataset. This is as easy as intersecting the two datasets with the & operator. Running the following code cell gives you a single dataset that contains both the images and the labels.

dataset = cdl & kansas

TorchGeo computes the intersection of the two datasets for you. You also don’t need to worry about mismatched coordinate reference systems (CRSs): TorchGeo reprojects the second dataset to the CRS of the first, so the two are consistent.

Once the dataset is created, we can sample from it using TorchGeo’s RandomGeoSampler, which draws random patches of a size you specify from the dataset. In our case, we drew random 256x256 patches from western Kansas. Each sample gives us a patch of satellite imagery and its corresponding CDL mask (label), as shown in the figure below.

Left — Landsat 8 imagery; Right — Corresponding CDL mask — image by author

You can learn more about RandomGeoSampler in the official documentation. Here’s the code cell that creates the sampler; length sets the number of random patches drawn per epoch:

sampler = RandomGeoSampler(dataset, size=256, length=100) #Dataset includes the intersection of Kansas & CDL
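For a sense of scale, assuming size is interpreted in pixels (the sampler’s default units) and the intersected dataset uses the 30 m resolution shared by the Landsat 8 surface reflectance bands and the CDL, each 256x256 patch covers roughly:

```latex
256 \times 30\,\text{m} = 7680\,\text{m} \approx 7.7\,\text{km per side},
\qquad (7.68\,\text{km})^2 \approx 59\,\text{km}^2 \text{ per patch}
```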

6. Pre-processing the Dataset

We have now created a dataset of satellite imagery and corresponding labels, so what next? We could feed these samples directly into a training loop, but there are a few problems. Landsat 8 scenes sometimes have missing data, which should not be fed into the model, and the masks contain noisy pixels that need to be removed for better accuracy. Before feeding data into the model, we therefore removed all images with NaNs or missing pixels. We semi-automated this process by reviewing ~150 images per month and discarding any black images.
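Part of this screening can be automated. Below is a minimal sketch of a check for patches with missing data, assuming (H, W, C) NumPy patches and treating a pixel as missing if any band is NaN or if every band equals 0 (assumed to be the no-data fill value, which is what Landsat Collection 2 surface reflectance bands use). has_missing_data and its thresholds are hypothetical, not the procedure used in this article.

```python
import numpy as np


def has_missing_data(image, nodata_value=0, max_missing_fraction=0.0):
    """Return True if more than `max_missing_fraction` of the pixels are missing.

    image: (H, W, C) array. A pixel counts as missing if any band is NaN,
    or if every band equals `nodata_value` (assumed fill value).
    """
    img = np.asarray(image, dtype=np.float64)
    nan_pixels = np.isnan(img).any(axis=-1)              # NaN in at least one band
    nodata_pixels = (img == nodata_value).all(axis=-1)   # all bands equal the fill value
    missing_fraction = np.mean(nan_pixels | nodata_pixels)
    return bool(missing_fraction > max_missing_fraction)
```

A check like this could pre-filter fully black patches, leaving only borderline cases for a manual look.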

The code cell below draws samples from the RandomGeoSampler and displays each satellite image next to its mask. Type y to keep a sample, n to discard it, or q to stop. Each saved image has five channels (blue, green, red, NIR, and NDVI). An extensive literature review suggested that these bands plus the NDVI index perform reasonably well for crop classification models. Since we’re only interested in the wheat pixels from the CDL, every pixel whose value is not 24 (the CDL code for winter wheat) is set to 255.


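import time

import matplotlib.pyplot as plt
import numpy as np
from IPython import display as disp

img_list, mask_list = [], []  # accepted image patches and their masks
count = 0  # number of accepted samples so far

while True:
    # Clear the previous output in the notebook
    disp.clear_output()

    # Draw one random bounding box and load the matching image and mask
    sample = next(iter(sampler))
    datapoint = dataset[sample]

    # Keep only winter wheat (CDL value 24); set every other class to 255
    mask = datapoint["mask"].numpy().transpose(1, 2, 0)
    mask[mask != 24] = 255

    # Cast to float so the NDVI arithmetic cannot overflow integer types
    img_7bands = datapoint["image"].numpy().astype(np.float32).transpose(1, 2, 0)
    # NDVI = (NIR - Red) / (NIR + Red), with B5 (index 4) as NIR and B4 (index 3) as red;
    # where NIR + red is 0 (no-data pixels), this yields NaN
    ndvi = (img_7bands[:, :, 4] - img_7bands[:, :, 3]) / (img_7bands[:, :, 4] + img_7bands[:, :, 3])
    img_with_ndvi = np.dstack((img_7bands, ndvi))  # 8 channels: B1-B7 + NDVI

    # Show the true-color image (B4, B3, B2) and the mask side by side
    img_rgb = img_with_ndvi[:, :, [3, 2, 1]]
    if np.max(img_rgb) > 0:
        img_rgb = img_rgb / np.max(img_rgb)  # scale to [0, 1] for display

    fig, ax = plt.subplots(1, 2, figsize=(5, 5))
    ax[0].imshow(img_rgb)
    ax[1].imshow(mask)
    plt.show()

    user_input = input("Do you want to save the image? (y/n, q to quit)")
    if user_input == "y":
        # Keep B2 (blue), B3 (green), B4 (red), B5 (NIR) and NDVI
        img_5_bands_ndvi = img_with_ndvi[:, :, [1, 2, 3, 4, 7]]
        mask = mask.astype(np.uint8)  # 24 and 255 both fit in uint8

        img_list.append(img_5_bands_ndvi)
        mask_list.append(mask)

        # Re-display the saved sample as a check (channels 2, 1, 0 = red, green, blue)
        mask2 = mask_list[count].copy()
        im2 = img_list[count][:, :, [2, 1, 0]].copy()
        im2 = im2 / np.max(im2)
        fig, ax = plt.subplots(1, 2, figsize=(5, 5))
        ax[0].imshow(im2)
        ax[1].imshow(mask2)
        plt.show()
        count += 1
        time.sleep(2)  # pause for 2 seconds before the next sample

    elif user_input == "n":
        continue

    elif user_input == "q":
        break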
Landsat 8 imagery with missing data — such images were discarded — image by author

The final pre-processing step applies denoising and morphological operations to the masks. This reduces noisy pixels that could otherwise hurt the model’s accuracy. The figure below shows a mask before and after denoising and morphological closing.

Left — Before; Right — After — image by author
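The article does not specify which denoising or morphological operations were used. As one possible sketch, the function below applies a 3x3 median filter to the wheat mask (removing isolated pixels) followed by a binary closing (filling small holes), using SciPy’s ndimage module. denoise_mask and its parameters are hypothetical, and the filter sizes are assumptions to tune on your own masks.

```python
import numpy as np
from scipy import ndimage


def denoise_mask(mask, wheat_value=24, background_value=255, filter_size=3, closing_iterations=1):
    """Clean a wheat mask with a median filter followed by a binary closing.

    mask: (H, W) or (H, W, 1) array where `wheat_value` marks wheat.
    Returns a 2-D uint8 mask with the same wheat/background encoding.
    """
    mask2d = np.squeeze(np.asarray(mask))
    wheat = mask2d == wheat_value

    # Median filter: a pixel keeps the majority label of its neighborhood,
    # which removes isolated "salt and pepper" pixels
    wheat = ndimage.median_filter(wheat.astype(np.uint8), size=filter_size).astype(bool)

    # Closing (dilation then erosion) fills small holes inside fields
    wheat = ndimage.binary_closing(
        wheat, structure=np.ones((3, 3), dtype=bool), iterations=closing_iterations
    )

    out = np.full(mask2d.shape, background_value, dtype=np.uint8)
    out[wheat] = wheat_value
    return out
```

Larger filter sizes remove more noise but also round off field corners and erase small fields, so the trade-off is worth checking visually.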

7. Combining Data for 7 Months

While spatial features are crucial for an effective satellite imagery classification model, temporal features are equally vital for crop classification, where capturing the entire crop cycle is essential. We therefore repeated the steps above for all seven months. The data for each month (imagery and masks) was saved separately and then concatenated, giving 834 data points. Twelve images still contained NaN values and were removed, leaving 822 images and masks, which we stored in NumPy arrays.
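A likely source of these NaNs is the NDVI computation itself: wherever both NIR and red are 0 (for example, no-data fill), (NIR − Red) / (NIR + Red) evaluates to 0/0. Below is a minimal sketch of an NDVI helper that fills such pixels with a chosen value instead. As an optional extra, it can first convert Collection 2 Level-2 digital numbers to reflectance using the USGS scale factor (0.0000275) and offset (−0.2); because of the offset, NDVI computed on raw digital numbers differs from NDVI on reflectance. safe_ndvi is a hypothetical helper, and treating "both bands equal 0" as no-data is an assumption.

```python
import numpy as np

C2_SCALE, C2_OFFSET = 2.75e-5, -0.2  # Collection 2 Level-2 surface reflectance scaling


def safe_ndvi(nir, red, fill_value=0.0, c2_scaling=False):
    """NDVI = (NIR - Red) / (NIR + Red), with `fill_value` where it is undefined."""
    nir = np.asarray(nir, dtype=np.float64)
    red = np.asarray(red, dtype=np.float64)
    # Pixels where both raw bands are 0 are treated as no-data (assumed fill value)
    invalid = (nir == 0) & (red == 0)
    if c2_scaling:
        # Convert digital numbers to surface reflectance before computing NDVI
        nir = nir * C2_SCALE + C2_OFFSET
        red = red * C2_SCALE + C2_OFFSET
    denom = nir + red
    invalid |= denom == 0
    ndvi = np.full(denom.shape, fill_value, dtype=np.float64)
    # Divide only where the result is defined; other entries keep fill_value
    np.divide(nir - red, denom, out=ndvi, where=~invalid)
    return ndvi
```

Passing fill_value=np.nan keeps the current behavior of flagging such pixels, while a numeric fill lets those samples stay in the dataset, if you prefer that.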

The following code concatenates the images and masks from each month and then removes the samples that contain NaNs.

import numpy as np

# Stack the monthly arrays along the sample (first) axis
img_arrays = [jan_img, feb_img, march_img, april_img, may_img, june_img, july_img]
images = np.concatenate(img_arrays, axis=0)

mask_arrays = [jan_mask, feb_mask, march_mask, april_mask, may_mask, june_mask, july_mask]
masks = np.concatenate(mask_arrays, axis=0)

# Find the samples whose image contains at least one NaN
nan_indices = [i for i in range(len(images)) if np.isnan(images[i]).any()]
print(nan_indices)  # in our run: [17, 26, 27, 41, 50, 214, 335, 404, 453, 460, 563, 708]

# Delete those samples from both arrays so images and masks stay aligned
images = np.delete(images, nan_indices, axis=0)
masks = np.delete(masks, nan_indices, axis=0)
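Before training, the masks (24 = winter wheat, 255 = everything else) need to be turned into targets that match the two-class softmax output of the UNET described in the next section. The article does not show this step; here is a minimal sketch, assuming one-hot targets for a categorical cross-entropy loss, with channel 0 as non-wheat and channel 1 as wheat (masks_to_one_hot is a hypothetical helper). With a sparse categorical cross-entropy loss, the integer labels alone would suffice.

```python
import numpy as np


def masks_to_one_hot(masks, wheat_value=24):
    """Convert (N, H, W) or (N, H, W, 1) CDL masks into (N, H, W, 2) one-hot targets.

    Channel 0 = non-wheat (any value other than `wheat_value`, here 255),
    channel 1 = wheat.
    """
    masks = np.asarray(masks)
    if masks.ndim == 4 and masks.shape[-1] == 1:
        masks = masks[..., 0]  # drop the trailing channel axis
    labels = (masks == wheat_value).astype(np.int64)  # 1 = wheat, 0 = other
    # Indexing the identity matrix by label gives one one-hot row per pixel
    return np.eye(2, dtype=np.float32)[labels]
```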

8. Model Training

Once we had the NumPy arrays of images and masks, we trained a UNET model on them in Google Colab. The model, implemented in Keras, takes 256x256x5 inputs and produces 2 output classes (wheat and non-wheat). We used the following UNET architecture:

from tensorflow.keras.layers import Conv2D, Conv2DTranspose, Dropout, Input, MaxPooling2D, concatenate
from tensorflow.keras.models import Model


def multi_unet_model(n_classes=2, image_height=256, image_width=256, image_channels=5):
    # Input: 256x256 patches with 5 channels (blue, green, red, NIR, NDVI)
    inputs = Input((image_height, image_width, image_channels))

    source_input = inputs

    # Encoder: each level applies two 3x3 convolutions, then halves the resolution
    c1 = Conv2D(16, (3,3), activation="relu", kernel_initializer="he_normal", padding="same")(source_input)
    c1 = Dropout(0.2)(c1)
    c1 = Conv2D(16, (3,3), activation="relu", kernel_initializer="he_normal", padding="same")(c1)
    p1 = MaxPooling2D((2,2))(c1)

    c2 = Conv2D(32, (3,3), activation="relu", kernel_initializer="he_normal", padding="same")(p1)
    c2 = Dropout(0.2)(c2)
    c2 = Conv2D(32, (3,3), activation="relu", kernel_initializer="he_normal", padding="same")(c2)
    p2 = MaxPooling2D((2,2))(c2)

    c3 = Conv2D(64, (3,3), activation="relu", kernel_initializer="he_normal", padding="same")(p2)
    c3 = Dropout(0.2)(c3)
    c3 = Conv2D(64, (3,3), activation="relu", kernel_initializer="he_normal", padding="same")(c3)
    p3 = MaxPooling2D((2,2))(c3)

    c4 = Conv2D(128, (3,3), activation="relu", kernel_initializer="he_normal", padding="same")(p3)
    c4 = Dropout(0.2)(c4)
    c4 = Conv2D(128, (3,3), activation="relu", kernel_initializer="he_normal", padding="same")(c4)
    p4 = MaxPooling2D((2,2))(c4)

    # Bottleneck at 1/16 of the input resolution
    c5 = Conv2D(256, (3,3), activation="relu", kernel_initializer="he_normal", padding="same")(p4)
    c5 = Dropout(0.2)(c5)
    c5 = Conv2D(256, (3,3), activation="relu", kernel_initializer="he_normal", padding="same")(c5)

    # Decoder: upsample with a transposed convolution, then concatenate the
    # matching encoder features (skip connection)
    u6 = Conv2DTranspose(128, (2,2), strides=(2,2), padding="same")(c5)
    u6 = concatenate([u6, c4])
    c6 = Conv2D(128, (3,3), activation="relu", kernel_initializer="he_normal", padding="same")(u6)
    c6 = Dropout(0.2)(c6)
    c6 = Conv2D(128, (3,3), activation="relu", kernel_initializer="he_normal", padding="same")(c6)

    u7 = Conv2DTranspose(64, (2,2), strides=(2,2), padding="same")(c6)
    u7 = concatenate([u7, c3])
    c7 = Conv2D(64, (3,3), activation="relu", kernel_initializer="he_normal", padding="same")(u7)
    c7 = Dropout(0.2)(c7)
    c7 = Conv2D(64, (3,3), activation="relu", kernel_initializer="he_normal", padding="same")(c7)

    u8 = Conv2DTranspose(32, (2,2), strides=(2,2), padding="same")(c7)
    u8 = concatenate([u8, c2])
    c8 = Conv2D(32, (3,3), activation="relu", kernel_initializer="he_normal", padding="same")(u8)
    c8 = Dropout(0.2)(c8)
    c8 = Conv2D(32, (3,3), activation="relu", kernel_initializer="he_normal", padding="same")(c8)

    u9 = Conv2DTranspose(16, (2,2), strides=(2,2), padding="same")(c8)
    u9 = concatenate([u9, c1], axis=3)
    c9 = Conv2D(16, (3,3), activation="relu", kernel_initializer="he_normal", padding="same")(u9)
    c9 = Dropout(0.2)(c9)
    c9 = Conv2D(16, (3,3), activation="relu", kernel_initializer="he_normal", padding="same")(c9)

    # 1x1 convolution gives per-pixel class probabilities
    outputs = Conv2D(n_classes, (1,1), activation="softmax")(c9)

    model = Model(inputs=[inputs], outputs=[outputs])
    return model
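Training also requires holding out a validation set. The article does not say how its validation set was drawn; here is a minimal sketch of a seeded random split, assuming images and masks are NumPy arrays with samples along the first axis (train_val_split and the 20% default are hypothetical choices).

```python
import numpy as np


def train_val_split(images, masks, val_fraction=0.2, seed=0):
    """Randomly split paired image/mask arrays into training and validation sets."""
    if len(images) != len(masks):
        raise ValueError("images and masks must have the same length")
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(images))  # shuffled sample indices
    n_val = int(round(len(images) * val_fraction))
    val_idx, train_idx = order[:n_val], order[n_val:]
    return images[train_idx], masks[train_idx], images[val_idx], masks[val_idx]
```

With the 822 samples here, val_fraction=0.2 would hold out 164 patches. Because patches from different months can cover the same ground, a purely random split may put near-duplicate locations in both sets. A split by location would be the stricter test.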

The complete Colab notebook for model training is attached below.

The UNET model achieved a training accuracy of 94% and a validation accuracy of 90%. These results suggest that it captures spatial information and has learned how wheat appears at the different stages of its crop cycle.

The high training accuracy shows that the model has learned the training data well. The validation accuracy, four percentage points lower, indicates that the model also generalizes reasonably well to new, unseen data, which is promising for real-world applications.
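Assuming the model was compiled with categorical cross-entropy and metrics=["accuracy"], these figures are per-pixel accuracies: for each pixel, the arg-max class of the softmax output is compared with the arg-max of the one-hot target, and the matches are averaged over all pixels. Here is a NumPy equivalent, useful for checking predictions outside Keras (pixel_accuracy is a hypothetical helper).

```python
import numpy as np


def pixel_accuracy(probs, targets):
    """Fraction of pixels whose predicted class matches the target class.

    probs:   (N, H, W, C) softmax outputs
    targets: (N, H, W, C) one-hot labels
    """
    predicted = np.argmax(probs, axis=-1)  # most likely class per pixel
    true = np.argmax(targets, axis=-1)     # labelled class per pixel
    return float(np.mean(predicted == true))
```

Because non-wheat pixels often dominate a patch, pixel accuracy can stay high even when many wheat pixels are missed, so per-class numbers are a useful complement.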

The ability to capture temporal patterns in wheat’s crop cycle is of utmost importance for crop classification tasks. By understanding the dynamic changes that occur over time, the model can discern different growth stages, account for seasonal variations, and make contextually informed predictions.

Conclusion

TorchGeo is a library with a lot of potential, especially for people who are comfortable with Python and its syntax. It gives users a great deal of control over their data and what they can do with it. So far, relatively few people have explored TorchGeo, which makes it hard to get a sense of the variety of functions it offers. We believe this article gives users enough material to start exploring its full potential!

The code for the project is available on GitHub. Happy learning!

This work was undertaken with the support of the Smart AgriTech lab and the EmbeddedAIoT lab, under the guidance of our founder Dr. Shahzad Younis. This blog gives a detailed account of the process, from collecting multi-spectral imagery from Landsat 8 to training a segmentation model. It serves as a comprehensive, step-by-step guide to using TorchGeo for wheat classification and is one of the few published works on TorchGeo.


Suleman Hamdani

Using Computer Science and Artificial Intelligence to spread knowledge. Advocating Open-Source software and data!