Exploring Model Decomposition Strategies in FSDP

Dwaha Daud
Princeton Systems Course
May 8, 2024

Jipeng Sun & Dwaha Daud

1. Introduction

Neural networks are growing rapidly: state-of-the-art models such as GPT-3 have 175 billion parameters. Training models of this size is extremely costly. To mitigate these costs, recent techniques aim to train large models efficiently by using various forms of parallelism.

Fully Sharded Data Parallel (FSDP) is one such technique. It combines model sharding with data parallelism to train efficiently. In data parallelism, the training data is split into subsets, and each GPU trains the model on a different subset in parallel. FSDP is implemented as a library within PyTorch, an open-source machine learning library.

FSDP builds on Distributed Data Parallel (DDP), another library within PyTorch, to further extend the size of models that can be trained. In DDP, the entire model is replicated on each GPU in the system, and each GPU trains its copy of the model on a different subset of the training data. During each backward pass, the GPUs average their gradients with an all-reduce, so every replica applies the same parameter update. In the end, each GPU holds a copy of the same fully trained model. One major limitation of DDP is that the model, together with its gradients and optimizer state, must fit in the memory of a single GPU. On a GPU with 40GB of memory, this limit can be exceeded by models with 1 billion parameters, leading to out-of-memory errors.
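A back-of-the-envelope estimate shows why a model of this size strains a single 40GB GPU. The sketch assumes plain fp32 training with the Adam optimizer (4 bytes per parameter for weights, 4 for gradients, and 8 for Adam's two moment buffers). This setup is our assumption, and the estimate ignores activations, communication buffers, and allocator overhead, all of which add to the real footprint.

```python
# Rough per-GPU estimate of model state (weights + gradients + optimizer state).
# Assumption: fp32 weights (4 B) + fp32 gradients (4 B) + Adam's two fp32
# moment buffers (8 B) = 16 bytes per parameter. Activations, communication
# buffers, and allocator overhead are NOT included.
BYTES_PER_PARAM_FP32_ADAM = 4 + 4 + 8


def model_state_gb(num_params: int, bytes_per_param: int = BYTES_PER_PARAM_FP32_ADAM) -> float:
    """Model state held by one DDP replica, in GB (1e9 bytes)."""
    return num_params * bytes_per_param / 1e9


def sharded_state_gb(num_params: int, world_size: int) -> float:
    """Same model state split evenly across world_size GPUs (idealized sharding,
    ignoring the temporarily materialized unit)."""
    return model_state_gb(num_params) / world_size


if __name__ == "__main__":
    one_billion = 1_000_000_000
    print(f"DDP replica, 1B params: {model_state_gb(one_billion):.1f} GB per GPU")
    print(f"Sharded over 8 GPUs:    {sharded_state_gb(one_billion, 8):.1f} GB per GPU")
```

Under these assumptions, model state alone takes about 16 GB per GPU with DDP before any activations are stored, while idealized sharding over 8 GPUs brings it down to about 2 GB each, plus the temporary cost of whichever unit is currently materialized.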

FSDP works by splitting a model into units and sharding each unit across all GPUs in the system. A “unit” is a collection of one or more layers of the model. In this project, we explored different strategies that users can use to break a model into units, and we examined how these decisions affect training performance.

The rest of this post proceeds as follows. First, we describe how the FSDP algorithm works, along with a toy example illustrating how a model is sharded across GPUs. Next, we present our implementation of different unit breakdowns of a model, followed by an evaluation of each strategy. Finally, we conclude with some remarks on the results we observed.

2. The FSDP Algorithm

2.1 PyTorch Background

To understand how FSDP works, it helps to understand how PyTorch represents neural networks. Two important object types in PyTorch are Tensors and Modules. PyTorch uses Tensors to store values: they are n-dimensional arrays that can take advantage of parallel computation on GPUs. Each Tensor is stored on a specific device, e.g., a GPU.

Layers in a neural network are represented as Module objects. A Module describes a transformation from input values to output values, and it holds parameters, each represented as a Tensor. A Linear layer, for example, has parameters that define a linear function, e.g., (a * x_1) + (b * x_2) + c.

Here, a and b are “weight” parameters stored in the neural network, the x values are the input data, and c is a “bias” parameter. All of these values can be multidimensional: a Linear layer can have several nodes, each with its own linear function and its own parameters. During the forward computation, i.e., when the neural network transforms an input value into an output, the Linear layer applies these parameters to the input data it receives.
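The linear function above maps directly onto PyTorch's `nn.Linear`. In this minimal sketch, a layer with two inputs and one output node stores a and b in its `weight` Tensor (shape `[1, 2]`) and c in its `bias` Tensor (shape `[1]`), and both are registered as parameters of the Module. The input values are arbitrary.

```python
import torch
from torch import nn

# One output node, two inputs: y = a * x_1 + b * x_2 + c
layer = nn.Linear(in_features=2, out_features=1)

# Parameters are Tensors registered on the Module; each lives on a device.
for name, param in layer.named_parameters():
    print(name, tuple(param.shape), param.device)

x = torch.tensor([[3.0, 4.0]])  # one input row: x_1 = 3, x_2 = 4
y = layer(x)                    # forward computation

# Recompute the same output by hand from the stored parameters.
a, b = layer.weight[0]
c = layer.bias[0]
manual = a * x[0, 0] + b * x[0, 1] + c
print(torch.allclose(y[0, 0], manual))  # True

# Moving the Module moves every parameter Tensor to that device.
if torch.cuda.is_available():
    layer = layer.to("cuda")
    print(layer.weight.device)
```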

The neural network as a whole is itself a Module object, which describes how input data is transformed to a predicted output value.

2.2 Overview of FSDP

FSDP works by decomposing a model into smaller units, where each unit is a collection of one or more layers of the model. Each unit is then sharded evenly across all GPUs in the system. When the model is idle, every unit remains fully sharded, so each GPU holds only its own shard of each unit.

A unit that is split across GPUs cannot be used for computation. When it is time for a particular unit to compute, in either the forward or backward pass, the full unit is first materialized on each GPU: each GPU sends its shard of the unit to all other GPUs. Once every GPU has the full unit in memory, all GPUs perform the required computation in parallel on their respective inputs. This input could be the input data to the model or the input passed on from the previous unit.

When the computation is complete, each GPU reshards the unit: it frees the memory holding the shards that came from other GPUs and keeps only the shard it was originally responsible for.

Below is a visual example of how a unit is sharded and resharded across GPUs. Say we have a unit that wraps four layers as shown in Figure 1. In a system with 3 GPUs, this unit will be decomposed into 3 shards, with each GPU storing one shard. Each highlighted section of the unit represents a different shard.

Figure 1: An FSDP unit split into 3 shards

When the unit is in its sharded state, GPU memory (showing only the unit shards) looks as shown in Figure 2.

Figure 2: Memory diagram for an FSDP unit sharded across 3 GPUs

When the unit needs to perform a computation, say, in the forward pass, each GPU collects the other shards of the unit and then runs the unit's computation on its own input. When the computation is complete, the unit is resharded, so GPU memory reverts to the state in Figure 2.
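The gather, compute, and reshard cycle of this three-GPU example can be simulated in a single Python process. This is a toy model, not FSDP code: each “GPU” is a Python dict, the unit's parameters are 12 made-up numbers standing in for the four layers, and the gather is a list concatenation. In this toy model, the local shard stays allocated next to the gathered copy, and the sketch counts how many parameter values each GPU holds at each stage.

```python
# Toy, single-process simulation of one FSDP unit on 3 "GPUs".
# Nothing here uses torch.distributed; each GPU is just a dict.
WORLD_SIZE = 3
unit_params = list(range(12))  # hypothetical flat parameters of a 4-layer unit
shard_len = len(unit_params) // WORLD_SIZE

# Sharded (idle) state: GPU r keeps only slice r of the unit.
gpus = [{"shard": unit_params[r * shard_len:(r + 1) * shard_len], "full": None}
        for r in range(WORLD_SIZE)]


def held(gpu):
    """Number of parameter values currently resident on this GPU."""
    return len(gpu["shard"]) + (len(gpu["full"]) if gpu["full"] is not None else 0)


def unshard(gpus):
    """Every GPU receives every shard, materializing the full unit."""
    full = [v for g in gpus for v in g["shard"]]
    for g in gpus:
        g["full"] = list(full)


def reshard(gpus):
    """Free the gathered copy; each GPU keeps only its own shard."""
    for g in gpus:
        g["full"] = None


print("sharded:     ", [held(g) for g in gpus])
unshard(gpus)
# ... each GPU would now run the unit's computation on its own input batch ...
print("materialized:", [held(g) for g in gpus])
reshard(gpus)
print("resharded:   ", [held(g) for g in gpus])
```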

When the model as a whole is decomposed into several units, each unit is sharded as described above. The units then compute in sequence: in forward order during the forward pass and in reverse order during the backward pass. Each unit is fully materialized on a GPU only when it is time to compute with that unit; otherwise it stays sharded.

With data parallelism, each GPU operates on different input data, so each GPU predicts different outputs during the forward pass. Without synchronization, parameters would be updated differently on each GPU, because the gradients, and therefore the updates, depend on those outputs. To prevent this, there is an extra step in the backward pass that synchronizes gradients across GPUs.
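A small numeric sketch shows why this synchronization matters. It assumes plain SGD on a one-parameter model `y = w * x` with squared-error loss, which is our choice for illustration. Two replicas start from the same weight but see different data. Updating with local gradients makes them diverge, while updating both with the averaged gradient keeps them identical. In FSDP, this synchronization is carried out on gradient shards by the ReduceScatter collective, but the consistency argument is the same.

```python
# Two data-parallel replicas of the model y = w * x with loss (w * x - t) ** 2.
LR = 0.1


def grad(w, x, t):
    """Derivative of (w * x - t) ** 2 with respect to w."""
    return 2 * (w * x - t) * x


w0 = 1.0
batches = [(1.0, 2.0), (2.0, 1.0)]  # (input, target) seen by GPU 0 and GPU 1

# Without synchronization: each replica steps on its own local gradient.
unsynced = [w0 - LR * grad(w0, x, t) for x, t in batches]

# With synchronization: average the gradients first, then every replica
# applies the same update and the weights stay identical.
avg_grad = sum(grad(w0, x, t) for x, t in batches) / len(batches)
synced = [w0 - LR * avg_grad for _ in batches]

print("unsynced:", unsynced)
print("synced:  ", synced)
```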

2.3 How units are sharded

This section describes in more detail how a unit is sharded. Say we have an FSDP unit that contains multiple layers, each holding its own parameter values. For example, a Linear layer has weight and bias parameters for each of its nodes, as described in Section 2.1.

To shard this unit, all of the parameters across all of its layers are flattened. Flattening turns an n-dimensional tensor into a 1-dimensional tensor; e.g., a parameter that is a 2-D matrix becomes a 1-D array. All of the flattened parameters are concatenated into one large “flat parameter”. This flat parameter is what gets sharded across the GPUs, with each GPU taking an equal chunk. If the flat parameter cannot be divided evenly across all GPUs, padding is added until it can be.
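The flatten, concatenate, pad, and chunk steps can be sketched with ordinary tensor operations. This mirrors the idea described above but is not FSDP's internal code, whose padding and storage details may differ. The two-layer unit and the world size of 3 are arbitrary choices for illustration, and `shard_flat_param` is a hypothetical helper.

```python
import torch
import torch.nn.functional as F
from torch import nn


def shard_flat_param(module: nn.Module, world_size: int):
    """Hypothetical helper: flatten all parameters of `module` into one 1-D
    tensor, pad it to a multiple of world_size, and split it into equal shards."""
    # 1) Flatten each parameter (a 2-D weight matrix becomes a 1-D vector)
    #    and concatenate everything into one "flat parameter".
    flat = torch.cat([p.detach().reshape(-1) for p in module.parameters()])
    # 2) Pad with zeros so the length divides evenly across ranks.
    pad = (-flat.numel()) % world_size
    flat = F.pad(flat, (0, pad))
    # 3) Each rank owns one equal, contiguous chunk.
    return list(flat.chunk(world_size)), pad


# A hypothetical two-layer unit: 3*4 + 4 + 4*2 + 2 = 26 parameters.
unit = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 2))
shards, pad = shard_flat_param(unit, world_size=3)
print("padding added:", pad)                        # 1
print("shard sizes:", [s.numel() for s in shards])  # [9, 9, 9]
```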

2.4 Communication optimizations

Sharding a model across GPUs introduces communication overhead, which reduces training efficiency. To mitigate it, one central optimization in FSDP is to overlap computation with communication.

As described above, communication mainly takes place when GPUs unshard a unit and when they synchronize a unit’s gradients in the backward pass. Since communication is contained within units, the computation of one unit does not depend on the communication required for another unit. This lets GPUs prefetch the parameters needed for the next unit in the sequence while still computing the output of the current unit. Prefetching can be applied in both the forward and backward passes.

3. Experiment Design

Sharding a model comes with a communication overhead, which comes from units gathering their parameters onto each GPU. As a first approximation, one might expect that reducing this overhead calls for as few units as possible, i.e., making each unit wrap as many layers as possible, since fewer and larger collectives amortize the fixed cost of launching each one. This carries a clear memory penalty, because larger units occupy more space on each GPU when fully materialized. Here, however, we want to investigate the impact of this approach on training speed itself. Understanding this impact requires looking more closely at the FSDP mechanism from a distributed-systems perspective.

Communication between GPUs in FSDP happens through two main collective communication operations: AllGather and ReduceScatter. AllGather gathers parameter shards onto all GPUs, materializing the full unit on each GPU. ReduceScatter sums gradients across GPUs and leaves each GPU with the reduced gradients for its own shard, so that every GPU updates its parameters consistently.
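The semantics of the two collectives can be written out in a few lines of plain Python, with each rank's data as a list. The functions `all_gather` and `reduce_scatter` below are our own toy, single-process models of what the operations compute, not how NCCL implements them. In PyTorch, the real collectives live in `torch.distributed` (for example `all_gather_into_tensor` and `reduce_scatter_tensor`). The parameter and gradient values are invented.

```python
# Single-process model of the two collectives used by FSDP.
# Element r of each input list is the data held by rank r.


def all_gather(shards):
    """Every rank ends up with the concatenation of all ranks' shards."""
    full = [v for shard in shards for v in shard]
    return [list(full) for _ in shards]


def reduce_scatter(full_grads):
    """Sum the full-length gradients elementwise across ranks, then give
    rank r only the r-th equal chunk of the summed result."""
    world_size = len(full_grads)
    summed = [sum(vals) for vals in zip(*full_grads)]
    chunk = len(summed) // world_size  # assumes padding made this divide evenly
    return [summed[r * chunk:(r + 1) * chunk] for r in range(world_size)]


# Gather a unit's parameters that are sharded over 2 ranks.
print(all_gather([[1.0, 2.0], [3.0, 4.0]]))  # each rank: [1.0, 2.0, 3.0, 4.0]

# After backward, each rank holds a full-size gradient computed from its own data.
grads = [[0.5, 1.0, 1.5, 2.0],   # rank 0
         [1.5, 1.0, 0.5, 1.0]]   # rank 1
print(reduce_scatter(grads))     # [[2.0, 2.0], [2.0, 3.0]]
```

Data-parallel training normally divides this sum by the number of ranks so that each shard holds an averaged gradient.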

Figure 3 shows how these two operations are used in the FSDP workflow. In the forward pass, an AllGather is issued to collect the full parameters of the unit onto each GPU. Once the AllGather completes, the unit performs its forward computation on its input data. Afterwards, the unit is resharded, and each GPU keeps only its own shard. The unit must be AllGathered again for the backward pass. After the gradients for each parameter are computed on each GPU, ReduceScatter is called to synchronize the gradients globally.
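The ordering described above can be summarized by listing the operations a single rank issues in one training step for a model of `num_units` units, ignoring prefetching and overlap. The `AG` and `FWD` labels follow the style of Figure 4, while `BWD`, `RS`, and `RESHARD` are our own shorthand. The function `fsdp_step_schedule` is a hypothetical, simplified description, not FSDP's scheduler.

```python
def fsdp_step_schedule(num_units: int) -> list:
    """Sequential (non-overlapped) order of operations for one training step."""
    ops = []
    # Forward pass: units run in order 0, 1, ..., n-1, each gathered, used, and freed.
    for u in range(num_units):
        ops += [f"AG{u}", f"FWD{u}", f"RESHARD{u}"]
    # Backward pass: units run in reverse order. Parameters must be gathered
    # again, and the local gradients are reduce-scattered after the unit's backward.
    for u in reversed(range(num_units)):
        ops += [f"AG{u}", f"BWD{u}", f"RESHARD{u}", f"RS{u}"]
    return ops


print(fsdp_step_schedule(2))
```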

Figure 3: The FSDP Workflow [2]

To illustrate the tradeoff of wrapping more layers into a single unit, we look more closely at FSDP's logic for overlapping computation and communication. Figure 4 illustrates the three streams of an FSDP training process: CPU computation, GPU computation, and GPU communication. We focus on the GPU computation and communication streams.

Figure 4: CPU stream and GPU communication and computation streams in FSDP training [1]

The forward computation of the current unit does not affect a later unit’s communication, i.e., the AllGather of the next unit’s parameters. This allows computation and communication to overlap: FSDP can issue the AllGather for the next unit while the current unit is still computing. For example, in Figure 4, FWD0, the forward computation of unit 0, overlaps with AG1, the AllGather for unit 1. There must be a limit on how many AllGathers are issued in advance, because each materialized unit consumes GPU memory; FSDP rate-limits AllGathers to avoid using up all free memory. Also note that a unit's computation can only start after the AllGather for that unit completes.
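The effect of overlapping can be estimated with a simple two-stream timeline model of the forward pass. The assumptions are ours, not FSDP's actual scheduler. AllGathers run one after another on the communication stream. Unit `i` can compute only after its own AllGather finishes and after unit `i-1` finishes computing. At most `max_inflight` units may be gathered ahead of the unit currently computing, which stands in for FSDP's rate limit; `max_inflight=0` means no prefetching. Durations are hypothetical and in arbitrary time units.

```python
def forward_time(ag, fwd, max_inflight):
    """Finish time of the last forward computation in a two-stream model.

    ag[i], fwd[i]: AllGather and compute durations of unit i.
    max_inflight: how many units may be gathered ahead of the unit that is
    currently computing (0 disables prefetching).
    """
    n = len(ag)
    ag_end = [0.0] * n
    fwd_end = [0.0] * n
    for i in range(n):
        start = ag_end[i - 1] if i > 0 else 0.0      # communication stream is serial
        gate = i - max_inflight - 1                  # unit whose compute must finish first
        if gate >= 0:
            start = max(start, fwd_end[gate])        # rate limit on prefetching
        ag_end[i] = start + ag[i]
        prev = fwd_end[i - 1] if i > 0 else 0.0
        fwd_end[i] = max(ag_end[i], prev) + fwd[i]   # needs its parameters and in-order compute
    return fwd_end[-1]


ag = [3.0, 3.0, 3.0, 3.0]   # hypothetical per-unit AllGather times
fwd = [2.0, 2.0, 2.0, 2.0]  # hypothetical per-unit compute times
print("no prefetch:", forward_time(ag, fwd, max_inflight=0))  # 20.0
print("prefetch 1: ", forward_time(ag, fwd, max_inflight=1))  # 14.0
```

With communication slower than compute, prefetching hides every computation except the last one behind the AllGathers (14 = 4 × 3 + 2). If the AllGathers were shorter than the computations, the total would instead approach the first AllGather plus the sum of all compute times.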

Figure 4 hints at why wrapping as many layers as possible into a unit may need to be reconsidered. Collective communication operations usually take longer than GPU computation operations. As long as communication is the bottleneck, the longer computation of a larger unit stays hidden behind communication, so wrapping more weights into one unit does not slow down FSDP training. However, as more and more weights are wrapped into one unit, the computation time for each unit keeps growing. For a given model, there may be a point at which a unit's computation time exceeds the AllGather time for that unit. Beyond that point, wrapping more layers into one unit stops improving performance while still consuming more GPU memory. That memory could be better used for other optimizations (e.g., prefetching, offloading, or model replication). This problem would be more pronounced on a GPU cluster with fast network connections but relatively slow computation.
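The argument can be made concrete with a simple linear cost model, which is our own simplification and not a measurement. In this model, the AllGather for a unit with `k` layers costs a fixed latency `alpha` plus `k * beta` for the bytes moved, and the unit's forward computation costs `k * gamma`. Because the fixed latency is paid once per unit, grouping more layers amortizes it, and per-unit compute catches up with per-unit communication only when `gamma > beta`. All numbers below are hypothetical.

```python
def per_unit_times(k, alpha, beta, gamma):
    """(AllGather time, compute time) for a unit that wraps k layers."""
    return alpha + k * beta, k * gamma


def crossover_layers(alpha, beta, gamma):
    """Smallest k at which compute time >= AllGather time, or None if never."""
    if gamma <= beta:
        return None  # communication grows at least as fast as compute
    k = 1
    while k * gamma < alpha + k * beta:
        k += 1
    return k


# Hypothetical costs (ms): 1.0 ms latency per collective, 0.5 ms of transfer
# per layer, 0.8 ms of compute per layer, 12 layers in total.
ALPHA, BETA, GAMMA, LAYERS = 1.0, 0.5, 0.8, 12
for k in (1, 2, 3, 4, 6, 12):
    comm, comp = per_unit_times(k, ALPHA, BETA, GAMMA)
    print(f"{k:2d} layers/unit ({LAYERS // k:2d} units): AG {comm:4.1f} ms, compute {comp:4.1f} ms")
print("compute overtakes AllGather at k =", crossover_layers(ALPHA, BETA, GAMMA))
```

With these made-up costs, units of 4 or more layers become compute-bound, so further grouping would mainly cost memory.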

To evaluate this effect, we design an experiment that measures the time to complete one epoch of training under different FSDP unit breakdowns. We wrap more and more layers into each FSDP unit and observe how this affects the completion time.

4. Implementation

In our implementation, we explore how different unit decompositions affect the training efficiency of a model. Currently, users must decide for themselves how a model should be broken into units. This decision could have a large impact on performance because of differences in memory usage and communication overhead. By measuring the performance of different strategies, we investigate whether an optimal unit decomposition exists that FSDP could select automatically.

We conduct our experiments on a single node with two NVIDIA A100 80GB GPUs interconnected by NVIDIA NVLink. The GPU specifications are shown below:

Figure 5: GPU cluster information

The model we use with FSDP is the BERT language encoder with 12 Transformer layers. We measure the average training time per epoch. The different FSDP wrapping policies correspond to different numbers of layers wrapped per unit. The structure of the model is shown in Figure 6, which also shows the entire BERT model wrapped in a single FSDP unit.
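One way to vary the number of layers per unit is to group consecutive layers into container modules and let FSDP's auto-wrap policy turn each container into a unit. The sketch below is a hypothetical illustration, not our experiment code, and it does not use the real BERT model. It builds a stand-in stack of 12 `nn.Linear` “layers”, groups them into blocks of `layers_per_unit` inside a hypothetical `Block` class, and wraps each block via `ModuleWrapPolicy`. It assumes PyTorch 2.x, NCCL, and a launch with `torchrun` (one process per GPU); the grouping helper also runs without GPUs.

```python
import os

import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

NUM_LAYERS = 12  # same depth as the 12-layer encoder, but a toy stand-in


class Block(nn.Sequential):
    """Hypothetical container: each Block instance becomes one FSDP unit."""


def build_grouped_model(layers_per_unit: int, hidden: int = 64) -> nn.Sequential:
    """Stack NUM_LAYERS toy layers, grouped into Blocks of layers_per_unit."""
    layers = [nn.Linear(hidden, hidden) for _ in range(NUM_LAYERS)]
    blocks = [Block(*layers[i:i + layers_per_unit])
              for i in range(0, NUM_LAYERS, layers_per_unit)]
    return nn.Sequential(*blocks)


def wrap_with_fsdp(model: nn.Module) -> FSDP:
    """Each Block becomes its own FSDP unit; the root FSDP instance manages any
    parameters outside the Blocks (none in this toy model)."""
    return FSDP(model,
                auto_wrap_policy=ModuleWrapPolicy({Block}),
                device_id=torch.cuda.current_device())


if __name__ == "__main__" and "LOCAL_RANK" in os.environ:  # launched via torchrun
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    dist.init_process_group("nccl")
    model = wrap_with_fsdp(build_grouped_model(layers_per_unit=3))
    if dist.get_rank() == 0:
        print(model)  # the module tree should show 4 nested FSDP-wrapped Blocks
    dist.destroy_process_group()
```

If the BERT model comes from Hugging Face Transformers, passing its encoder layer class (`BertLayer`) to `ModuleWrapPolicy` would give one encoder layer per unit. Putting several encoder layers into one unit would need a grouping step like the one above, because in this wrapping scheme each unit corresponds to a single submodule.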

Figure 6: Structure of the BERT model. Entire model wrapped in one FSDP unit

5. Evaluation

Below, we present our results. The x-axis is the number of FSDP units in the model; the fewer the units, the more layers each unit wraps. The y-axis is the training time for one epoch. As the results show, the curve saturates as the number of units decreases: training time keeps falling, but at a decreasing rate.

This suggests that the computation time per unit is growing to match the communication time. With even larger models, the unit computation time may become the bottleneck, which would halt the benefit of wrapping more and more layers into one unit.

Figure 7: BERT model training time for one epoch vs. number of FSDP units in the model

6. Conclusion

In summary, we explored the performance improvement gained from wrapping as many layers as possible into one unit, analyzing the tradeoff between computation and communication time as unit size grows. Our experiment provides evidence that the computation time of a unit has the potential to become a bottleneck in training. We believe that under different hardware settings, and with much larger state-of-the-art models, the performance benefits of larger units may stop.

Our code is available at this repo: https://github.com/JipengSun/fsdp

References

  1. Zhao, Y., Gu, A., Varma, R., Luo, L., Huang, C.-C., Xu, M., Wright, L., Shojanazeri, H., Ott, M., Shleifer, S., Desmaison, A., Balioglu, C., Damania, P., Nguyen, B., Chauhan, G., Hao, Y., Mathews, A., & Li, S. (2023). PyTorch FSDP: Experiences on scaling fully sharded data parallel. Proceedings of the VLDB Endowment, 16(12), 3848–3860. https://doi.org/10.14778/3611540.3611569
  2. Ott, M., Shleifer, S., Xu, M., Goyal, P., Duval, Q., & Caggiano, V. (2021, July 22). Fully sharded data parallel: Faster AI training with fewer GPUs. Engineering at Meta. https://engineering.fb.com/2021/07/15/open-source/fsdp/
