Fully Sharded Data Parallel (FSDP) Theory of Operations

FSDP enables data parallel training through the following steps (a minimal usage sketch follows the list):

  1. Init step - Full parameter sharding, where each device holds only the shard of the model parameters, gradients, and optimizer states needed for its local computation.

  2. Forward step:

    1. The full weights are gathered locally from all Gaudi devices with all_gather.

    2. The forward pass is calculated.

    3. The all_gather is performed again before the backward pass.

  3. Backward step:

    1. The backward pass is calculated.

    2. The local gradients are averaged and sharded across the Gaudi devices with reduce_scatter. This allows each Gaudi device to update its local weight shard.
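
The steps above map onto a standard PyTorch training loop once the model is wrapped in FullyShardedDataParallel; the collectives are issued internally by FSDP. The following is a minimal sketch, assuming the Intel Gaudi PyTorch bridge (habana_frameworks) is installed and the script is launched with one process per Gaudi device (for example, via torchrun). The model, optimizer, and data below are placeholders for illustration only:

  import torch
  import torch.distributed as dist
  import torch.nn as nn
  from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

  # Assumption: the Intel Gaudi PyTorch bridge is installed; importing this
  # module registers the "hccl" backend used for Gaudi collectives.
  import habana_frameworks.torch.distributed.hccl  # noqa: F401

  def main():
      # Init step: one process per Gaudi device; FSDP shards the parameters,
      # gradients, and optimizer states across the ranks in the process group.
      dist.init_process_group(backend="hccl")
      device = torch.device("hpu")

      model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 10)).to(device)
      model = FSDP(model)  # each rank now owns only a shard of the parameters

      optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
      inputs = torch.randn(32, 1024, device=device)
      labels = torch.randint(0, 10, (32,), device=device)

      # Forward step: FSDP issues all_gather internally to materialize the full
      # weights of each wrapped module, then frees them after use.
      loss = nn.functional.cross_entropy(model(inputs), labels)

      # Backward step: weights are all-gathered again, then gradients are averaged
      # and sharded with reduce_scatter, so each rank keeps only its gradient shard.
      loss.backward()

      # Each rank updates its local shard of the parameters and optimizer state.
      optimizer.step()
      optimizer.zero_grad()

      dist.destroy_process_group()

  if __name__ == "__main__":
      main()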

DDP and FSDP

The image below illustrates the standard data parallel (DDP) and fully sharded data parallel (FSDP) training processes.

../../_images/fsdp_flow_1.png

DDP and FSDP Flows

  • In DDP, each device holds a full copy of the model and computes on its own shard of the data before synchronizing gradients globally. Using all_reduce, it sums the gradients across workers and distributes the result equally, as sketched after this list.

  • FSDP shards the model across the Gaudi devices, gathers the full weights before both the forward and backward passes, and lets each device update only its local weight shard.
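
The DDP behavior described above amounts to summing every gradient tensor across ranks and dividing by the world size. The sketch below shows that manual equivalent with torch.distributed collectives, assuming a process group is already initialized; DDP performs this synchronization automatically and overlaps it with the backward pass, so this is purely illustrative:

  import torch
  import torch.distributed as dist

  def average_gradients(model: torch.nn.Module) -> None:
      """Manual equivalent of DDP's gradient synchronization: sum each
      gradient across all ranks, then divide by the world size so every
      rank ends up with the same averaged gradient."""
      world_size = dist.get_world_size()
      for param in model.parameters():
          if param.grad is not None:
              dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
              param.grad /= world_size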

In FSDP, all_reduce is broken down into reduce_scatter and all_gather. During the backward pass, FSDP reduces and scatters the gradients across ranks, and each rank then updates its respective shard of the parameters. In the subsequent forward pass, FSDP performs all_gather to collect and combine the updated parameter shards, as sketched below.
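
The decomposition can be written out directly with torch.distributed collectives. The sketch below is illustrative only, assuming an initialized process group and a flattened (1-D) gradient tensor whose length is divisible by the world size; FSDP issues these collectives internally, so the function merely demonstrates the equivalence:

  import torch
  import torch.distributed as dist

  def all_reduce_via_reduce_scatter(grad: torch.Tensor) -> torch.Tensor:
      """Decompose all_reduce into reduce_scatter followed by all_gather,
      mirroring how FSDP splits the collective across the backward and
      the next forward pass."""
      world_size = dist.get_world_size()
      shard = torch.empty(grad.numel() // world_size, dtype=grad.dtype, device=grad.device)

      # Backward pass: each rank receives the reduced (summed) shard of the gradient.
      dist.reduce_scatter_tensor(shard, grad, op=dist.ReduceOp.SUM)

      # (Each rank would update its parameter shard here.)

      # Next forward pass: all_gather recombines the updated shards on every rank.
      full = torch.empty_like(grad)
      dist.all_gather_into_tensor(full, shard)
      return full  # matches the result of dist.all_reduce(grad, op=dist.ReduceOp.SUM)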

../../_images/all_reduce_fsdp.JPG

FSDP Allreduce