Authors: Aman Dalmia and Vishal Agarwal, Wadhwani AI
Cotton is a major fibre crop, cultivated in over 80 countries, and nearly 100 million families across the world rely on cotton farming for their livelihood. With so many farmers depending on the crop, cotton’s particular vulnerability to pest infestations is a serious concern. Pest infestation is one of the most significant problems that farmers face, yet also one of the most preventable, with 55% of all pesticide usage in India being devoted to cotton farming.
AI for Pest Monitoring
At Wadhwani AI, we are dedicated to applying AI to solve problems for social good. When it comes to combating and preventing crop attacks, pest traps can be used as an early warning system, with the number of pests captured used as a proxy for infestation. By applying deep learning on trap images, Wadhwani AI is able to provide detailed spray recommendations to farmers in order to proactively address specific pest issues.
Using PyTorch, Wadhwani AI researchers have built a model that accurately predicts the location of pests in trap images. Once an image is captured, it passes through a multi-task network which verifies whether the image is valid. If it is, the detection branch provides the potential locations of the detected pests. The final recommendation is based on the number of pests detected and on rules laid out by entomologists.
The overall process is outlined with further detail below:
Another advantage of PyTorch is its support for offline inference. Many smallholder farms are located in rural areas where network coverage is poor and limited. To improve the user experience under these conditions, Wadhwani AI uses model compression algorithms to reduce the model size by 98%, from 250MB to less than 5MB, before deployment. A recent paper by our team, published at KDD 2020, describes our journey, the challenges we faced and what we learnt.
Building the Machine Learning Solution
Our work is enabled by a host of tools, particularly PyTorch for developing and deploying our models, and Weights & Biases for experiment tracking and communicating results to relevant stakeholders. A comprehensive summary of all the tools we use is shown below.
The Multi-task Model
The core of our solution is object detection. Broadly, object detection models can be divided into two categories, two-stage and single-stage detectors. Two-stage architectures are generally more accurate than single-stage approaches, albeit slower. As inference speed is paramount to our user experience, we use a single-stage approach.
Single Shot MultiBox Detector (SSD) was one of the first single-stage detectors to give comparable performance to its two-stage counterparts while being fast. However, we faced a significant challenge because our dataset is very different from standard object detection datasets. Our images are of poorer quality as they are captured from smartphones in remote villages. The pests tend to cluster together, making it hard to identify their boundaries. Parts such as limbs and wings, which are the key features for classification, often end up not being present in the input image.
The objective function used to train SSD is called MultiBox loss. It consists of a localization component and a classification component, which account for the tightness of the bounding box and the accuracy of the predicted class respectively. Experiment tracking comes in very handy here for understanding how the different components of the loss are affected by changes in optimization, data augmentation, network architecture and input size. For example, it can be seen from the image below that the overall loss (train/loss) is always dominated by the classification loss (train/loss_c) and not the localization loss (train/loss_l).
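To make the two components concrete, here is a minimal sketch of a simplified MultiBox-style loss in PyTorch. It is an illustration under stated assumptions, not our training code: anchor matching and hard negative mining are omitted, and the function name, tensor layout and the `alpha` weight are hypothetical. The returned `loss_l` and `loss_c` correspond in spirit to the train/loss_l and train/loss_c curves mentioned above.

```python
import torch
import torch.nn.functional as F


def simplified_multibox_loss(loc_pred, cls_logits, loc_target, cls_target, alpha=1.0):
    """Simplified MultiBox-style loss (assumed layout, no hard negative mining).

    loc_pred, loc_target: (batch, anchors, 4) box offsets
    cls_logits:           (batch, anchors, classes) raw class scores
    cls_target:           (batch, anchors) class indices, 0 = background
    """
    positive = cls_target > 0                   # anchors matched to an object
    num_pos = positive.sum().clamp(min=1)       # avoid division by zero

    # Localization: smooth L1 on matched anchors only.
    loss_l = F.smooth_l1_loss(loc_pred[positive], loc_target[positive], reduction="sum")

    # Classification: cross-entropy over all anchors.
    loss_c = F.cross_entropy(
        cls_logits.reshape(-1, cls_logits.size(-1)),
        cls_target.reshape(-1),
        reduction="sum",
    )

    loss_l = loss_l / num_pos
    loss_c = loss_c / num_pos
    return loss_c + alpha * loss_l, loss_l, loss_c


# Toy example: 2 images, 8 anchors each, 3 classes (including background).
torch.manual_seed(0)
loc_pred = torch.randn(2, 8, 4)
loc_target = torch.randn(2, 8, 4)
cls_logits = torch.randn(2, 8, 3)
cls_target = torch.randint(0, 3, (2, 8))
total, loss_l, loss_c = simplified_multibox_loss(loc_pred, cls_logits, loc_target, cls_target)
```

Logging the three returned values separately is what makes it possible to see which component dominates the overall loss.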
Another way to observe the evolution of model performance is by visualizing how the predictions evolve over time. Weights & Biases helps us do this as we can track how the predicted bounding boxes change over time as shown below:
Upon deploying our model in the field, we found that it was common for users to submit images that did not contain any pest traps. Deep learning models are known to perform poorly on out-of-distribution images, often being overly confident about their mispredictions. To explicitly tackle this issue, and after trying several approaches, we trained a separate classifier to reject a wide variety of out-of-distribution images. Specifically, a VGGNet model was trained with images from the COCO dataset as the invalid class and images from our pest dataset as the valid class.
However, during inference, the image had to pass through two networks, significantly increasing the inference time. Since both the image validation and object detection models used the same base network (VGGNet), which accounted for the majority of the inference time, we combined them into a single model by adding a new branch for image validation from the last layer of the base network in SSD. The network was jointly trained in a multi-task learning setup, with the detection and validation heads being trained by the MultiBox loss and cross-entropy respectively.
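The shared-backbone idea can be sketched as follows. This is a toy illustration, not our production architecture: a tiny convolutional stack stands in for VGGNet, and the class name `MultiTaskNet`, the head names, the layer sizes and the number of anchors per cell are all assumptions made for the example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiTaskNet(nn.Module):
    """Toy shared backbone with a detection branch and an image-validation branch."""

    def __init__(self, num_classes=3, anchors_per_cell=4):
        super().__init__()
        self.num_classes = num_classes
        # Stand-in for the shared base network (VGGNet in the text).
        self.base = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1), nn.ReLU(),
        )
        # Detection branch: box offsets and class scores per anchor.
        self.loc_head = nn.Conv2d(32, anchors_per_cell * 4, kernel_size=3, padding=1)
        self.cls_head = nn.Conv2d(32, anchors_per_cell * num_classes, kernel_size=3, padding=1)
        # Validation branch: valid vs. invalid image, from the same features.
        self.valid_head = nn.Linear(32, 2)

    def forward(self, x):
        feat = self.base(x)  # computed once, shared by both branches
        n = x.size(0)
        loc = self.loc_head(feat).permute(0, 2, 3, 1).reshape(n, -1, 4)
        cls = self.cls_head(feat).permute(0, 2, 3, 1).reshape(n, -1, self.num_classes)
        valid = self.valid_head(feat.mean(dim=(2, 3)))  # global average pooling
        return loc, cls, valid


model = MultiTaskNet()
images = torch.rand(2, 3, 64, 64)
loc, cls, valid_logits = model(images)

# Validation head is trained with cross-entropy (1 = valid trap image, 0 = invalid).
is_valid = torch.tensor([1, 0])
validation_loss = F.cross_entropy(valid_logits, is_valid)
```

In joint training, a detection loss computed on `loc` and `cls` would be added to `validation_loss`; how the two terms are weighted is a design choice this sketch leaves open.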
The most common approach to evaluating object detection networks is Average Precision (AP) per class. However, the relationship between AP and the goal of this solution is not always straightforward. AP is defined using a range of confidence thresholds for each class whereas, during deployment, we need to select a single operating point. We treat the final pesticide spray recommendation task as a binary classification problem where the classes indicate whether to spray. Because the relevant stakeholders generally find precision and recall hard to interpret, we redefine false positives and false negatives as follows to enable clearer communication:
- Missed alarm rate (MAR): percentage of cases where the recommendation should have been to spray, but the system suggested otherwise.
- False alarm rate (FAR): percentage of cases where no action should be taken, but the system suggested spraying.
We set a goal of keeping both MAR and FAR below 5%, as both false positives and false negatives are harmful to the farmer. Hyperparameter tuning and the selection of the optimal confidence threshold for deployment were done towards this joint objective.
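The two rates defined above can be computed as in the sketch below. One assumption is worth flagging: each rate is normalised by the number of cases in its own ground-truth class (cases that should have been sprayed for MAR, cases that should not for FAR). The function name `alarm_rates` and the toy data are hypothetical.

```python
def alarm_rates(should_spray, recommended_spray):
    """Return (MAR, FAR) in percent from paired ground-truth / predicted booleans."""
    if len(should_spray) != len(recommended_spray):
        raise ValueError("inputs must have the same length")

    pairs = list(zip(should_spray, recommended_spray))
    missed = sum(1 for truth, pred in pairs if truth and not pred)
    false_alarms = sum(1 for truth, pred in pairs if not truth and pred)

    n_spray = sum(1 for truth in should_spray if truth)
    n_no_spray = len(should_spray) - n_spray

    # Assumption: each rate is normalised by its own ground-truth class size.
    mar = 100.0 * missed / n_spray if n_spray else 0.0
    far = 100.0 * false_alarms / n_no_spray if n_no_spray else 0.0
    return mar, far


# Toy data: 4 cases that needed spraying, 6 that did not.
truth = [True, True, True, True, False, False, False, False, False, False]
preds = [True, True, True, False, True, False, False, False, False, False]
mar, far = alarm_rates(truth, preds)
meets_goal = mar < 5.0 and far < 5.0
```

In this toy example one of the four spray cases is missed and one of the six no-spray cases raises a false alarm, so neither rate meets the 5% goal.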
We specifically found the DataFrame plot within Weights & Biases to be extremely helpful in understanding the predictions at an image level during inference.
It helps us visually compare the predicted and ground truth bounding boxes for any image, while also allowing us to sort, group and filter data as per our requirement. For example, the predictions can be sorted based on the confidence score to qualitatively evaluate low confidence input data.
The Parameter Importance plot in Weights & Biases helps to identify the most impactful hyperparameters with respect to the validation loss.
Even though the primary motivation for model compression is deployment on-phone, it proves beneficial for deployment on-cloud as well, since the compressed model runs significantly faster. We use iterative pruning (as outlined in this paper) and quantization to reduce our model size from 265MB to 5MB. The paper introduces the following technique for classification networks.
The figure above shows how one filter is pruned. Let’s look at two consecutive layers, L and L+1. The number of input and output channels for layer L is K and N respectively. Since the output of layer L is the input to layer L+1, the number of input channels for layer L+1 is also N, while the number of output channels is M. Pruning the filter at index i from layer L would reduce the number of output channels of layer L, and hence the number of input channels in layer L+1, by one. Thus, when pruning the chosen filter, we also need to update every layer that takes the output of layer L as its input. We modify this technique to work with our multi-task architecture, removing 1024 filters at each pruning iteration, for 15 iterations. Refer to this blog post for more context and details.
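The channel bookkeeping described above can be sketched for two plain convolution layers. This is a minimal illustration, not our pruning code: it assumes ordinary `nn.Conv2d` layers with `groups=1` and default dilation, the helper name `prune_filter` is hypothetical, and choosing which index to prune (the ranking criterion) is outside the sketch.

```python
import torch
import torch.nn as nn


def prune_filter(layer, next_layer, index):
    """Remove output filter `index` from `layer` and the matching input
    channel from `next_layer`. Returns two new, smaller Conv2d layers."""
    keep = [i for i in range(layer.out_channels) if i != index]

    # Layer L: K input channels, N - 1 output channels after pruning.
    pruned = nn.Conv2d(layer.in_channels, len(keep), layer.kernel_size,
                       layer.stride, layer.padding, bias=layer.bias is not None)
    # Layer L+1: N - 1 input channels, M output channels (unchanged).
    pruned_next = nn.Conv2d(len(keep), next_layer.out_channels, next_layer.kernel_size,
                            next_layer.stride, next_layer.padding,
                            bias=next_layer.bias is not None)

    with torch.no_grad():
        pruned.weight.copy_(layer.weight[keep])               # drop filter i
        if layer.bias is not None:
            pruned.bias.copy_(layer.bias[keep])
        pruned_next.weight.copy_(next_layer.weight[:, keep])  # drop input channel i
        if next_layer.bias is not None:
            pruned_next.bias.copy_(next_layer.bias)
    return pruned, pruned_next


# K = 3 input channels, N = 8 filters in layer L, M = 5 filters in layer L+1.
layer_l = nn.Conv2d(3, 8, kernel_size=3, padding=1)
layer_next = nn.Conv2d(8, 5, kernel_size=3, padding=1)
pruned_l, pruned_next = prune_filter(layer_l, layer_next, index=2)
out = pruned_next(torch.relu(pruned_l(torch.rand(1, 3, 16, 16))))
```

In a multi-branch network, the same input-channel update has to be applied to every layer that consumes the output of layer L, not just a single successor.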
To deploy models in production, we first convert the PyTorch model to a TorchScript module. TorchScript serializes the model and makes it independent of any Python dependency. We mix scripting and tracing for our model, since certain parts of the model have data-dependent control flow.
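As a rough illustration of mixing the two modes, the sketch below traces a small backbone that has no control flow and then scripts a wrapper whose branch depends on the data. The `Backbone` and `Detector` classes and the fallback behaviour are hypothetical and only meant to show the pattern, not our actual model.

```python
import io

import torch
import torch.nn as nn


class Backbone(nn.Module):
    """No control flow, so tracing captures it faithfully."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 1, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)


class Detector(nn.Module):
    """Has a data-dependent branch, so it is scripted rather than traced."""

    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone

    def forward(self, x: torch.Tensor, threshold: float) -> torch.Tensor:
        scores = torch.sigmoid(self.backbone(x)).flatten()
        kept = scores[scores > threshold]
        if kept.numel() == 0:
            # Nothing passed the threshold: fall back to the single best score.
            return scores.max().reshape(1)
        return kept


example = torch.rand(1, 3, 8, 8)
traced_backbone = torch.jit.trace(Backbone().eval(), example)
scripted = torch.jit.script(Detector(traced_backbone).eval())

# Serialize and reload; the result no longer needs the Python class definitions.
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)
buffer.seek(0)
restored = torch.jit.load(buffer)
```

Tracing the wrapper instead would record only the branch taken for the example input, which is why the data-dependent part is scripted.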
We use TorchServe for serving our model. It has a built-in web server which takes a model and serves it through a REST API. It also allows us to configure the model for batched predictions easily. The TorchScript module is packaged in a model archive (.mar) file, required by TorchServe, which contains the entry point and workflow for executing an inference request. Finally, both the TorchScript module and the .mar file are uploaded to an S3 bucket. We built an internal framework for maintaining a model registry and serving. Each model that is ready for production is added to the model registry along with all of its associated metadata. The framework then uses these artifacts to serve the model on an AWS EC2 instance with TorchServe.
To deploy our model on smartphones, we use PyTorch Mobile. We built an SDK around it which fetches the artifacts from the model registry and enables us to efficiently deploy multiple models on-device. The application can run inference without any internet connection. It queues all the data required for logging and syncs it back to the server in the background when the phone is connected to the network.
To ensure such experiments are replicable, please take into consideration the following:
- Containerization: Docker has been widely used to ensure that all the dependencies required for replicating an experiment are packaged within a single contained environment.
- Data versioning: As datasets continue to evolve over time, train/val/test splits cannot be static. Explicitly version data to take snapshots of samples used to train any given experiment. For now, we store the splits explicitly in a data version file. In the future, we are looking forward to migrating to Artifacts by Weights & Biases.
- Experiment versioning: We store all our hyperparameters for each experiment in a config file along with the Git commit ID and the specific seed.
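The experiment versioning practice above can be sketched with the standard library alone. This is an illustrative sketch rather than our internal tooling: the function names, the config file name, the record layout and the `data_version` value are all hypothetical.

```python
import json
import random
import subprocess
import tempfile
from pathlib import Path


def current_commit():
    """Return the current Git commit ID, or "unknown" outside a repository."""
    try:
        return subprocess.check_output(
            ["git", "rev-parse", "HEAD"], text=True, stderr=subprocess.DEVNULL
        ).strip()
    except (OSError, subprocess.CalledProcessError):
        return "unknown"


def save_experiment_config(path, hyperparameters, seed, data_version):
    """Write everything needed to re-run an experiment to a JSON config file."""
    random.seed(seed)  # in real training code, also seed NumPy / PyTorch
    record = {
        "hyperparameters": hyperparameters,
        "seed": seed,
        "git_commit": current_commit(),
        "data_version": data_version,  # pointer to the data version file
    }
    Path(path).write_text(json.dumps(record, indent=2, sort_keys=True))
    return record


# Hypothetical usage with made-up values.
config_path = Path(tempfile.mkdtemp()) / "experiment_config.json"
record = save_experiment_config(
    config_path,
    hyperparameters={"lr": 1e-3, "batch_size": 32, "input_size": 512},
    seed=42,
    data_version="v1",
)
```

Keeping the commit ID, seed and data version in the same record means a single file identifies the code, randomness and data behind a given run.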
Our solution is going to be used by more than 10,000 farmers across India in the coming season and we hope that this encourages more people to creatively apply AI to solve societal problems.
Aman is a Research Fellow at Wadhwani AI. At Wadhwani AI, he has been working on AI for early detection of pest infestation, identifying malnutrition in neonates using a smartphone video and using cough sounds to screen for COVID-19. Aman graduated with a bachelor’s degree in Electronics and Communication engineering from the Indian Institute of Technology, Guwahati.
Vishal is an ML Engineer at Wadhwani AI. He has been working on AI for early detection of pest infestation and screening tools for Tuberculosis. Currently he is focusing on MLOps, developing and adapting best practices that enable us to build scalable machine learning systems. Vishal has an interest in Systems and ML and has an academic background in Electronics and Electrical Engineering.
About Wadhwani AI
The Wadhwani Institute for Artificial Intelligence is an independent, nonprofit research institute and global hub, developing AI solutions for social good. We harness the power of modern artificial intelligence to meet global challenges. Our mission is to develop and apply AI-based innovations and solutions to a broad range of societal domains including healthcare, agriculture, education, infrastructure, and financial inclusion.