A Gentle Introduction to Split Learning

Đồng Nguyễn Minh ANH
Apr 11, 2023

Split learning, developed by MIT Media Lab’s Camera Culture Group, enables distributed deep learning without sharing raw data and without being computationally demanding.

INTRODUCTION & BACKGROUND

In traditionally private industries like finance and health care, the shortage of labeled data and computational resources is a critical issue when developing ML algorithms. Historically, ML architectures have been built on the assumption that algorithms are centralized, meaning that both the training data and the model are in the same location and known to the researcher. Data can be anonymized, obfuscated, or encrypted (using methodologies like homomorphic encryption). Despite this, there is always a risk that raw data can be recovered from the trained model through adversarial attacks. The genie that has left the bottle here is decentralization, which opens the door to innovation through an information resource that has previously been inaccessible: private data.

WHY IS DATA LIMITED IN A HEALTHCARE SETTING?

Collaboration in health care is heavily impeded by a lack of trust, data sharing regulations such as HIPAA (the Health Insurance Portability and Accountability Act, which regulates the privacy and security of individuals’ health information), and limited patient consent. Split learning comes in handy in environments where different institutions hold different types of patient data: electronic health records, imaging data in picture archiving and communication systems, pathology test results, and genetic markers for disease.

For example, two entities holding pathology test results and radiology data, respectively, can collaborate on training a deep learning model for patient diagnosis without sharing their raw data. Smaller hospitals can contribute to an aggregate model (an ML model that combines data and knowledge from various sources to improve its overall effectiveness) without sharing any raw data, enabling them to serve those in need more effectively. Split learning can improve accuracy while using significantly less computational power and communication bandwidth than previous methodologies such as large-batch synchronous stochastic gradient descent.

Distributed learning can also be applied to detection tasks, such as analyzing retinopathy images or spotting fast-moving threats that would otherwise go undetected: health institutions with slow internet connections can pool their images to train a detection model without exchanging raw patient data.

To use sensitive data to build deep learning applications and gain insights, such as predicting rare diseases or financial crimes, without getting hold of the raw data, we can collaboratively train distributed machine learning models without any data sharing. In this article, we explore split learning, a technique created at MIT Media Lab that allows ML models to be trained without sharing raw data, overcoming challenges such as data silos and restrictions on data sharing.

SplitNN (split neural network) shares neither raw data nor model details with collaborating institutions. It caters to several practical settings in the health sector:

i) entities holding different types of information about patients, such as their medical history, imaging data and pathology test results.

ii) centralized and local health entities collaborating on multiple tasks, meaning that various health organizations work together on different healthcare-related projects towards a common goal.

iii) learning without sharing labels: training ML models where labels such as diagnoses or outcomes are not shared among entities due to privacy concerns and regulations, while the models can still learn from the available data.

WHAT IS SPLIT LEARNING?

In short, split learning is a type of distributed deep learning that enables entities to collaboratively train an ML model without sharing raw data. In split learning, data is kept on local devices and only intermediate model outputs (activations) and their gradients are exchanged. This addresses data privacy and security while reducing computational and communication costs.

This method works by dividing a deep learning model into two parts: a front-end and a back-end.

The front-end is kept on a local device, and it processes the input data up to a certain point, typically right before the last few layers of the model that produce the final output. The output of the front-end is sent to a central server, which has the back-end of the model. The server then continues to process the output of the front-end and generates the final output.

During training, the front-end and back-end are updated separately. Each local device runs its front-end on its own data and sends the resulting activations to the server. The server completes the forward pass through the back-end, computes the loss, backpropagates through the back-end to update it, and sends the gradients of the loss with respect to the received activations back to the device. The device then uses these gradients to finish backpropagation and update its front-end. This process is repeated until the model converges.
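The exchange can be sketched in a few lines of plain Python. This is a minimal sketch under stated assumptions, not a reference implementation: a hypothetical toy model with a two-unit cut layer, squared-error loss, plain SGD, and the label held by the server (as in the vanilla setup). The point is what crosses the network: the client sends only the cut-layer activations `a`, and the server replies only with `grad_a`, the gradient of the loss with respect to those activations. Neither side sees the other's weights, and the server never sees `x`.

```python
# Minimal split-learning sketch in plain Python (hypothetical toy model).
# Client (front-end): one linear layer + ReLU, ending at the cut layer.
# Server (back-end): one linear output unit + squared-error loss; in this
# vanilla setup the server holds the label.

def client_forward(W1, x):
    """Client runs its layers up to the cut layer; only `a` leaves the device."""
    pre = [sum(w * xi for w, xi in zip(row, x)) for row in W1]
    a = [max(0.0, p) for p in pre]          # ReLU activations at the cut layer
    return pre, a

def server_step(W2, a, y, lr):
    """Server finishes the forward pass, backpropagates to the cut layer,
    updates its own weights and returns only dLoss/da to the client."""
    y_hat = sum(w * ai for w, ai in zip(W2, a))
    loss = 0.5 * (y_hat - y) ** 2
    d_yhat = y_hat - y
    grad_a = [d_yhat * w for w in W2]       # uses W2 *before* the update
    for j in range(len(W2)):
        W2[j] -= lr * d_yhat * a[j]
    return loss, grad_a

def client_backward(W1, x, pre, grad_a, lr):
    """Client completes backpropagation through its ReLU and linear layer."""
    for i, row in enumerate(W1):
        g = grad_a[i] if pre[i] > 0 else 0.0   # ReLU derivative
        for k in range(len(row)):
            row[k] -= lr * g * x[k]

W1 = [[0.5, 0.5], [0.4, 0.2]]   # client-side weights (never sent)
W2 = [1.0, 0.5]                 # server-side weights (never sent)
x, y = [1.0, 0.5], 2.0          # raw input stays on the client

losses = []
for _ in range(2):
    pre, a = client_forward(W1, x)                 # client -> server: activations
    loss, grad_a = server_step(W2, a, y, lr=0.1)   # server -> client: gradients
    client_backward(W1, x, pre, grad_a, lr=0.1)
    losses.append(loss)
print(losses)
```

Running it prints the loss for two consecutive steps on the same sample (0.5, then roughly 0.281): the split update reduces the error exactly as ordinary backpropagation through the whole network would, because the chain rule is simply evaluated on two machines.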

There are various benefits to this method:

  1. Avoids the sharing of raw data, addressing privacy and security concerns.
  2. Decreases computational and communication costs by requiring only the exchange of intermediate activations and gradients instead of raw data.
  3. Can be applied to a wide range of deep learning models and integrated into existing ML workflows.

If all of this sounds a bit confusing, think of a group of friends playing an escape game. One of the tasks is to solve a puzzle without sharing the complete picture with each other. Each person has only a small section of the puzzle and cannot see what the others are working on. However, they all know what the final puzzle should look like.

In split learning, each member corresponds to a different entity holding a subset of data, and the puzzle pieces correspond to the models trained on the data. Each entity trains its own model on its data and sends only the model’s output to the next entity, without sharing any raw data. The final model is created by combining the output of all the individual models, which results in a complete solution without revealing any sensitive data.

HOW DOES SPLIT LEARNING WORK IN DETAIL?

In split learning, a deep neural network is split into multiple sections, each of which is trained on a different client. The data being trained on might reside on one supercomputing resource or might reside in the multiple clients taking part in the collaborative training.

But none of the friends involved in training the deep neural network can “see” each other’s data (their puzzle pieces, in the analogy).

Each client first encodes its data into a different space, by passing it through its own layers, before transmitting anything to continue training the deep neural network.

The neural network is split into several parts, and each part is trained on a different client. Training proceeds by passing the outputs (activations) of the last layer of each part, known as the cut layer, to the next part. No raw data is shared: only the outputs of each section’s cut layer are sent to the next client.

Let’s look at the following representation which elaborates SplitNN Training:

SplitNN training

SplitNN training, where the layer marked by the green line represents the cut layer. Here the top part of the network is trained on the server and the bottom part on multiple clients.

Definition of forward propagation and backward propagation

This process continues until the distributed split learning network is trained without any party looking at the others’ raw data.

For example, a split learning configuration allows resource-constrained local hospitals with smaller individual datasets to collaborate and build an ML model that offers superior healthcare diagnostics, without sharing any raw data with one another, as required by trust, regulatory and privacy constraints.

SplitNN CONFIGURATIONS

Simple vanilla configuration for split learning: In the basic configuration of SplitNN, multiple clients, such as radiology centers, each train a portion of a deep neural network up to a specific layer called the “cut layer.” The outputs at this layer are sent to a server, which completes the rest of the training without accessing any raw data from the clients. This completes a round of forward propagation without sharing raw data. The gradients are then backpropagated from the server’s last layer down to the cut layer, and only these cut-layer gradients are sent back to the radiology centers, which complete the backpropagation locally. This process is repeated until the split learning network is trained without sharing any raw data between the clients. This configuration is illustrated in Figure a.

U-shaped configurations for split learning without label sharing: The other two configurations require the labels to be available to the server, although no raw input data is shared. To avoid sharing labels, we can use a U-shaped configuration: the network wraps around the end layers of the server’s network, and the server’s outputs are sent back to the client entities, as shown in Figure b. The server retains most of its layers, while the clients compute the loss and gradients from the end layers and use them for backpropagation without sharing the corresponding labels. This makes it well suited to distributed deep learning where labels contain highly sensitive information, such as a patient’s disease status.
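Here is a sketch of one U-shaped forward and backward pass that returns the weight gradients for all three segments. Assumptions (all hypothetical): the client holds the first layer `W1`, the output head `W3` and the labels; the server holds only a middle layer `W2` followed by a ReLU; the loss is squared error. Four messages cross the network: `a1` (client to server), `a2` (server to client), `g_a2` (client to server) and `g_a1` (server to client). The labels `y` never leave the client.

```python
# U-shaped split learning sketch (hypothetical toy model, plain Python).
# Client owns W1 (bottom), W3 (head) and the labels; server owns W2 (middle).

def linear(W, v):                     # v -> W v
    return [sum(w * vi for w, vi in zip(row, v)) for row in W]

def outer(g, v):                      # weight gradient of v -> W v
    return [[gi * vj for vj in v] for gi in g]

def input_grad(W, g):                 # gradient w.r.t. v of v -> W v
    return [sum(W[i][j] * g[i] for i in range(len(W))) for j in range(len(W[0]))]

def u_shaped_grads(W1, W2, W3, x, y):
    # 1. client: bottom layer, send a1 to the server
    a1 = linear(W1, x)
    # 2. server: middle layer + ReLU, send a2 back to the client
    pre2 = linear(W2, a1)
    a2 = [max(0.0, p) for p in pre2]
    # 3. client: head + loss using its private labels, send g_a2 to the server
    y_hat = linear(W3, a2)
    loss = 0.5 * sum((yh - yt) ** 2 for yh, yt in zip(y_hat, y))
    g_yhat = [yh - yt for yh, yt in zip(y_hat, y)]
    dW3 = outer(g_yhat, a2)
    g_a2 = input_grad(W3, g_yhat)
    # 4. server: backprop through ReLU and W2, send g_a1 to the client
    g_pre2 = [g if p > 0 else 0.0 for g, p in zip(g_a2, pre2)]
    dW2 = outer(g_pre2, a1)
    g_a1 = input_grad(W2, g_pre2)
    # 5. client: finish backprop into its bottom layer
    dW1 = outer(g_a1, x)
    return loss, dW1, dW2, dW3

W1 = [[0.3, -0.2], [0.1, 0.4]]   # client bottom layer
W2 = [[0.5, 0.2], [-0.3, 0.8]]   # server middle layer
W3 = [[1.0, -0.5]]               # client head
loss, dW1, dW2, dW3 = u_shaped_grads(W1, W2, W3, x=[1.0, 2.0], y=[0.5])
print(loss)
```

The returned gradients are exactly those of the unsplit network, so they can be applied with the same SGD update as in the vanilla case; only the location of each computation changes.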

Vertically partitioned data for split learning: Figure c’s configuration allows multiple institutions with different patient data types to collaborate on learning distributed models without sharing raw data. In the image, radiology centers collaborate with pathology test centers and a server for disease diagnosis. Radiology centers & pathology test centers train their partial models up to the cut layer based on their respective patient data modalities. The outputs at the cut layer from both centers are merged and sent to the disease diagnosis server, which trains the rest of the model. This is repeated to complete the forward and backward propagations, resulting in a trained distributed deep learning model without any raw data sharing. While the example configurations above demonstrate the versatility of splitNN, there are other possible configurations.
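A vertically partitioned step differs from the vanilla one mainly in the merge: the server concatenates the cut-layer outputs from both clients and, on the way back, splits the gradient so each client receives only the slice for its own activations. The sketch below assumes linear layers everywhere, a single output with squared-error loss, and labels held by the diagnosis server; the names (`W_rad`, `W_path`, `W_srv`) and numbers are hypothetical.

```python
# Vertically partitioned split learning sketch (hypothetical, plain Python).
# A radiology client and a pathology client each see different features of
# the same patient; the diagnosis server merges their cut-layer outputs.

def linear(W, v):
    return [sum(w * vi for w, vi in zip(row, v)) for row in W]

def outer(g, v):
    return [[gi * vj for vj in v] for gi in g]

def input_grad(W, g):
    return [sum(W[i][j] * g[i] for i in range(len(W))) for j in range(len(W[0]))]

def vertical_step(W_rad, W_path, W_srv, x_rad, x_path, y):
    a_rad = linear(W_rad, x_rad)          # radiology client -> server
    a_path = linear(W_path, x_path)       # pathology client -> server
    merged = a_rad + a_path               # server concatenates cut-layer outputs
    y_hat = linear(W_srv, merged)
    loss = 0.5 * sum((yh - yt) ** 2 for yh, yt in zip(y_hat, y))
    g = [yh - yt for yh, yt in zip(y_hat, y)]
    dW_srv = outer(g, merged)
    g_merged = input_grad(W_srv, g)
    # split the cut-layer gradient and send each slice to its owner
    g_rad, g_path = g_merged[:len(a_rad)], g_merged[len(a_rad):]
    dW_rad = outer(g_rad, x_rad)          # computed on the radiology client
    dW_path = outer(g_path, x_path)       # computed on the pathology client
    return loss, dW_rad, dW_path, dW_srv

loss, dW_rad, dW_path, dW_srv = vertical_step(
    W_rad=[[0.2, 0.1]], W_path=[[0.5], [-0.3]], W_srv=[[1.0, 0.5, -0.5]],
    x_rad=[1.0, 2.0], x_path=[2.0], y=[1.0])
print(loss, dW_rad, dW_path)
```

Each client only ever receives the gradient slice that corresponds to its own activations, so neither learns anything directly about the other's features.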

SplitNN Configurations with and without label sharing

In essence, the training of a neural network (NN) is ‘split’ across two or more hosts. Each model segment is a self-contained NN that feeds into the segment in front of it. In this example, Alice has unlabeled training data and the bottom of the network, whereas Bob has the corresponding labels and the top of the network. The image below shows this training process, where Bob has all the labels and there are multiple Alices with input (X) data. Once the first Alice has trained, she sends a copy of her bottom model to the next Alice; training is complete once all Alices have trained.
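The round-robin scheme with several Alices can be sketched as follows. Everything here is a hypothetical toy: two Alices with two samples each, a linear bottom layer held by whichever Alice is currently training, a linear top layer held by Bob, and labels that Bob matches to each Alice's samples by position (a real system would need an agreed sample ID). Only cut-layer activations and gradients pass between Alice and Bob; the bottom weights travel only from one Alice to the next.

```python
# Round-robin split training sketch: several Alices (inputs only) and one Bob
# (labels + top layer). Hypothetical toy data generated by y = 2*x0 - x1.

def forward_bottom(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def split_sgd_step(W_bottom, w_top, x, y, lr):
    a = forward_bottom(W_bottom, x)                    # Alice -> Bob: activations
    err = sum(w * ai for w, ai in zip(w_top, a)) - y   # Bob uses his label
    g_a = [err * w for w in w_top]                     # Bob -> Alice: cut-layer gradient
    for j in range(len(w_top)):                        # Bob updates the top layer
        w_top[j] -= lr * err * a[j]
    for i, row in enumerate(W_bottom):                 # Alice updates the bottom layer
        for k in range(len(row)):
            row[k] -= lr * g_a[i] * x[k]
    return 0.5 * err ** 2

alice_inputs = [
    [[1.0, 0.0], [0.0, 1.0]],     # Alice 1's raw inputs (never sent to Bob)
    [[1.0, 1.0], [1.0, -1.0]],    # Alice 2's raw inputs
]
bob_labels = [[2.0, -1.0], [1.0, 3.0]]   # Bob's labels, aligned by position

W_bottom = [[1.0, 0.0], [0.0, 1.0]]   # bottom model, handed from Alice to Alice
w_top = [0.5, 0.5]                    # top model, stays with Bob

epoch_losses = []
for epoch in range(50):
    total = 0.0
    for X, Y in zip(alice_inputs, bob_labels):       # each Alice takes her turn
        for x, y in zip(X, Y):
            total += split_sgd_step(W_bottom, w_top, x, y, lr=0.05)
        # hand-off: this Alice sends a copy of the bottom model to the next Alice
        W_bottom = [row[:] for row in W_bottom]
    epoch_losses.append(total / 4)
print(epoch_losses[0], epoch_losses[-1])
```

The copy at the end of each Alice's turn stands in for sending the bottom model over the network; Bob's top layer never moves.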

Overall, split learning communicates only the activations and gradients at the split layer, unlike other popular methods that share weights or gradients from all layers. Split learning requires no sharing of raw features, and in the U-shaped configuration, no sharing of labels either.
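To make the communication argument concrete, here is a back-of-the-envelope calculation. All numbers (10,000 samples, a 256-wide cut layer, a 10-million-parameter model, 32-bit floats) are hypothetical, and the comparison is deliberately simplified: it ignores labels, handoffs between clients and protocol overhead. Split learning traffic grows with the number of samples and the cut-layer width, while weight-sharing traffic grows with the number of parameters, so which approach is cheaper depends on the setting.

```python
# Back-of-the-envelope communication estimate (hypothetical numbers).

def split_learning_bytes(num_samples, cut_width, bytes_per_value=4):
    """Activations up + gradients down at the cut layer, for one pass over the data."""
    return 2 * num_samples * cut_width * bytes_per_value

def weight_sharing_bytes(num_params, bytes_per_value=4):
    """Upload + download of every parameter, for one synchronization."""
    return 2 * num_params * bytes_per_value

sl = split_learning_bytes(num_samples=10_000, cut_width=256)
ws = weight_sharing_bytes(num_params=10_000_000)
print(f"split learning, one epoch: {sl / 1e6:.1f} MB")   # 20.5 MB
print(f"weight sharing, one sync : {ws / 1e6:.1f} MB")   # 80.0 MB
```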

A variant called NoPeek SplitNN further reduces leakage through the communicated activations by minimizing their distance correlation with the raw data, while maintaining model performance through a categorical cross-entropy term (a loss function that measures the difference between two probability distributions, commonly used in classification problems where the goal is to predict the probability that a sample belongs to each class).
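The two ingredients of a NoPeek-style objective can be written out directly. Below is a sketch of the sample distance correlation (computed from double-centered pairwise distance matrices) and of categorical cross-entropy, combined as `alpha * dcor + cross_entropy`. The weighting `alpha = 0.1`, the use of the distance correlation rather than its square, and all function names are assumptions for illustration, not the NoPeek authors' code.

```python
import math

def _centered_distances(rows):
    """Double-centered matrix of pairwise Euclidean distances."""
    n = len(rows)
    d = [[math.dist(rows[i], rows[j]) for j in range(n)] for i in range(n)]
    mean = [sum(r) / n for r in d]            # row means (= column means, d is symmetric)
    grand = sum(mean) / n
    return [[d[i][j] - mean[i] - mean[j] + grand for j in range(n)] for i in range(n)]

def distance_correlation(X, Z):
    """Sample distance correlation between raw inputs X and activations Z."""
    A, B = _centered_distances(X), _centered_distances(Z)
    n = len(X)
    dcov2 = sum(A[i][j] * B[i][j] for i in range(n) for j in range(n)) / n**2
    dvar_x = sum(a * a for row in A for a in row) / n**2
    dvar_z = sum(b * b for row in B for b in row) / n**2
    if dvar_x == 0 or dvar_z == 0:
        return 0.0
    return math.sqrt(max(dcov2, 0.0) / math.sqrt(dvar_x * dvar_z))

def categorical_cross_entropy(y_true, probs):
    """Mean over the batch of -sum_k y_k * log(p_k), with one-hot y_true."""
    eps = 1e-12  # guards against log(0)
    return -sum(sum(t * math.log(p + eps) for t, p in zip(yt, pr))
                for yt, pr in zip(y_true, probs)) / len(y_true)

def nopeek_loss(X, Z, y_true, probs, alpha=0.1):
    """Hypothetical NoPeek-style objective: leakage penalty + task loss."""
    return alpha * distance_correlation(X, Z) + categorical_cross_entropy(y_true, probs)

X = [[0.0, 0.0], [1.0, 0.0], [0.0, 2.0], [3.0, 1.0]]   # raw inputs (batch of 4)
Z_leaky = [[2 * a + 1, 2 * b - 3] for a, b in X]      # activations that mirror X
Z_const = [[0.5]] * 4                                  # activations that reveal nothing
print(round(distance_correlation(X, Z_leaky), 6))     # 1.0
print(distance_correlation(X, Z_const))               # 0.0
y_true = [[1, 0], [0, 1], [1, 0], [0, 1]]
probs = [[0.9, 0.1], [0.2, 0.8], [0.7, 0.3], [0.4, 0.6]]
print(nopeek_loss(X, Z_leaky, y_true, probs))
```

Activations that are just a rescaled, shifted copy of the input score a distance correlation of 1 (maximal leakage), while constant activations score 0. Minimizing the combined loss pushes the client's layers toward representations that remain useful for classification but are less correlated with the raw data.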

Key technical idea: In the simplest configuration of split learning, each client (for example, a radiology center) trains a partial deep network up to a specific layer known as the cut layer. The outputs at the cut layer are sent to another entity (a server or another client), which completes the rest of the training without looking at raw data from any client. This completes a round of forward propagation without sharing raw data. The gradients are then backpropagated from the last layer down to the cut layer in a similar fashion. The gradients at the cut layer (and only these gradients) are sent back to the radiology client centers, where the rest of the backpropagation is completed. This process continues until the distributed split learning network is trained without any party looking at the others’ raw data.

TRAINING A SplitNN

Process of Training a SplitNN

The training of a neural network is ‘split’ across two or more hosts. Each model segment is a self-contained neural network that feeds into the segment in front of it. The following is a general outline:

  1. Divide the model into two parts: the client part and the server part. The client part can be a smaller part of the model that runs on a user’s device, such as their laptop. The server part is the larger part that runs on a more powerful server.
  2. Keep the raw data on the client. The client has access to the user’s data, while the server receives only the outputs of the client part (plus, in the vanilla configuration, the labels).
  3. Run the client part of the model on the user’s data and send its output at the cut layer to the server.
  4. Train the server part of the model on the output from the client: complete the forward pass, compute the loss, and backpropagate down to the cut layer.
  5. Send the gradients at the cut layer back to the client, which completes backpropagation and updates its part of the model for the next round of training.
  6. Repeat steps 3–5 for a fixed number of iterations, or until the model’s accuracy reaches the desired level.

Applications?

Such technology has multiple applications; one relevant example is protecting against reconstruction attacks.

Reconstruction attack is a type of privacy attack where an adversary tries to reconstruct the sensitive information that was used to train a machine learning model. This can be done by analyzing the model’s output, gradients or weights.

In split learning, instead of training a model on a centralized server, the data is split into different parts and processed locally on the devices or nodes. The model is divided into two parts: a client part and a server part.

In split learning, the client device processes the input data and sends the output to the server, where the rest of the model is trained. The server then sends the gradients at the cut layer back to the client, which uses them to update its part of the model for the next round of training. The raw data never leaves the client device, which helps protect the privacy of the user’s data.

Since the sensitive data never leaves the client device, split learning makes it more challenging for an adversary to launch a reconstruction attack. The client device sends only the output of its part of the model (the cut-layer activations), not its weights or raw data, making it harder for an adversary to reconstruct the original training data.

Reconstruction attacks

There you have it: a new tool for privacy-preserving machine learning that is not computationally complex, does not require a lot of network resources, and preserves privacy, making it more efficient than other methods such as homomorphic encryption.
