Introduction to Federated Learning

Manish Nayak
Nov 12 · 3 min read

Introduction

Any deep learning model learns from data, and that data usually has to be collected or uploaded to a server (a single machine or a data center). The most realistic and meaningful deep learning models learn from personal data. But personal data is extremely private and sensitive, and no one wants to send or upload it to a server. Federated learning is a collaborative machine learning approach in which a model is trained without centralizing the data on a server, and that is what makes it something of a revolution.

What if we bring the model to the place where the data is generated, instead of bringing the data to one location and training a model there?

The main use case is repeatedly improving a pre-trained model using data from many mobile devices, or from any kind of embedded device such as Internet of Things devices connected to the Internet, without uploading that data to a server or the cloud.

This is really interesting because the solution to this problem is conceptually simple. First, the clients (the mobile devices) get a pre-trained model and improve it using their local data. The model is trained locally on each device, and the device then sends its updated model back to the server.

The server combines all the models it gets from the clients. This combined model becomes the next initial model that is sent to the clients, and the process repeats. Every device gets the benefit of every other device's data.
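The round described above can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions, not the method of any particular framework: the model is a toy linear model with two weights, the function names (`client_update`, `server_aggregate`, `federated_round`) and the client data are hypothetical, and the server is assumed to combine models with an average weighted by each client's number of examples.

```python
from typing import List, Tuple

Weights = List[float]                 # [slope, intercept] of a toy linear model
Dataset = List[Tuple[float, float]]   # (x, y) pairs that stay on the client


def client_update(global_weights: Weights, data: Dataset,
                  lr: float = 0.05, epochs: int = 20) -> Weights:
    """Start from the server's model and improve it on local data only."""
    w, b = global_weights
    for _ in range(epochs):
        # Gradients of the mean squared error on this client's data.
        grad_w = sum(2 * (w * x + b - y) * x for x, y in data) / len(data)
        grad_b = sum(2 * (w * x + b - y) for x, y in data) / len(data)
        w -= lr * grad_w
        b -= lr * grad_b
    return [w, b]


def server_aggregate(client_weights: List[Weights],
                     client_sizes: List[int]) -> Weights:
    """Combine client models with an average weighted by example count."""
    total = sum(client_sizes)
    return [
        sum(cw[i] * size for cw, size in zip(client_weights, client_sizes)) / total
        for i in range(len(client_weights[0]))
    ]


def federated_round(global_weights: Weights,
                    client_datasets: List[Dataset]) -> Weights:
    """One round: send the model out, train locally, combine the results."""
    updates = [client_update(global_weights, d) for d in client_datasets]
    sizes = [len(d) for d in client_datasets]
    return server_aggregate(updates, sizes)


# Hypothetical clients whose private data follows y = 2x + 1.
clients = [
    [(0.0, 1.0), (1.0, 3.0)],
    [(2.0, 5.0), (3.0, 7.0), (4.0, 9.0)],
    [(-1.0, -1.0), (0.5, 2.0)],
]
weights = [0.0, 0.0]   # the initial "pre-trained" model
for _ in range(50):
    weights = federated_round(weights, clients)
print(weights)
```

Only the weights travel between the clients and the server in this sketch. The `(x, y)` pairs are read inside `client_update` and nowhere else.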

Federated Learning
  • Performance: Even a client with only a few training examples can still learn a little from its data. But if we have 50,000 clients, each with a small amount of data, and the model is really big, they spend most of their time sending the model back and forth and not much time training.
  • Privacy: By looking at how the weights changed, someone may be able to infer personal data. So we cannot rely on federated learning alone if the training data can be recovered from a weight update.
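
To get a feel for the performance concern above, here is a rough back-of-the-envelope calculation. The 50,000 clients come from the text. The model size (25 million parameters), the 4 bytes per parameter, and the function name `round_traffic_bytes` are assumptions chosen only for illustration.

```python
def round_traffic_bytes(num_params: int, num_clients: int,
                        bytes_per_param: int = 4) -> int:
    """Bytes moved in one round if every client downloads and uploads the full model."""
    return 2 * num_params * bytes_per_param * num_clients


# Hypothetical model size; the client count is the one used in the text.
total = round_traffic_bytes(num_params=25_000_000, num_clients=50_000)
print(f"{total / 1e12:.1f} TB per round")   # prints "10.0 TB per round"
```

Under these assumptions each client moves 200 MB per round, which can easily take longer than training on a handful of local examples.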

To deal with these issues, the privacy one in particular, Google developed a Secure Aggregation protocol. In simplified form, the main idea is that the server generates a public and private key pair and shares the public key with each client.

Then the clients talk directly to each other and share their updated weights, encrypted with the server's public key. Each client has only the public key shared by the server, so no client can decrypt, and therefore see, another client's weight update.

The clients accumulate their encrypted weight updates into a single, final update that is sent back to the server. The server then decrypts it using its private key and updates the server model's weights. Because the server receives only the accumulated weights, it cannot see any particular client's weight update either: in this secure aggregation protocol, no individual phone's update can be inspected by the server before averaging. The server can ask a client to share the update, and the client responds only if it has synced with other clients and accumulated weights from more than some threshold number of them (|Number Of Clients| > threshold). After getting the response, the server decrypts the accumulated weights with its private key and computes the aggregate value.

One question you may ask: how can clients accumulate weights that are encrypted? The answer is homomorphic encryption, which lets you perform computation on encrypted values without decrypting them. You can try it yourself by installing the python-paillier library (pip install phe).

Homomorphic Encryption in Python
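
To show the idea without any dependencies, the sketch below implements a toy version of the Paillier cryptosystem from scratch using only the standard library. It is not the python-paillier API and it is not secure: the primes are fixed and far too small. The function names, the fixed-point `SCALE` used to encode float weights, and the sample updates are all assumptions made for illustration. The property it relies on is that multiplying two Paillier ciphertexts yields a ciphertext of the sum of the two plaintexts, which is what lets clients accumulate weights they cannot read.

```python
import math
import secrets

# Toy Paillier cryptosystem, written from scratch for illustration only.
# NOT secure: the fixed primes below are far too small for real use.
P = 2**31 - 1    # a Mersenne prime
Q = 2**61 - 1    # a Mersenne prime
SCALE = 10**6    # hypothetical fixed-point scale for encoding float weights


def generate_keypair(p=P, q=Q):
    """Server side: return (public key n, private key (lam, mu))."""
    n = p * q
    lam = (p - 1) * (q - 1) // math.gcd(p - 1, q - 1)   # lcm(p - 1, q - 1)
    mu = pow(lam, -1, n)   # modular inverse; valid for generator g = n + 1
    return n, (lam, mu)


def encrypt(n, weight):
    """Client side: encrypt one float weight using the public key only."""
    m = round(weight * SCALE) % n   # negative values wrap around modulo n
    n_sq = n * n
    r = secrets.randbelow(n - 1) + 1   # random blinding factor in [1, n)
    while math.gcd(r, n) != 1:
        r = secrets.randbelow(n - 1) + 1
    # With g = n + 1, g^m mod n^2 equals 1 + m*n, so c = (1 + m*n) * r^n mod n^2.
    return (1 + m * n) * pow(r, n, n_sq) % n_sq


def add_encrypted(n, c1, c2):
    """Multiplying ciphertexts adds the underlying plaintexts."""
    return c1 * c2 % (n * n)


def accumulate(n, encrypted_updates, threshold):
    """Clients combine ciphertexts; refuse unless enough clients took part."""
    if len(encrypted_updates) <= threshold:
        raise ValueError("not enough clients to hide individual updates")
    total = encrypted_updates[0]
    for c in encrypted_updates[1:]:
        total = add_encrypted(n, total, c)
    return total


def decrypt(n, private_key, c):
    """Server side: decrypt a ciphertext back to a float."""
    lam, mu = private_key
    m = (pow(c, lam, n * n) - 1) // n * mu % n
    if m > n // 2:   # the upper half of the range represents negative values
        m -= n
    return m / SCALE


public_key, private_key = generate_keypair()
client_updates = [0.25, -0.10, 0.05]   # hypothetical single-weight updates
ciphertexts = [encrypt(public_key, w) for w in client_updates]
encrypted_sum = accumulate(public_key, ciphertexts, threshold=2)
average = decrypt(public_key, private_key, encrypted_sum) / len(client_updates)
print(average)
```

In this sketch the clients only ever call `encrypt` and `accumulate`, which need just the public key, while `decrypt` needs the private key that stays on the server and is applied only to the accumulated sum. For anything beyond experimentation, a maintained library such as python-paillier is the better choice.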

Conclusion

Using federated learning, we can now develop very useful and precise models that learn from personal data in areas such as healthcare and personal management, where data sets are often tightly locked down, which makes research difficult.

I hope this article helped you understand federated learning on users' personal and private data. I also tried to explain how users' data stays secure: no client's deep learning model can sniff around another user's sensitive data, and not even the server's deep learning model can see it.

Towards AI

Towards AI, is the world’s fastest-growing AI community for learning, programming, building and implementing AI.
