On device inference is very common. On device training, not so much. Federated learning paves the way for training on multiple devices while preserving user privacy.
Smartphones have become ubiquitous, and these days machine learning is prevalent on every device. On device inference has become quite common, especially in apps made by Google. The way it works is that an ML model is trained on a server and embedded in the app; the actual inference takes place on the device.
However, let’s talk about training the model. Most models are trained on the server. But where does the training data come from? To collect it, the apps that you are using send data to the server (hopefully anonymously). The training and validation sets are created from this data collected from multiple users, and the model is trained. This poses a big challenge to user privacy: users might not want to share their personal data with anyone else.
To assuage this concern, McMahan et al. came up with the idea of Federated Learning and published a paper in 2017 titled Communication-Efficient Learning of Deep Networks from Decentralized Data. In this post, we will primarily focus on the algorithm used for Federated Learning.
What is Federated Learning?
In federated learning, the ML model is trained on the user’s device. The trained weights are shared with the server (instead of the actual data), which then averages these weights to create a global model. The authors call this algorithm Federated Averaging, which we will discuss in this post. The one minute video below is a nice non-technical summary of how federated learning works.
It is obvious by now that this algorithm has two components. Training on device and averaging on server. Let’s first turn our attention to on device training.
On device training
At time t = 0, the device gets a model trained on the server. Let’s call this w⁰. The server also sends the mini-batch size b, the learning rate η, the number of epochs e, and any other parameters needed by the model. Meanwhile, data keeps accumulating on the device. Once a sufficient amount of data has been collected, the model is trained on device. The pseudo code looks something like this.
w¹ = model(x, y, b, e, η)
where w¹ is the new weight matrix calculated by the model, and x, y are the input and target output generated from the locally collected data. These new weights are then shared with the server.
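To make the pseudo code above concrete, here is a minimal sketch of the on-device step. The paper does not fix a particular model or loss, so this sketch assumes a simple linear model with squared-error loss and plain mini-batch SGD; b, e, and η match the notation above.

```python
import numpy as np

def local_update(w, data_x, data_y, b, e, eta):
    """One round of on-device training: mini-batch SGD.

    Illustrative sketch only: assumes a linear model with squared
    loss; the real model, loss, and optimizer are up to the app.
    """
    w = w.copy()
    n = len(data_x)
    for _ in range(e):                      # e local epochs
        idx = np.random.permutation(n)      # shuffle local data
        for start in range(0, n, b):        # mini-batches of size b
            batch = idx[start:start + b]
            x, y = data_x[batch], data_y[batch]
            grad = 2 * x.T @ (x @ w - y) / len(batch)  # dMSE/dw
            w -= eta * grad                 # SGD step with learning rate eta
    return w                                # w1, shared with the server
```

Each device runs this on its own locally collected (x, y) pairs and uploads only the returned weights, never the data.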
On server averaging
The server collects the trained weights from a bunch of devices. Let’s represent the weights collected from device k as w¹ᵏ. The server recomputes its global weight matrix (let’s call it 𝔤) each round as per the following pseudo code.
𝔤 = Σₖ (nᵏ / N) * w¹ᵏ

where nᵏ is the number of data points used to obtain w¹ᵏ on device k, and N is the sum of nᵏ across all such devices. In other words, 𝔤 is a weighted average of the client weights, so devices with more data contribute more to the global model.
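The server-side aggregation above can be sketched in a few lines. The function name is my own; the weighting by nᵏ/N is exactly the formula above.

```python
import numpy as np

def federated_average(client_weights, client_sizes):
    """Weighted average of client weights (FedAvg aggregation).

    client_weights: list of arrays w1_k uploaded by each device k
    client_sizes:   list of n_k, data-point counts behind each w1_k
    """
    N = sum(client_sizes)                        # total data points
    g = np.zeros_like(client_weights[0], dtype=float)
    for w_k, n_k in zip(client_weights, client_sizes):
        g += (n_k / N) * w_k                     # each client's weighted share
    return g
```

A client that trained on three times as much data pulls the global weights three times as hard toward its own solution.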
Out of all K available clients, the server considers only a small fraction C of them in each round when updating its global weights.

nc = max(C * K, 1)

where nc is the number of clients selected per round.
Federated Averaging vis-a-vis SGD
Because the server considers only a small fraction of clients each time, you can think of this algorithm as a form of mini-batch gradient descent where the batch size is nc: in each round, the server randomly chooses nc clients and updates the global weights. If C = 1, it is analogous to full-batch (non-stochastic) gradient descent.
Case Study: GBoard (Google Keyboard)
In 2018, Hard et al. published a paper titled Federated Learning for Mobile Keyboard Prediction, which describes how federated learning is used in GBoard. We will discuss the most important aspects of the paper in this section.
If you have used GBoard, you may have noticed that it predicts the next word while you are typing. Taking an example straight from the paper: if you have typed “I love you”, GBoard suggests the three most likely words you will type next. The word with the highest probability is shown in the center, while the second and third most probable words are shown on the left and right respectively.
For on device training of this model, the authors used a variant of the LSTM called the Coupled Input and Forget Gate (CIFG). This architecture uses 25% fewer parameters per cell, making it a good fit for a mobile device environment: the number of computations drops along with the parameter count, without a significant impact on model performance.
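The parameter saving comes from tying the forget gate to the input gate, f = 1 − i, so one of the LSTM’s four weight blocks disappears (hence roughly 25% fewer parameters per cell). A simplified single-step sketch, with biases and peephole connections omitted for brevity:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def cifg_step(x, h, c, Wi, Wc, Wo):
    """One CIFG cell step (illustrative; biases omitted).

    A standard LSTM needs four weight blocks (input, forget,
    candidate, output gates). CIFG couples the forget gate to the
    input gate, f = 1 - i, so only three blocks remain.
    """
    z = np.concatenate([x, h])           # combined input for all gates
    i = sigmoid(Wi @ z)                  # input gate
    f = 1.0 - i                          # coupled forget gate: no extra weights
    c_new = f * c + i * np.tanh(Wc @ z)  # cell state update
    o = sigmoid(Wo @ z)                  # output gate
    h_new = o * np.tanh(c_new)           # hidden state / output
    return h_new, c_new
```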
The authors show in the paper that the performance of the federated CIFG model is almost as good as a CIFG model trained on the server.