Simulation-to-Reality Gap in Federated Learning — Part 2

Training, validation, and test in federated learning

Yasmine Djebrouni
6 min read · Apr 12, 2022

In the first part of this series (here), I discussed the simulation-to-reality gap in federated learning (FL) and its causes. I explained how FL is simulated by researchers and raised some questions about how FL is done in real production systems. In this part, I try to answer those questions by explaining how training, validation, and testing are implemented in real commercial systems today. Note that I consider a standard federated learning system with one central server.

If you are very new to federated learning, I have something for you: a fairy tale about the emergence of FL :)

What Is Federated Learning (FL)? Plot Twist: A Fairy Tale

Federated Learning in Real Life

Similar to classical machine learning, federated learning consists of several phases: a training phase, a validation phase, and a testing phase. However, in FL, the data is distributed across different devices, and the learning process is distributed among the different data owners, also called FL clients.

During the training phase, clients contribute to the global goal by training local models on their local data. Validation takes place during training to validate the architecture of the trained model. Finally, the testing phase verifies the performance of the model before final deployment to the clients. FL is communication-efficient, as training data is not transmitted over the network during the FL process, and also privacy-friendly, as clients do not send their training data, which may contain private information, to other clients or to the server.

Training Phase

The training process in real FL systems takes place over several rounds under the direction of a central server. The participants may be referred to as clients, devices, or data owners; they can be mobile devices, banks, or even hospitals. Each client holds a set of private labeled data. Data is labeled either according to the activity of its owner, e.g., click history on a mobile device, or by asking the client to label its data manually, e.g., a hospital that labels its patients' instances according to their medical records.

Training begins when an FL developer submits a training job to the server. The training job consists of the model to be trained, the initial hyper-parameters and their values, and selection criteria for the clients. The server uses the selection criteria to choose, from the clients available on the network, those that will participate in the training. FL developers can specify several criteria; examples, when the clients are mobile devices, are a high battery level and an unmetered network connection.
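As an illustration, a training job might look roughly like the following sketch. All field names and values are hypothetical and do not correspond to any particular framework's API; they only show the three ingredients mentioned above (model, hyper-parameters, selection criteria).

```python
# Hypothetical sketch of a training job an FL developer might submit to the
# FL server. Field names and values are illustrative, not a real API.
training_job = {
    "model": "next_word_prediction_cnn",   # identifier of the architecture to train
    "hyper_parameters": {
        "rounds": 500,             # number of federated training rounds
        "clients_per_round": 100,  # cohort size selected each round
        "local_epochs": 1,         # local iterations per client per round
        "client_learning_rate": 0.05,
    },
    "selection_criteria": {        # the server filters available clients with these
        "min_battery_level": 0.8,
        "requires_unmetered_network": True,
        "device_idle_and_charging": True,
    },
}
```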

[Figure: One round of federated learning training]

At the beginning of each round, the server selects the clients that will participate in that round and sends them the current version of the global model. In the first round, the server sends a randomly initialized model or a pre-trained model. Each participating client then trains the received model on its local data and sends its local model update back to the server. Finally, the server performs a weighted aggregation of the received client updates, through FedAvg [3] or another aggregation method, to produce a new version of the global model. The process, illustrated in the figure above, continues for several rounds until the global model converges.
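To make one round concrete, here is a minimal sketch of the server-side logic, assuming model parameters are represented as flat NumPy arrays and that `local_train` is a hypothetical placeholder for the client-side training step. It illustrates FedAvg-style weighted averaging [3]; it is not the implementation of any particular framework.

```python
import numpy as np

def local_train(global_weights, client_data):
    """Hypothetical placeholder: a client starts from the global weights, runs
    a few epochs of SGD on its local data, and returns its updated weights
    together with its number of local training examples."""
    raise NotImplementedError

def run_round(global_weights, selected_clients):
    """One FedAvg-style round [3]: broadcast, local training, weighted aggregation."""
    client_weights, sample_counts = [], []
    for client_data in selected_clients:
        updated, n_samples = local_train(global_weights.copy(), client_data)
        client_weights.append(updated)
        sample_counts.append(n_samples)
    # Weighted average: clients with more data contribute proportionally more.
    total = float(sum(sample_counts))
    return sum((n / total) * w for n, w in zip(sample_counts, client_weights))
```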

Validation and Tuning

Hyper-parameter tuning is similar to centralized training, but FL additionally involves 1) the number of training rounds and 2) the number of local iterations, and each tuning attempt may cost a lot of resources (communication bandwidth, computing power, energy, etc.). AutoML tools can reduce the number of tuning attempts. According to FL developers, a common practice in FL hyper-parameter tuning is to keep observing the training curve for a given hyper-parameter setting and terminate the training as soon as the accuracy gets worse than in previous runs, without waiting for the end of the training.
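Below is a small, hedged sketch of that practice, assuming validation accuracy is recorded once per round for the current hyper-parameter setting and for the best previous run. The patience and margin values are illustrative, not recommendations.

```python
def should_abandon_trial(current_accuracies, best_previous_accuracies,
                         patience=20, margin=0.02):
    """Terminate a hyper-parameter trial early if, at the same round index,
    it trails the best previous run by more than `margin` for `patience`
    consecutive rounds. All thresholds here are illustrative."""
    trailing = 0
    for rnd, acc in enumerate(current_accuracies):
        if rnd < len(best_previous_accuracies) and \
           acc < best_previous_accuracies[rnd] - margin:
            trailing += 1
            if trailing >= patience:
                return True
        else:
            trailing = 0
    return False
```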

Hyper-parameter tuning can thus be done during model training. Every few rounds, proxy data stored on the server can be used as validation data to validate the model architecture and select values for the basic hyper-parameters [4]. The proxy data can be synthetically generated using private data generation methods, or collected from instances donated by real-world data owners. Beyond tuning hyper-parameters, the validation data can also serve to monitor the learning process when training is long: one can validate the model's performance against it during training rounds, without having to wait for the training to end.
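A minimal sketch of this periodic server-side validation, reusing the `run_round` sketch above and assuming hypothetical `select_clients` and `evaluate` helpers plus a proxy dataset held at the server:

```python
def train_with_periodic_validation(global_weights, num_rounds, select_clients,
                                   proxy_data, evaluate, eval_every=10):
    """Hedged sketch: every `eval_every` rounds, validate the current global
    model on proxy data held at the server (synthetic or donated examples).
    `select_clients` and `evaluate` are hypothetical callables."""
    history = []
    for rnd in range(num_rounds):
        global_weights = run_round(global_weights, select_clients(rnd))
        if rnd % eval_every == 0:
            # Developers can watch this curve to tune hyper-parameters or stop early.
            history.append((rnd, evaluate(global_weights, proxy_data)))
    return global_weights, history
```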

Test Phase

In this phase, the performance of the global model is tested against new data that it has never seen before, to measure how well it can generalize.

One common case is when the FL service provider has a global “public” test dataset that doesn’t contain any private information. It can be a collection of donated data or data synthetically generated by existing private data generation techniques. In this scenario, the evaluation can run on the server side, and its results can be analysed by the FL developers.

In other cases, FL developers need to test the performance of the model on real client data, i.e., on real edge devices or organizations. This may be done to validate the trained model before deployment, as in [4, 5], or to circumvent biased or non-representative proxy data used during training and validation, as described in [6]. To do so, the server selects a group of clients and sends them the final global model so that they run the test phase on their real data. After running the test, the clients send the performance of the final model on their data back to the server, where the developers can analyse it. When the test is run on the clients, it is usually called live testing, live inference, or on-device inference.
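A hedged sketch of such federated (on-device) evaluation follows, assuming a hypothetical `local_evaluate` callable that each client runs on its own data and that returns only an accuracy and an example count; raw data never leaves the device.

```python
def federated_evaluation(final_weights, test_clients, local_evaluate):
    """Hedged sketch of live / on-device evaluation. Each selected client
    evaluates the final global model on its own local data and reports back
    only (accuracy, example_count). `local_evaluate` is a hypothetical
    per-client callable."""
    weighted_accuracy_sum, total_examples = 0.0, 0
    for client in test_clients:
        accuracy, n_examples = local_evaluate(final_weights, client)
        weighted_accuracy_sum += accuracy * n_examples
        total_examples += n_examples
    # Server-side aggregation: example-weighted average accuracy over the cohort.
    return weighted_accuracy_sum / max(total_examples, 1)
```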

Another way of testing is fully decentralized testing, where the test happens on the clients without any involvement from the server. In practice, an organization typically migrates its ML training pipeline from centralized training on a single data silo (its local dataset) to a federated setting that incorporates many other data silos, in order to improve its own model. In such a case, the organization already has its own test dataset, so it simply wants to see whether FL improves the accuracy or generalization ability of its model. The evaluation runs at the local client/organization side and, in contrast to the previous cases, nothing is communicated to the server. Fully decentralized testing is also used in a more complex setting, personalized FL, where again each client has its own test dataset. In that case, clients intentionally fine-tune the received global model to fit their data, and such personalized models are tested against the clients' local test data.
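A minimal sketch of this fully decentralized setting, including the personalized-FL variant, assuming hypothetical client-side `fine_tune` and `evaluate` helpers; everything runs locally and nothing is reported to the server.

```python
def personalize_and_test_locally(global_weights, local_train_data,
                                 local_test_data, fine_tune, evaluate):
    """Hedged sketch of fully decentralized testing with personalization.
    Runs entirely on the client; no data or metrics are sent to the server.
    `fine_tune` and `evaluate` are hypothetical client-side callables."""
    personalized_weights = fine_tune(global_weights, local_train_data)
    baseline_accuracy = evaluate(global_weights, local_test_data)        # global model as received
    personalized_accuracy = evaluate(personalized_weights, local_test_data)
    # The client (or organization) can compare both numbers against its old
    # centralized baseline to see whether FL actually helps on its own test set.
    return baseline_accuracy, personalized_accuracy
```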

This is the end of the series. I will be very happy to discuss anything I said here in the comments section :)!

References

[1] Hard, A., Rao, K., Mathews, R., Ramaswamy, S., Beaufays, F., Augenstein, S., … & Ramage, D. (2018). Federated learning for mobile keyboard prediction. arXiv preprint arXiv:1811.03604.

[2] Lai, F., Zhu, X., Madhyastha, H. V., & Chowdhury, M. (2021). Oort: Efficient federated learning via guided participant selection. In 15th USENIX Symposium on Operating Systems Design and Implementation (OSDI 21) (pp. 19–35).

[3] McMahan, B., Moore, E., Ramage, D., Hampson, S., & y Arcas, B. A. (2017, April). Communication-efficient learning of deep networks from decentralized data. In Artificial intelligence and statistics (pp. 1273–1282). PMLR.

[4] Yang, T., Andrew, G., Eichner, H., Sun, H., Li, W., Kong, N., … & Beaufays, F. (2018). Applied federated learning: Improving google keyboard query suggestions. arXiv preprint arXiv:1812.02903.

[5] Wang, K., Mathews, R., Kiddon, C., Eichner, H., Beaufays, F., & Ramage, D. (2019). Federated evaluation of on-device personalization. arXiv preprint arXiv:1910.10252.

[6] Augenstein, S., McMahan, H. B., Ramage, D., Ramaswamy, S., Kairouz, P., Chen, M., & Mathews, R. (2019). Generative models for effective ML on private, decentralized datasets. arXiv preprint arXiv:1911.06679.

Useful Links

https://doc.fedml.ai/
