VisionAir: Using Federated Learning to Improve Model Performance on Android

Oct 19, 2019
VisionAir: A privacy-preserving air quality estimation Android application

The rapidly deteriorating quality of air affects billions of people globally. This calls for an inexpensive, easily accessible application that helps anyone understand the quality of the air they are breathing.

However, like most applications in this data-hungry era, currently available Air Quality Index (AQI) estimation applications share private user data with remote servers, which discourages users from engaging with them.

Existing applications lack the vision of preserving users' private data

VisionAir is a privacy-preserving Android application that lets the user estimate the Air Quality Index (AQI) of a region from an image the user takes.

It uses On-Device training to ensure that all of the user's private data, including the images they take, is processed on the device, so no private data ever leaves it.

It also uses Federated Learning to continually diversify and improve the deep learning model's performance by taking the application's crowd-sourced inputs into consideration.

What is Federated Learning?

In Federated Learning, a global model is pushed onto client devices, and each client device trains this model for a few iterations. The clients then send their modified weights for averaging. The averaged weights are used to update the weights of the global model, which is again pushed onto all the client devices.
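
As a minimal sketch of the averaging step (assuming each client reports its weight arrays along with its local sample count, as in the standard FedAvg algorithm):

```python
import numpy as np

def federated_average(client_weights, client_sizes):
    """Weighted average of client weights (FedAvg).

    client_weights: one list of np.ndarray per client (layer by layer).
    client_sizes:   number of local training samples per client.
    """
    total = sum(client_sizes)
    # Average each layer's weights, weighting clients by dataset size.
    return [
        sum(w[layer] * (n / total) for w, n in zip(client_weights, client_sizes))
        for layer in range(len(client_weights[0]))
    ]
```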

The Global Model

VisionAir’s global model takes the following inputs:
1. An image with at least 60% skyline, taken at a location of the user's choice.

A sample image from our dataset

2. Ground truth AQI of that location: obtained by making an API call to the nearest Central Pollution Control Board (CPCB) centre.

3. Meteorological data of that location: obtained by making API calls to the state weather monitoring station.

The global model is a lightweight neural network that uses image, weather and temporal features to predict the PM 2.5 AQI.
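
As an illustrative sketch only, such a network might be built as follows; the feature dimension, layer sizes and op names here are our assumptions, not VisionAir's published architecture (the op names input, label, output, train and init are reused in the on-device snippets later):

```python
import tensorflow as tf  # TensorFlow 1.x, graph mode

# Feature vector: extracted image features + weather + temporal features.
# The dimensions below are illustrative assumptions.
x = tf.placeholder(tf.float32, shape=[None, 16], name="input")
y = tf.placeholder(tf.float32, shape=[None, 1], name="label")

hidden = tf.layers.dense(x, 32, activation=tf.nn.relu)
output = tf.identity(tf.layers.dense(hidden, 1), name="output")

loss = tf.reduce_mean(tf.square(output - y))
tf.train.AdamOptimizer(1e-3).minimize(loss, name="train")
init = tf.variables_initializer(tf.global_variables(), name="init")
```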

We use TensorFlow's API for Java to achieve completely isolated on-device processing, and OpenCV's Java library to carry out the image processing required for feature extraction.

The global model that is pushed to the client devices was trained on a diverse dataset of 4000+ images taken across 80+ locations in Delhi-NCR. The dataset was mapped with the relevant ground truth values and corresponding weather data, and is available here.

A glimpse of the dataset used for training VisionAir

To curate this dataset, we used two strategies:

  1. If there was a CPCB centre within a 1 km radius of the location where the photo was taken, we used CPCB's data as the label for that particular image.
  2. In any other case, we used an accurate, high-quality portable sensor, AirVeda, which gave us the AQI reading corresponding to that particular image.

Note that we had previously calibrated the AirVeda sensor and validated that CPCB sensor readings and AirVeda readings are concordant for the same location.

The On-Device Model

For the models being trained On-Device, the client device needs the following inputs:

  1. The image that the user takes. Note that this image is not stored anywhere, and all the processing needed to extract the features is carried out on the device.
  2. The ground truth from the nearest CPCB centre. This value is used as the label in the On-Device training process.
  3. Weather parameters: obtained by making API calls to the nearest weather monitoring station.

How VisionAir uses Federated Learning

Using Federated Learning allows VisionAir to maintain the user's privacy while letting the model improve gradually. Hence, the model is able to incorporate seasonal changes into its PM 2.5 estimates and generalizes better.

The process of Federated Learning

Federated Learning consists of two processes:

Client Side Processing: The client device uses On-Device Training to improve the PM 2.5 estimates for its user, allowing the estimates to adapt to each user.

Server Side Processing: The improvements in the individual On-Device models are aggregated using Federated Averaging to update the global model, which increases performance and helps the model generalize better.

The two processes are explained below:

Client Side Processing (On-Device Training):

VisionAir uses On-Device Training to improve model performance. It uses the TensorFlow API for Java to run training epochs on the device.

This requires the deep learning model’s metagraph and a checkpoint file.

The code below will generate the metagraph for the model being used in the current TensorFlow session. Note that the metagraph only contains the model’s architecture and not the weights.
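
A minimal sketch of this step, assuming TensorFlow 1.x and the illustrative graph above (the file names are ours):

```python
import tensorflow as tf

# Creating a Saver also adds the save/restore ops to the graph; the Java
# client later uses their default names ("save/Const", "save/restore_all").
saver = tf.train.Saver()

# Serialize the graph (architecture only, no weights) to a .pb file.
tf.train.write_graph(tf.get_default_graph().as_graph_def(),
                     "model", "graph.pb", as_text=False)
```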

Saving the metagraph of the model
Metagraph generated by the above code

To restore the weights for the graph, a checkpoint file is created:
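
A minimal sketch, again assuming TensorFlow 1.x and the Saver created above:

```python
with tf.Session() as sess:
    sess.run(init)
    # ... train the global model here ...
    # Writes ckpt.index / ckpt.data-* files holding the current weights.
    saver.save(sess, "model/ckpt")
```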

This will generate the checkpoint files for the model which will be used to restore the weights of the graph.

Checkpoint files containing the weights of the model

Training the model On-Device:

Import TensorFlow’s API for Java by adding its dependency in your app’s build.gradle:
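
For example (the artifact version below is an assumption; use the release you need):

```groovy
dependencies {
    // TensorFlow's Android/Java API
    implementation 'org.tensorflow:tensorflow-android:1.13.1'
}
```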

In the Activity, create an object of org.tensorflow.Graph:
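
For example (a Session on the same graph is created here as well, since the later snippets need one; in practice both are kept as fields of the Activity):

```java
import org.tensorflow.Graph;
import org.tensorflow.Session;

Graph graph = new Graph();
// The Session runs the restore/train/predict ops on this graph later.
Session session = new Session(graph);
```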

Place the .pb file generated before in the assets folder and import it as a byte[] array. Let the array’s name be graphdef:
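
A minimal sketch (the file name graph.pb is ours; the stream classes come from java.io):

```java
// Read assets/graph.pb into a byte[].
InputStream is = getAssets().open("graph.pb");
ByteArrayOutputStream bos = new ByteArrayOutputStream();
byte[] buf = new byte[4096];
int n;
while ((n = is.read(buf)) != -1) {
    bos.write(buf, 0, n);
}
is.close();
byte[] graphdef = bos.toByteArray();
```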

Now, import the graphdef into the graph variable:
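
For example:

```java
// Loads the serialized architecture into the Graph object.
graph.importGraphDef(graphdef);
```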

The architecture of the model has been loaded. Now, the checkpoint files will be used to restore the model’s weights.

To load the checkpoint, place the checkpoint files in the device and create a Tensor to the path of the checkpoint prefix:
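
A sketch, assuming the checkpoint files were copied into the app's files directory (the path is ours):

```java
import java.io.File;
import org.tensorflow.Tensor;
import org.tensorflow.Tensors;

// Path prefix of the checkpoint files on the device.
String prefix = new File(getFilesDir(), "ckpt").toString();
Tensor<String> checkpointPrefix = Tensors.create(prefix);
```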

Now, load the checkpoint by running the restore checkpoint op in the graph:
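
The op names below are the defaults created by tf.train.Saver:

```java
// Feed the checkpoint path to the Saver and run its restore op.
session.runner()
       .feed("save/Const", checkpointPrefix)
       .addTarget("save/restore_all")
       .run();
```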

Alternatively, initialize the graph by calling the init op:
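
For example, assuming the initializer was named init when the graph was built:

```java
// Runs the variable initializer instead of restoring a checkpoint.
session.runner().addTarget("init").run();
```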

This will initialize the graph using the initializer specified in the graph.

Running On-Device Training:

Now that the model has been restored, VisionAir can perform On-Device training.

First, VisionAir estimates the PM 2.5 concentration of an area using the model and displays it. The estimation is done using the predict() method below:
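
A sketch of predict() under the assumptions of the illustrative graph above (tensor names "input"/"output" and the feature layout are ours):

```java
private float predict(float[] features) {
    try (Tensor<Float> input = Tensors.create(new float[][] { features })) {
        float[][] out = new float[1][1];
        session.runner()
               .feed("input", input)
               .fetch("output")
               .run()
               .get(0)
               .copyTo(out);
        return out[0][0];  // estimated PM 2.5 AQI
    }
}
```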

The current concentration of the area is then fetched through an API call; this value serves as the label for On-Device Training:
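
A hypothetical sketch of this call (the endpoint URL and the JSON field name are assumptions, not CPCB's actual API):

```java
private float fetchGroundTruthAqi(double lat, double lon)
        throws IOException, JSONException {
    URL url = new URL("https://example.com/cpcb/nearest?lat=" + lat
            + "&lon=" + lon);
    BufferedReader in = new BufferedReader(
            new InputStreamReader(url.openStream()));
    StringBuilder body = new StringBuilder();
    String line;
    while ((line = in.readLine()) != null) {
        body.append(line);
    }
    in.close();
    // Parse the ground truth PM 2.5 value out of the JSON response.
    return (float) new JSONObject(body.toString()).getDouble("pm25");
}
```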

The following function then performs the On-Device training:
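
A sketch of one training pass, using the op names from the illustrative graph ("input", "label", "train"):

```java
private void trainOnDevice(float[] features, float label, int epochs) {
    try (Tensor<Float> x = Tensors.create(new float[][] { features });
         Tensor<Float> y = Tensors.create(new float[][] { { label } })) {
        for (int i = 0; i < epochs; i++) {
            // Each run of the "train" target performs one optimizer step.
            session.runner()
                   .feed("input", x)
                   .feed("label", y)
                   .addTarget("train")
                   .run();
        }
    }
}
```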

Now that the model has been trained, the updated weights need to be sent for averaging. The following function extracts the weights from the model:
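
A sketch under the naming assumptions of the illustrative graph (tf.layers.dense names its first layer's weights dense/kernel by default; VisionAir's actual variable names may differ):

```java
private float[][] getWeights() {
    // Fetching a variable by name yields its current value.
    Tensor<?> w = session.runner().fetch("dense/kernel").run().get(0);
    float[][] weights = new float[(int) w.shape()[0]][(int) w.shape()[1]];
    w.copyTo(weights);
    w.close();
    return weights;
}
```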

Now, these weights are saved on the device in order to send them for averaging. The function finalSave() saves the weights on the device:
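
A sketch of finalSave(); here save() is assumed to flatten the weight matrices into a single float[] as its pre-processing step:

```java
private void finalSave(float[][] weights) throws IOException {
    float[] flat = save(weights);  // pre-processing done by save()
    DataOutputStream out = new DataOutputStream(
            openFileOutput("weights.bin", Context.MODE_PRIVATE));
    for (float v : flat) {
        out.writeFloat(v);  // DataOutputStream writes big-endian floats
    }
    out.close();
}
```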

Note that this function includes a call to another function, save(), which does the required pre-processing.

Now that the weights have been saved, VisionAir sends them to a secure server.

The weights are uploaded using an AsyncTask. The AsyncTask has three functions, sketched after the list below:

  1. isUpdated(): checks if the global model has been updated.
  2. uploadWeights(): uploads the weights to the server.
  3. downloadFiles(): downloads the global model if it has been updated.
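
A skeleton of such a task (the networking details are left as stubs, since the endpoints and wire format are not public):

```java
private class WeightSyncTask extends AsyncTask<Void, Void, Void> {
    @Override
    protected Void doInBackground(Void... params) {
        try {
            uploadWeights();      // push this device's weights
            if (isUpdated()) {    // has the global model changed?
                downloadFiles();  // pull the updated global model
            }
        } catch (IOException e) {
            Log.e("VisionAir", "weight sync failed", e);
        }
        return null;
    }

    private boolean isUpdated() throws IOException {
        // Compare a server-side model version with the local one (sketch).
        return false;
    }

    private void uploadWeights() throws IOException {
        // POST weights.bin to the aggregation server (sketch).
    }

    private void downloadFiles() throws IOException {
        // Download the new graph/checkpoint files (sketch).
    }
}
```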

Server Side Processing (Federated Averaging):

The weights sent by the client device (Android) are received on the server. The following function receives the weights and temporarily stores them for averaging:
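
A minimal sketch of such an endpoint, assuming a Flask server and that each client POSTs its weight file (route names and paths are ours):

```python
import os
from flask import Flask, request

app = Flask(__name__)

@app.route("/upload", methods=["POST"])
def receive_weights():
    f = request.files["weights"]
    # Store each client's weights until the next averaging round.
    f.save(os.path.join("pending_weights", f.filename))
    return "OK"
```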

Receiving the weights on the Server side

Performing Averaging on the Server:

Now, the weights are averaged on the server.
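
A sketch of this step, assuming each client file holds one flattened weight vector (a plain average is shown; weighting clients by sample count, as in the FedAvg sketch earlier, is the more general form). Java's DataOutputStream writes big-endian floats, hence the >f4 dtype:

```python
import glob
import numpy as np

# Collect all pending client weight vectors.
files = glob.glob("pending_weights/*.bin")
client_weights = [np.fromfile(f, dtype=">f4") for f in files]

# Average them element-wise and write back the new global weights.
global_weights = np.mean(client_weights, axis=0)
global_weights.astype(">f4").tofile("global/weights.bin")
```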

The Android device now downloads this updated model using the downloadFiles() function shown above.

Conclusion: Federated Learning and Beyond

VisionAir demonstrates that a custom-built deep learning model can be improved continuously while maintaining the user's privacy. As one of the first applications of its kind, its performance is being continuously monitored, and aggregated results will be released soon.

Authored by Divyanshu Sharma, Harshita Diddee, Shivani Jindal and Shivam Grover.
