Using Fully Sharded Data Parallel (FSDP) with Intel Gaudi

In DistributedDataParallel (DDP) training, each process (worker) owns a replica of the model and processes its own batch of data. As a result, model weights and optimizer states are replicated across all workers.

Fully Sharded Data Parallel (FSDP) is a type of data parallel training supported by the Intel® Gaudi® 2 AI accelerator for running distributed training on large-scale models. It shards model parameters, optimizer states and gradients across all ranks. For more details, refer to Fully Sharded Data Parallel (FSDP) Theory of Operations. Because each worker does not hold a full copy of the model, FSDP has a smaller memory footprint on each Gaudi device than DDP.

Intel Gaudi FSDP integration can be executed using Eager mode with torch.compile.
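
For example (a minimal sketch, assuming the process-group and model setup from the example below and that the Gaudi torch.compile backend is named hpu_backend), the FSDP-wrapped module can be passed to torch.compile:

# Sketch only: assumes the hccl process group and `model` setup shown in
# toy_example.py below, and that the Gaudi torch.compile backend is named
# "hpu_backend". use_orig_params=True is commonly needed when combining
# FSDP with torch.compile.
model = FSDP(model, device_id=torch.device("hpu"), use_orig_params=True)
model = torch.compile(model, backend="hpu_backend")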

Running a Simple FSDP Example on Intel Gaudi

The toy_example.py below is a simple example that demonstrates FSDP distributed training on Gaudi. It uses two Gaudi devices on one node; to run it on eight devices, use the --ws 8 flag.

import os
import torch
import argparse
import torch.nn as nn
from torch.nn import Linear
from torch.optim import SGD
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
import habana_frameworks.torch.distributed.hccl

os.environ["PT_HPU_LAZY_MODE"] = "0"
device_hpu = torch.device('hpu')

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    import habana_frameworks.torch.distributed.hccl
    dist.init_process_group(backend='hccl', rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.lin1 = Linear(3, 3, bias=False)
    def forward(self, x):
        return (self.lin1(x))

def simple_demo(rank, world_size, args):
    setup(rank, world_size)
    model = ToyModel().to(device_hpu)
    input = torch.rand(8, 3)  # 8 input rows; each rank uses the row matching its rank id

    model = FSDP(model, device_id = device_hpu)
    optim = SGD(model.parameters(), lr=0.1)
    in_data = input[rank].to(device_hpu)  # each rank trains on its own row
    for i in range(args.iterations):
        out = model(in_data)
        out.float().sum().backward()
        optim.step()
        optim.zero_grad()
    cleanup()
    print("All done")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Simple_demo Example')
    parser.add_argument('--iterations', type=int, default=5, metavar='I',  help='iterations (default: 5)')
    parser.add_argument('--ws', type=int, default=2, help='world size (default: 2)')
    parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbosity")

    args = parser.parse_args()
    if args.verbose:
        os.environ["TORCH_CPP_LOG_LEVEL"]="INFO"
        os.environ["TORCH_DISTRIBUTED_DEBUG"]="DETAIL"
        os.environ["TORCH_SHOW_CPP_STACKTRACES"]="1"
        torch._dynamo.config.verbose=True
    WORLD_SIZE = args.ws
    mp.spawn(simple_demo,
        args=(WORLD_SIZE, args),
        nprocs=WORLD_SIZE,
        join=True)

The Intel Gaudi-specific lines are explained below.

  • Import habana_frameworks.torch.distributed.hccl, which provides the hccl backend for torch.distributed:

import habana_frameworks.torch.distributed.hccl

  • Enable Eager mode by disabling Lazy mode:

os.environ["PT_HPU_LAZY_MODE"] = "0"

  • Target the Gaudi device:

device_hpu = torch.device('hpu')

  • Move the model to the Gaudi device:

model = ToyModel().to(device_hpu)

  • Wrap the model with FSDP, targeting the Gaudi device:

model = FSDP(model, device_id = device_hpu)

Executing the Example

Execute toy_example.py by running:

python3 toy_example.py
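
To run on eight Gaudi devices with verbose logging enabled, for example:

python3 toy_example.py --ws 8 -v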

Note

  • FSDP cannot be executed on Intel Gaudi using Lazy mode.

  • Using FSDP may increase communication volume.

  • If you encounter stability issues when running your model with FSDP, set the PT_HPU_EAGER_PIPELINE_ENABLE=false flag, as shown in the example below. Note that this may affect performance.
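
For example, the flag can be exported in the shell before launching the script:

export PT_HPU_EAGER_PIPELINE_ENABLE=false
python3 toy_example.py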

Supported Features

The following PyTorch FSDP features are supported by Intel Gaudi. Some of them are optional and can be used according to your requirements; a configuration sketch combining several of them is shown after the list.

  • module - Module that FSDP wraps.

  • process_group (optional) - Process group used for model sharding.

  • sharding_strategy (optional) - Configures the sharding strategy. Options: FULL_SHARD, SHARD_GRAD_OP, NO_SHARD, HYBRID_SHARD.

  • cpu_offload (optional) - Configures CPU offloading. Options: TRUE, FALSE.

  • auto_wrap_policy (optional) - Enables a policy that applies FSDP to submodules of module.

  • backward_prefetch (optional) - Configures explicit backward prefetching of all-gathers. Options: BACKWARD_PRE, BACKWARD_POST.

  • mixed_precision (optional) - Configures native mixed precision for FSDP. Options: param_dtype (torch.dtype), reduce_dtype (torch.dtype), buffer_dtype (torch.dtype), keep_low_precision_grads (bool), cast_forward_inputs (bool), cast_root_forward_inputs (bool).

  • ignored_modules (optional) - Avoids sharding specific parameters at module granularity when using auto_wrap_policy.

  • param_init_fn (optional) - Specifies how modules on the meta device should be initialized onto an actual device.

  • device_id (optional) - An int or torch.device specifying the device FSDP runs on.

  • sync_module_states - If True, each FSDP module broadcasts module parameters and buffers from rank 0 to ensure that they are replicated across ranks. Options: TRUE, FALSE.

  • forward_prefetch - If True, FSDP prefetches the next forward-pass all-gather before the current forward computation. Options: TRUE, FALSE.

  • limit_all_gathers - If True, FSDP synchronizes the CPU thread so that Gaudi memory is used by only two consecutive FSDP instances at a time. Options: TRUE, FALSE.

  • use_orig_params - If True, FSDP uses the module's original parameters. Options: TRUE, FALSE.

  • ignored_states (optional) - Ignores parameters or modules that are not managed by the FSDP instance.

  • state_dict - Dictionary object that stores the model's parameters. Options: FULL_STATE_DICT, LOCAL_STATE_DICT, SHARDED_STATE_DICT.

  • Activation checkpoint - Reduces memory consumption by trading computation for memory.
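
The sketch below (referenced above) shows how several of these options can be combined when wrapping a module on Gaudi. It is a minimal, single-rank illustration rather than a tuned configuration: ToyBlock, the bfloat16 dtypes, and the single-rank process group are assumptions made for this example.

import functools
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import (
    BackwardPrefetch,
    CPUOffload,
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
import habana_frameworks.torch.distributed.hccl

os.environ["PT_HPU_LAZY_MODE"] = "0"   # FSDP on Gaudi requires Eager mode

# Single-rank process group for illustration only; a real run would use the
# spawn/setup pattern from toy_example.py above.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "12355")
dist.init_process_group(backend="hccl", rank=0, world_size=1)

device_hpu = torch.device("hpu")

class ToyBlock(nn.Module):
    """Stand-in for the layer class a real model would auto-wrap."""
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(3, 3, bias=False)

    def forward(self, x):
        return self.lin(x)

model = nn.Sequential(ToyBlock(), ToyBlock()).to(device_hpu)

fsdp_model = FSDP(
    model,
    device_id=device_hpu,
    sharding_strategy=ShardingStrategy.FULL_SHARD,      # shard params, grads, optimizer states
    auto_wrap_policy=functools.partial(
        transformer_auto_wrap_policy, transformer_layer_cls={ToyBlock}
    ),
    mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    ),
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,    # prefetch next all-gather during backward
    cpu_offload=CPUOffload(offload_params=False),
    limit_all_gathers=True,
    use_orig_params=True,
)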

For more details on the FSDP parameters and their usage, refer to the PyTorch documentation.

Supported Models

For an example model using FSDP on Intel Gaudi, see LLaMA 2 70B.