Federated Learning: An Introduction

Privacy is an increasing concern in the smartphone era. Every smart device owner generates lots of data, and this data has great potential to improve how we communicate and perform our day-to-day tasks. For instance, based on the messages and emails we type, what if we could get better suggestions as we write, or even help with story writing (like GPT-2)? Wouldn't our lives be a lot easier, since it would save us the time of typing the same responses to FAQs? But there is a catch here. Training an LSTM or any attention-based model on the data a user generates would give more personalized suggestions than training it on a large generic corpus. The problem is that the data would have to leave the device and be fed to a model running on a central server. We would not want the things we type on our smartphones to be accessible to a central server.

One solution is to randomize every text before it is sent to the server (differential privacy). Another effective method is federated learning.

Source: http://vision.cloudera.com/wp-content/uploads/2018/11/2018-10-31-181344-federated_learning_animated_labeled.gif

Before we start on federated learning, a common question that may have crossed your mind is: "wait, if data leaving the device is the problem, why not set up the entire model on the device and train it locally?" Well, that won't work out well for two reasons:

  1. The data on a single client is very limited.
  2. Other clients do not contribute to the training, so we risk overfitting to one device's data, and the model never becomes very capable. In short, the suggestions might become more of an annoyance than a convenience.

So, what if we could still train a centralized model on decentralized data without invading privacy? That's where FL comes in. A more intuitive description of FL is "bringing the model to the data, rather than the data to the model as in the traditional approach".

How does FL work?

The server first initializes the model. This can be done with any of the usual neural network initialization methods, or the server can start from a pretrained model. The server then sends the model to a set of client devices selected at random. Each device trains the model locally and computes an update to the global model. The server receives the updates and averages them, weighting each update by the size of the training set that client used. The global model is then updated using some form of gradient descent (such as SGD). This entire flow amounts to one round, and several such rounds are performed until the model converges.
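The round described above can be sketched in a few lines. Below is a minimal toy simulation in NumPy, assuming a linear model trained with full-batch gradient descent; the function names (`local_train`, `fedavg_round`) are illustrative, not from any library. The key FL idea is in `fedavg_round`: the server averages the clients' locally trained weights, weighted by each client's sample count.

```python
import numpy as np

def local_train(weights, X, y, lr=0.1, epochs=5):
    """One client's local training: gradient descent on a linear
    model with mean-squared-error loss, starting from the global weights."""
    w = weights.copy()
    for _ in range(epochs):
        grad = 2 * X.T @ (X @ w - y) / len(y)  # MSE gradient
        w -= lr * grad
    return w, len(y)

def fedavg_round(global_weights, clients):
    """One federated round: every client trains locally, then the
    server averages the results weighted by each client's data size."""
    results = [local_train(global_weights, X, y) for X, y in clients]
    total = sum(n for _, n in results)
    return sum((n / total) * w for w, n in results)
```

A usage sketch: give each simulated client data drawn from the same underlying linear relationship, initialize the global weights to zero, and call `fedavg_round` repeatedly; the global model converges without any client's raw data ever being pooled centrally.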

To avoid too many round trips, the number of users sampled per round should be chosen carefully.

Well, the above methodology sounds pretty cool, but there are some problems in federated learning (FL):

  1. The distribution the data is sampled from varies from client to client. In traditional distributed learning there are ways to ensure the data is well represented, but in FL, since users are sampled at random, the datasets of two random clients may come from completely different distributions. Clients also do not have the same amount of data: one user might generate only 10 samples while another generates far more, depending on usage.
  2. Communication problems arise because devices may be on different networks (lower bandwidth, etc.).

And one more problem we didn't touch on while discussing how FL works is privacy.

While it is true that the data never leaves the device in the process described above, there is the possibility of a reconstruction attack. The coordinating server may be an adversary, and if that adversary gets access to the gradients, it may be possible to reconstruct the original data from them (e.g., reconstructing faces in the facial recognition domain).

There are several ways to tackle this.

  1. Gradient clipping: clipping limits how much influence a single update (an individual device) can have on the final update. This makes it harder for the adversary to learn about an individual from their contribution.
  2. Another way is to not sample from a fixed set of users, but to involve different users at random in each round.
  3. A popular way is to add some noise to the individual updates (local differential privacy) or to the final aggregated update (global differential privacy). The noise added to individual updates can be constructed so that it cancels out during aggregation, so a proper final update can still be obtained.
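To make defenses 1 and 3 concrete, here is a minimal NumPy sketch of update clipping plus noisy aggregation. This is only illustrative of the idea, not a calibrated differentially private mechanism (choosing `noise_std` to meet a formal privacy budget is a separate topic); the function names are hypothetical.

```python
import numpy as np

def clip_update(update, max_norm=1.0):
    """Scale a client's update so its L2 norm is at most max_norm,
    bounding the influence any single device has on the aggregate."""
    norm = np.linalg.norm(update)
    if norm > max_norm:
        update = update * (max_norm / norm)
    return update

def noisy_aggregate(updates, max_norm=1.0, noise_std=0.1, rng=None):
    """Average the clipped updates, then add Gaussian noise to the
    result (the flavor of global differential privacy: the server's
    output no longer reveals any one client's exact contribution)."""
    rng = rng if rng is not None else np.random.default_rng()
    clipped = [clip_update(u, max_norm) for u in updates]
    mean = np.mean(clipped, axis=0)
    return mean + rng.normal(scale=noise_std, size=mean.shape)
```

Note the order of operations: clipping must happen before noise is added, because the noise magnitude only provides a meaningful privacy guarantee when each individual contribution is bounded.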

Conclusion

I hope this gives some insight into how FL works. One fun fact: Google's Gboard keyboard uses FL to give personalized suggestions. And don't worry about the model training draining your battery, as it happens only when the device is plugged in.
