VisionAir: Using Federated Learning to Improve Model Performance on Android
The rapidly deteriorating quality of air has affected billions of people globally. This calls for an inexpensive and easily accessible application that helps ordinary people understand the quality of the air they are breathing.
However, like most existing applications in this data-hungry era, currently available Air Quality Index (AQI) estimation applications share private user data with remote servers, which discourages users from engaging with them.
VisionAir is a privacy-preserving Android application that allows the user to estimate the Air Quality Index (AQI) of a region using an image that the user takes.
It uses On-Device training to ensure that all of the user’s private data, including the image the user takes, is processed on the device, avoiding any breach of private data.
It also uses federated learning to continually diversify and improve the deep learning model’s performance using the inputs of the crowdsourced application.
What is Federated Learning?
In Federated Learning, a global model is pushed onto client devices, and each client device trains this model locally for a few iterations. The clients then send their modified weights to a server for averaging. The averaged weights are used to update the global model, which is again pushed onto all the client devices.
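One round of this loop can be sketched in a few lines of plain Python. This is only an illustration of the push, train and average cycle, not VisionAir's code: the names `federated_round`, `average_weights` and `make_client` are hypothetical, the model is reduced to a flat list of numbers, and "local training" is faked by nudging every weight by a fixed step.

```python
from typing import Callable, List

Weights = List[float]


def average_weights(client_weights: List[Weights]) -> Weights:
    """Element-wise mean of the weight vectors returned by the clients."""
    count = len(client_weights)
    return [sum(values) / count for values in zip(*client_weights)]


def federated_round(global_weights: Weights,
                    local_updates: List[Callable[[Weights], Weights]]) -> Weights:
    """Run one round: push the global model, train locally, average."""
    # Each client starts from its own copy of the current global weights.
    client_weights = [update(list(global_weights)) for update in local_updates]
    # The averaged weights become the next global model.
    return average_weights(client_weights)


def make_client(step: float) -> Callable[[Weights], Weights]:
    """Hypothetical client whose 'training' shifts every weight by `step`."""
    def update(weights: Weights) -> Weights:
        return [w + step for w in weights]
    return update


clients = [make_client(0.1), make_client(0.3)]
new_global = federated_round([1.0, 2.0], clients)
```

In a real deployment each `update` would be several training iterations on the device's own data, and only the resulting weights would travel to the server.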
The Global Model
VisionAir’s global model takes the following inputs:
1. An image with at least 60% skyline, taken at a location of the user’s choice.
2. Ground truth AQI of that location: obtained by making an API call to the nearest Central Pollution Control Board (CPCB) centre.
3. Meteorological data of that location: obtained by making API calls to the state weather monitoring station.
We use TensorFlow’s API for Java to achieve completely isolated on-device processing, and OpenCV’s Java library to carry out the image processing required for feature extraction.
The global model that is pushed to the client devices was trained on a diverse dataset of 4000+ images taken across 80+ locations in Delhi-NCR. The dataset is mapped to the relevant ground truth values and corresponding weather data, and is available here.
To curate this dataset we used two strategies:
- If there was a CPCB centre within a 1 km radius of the location where the photo was taken, we used CPCB’s data as the label for that particular image.
- In any other case, we used an accurate, high-quality portable sensor, AirVeda, which gave us the AQI reading corresponding to that particular image.
Note that we had previously calibrated the AirVeda sensor and verified that the CPCB sensor readings and the AirVeda readings concur for the same location.
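The labelling rule above can be written as a small distance check. The sketch below is an assumption about how such a rule could be coded, not the script used to build the dataset: `choose_label` and `haversine_km` are hypothetical names, and great-circle (haversine) distance is assumed as the way the 1 km radius is measured.

```python
import math
from typing import Tuple

EARTH_RADIUS_KM = 6371.0
CPCB_RADIUS_KM = 1.0  # threshold taken from the curation rule above

LatLon = Tuple[float, float]


def haversine_km(a: LatLon, b: LatLon) -> float:
    """Great-circle distance in kilometres between two (lat, lon) points."""
    lat1, lon1 = map(math.radians, a)
    lat2, lon2 = map(math.radians, b)
    h = (math.sin((lat2 - lat1) / 2) ** 2
         + math.cos(lat1) * math.cos(lat2) * math.sin((lon2 - lon1) / 2) ** 2)
    return 2 * EARTH_RADIUS_KM * math.asin(math.sqrt(h))


def choose_label(photo: LatLon, station: LatLon,
                 station_aqi: float, sensor_aqi: float) -> Tuple[float, str]:
    """Return (label, source) for one image according to the 1 km rule."""
    if haversine_km(photo, station) <= CPCB_RADIUS_KM:
        return station_aqi, "CPCB"
    return sensor_aqi, "AirVeda"


# Made-up coordinates and readings, purely for illustration.
near = choose_label((28.6139, 77.2090), (28.6139, 77.2150), 180.0, 175.0)
far = choose_label((28.6139, 77.2090), (28.7000, 77.2090), 180.0, 175.0)
```

Here `near` takes the station reading because the station is roughly 0.6 km away, while `far` falls back to the portable sensor.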
The On-Device Model
For the models that were being trained On-Device, the client device needed the following inputs:
- The image that the user took. Note that this image is not stored anywhere, and all the processing needed to extract the features is carried out on the device.
- The ground truth from the nearest CPCB centre. This value was used as a label in the On-Device training process.
- Weather parameters: Obtained by making API calls to the nearest weather monitoring station.
How VisionAir uses Federated Learning
Federated Learning lets VisionAir preserve the user’s privacy while the model improves gradually. Because the model keeps learning from new images, it can incorporate seasonal changes in its PM 2.5 estimates and generalize better.
Federated Learning consists of two processes:
Client Side Processing: The client device uses On-Device Training to improve the PM 2.5 estimates for a user. This allows the estimates to be adapted to each user.
Server Side Processing: The improvements in the individual On-Device models are aggregated using the concept of Federated Averaging to update the global model which increases performance and helps the model to generalize better.
The two processes are explained below:
Client Side Processing (On-Device Training):
VisionAir uses On-Device Training to improve model performance. It uses the TensorFlow API for Java to run training epochs on the device.
This requires the deep learning model’s metagraph and a checkpoint file.
The code below will generate the metagraph for the model being used in the current TensorFlow session. Note that the metagraph only contains the model’s architecture and not the weights.
To restore the weights for the graph, a checkpoint file is created:
This will generate the checkpoint files for the model which will be used to restore the weights of the graph.
Training the model On-Device:
Import TensorFlow’s API for Java by adding its dependency in your app’s build.gradle:
In the Activity, create an object of org.tensorflow.Graph:
Place the .pb file generated before in the assets folder and import it as a byte array. Let the array’s name be graphdef:
Now, import the graphdef into the graph variable:
The architecture of the model has been loaded. Now, the checkpoint files will be used to restore the model’s weights.
To load the checkpoint, place the checkpoint files on the device and create a Tensor that holds the path of the checkpoint prefix:
Now, load the checkpoint by running the restore checkpoint op in the graph:
Alternatively, initialize the graph by calling the init op:
This will initialize the graph using the initializer specified in the graph.
Running On-Device Training:
Now that the model has been restored, VisionAir can perform On-Device training.
First, VisionAir estimates the PM 2.5 concentration in an area using the model and displays an estimate. The estimation is done using the predict() method below:
The current concentration of the area is then fetched through an API call which serves as the label for On-Device Training:
The following function then performs the On-Device training:
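As a language-neutral illustration of what a single training step does, the sketch below runs one gradient-descent update in plain Python. It is a stand-in under loud assumptions: the app itself trains its deep learning model through the TensorFlow API for Java, whereas this example uses a tiny linear model, a squared-error loss and a hypothetical `train_step` helper, with the fetched PM 2.5 reading playing the role of `label`.

```python
from typing import List, Tuple


def predict(weights: List[float], bias: float, features: List[float]) -> float:
    """Linear stand-in for the model's PM 2.5 estimate."""
    return sum(w * x for w, x in zip(weights, features)) + bias


def train_step(weights: List[float], bias: float, features: List[float],
               label: float, learning_rate: float = 0.01
               ) -> Tuple[List[float], float, float]:
    """One gradient-descent step on the loss 0.5 * (prediction - label)^2.

    Returns the updated weights, the updated bias and the loss measured
    before the update.
    """
    error = predict(weights, bias, features) - label
    new_weights = [w - learning_rate * error * x
                   for w, x in zip(weights, features)]
    new_bias = bias - learning_rate * error
    return new_weights, new_bias, 0.5 * error ** 2


# Made-up features and label, purely for illustration.
features, label = [1.0, 2.0], 10.0
weights, bias, loss_before = train_step([0.0, 0.0], 0.0, features, label,
                                        learning_rate=0.1)
loss_after = 0.5 * (predict(weights, bias, features) - label) ** 2
```

The step moves the estimate for this sample toward the measured value, which is why `loss_after` is smaller than `loss_before`.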
Now that the model has been trained, the updated weights need to be sent for averaging. The following function extracts the weights from the model:
Now, these weights are saved on the device in order to send them for averaging. The function finalSave() saves the weights on the device:
Note that this function includes a call to another function save() which does the required pre-processing.
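The split between a pre-processing step and a file write can be sketched as follows. Only the names finalSave() and save() come from the text. The Python analogues `final_save` and `save`, the `load_weights` helper and the file layout (a 4-byte count followed by little-endian float32 values) are all assumptions made for this example.

```python
import os
import struct
import tempfile
from typing import List


def save(layers: List[List[float]]) -> bytes:
    """Pre-processing: flatten the per-layer weights and pack them as bytes."""
    flat = [w for layer in layers for w in layer]
    header = struct.pack("<I", len(flat))         # number of values
    body = struct.pack(f"<{len(flat)}f", *flat)   # little-endian float32
    return header + body


def final_save(layers: List[List[float]], path: str) -> int:
    """Write the packed weights to `path` and return the byte count."""
    data = save(layers)
    with open(path, "wb") as handle:
        handle.write(data)
    return len(data)


def load_weights(path: str) -> List[float]:
    """Read back the flat weight list written by final_save()."""
    with open(path, "rb") as handle:
        data = handle.read()
    (count,) = struct.unpack_from("<I", data, 0)
    return list(struct.unpack_from(f"<{count}f", data, 4))


# Round trip with made-up weights in a temporary directory.
with tempfile.TemporaryDirectory() as tmp_dir:
    weights_path = os.path.join(tmp_dir, "weights.bin")
    bytes_written = final_save([[0.5, -1.0], [2.0]], weights_path)
    restored = load_weights(weights_path)
```

A fixed binary layout like this keeps the upload small and lets the server parse the file without knowing anything about the device that produced it.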
Now that the weights have been saved, VisionAir sends these weights to a secure server:
The weights are uploaded using an AsyncTask. The AsyncTask has three functions:
1. isUpdated(): Checks if the global model has been updated.
2. uploadWeights(): Uploads the weights to the server.
3. downloadFiles(): Downloads the global model if it is updated.
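How the three functions fit together can be sketched with an in-memory stand-in for the server. Everything here is hypothetical apart from the three function names, which are mirrored in snake_case: the `FakeServer` and `Client` classes, the version counter used to detect an update, and the order of the calls in `sync` are assumptions, since the text does not specify them.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class FakeServer:
    """In-memory stand-in for the remote server (hypothetical)."""
    model_version: int = 1
    global_weights: List[float] = field(default_factory=lambda: [0.0, 0.0])
    received: List[List[float]] = field(default_factory=list)


@dataclass
class Client:
    server: FakeServer
    model_version: int = 0
    weights: List[float] = field(default_factory=list)

    def is_updated(self) -> bool:
        """True if the server holds a newer global model than this client."""
        return self.server.model_version > self.model_version

    def upload_weights(self) -> None:
        """Send a copy of the locally trained weights to the server."""
        self.server.received.append(list(self.weights))

    def download_files(self) -> None:
        """Replace the local model with the server's global model."""
        self.weights = list(self.server.global_weights)
        self.model_version = self.server.model_version

    def sync(self) -> None:
        # Assumed order: upload first, then fetch a newer model if one exists.
        self.upload_weights()
        if self.is_updated():
            self.download_files()


server = FakeServer(model_version=2, global_weights=[1.0, 2.0])
client = Client(server=server, model_version=1, weights=[0.5, 0.5])
client.sync()
```

On Android the same sequence would run off the main thread, which is the role the AsyncTask plays.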
Server Side Processing (Federated Averaging):
The weights sent by the Client device (Android) are received on the server. The following function is responsible for receiving the weights and temporarily storing them for averaging:
Performing Averaging on the Server:
Now, the weights are averaged on the server.
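A minimal sketch of the averaging step, assuming each client upload is a mapping from layer name to a flat list of weights. The function name `federated_average` and the layer names are hypothetical. Weighting each client by its number of training samples follows the usual Federated Averaging formulation, but the text does not say whether VisionAir weights clients this way. With equal sample counts the result reduces to a plain mean.

```python
from typing import Dict, List

LayerWeights = Dict[str, List[float]]


def federated_average(updates: List[LayerWeights],
                      num_samples: List[int]) -> LayerWeights:
    """Sample-weighted mean of the clients' weights, layer by layer."""
    if not updates or len(updates) != len(num_samples):
        raise ValueError("need one sample count per client update")
    total = sum(num_samples)
    if total <= 0:
        raise ValueError("total sample count must be positive")
    averaged: LayerWeights = {}
    for name, reference in updates[0].items():
        averaged[name] = [
            sum(update[name][i] * n
                for update, n in zip(updates, num_samples)) / total
            for i in range(len(reference))
        ]
    return averaged


# Two made-up client uploads: the second trained on three times as many samples.
client_a = {"dense/kernel": [1.0, 3.0], "dense/bias": [0.0]}
client_b = {"dense/kernel": [3.0, 5.0], "dense/bias": [4.0]}
new_global = federated_average([client_a, client_b], num_samples=[1, 3])
```

The resulting `new_global` is pulled toward `client_b`, which contributed more samples, and would then be written into the global model.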
The Android device now downloads this updated model using the downloadFiles() function shown above.
Conclusion: Federated Learning and Beyond
VisionAir demonstrates that a custom-built deep learning model can be improved continuously while preserving the privacy of the user. As one of the first applications of its kind, VisionAir is being continuously monitored, and aggregated results will be released soon.