Federated Learning on AWS

A privacy-preserving ML approach

Valerio Paduano
Storm Reply
12 min read · Oct 31, 2023


Introduction
With computing now such a central part of our daily lives, we are living in the golden age of Big Data. Today, individual users and organizations generate huge amounts of data every second and, as the saying goes, “data is the new oil”. Because of this, numerous organizations are constantly striving to unlock value from data, exploiting it to extract data-driven insights about users’ behaviors, preferences, and market and societal trends.
This abundance of data has paved the way for increasingly sophisticated IT applications, leading to the definitive explosion of Artificial Intelligence. In 2022 alone we saw the release of ground-breaking AI systems, such as generative image and text models, that are already revolutionizing the way we live and work and will continue to do so in the coming years. Industry players, researchers, and academics are already asking: where to next?

Challenges
Along the way, of course, there have been some bumps in the road. For example, scandals about private data leaks have raised the question of how we can embrace these new technologies while, at the same time, protecting people’s private information. Data protection has become one of the biggest concerns in the AI community. In particular, AI systems often rely on large amounts of personal data to learn and make predictions, which raises concerns about how the collection, processing, and storage of such data is regulated. This may leave users and companies uncomfortable sharing potentially private or sensitive data.
Another challenge is that similar (but non-overlapping) user and company datasets exist as isolated islands, making it difficult to enrich and merge them to train better-performing AI/ML models.
One approach that aims to mitigate these problems is Federated Learning (FL).

Key concepts
Federated Learning was first introduced in the literature between 2015 and 2016 (https://arxiv.org/abs/1511.03575 and https://arxiv.org/abs/1602.05629). One of the first practical applications was developed by Google in 2017 and aimed at improving next-word suggestions while typing on a smartphone keyboard, Gboard in this case. The text predictions were inferred by Machine Learning models trained on data spread across multiple devices. This new approach caught the attention of the research community because its paradigm was a breakthrough over traditional ML in tackling data privacy issues: it allows multiple entities to work together on training a shared model without sharing their private data with a central organization.

In a nutshell, Federated Learning is all about decentralized ML and involves one trusted central entity and two or more participants. It is a collaborative ML technique that enables participants to train a shared ML model while keeping their data private. After an ‘initialization’ stage in which the model is chosen, FL relies on an iterative process that can be broken down into atomic sets of client-server interactions known as FL Rounds (a minimal code sketch of this loop follows the list). Each Round consists of:

  • Transferring the current global model state (weights and biases) to each participant.
  • Training a local model on each participant’s side (each using its own data) to produce a set of candidate model updates.
  • Transferring each trained model back to the central server.
  • Aggregating the received model updates in a pre-defined fashion (several model aggregation algorithms exist) to generate a new global model, which is sent back to the participants.
  • If the chosen termination criterion is met (e.g., convergence is reached), the process stops. Otherwise, a new Round starts.
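
To make the Round structure concrete, below is a minimal, framework-agnostic sketch of the loop just described. All names here (local_train, evaluate, and the plain averaging used as the aggregation step) are illustrative assumptions, not the API of any specific library:

```python
import numpy as np

def federated_training(global_weights, participants, max_rounds=50, target_loss=0.05):
    """Run FL Rounds until the termination criterion is met (a sketch)."""
    for _ in range(max_rounds):
        # 1) Transfer the current global model state to each participant,
        # 2) who trains a local model on its own private data.
        updates = [p.local_train(global_weights) for p in participants]  # hypothetical API

        # 3) The trained weights are transferred back, and
        # 4) aggregated into a new global model (plain averaging here;
        #    FedAvg additionally weights by dataset size).
        global_weights = [np.mean(layers, axis=0) for layers in zip(*updates)]

        # 5) Stop if the termination criterion is met, else start a new Round.
        if np.mean([p.evaluate(global_weights) for p in participants]) < target_loss:
            break
    return global_weights
```

In a real deployment the model exchange would, of course, happen over a network rather than in-process, and the aggregation step would be the chosen strategy (e.g. FedAvg, discussed later).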

With this approach, the participants’ sensitive training data is never exposed, since only the model updates are shared with the central entity. Sensitive user data therefore remains private and secure, while collaborative model training is still possible.

A key benefit of Federated Learning is its ability to handle large-scale distributed datasets. This is particularly useful where data is siloed across multiple devices or locations that prefer not to, or are not allowed to, share it.
As one might imagine, implementing an FL solution brings multiple challenges: for example, the need for adequate computing power (both on the participants’ side during training and on the central entity’s side during aggregation) and for a highly scalable, modular design.
Moreover, FL may be required in scenarios where there are no physical devices such as smartphones to execute the training. For example, multiple healthcare institutions may want to train ML models on confidential patient data. This would require proper computing infrastructure at each institution, resulting in high costs.

In this scenario, Cloud Computing can be leveraged to overcome these challenges. Here is where the expertise of Storm Reply IT in designing efficient, highly scalable, and robust AWS cloud-native solutions fits in.

High-level architecture

High-Level FL Architectural Process

In the above diagram, we propose a high-level overview of a possible FL implementation on AWS, leveraging a serverless and event-driven design.
There will be one Federation account, hosting the orchestration and aggregation logic, and two or more client accounts hosting the training logic.
When an FL process is started, a starting model is distributed to all the participants. This action triggers the compute unit, which trains the starting model on each participant’s private dataset. Once completed, the model trained by each participant is copied to the central organization’s storage.
Once all the participants (or a chosen subset) have finished their training, the aggregator is triggered and merges the models using the chosen aggregation algorithm. The aggregated model is then saved and sent to the participants for testing.
Each Participant tests the global model on its private dataset and publishes the resulting metrics to the central organization’s storage. The central organization then evaluates the overall metrics and, if they are satisfactory, deploys/publishes the model; otherwise, it uses the aggregated model as the starting point for a new round of FL.

Central entity architecture
Now, let’s take a deeper look into the architecture of the Central Entity. To maximize security, network traffic across the public internet is intentionally avoided by using VPC endpoints to access the S3 and DynamoDB services.
The process in the Central Entity is coordinated by 3 Lambda Functions that are asynchronously triggered during the FL Rounds:

  • Start FL: generates a unique execution ID to identify the FL process and then invokes the first run of the Orchestrator Function with an empty payload, which is interpreted as a START signal (the Start FL Function is not involved in the rest of the FL process).
  • Orchestrator: generates the first “base model” using well-known ML frameworks. This function is also in charge of checking the final aggregated model’s metrics against the stopping conditions (e.g. loss convergence, accuracy, or F1 target scores) that dictate whether the FL process needs further Rounds to converge or can be stopped.
  • Aggregator: merges the different models into one by applying a given FL strategy (e.g. FedAvg, sketched right after this list) and then stores the result on S3.
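
Since FedAvg is the strategy named above, here is a minimal sketch of what it computes: a dataset-size-weighted average of the participants’ model weights. The function name and the plain-NumPy representation are assumptions for illustration:

```python
import numpy as np

def fed_avg(models, sample_counts):
    """Aggregate per-participant weight lists into one global model.

    models:        list of models, each a list of per-layer numpy arrays
    sample_counts: number of training samples each participant used
    """
    total = sum(sample_counts)
    coefficients = [n / total for n in sample_counts]
    # Weighted sum of each layer across participants.
    return [
        sum(c * layer for c, layer in zip(coefficients, layers))
        for layers in zip(*models)
    ]
```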

The Central Entity’s process will use 3 DynamoDB Tables to regulate and synchronize the various steps of the FL Round:

  • Participants Control Table: contains the Account IDs of all the participating clients.
  • Participants Barrier Table: each Participant writes an ACK record to this table when it finishes training. The table triggers an event to the Aggregator when a certain number of clients have written their entry (a sketch of this barrier pattern follows the list).
  • Metrics Control Table: each client writes the testing metrics of the global model to this table. The table triggers an event to the Orchestrator when a certain number of clients have written their entries.
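
One way to realize the barrier behavior just described is a Lambda function subscribed to the table’s DynamoDB Stream: on every new ACK it counts the entries and, once every expected Participant has reported, it invokes the Aggregator asynchronously. A sketch under assumed table and function names (pagination is ignored for brevity):

```python
import boto3

dynamodb = boto3.client("dynamodb")
lambda_client = boto3.client("lambda")

EXPECTED_PARTICIPANTS = 3  # assumption: could be read from the Participants Control Table

def handler(event, context):
    """Triggered by the DynamoDB Stream of the Participants Barrier Table."""
    acks = dynamodb.scan(
        TableName="participants-barrier-table",  # hypothetical name
        Select="COUNT",
    )["Count"]

    if acks >= EXPECTED_PARTICIPANTS:
        # All participants have finished training: fire the Aggregator.
        lambda_client.invoke(
            FunctionName="aggregator",  # hypothetical name
            InvocationType="Event",     # asynchronous invocation
        )
```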

Moreover, the process needs 2 S3 Buckets in the Central Entity, each one encrypted with a different KMS CMK:

  • Aggregated Models Bucket: the base and the aggregated models are saved in this Bucket. Cross-account replication is configured from this Bucket to the Participants’ ones to automatically distribute the model at the start of the FL Round (a replication-setup sketch follows the list).
  • Trained Models Bucket: this Bucket collects the trained models of each Participant.
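
For reference, cross-account replication of this kind could be configured along the following lines; the bucket names, account IDs, role and key ARNs are placeholders, and the exact rule fields depend on the specific setup:

```python
import boto3

s3 = boto3.client("s3")

# Replicate objects under "base-models/" to one Participant's bucket.
s3.put_bucket_replication(
    Bucket="central-aggregated-models",  # hypothetical source bucket
    ReplicationConfiguration={
        "Role": "arn:aws:iam::111111111111:role/replication-role",  # placeholder
        "Rules": [{
            "ID": "distribute-base-models",
            "Status": "Enabled",
            "Priority": 1,
            "Filter": {"Prefix": "base-models/"},
            "DeleteMarkerReplication": {"Status": "Disabled"},
            "Destination": {
                "Bucket": "arn:aws:s3:::participant-base-models",  # placeholder
                "Account": "222222222222",                         # Participant account
                "AccessControlTranslation": {"Owner": "Destination"},
                "EncryptionConfiguration": {
                    # Re-encrypt replicas with the destination account's KMS key.
                    "ReplicaKmsKeyID": "arn:aws:kms:eu-west-1:222222222222:key/REPLACE-ME",
                },
            },
            "SourceSelectionCriteria": {
                "SseKmsEncryptedObjects": {"Status": "Enabled"}
            },
        }],
    },
)
```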

Central Entity process

Central Entity Architectural Process Flow

The FL process is started by a Data Scientist who invokes the Start FL Lambda Function, the entry point of the first FL Round. The process then follows the steps below:

  1. As already mentioned, the Start FL Function kicks off the first run of the Orchestrator Function.
  2. The Orchestrator Function generates the “base model” and puts it in the “base-models/” folder of the Aggregated Models Bucket. From this S3 Bucket, cross-account replication continuously copies new objects into the Participants’ S3 Buckets where the “base models” are collected. At this stage, the Central Entity waits for the Participants’ training to finish.
  3. When each Participant finishes the training:
    a. The trained model is copied back to the Trained Models Bucket of the Central Entity through cross-account replication.
    b. The Participant notifies the Central Entity at the end of the training by assuming a cross-account IAM Role and writing a record in the DynamoDB Participants Barrier Table.
  4. When all the Federation Participants have written to the barrier, the Aggregator Function is triggered and the Participants Barrier Table is emptied.
  5. The Aggregator has two tasks:
    a. As already mentioned, it merges the different models generating an aggregated model.
    b. After that, the Aggregator stores the new aggregated model in the “testing-models/” folder of the Aggregated Models Bucket (and consequently in the Participants’ Test Models Buckets thanks to the cross-account replication).
  6. At this stage, again, the Central Entity waits for the Participants’ metrics testing to finish:
    a. The cross-account replication on the Test Models Buckets triggers the metrics evaluation process on the Participants’ side.
    b. Each Participant notifies the Central Entity at the end of the testing by assuming a cross-account IAM Role and writing a record in the DynamoDB Metrics Control Table.
  7. When all the Federation Participants have written to the Metrics Control Table, the Orchestrator checks the stopping condition (sketched after this list):
    a. If the stopping condition is met, the process ends.
    b. If the stopping condition is not met, the aggregated model becomes the new “base model” and is stored in the “base-models/” folder of the Aggregated Models Bucket to start a new FL Round.
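
A condensed sketch of the Orchestrator’s final step (step 7) might look as follows; the metric name, the accuracy target, and the table, bucket, and object names are all assumptions:

```python
import boto3

dynamodb = boto3.resource("dynamodb")
s3 = boto3.client("s3")

TARGET_ACCURACY = 0.90  # hypothetical stopping condition

def check_stopping_condition(event, context):
    """Triggered once all Participants have written to the Metrics Control Table."""
    table = dynamodb.Table("metrics-control-table")  # hypothetical name
    items = table.scan()["Items"]
    mean_accuracy = sum(float(i["accuracy"]) for i in items) / len(items)

    if mean_accuracy >= TARGET_ACCURACY:
        return {"status": "CONVERGED", "accuracy": mean_accuracy}

    # Not converged: promote the aggregated model to the new base model,
    # which restarts the cross-account distribution and a new FL Round.
    s3.copy_object(
        Bucket="central-aggregated-models",  # hypothetical name
        CopySource={"Bucket": "central-aggregated-models",
                    "Key": "testing-models/model.tar.gz"},
        Key="base-models/model.tar.gz",
    )
    return {"status": "NEW_ROUND", "accuracy": mean_accuracy}
```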

Federation Participant architecture
Let’s now shift the focus to the architecture of each Federation Participant.
Just as in the Central Entity, to maximize security, network traffic across the public internet is intentionally avoided by using VPC endpoints to access S3.
The process in the Participant is coordinated by 4 Lambda Functions that are asynchronously triggered during the FL Rounds:

  • ECR Image Creator: builds a custom ECR image embedding the “base model”. This image is then used in the SageMaker Training Job.
  • Start Training Job: creates and starts the SageMaker Training Job.
  • Notify Training End: notifies the Central Entity of the end of the training by assuming a cross-account IAM Role and writing a record in the DynamoDB Participants Barrier Table (a sketch of this pattern follows the list).
  • Test Aggregated Model: evaluates the metrics of the aggregated model and notifies the Central Entity at the end of the testing by assuming a cross-account IAM Role and writing a record in the DynamoDB Metrics Control Table.
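
The cross-account notification pattern used by the last two functions could look roughly like this; the role ARN, table name, and record attributes are placeholders:

```python
import boto3

def notify_central_entity(execution_id, account_id):
    """Write an ACK into the Central Entity's Participants Barrier Table."""
    # Assume the cross-account role exposed by the Central Entity.
    creds = boto3.client("sts").assume_role(
        RoleArn="arn:aws:iam::111111111111:role/barrier-writer",  # placeholder
        RoleSessionName="notify-training-end",
    )["Credentials"]

    dynamodb = boto3.client(
        "dynamodb",
        aws_access_key_id=creds["AccessKeyId"],
        aws_secret_access_key=creds["SecretAccessKey"],
        aws_session_token=creds["SessionToken"],
    )
    dynamodb.put_item(
        TableName="participants-barrier-table",  # hypothetical name
        Item={
            "execution_id": {"S": execution_id},
            "participant": {"S": account_id},
        },
    )
```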

There will also be 4 S3 Buckets:

  • Dataset Bucket: this Bucket holds the Participant’s private dataset. It is encrypted with a dedicated KMS CMK.
  • Base Models Bucket: the base and the aggregated models are copied into this Bucket by the Central Entity through cross-account replication at the start of the FL Round.
  • Trained Models Bucket: this Bucket contains the output model of the SageMaker Training Job.
  • Test Models Bucket: this Bucket contains the aggregated model for metrics evaluation purposes.

Federation Participant process

Federation Participant Architectural Process Flow

On the Participant side, each FL Round follows the steps below:

  1. The Base Models Bucket receives the new base model through cross-account replication.
  2. The ObjectCreation event triggers the ECR Image Creator Function, which is in charge of:
    a. Creating the container image of the custom base model,
    b. And then invoking the Start Training Job Function.
  3. The Start Training Job Function creates and starts the SageMaker Training Job, configuring it with the input payload: a JSON file containing all the attributes needed by the Training Job. Moreover, the Training Job also receives the training dataset from the Dataset Bucket as input (a sketch of this call follows the list).
  4. Once the training is finished:
    a. The output is written to the Trained Models Bucket, on which cross-account replication to the Central Entity bucket is configured.
    b. Moreover, the ObjectCreation event triggers the Notify Training End Function, whose task is to notify the Central Entity, as already said, of the end of the training.
    At this stage, the Participant waits for the Central Entity to perform the aggregation task.
  5. The Test Models Bucket receives the new aggregated model through cross-account replication.
  6. The ObjectCreation event triggers the Test Aggregated Model Function which is in charge of:
    a. Computing the model metrics,
    b. And then, writing them in the DynamoDB Metrics Control Table of the Central Entity.
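
For reference, the Start Training Job Function (step 3) could create the SageMaker job roughly as follows; the payload keys, image URI, role ARN, bucket names, and instance settings are placeholders that would come from the input payload in practice:

```python
import time
import boto3

sagemaker = boto3.client("sagemaker")

def start_training_job(event, context):
    """Create and start the SageMaker Training Job for the current FL Round."""
    sagemaker.create_training_job(
        TrainingJobName=f"fl-local-training-{int(time.time())}",
        AlgorithmSpecification={
            "TrainingImage": event["training_image_uri"],  # ECR image built from the base model
            "TrainingInputMode": "File",
        },
        RoleArn=event["sagemaker_role_arn"],  # placeholder, taken from the payload
        InputDataConfig=[{
            "ChannelName": "training",
            "DataSource": {"S3DataSource": {
                "S3DataType": "S3Prefix",
                "S3Uri": "s3://participant-dataset/",  # hypothetical Dataset Bucket
                "S3DataDistributionType": "FullyReplicated",
            }},
        }],
        OutputDataConfig={"S3OutputPath": "s3://participant-trained-models/"},  # hypothetical
        ResourceConfig={"InstanceType": "ml.m5.xlarge", "InstanceCount": 1, "VolumeSizeInGB": 30},
        StoppingCondition={"MaxRuntimeInSeconds": 3600},
    )
```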

Overall architecture

Low-Level Overall Architectural Process

The cross-account interactions together with the overall design of the architecture are depicted in the figure above.

It is worth emphasizing that the proposed architecture follows a serverless, event-driven paradigm. This makes the solution asynchronous and cost-effective, since each component runs only when needed and only for the time required. Moreover, a serverless architecture ensures high scalability during peak computations, maintainability, robustness, and, last but not least, contained costs.

Moreover, on the privacy-preserving side, it should be highlighted that, by design, the architecture embeds 3 security levels:

  1. Data Level: Each Participant has 2 KMS CMKs in their account with which they handle encryption. More specifically, there are 2 DEKs (Data Encryption Keys):
    - One is used exclusively by the client to encrypt the private dataset.
    - One is used to encrypt the trained models before replicating them to the Central Entity. Note that cross-account bucket replication happens entirely within the AWS network: the models never traverse the public internet during the replication process, which ensures both the security and the efficiency of replication.
  2. Service Level: The data is stored in S3 Buckets, and each bucket is secured using a Bucket Policy intended to deny access from unexpected entities (an example policy follows this list).
  3. Account Level: Following the least-privilege best practice, each resource in each Federation account has an attached IAM Role that provides limited access to the other AWS services.
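
As an illustration of the Service Level, a bucket policy of the kind described above can deny every principal outside an allow-list of Federation accounts. The account IDs and bucket name are placeholders:

```python
import json
import boto3

s3 = boto3.client("s3")

policy = {
    "Version": "2012-10-17",
    "Statement": [{
        "Sid": "DenyUnexpectedAccounts",
        "Effect": "Deny",
        "Principal": "*",
        "Action": "s3:*",
        "Resource": [
            "arn:aws:s3:::central-trained-models",    # placeholder bucket
            "arn:aws:s3:::central-trained-models/*",
        ],
        # Deny any request whose caller is not a Federation account.
        "Condition": {
            "StringNotEquals": {
                "aws:PrincipalAccount": ["111111111111", "222222222222"]
            }
        },
    }],
}

s3.put_bucket_policy(Bucket="central-trained-models", Policy=json.dumps(policy))
```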

With these security measures, the solution guarantees that:

  • Only the models are transferred cross-account; the data resides and remains only in the owner’s account.
  • The Participants’ data is encrypted, private, and safe; the Central Organization cannot interact with it in any way.
  • The whole system is isolated from external intrusions.

Conclusions
Federated Learning emerges as a promising solution to address the growing concerns about data privacy and security, as well as the practical challenges of merging disparate datasets. It attempts to mitigate these issues by enabling collaborative training of shared machine learning models while safeguarding the privacy of sensitive data.

To address these challenges, we at Storm Reply IT proposed an efficient, highly scalable, and robust AWS cloud-native Federated Learning solution. The high-level architecture of Federated Learning on AWS involves a central entity and multiple client accounts, where models are trained, aggregated, and tested. The central entity’s architecture comprises Lambda functions, DynamoDB tables, and S3 buckets, ensuring secure and efficient orchestration of the Federated Learning process. Meanwhile, each participant’s architecture includes Lambda functions and S3 buckets, with VPC endpoints for secure data access. The whole process is fully event-driven and asynchronous, ensuring that each component runs only when it needs to, thereby enabling a cost-effective solution.

Overall, the Federated Learning architecture is designed with privacy in mind, with each participant handling encryption using KMS Customer Managed Keys. Cross-account bucket replication within the AWS network ensures data security and efficiency, avoiding exposure to the public internet. This approach ensures that the Participants’ data remains encrypted, private, and inaccessible to other parties.

In conclusion, Federated Learning offers a promising avenue for collaborative machine learning while preserving data privacy and security. When implemented with the right infrastructure and architectural considerations, it holds the potential to unlock valuable insights from distributed datasets without compromising on data protection.

Authors: Valerio Paduano, Riccardo Polini
