Using Fully Sharded Data Parallel (FSDP) with Intel Gaudi
In DistributedDataParallel (DDP) training, each process/worker owns a replica of the model and processes a batch of data; model weights and optimizer states are replicated across all workers.

Fully Sharded Data Parallel (FSDP) is a type of data parallel training supported by the Intel® Gaudi® 2 AI accelerator for running distributed training on large-scale models. It shards model parameters, optimizer states, and gradients across all ranks. For more details, refer to Fully Sharded Data Parallel (FSDP) Theory of Operations. Compared to DDP, FSDP reduces the memory footprint on each Gaudi because no single worker holds a full copy of the model.

Intel Gaudi FSDP integration can be executed using Eager mode with torch.compile.
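A minimal sketch of that combination is shown below. It assumes the hpu_backend torch.compile backend provided by the Intel Gaudi PyTorch bridge and a process group that has already been initialized with the hccl backend, as in the full example that follows; wrap_and_compile is a hypothetical helper name used only for illustration.

# Minimal sketch: FSDP in Eager mode combined with torch.compile on Gaudi.
# Assumes the "hpu_backend" compile backend from the Intel Gaudi PyTorch bridge
# and that torch.distributed has already been initialized with the hccl backend.
import os
os.environ["PT_HPU_LAZY_MODE"] = "0"   # Eager mode, as required for FSDP on Gaudi

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
import habana_frameworks.torch.distributed.hccl  # Gaudi distributed support

def wrap_and_compile(model: nn.Module) -> nn.Module:
    # Shard the model across ranks, then compile the wrapped module.
    fsdp_model = FSDP(model, device_id=torch.device("hpu", torch.hpu.current_device()))
    return torch.compile(fsdp_model, backend="hpu_backend")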
Running a Simple FSDP Example on Gaudi

The toy_example.py below is a simple example that demonstrates FSDP distributed training on Gaudi. It uses two Gaudis on one node. To run it on eight Gaudis, use the --ws 8 flag.
import os
import torch
import argparse
import torch.nn as nn
from torch.nn import Linear
from torch.optim import SGD
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Gaudi-specific import: enables distributed training over the HCCL backend.
import habana_frameworks.torch.distributed.hccl

# Enable Eager mode (FSDP is not supported in Lazy mode).
os.environ["PT_HPU_LAZY_MODE"] = "0"

# Target the Gaudi device.
device_hpu = torch.device('hpu')

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    import habana_frameworks.torch.distributed.hccl
    dist.init_process_group(backend='hccl', rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.lin1 = Linear(3, 3, bias=False)

    def forward(self, x):
        return self.lin1(x)

def simple_demo(rank, world_size, args):
    setup(rank, world_size)
    model = ToyModel()
    input = torch.rand(8, 3)
    # Wrap the model with FSDP and place it on the current Gaudi device.
    model = FSDP(model, device_id=torch.device("hpu", torch.hpu.current_device()))
    optim = SGD(model.parameters(), lr=0.1)
    in_data = torch.Tensor(input[rank]).to(device_hpu)
    for i in range(args.iterations):
        out = model(in_data)
        out.float().sum().backward()
        optim.step()
        optim.zero_grad()
    cleanup()
    print("All done")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Simple_demo Example')
    parser.add_argument('--iterations', type=int, default=5, metavar='I', help='iterations (default: 5)')
    parser.add_argument('--ws', type=int, default=2, help='world size (default: 2)')
    parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbosity")
    args = parser.parse_args()
    if args.verbose:
        os.environ["TORCH_CPP_LOG_LEVEL"] = "INFO"
        os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
        os.environ["TORCH_SHOW_CPP_STACKTRACES"] = "1"
        torch._dynamo.config.verbose = True
    WORLD_SIZE = args.ws
    mp.spawn(simple_demo,
             args=(WORLD_SIZE, args),
             nprocs=WORLD_SIZE,
             join=True)
The Gaudi-specific lines are explained below:

Import the Habana distributed package:

import habana_frameworks.torch.distributed.hccl

Enable Eager mode:

os.environ["PT_HPU_LAZY_MODE"] = "0"

Target the Gaudi device:

device_hpu = torch.device('hpu')

Target FSDP execution with Gaudi:

model = FSDP(model, device_id=torch.device("hpu", torch.hpu.current_device()))
For an example model using FSDP on Gaudi, see LLaMA 2 70B.
Executing the Example

Execute the toy_example.py by running:

python3 toy_example.py
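To run the example on eight Gaudis on a single node, pass the world size flag defined in the script:

python3 toy_example.py --ws 8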
Note

FSDP cannot be executed on Gaudi using Lazy mode.

Using FSDP may increase communication volume.

If you encounter stability issues when running your model with FSDP, set the PT_HPU_EAGER_PIPELINE_ENABLE=false flag. Note that this may affect performance.
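For example, assuming the flag is read from the environment like other PT_HPU_* runtime flags, it can be set for a single run from the command line:

PT_HPU_EAGER_PIPELINE_ENABLE=false python3 toy_example.py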
Supported Features

The following table outlines the PyTorch FSDP features supported by Intel Gaudi. Some of them are optional and can be used according to your requirements. For more details on the FSDP parameters and their usage, refer to the PyTorch documentation.

| Feature | Options | Description |
|---|---|---|
| module | N/A | Module that FSDP wraps. |
| process_group (optional) | N/A | Process group for model sharding. |
| sharding_strategy (optional) |  | Configures the sharding strategy. |
| cpu_offload (optional) |  | Configures CPU offloading. |
| auto_wrap_policy (optional) | N/A | Enables a policy that applies FSDP to submodules of module. |
| backward_prefetch (optional) |  | Configures explicit backward prefetching of all-gathers. |
| mixed_precision (optional) |  | Configures native mixed precision for FSDP. |
| ignored_modules (optional) | N/A | Avoids sharding specific parameters at module granularity when using auto_wrap_policy. |
| param_init_fn (optional) | N/A | Specifies how meta device modules should be initialized onto an actual device. |
| device_id (optional) | N/A | An int or torch.device specifying the device on which FSDP initialization takes place. |
| sync_module_states | True/False | If True, each FSDP module broadcasts module parameters and buffers from rank 0 to ensure that they are replicated across ranks. |
| forward_prefetch | True/False | If True, FSDP prefetches the next forward-pass all-gather before the current forward computation. |
| limit_all_gathers | True/False | If True, FSDP synchronizes the CPU thread to limit Gaudi memory usage to only two consecutive FSDP instances. |
| use_orig_params | True/False | If True, FSDP uses the module's original parameters. |
| ignored_states (optional) | N/A | Ignores parameters or modules that are not managed by the FSDP instance. |
| state_dict |  | Dictionary object that stores the model's parameters. |
| Activation checkpoint | N/A | Reduces memory consumption by trading off computation for memory. |
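As an illustration, the sketch below shows how several of the optional parameters from the table might be passed when wrapping a model on Gaudi. The specific choices (FULL_SHARD, BACKWARD_PRE, a size-based auto-wrap policy, bfloat16 mixed precision) are assumptions made for demonstration only, not recommendations from this guide; refer to the PyTorch FSDP documentation for the full set of options.

# Illustrative sketch only: combining several optional FSDP parameters on Gaudi.
# The specific values below are assumptions chosen for demonstration.
import functools
import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
    BackwardPrefetch,
    MixedPrecision,
)
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
import habana_frameworks.torch.distributed.hccl  # Gaudi distributed support

def wrap_with_options(model: torch.nn.Module) -> FSDP:
    return FSDP(
        model,
        device_id=torch.device("hpu", torch.hpu.current_device()),
        sharding_strategy=ShardingStrategy.FULL_SHARD,      # shard params, grads, optimizer states
        auto_wrap_policy=functools.partial(
            size_based_auto_wrap_policy, min_num_params=1_000_000
        ),                                                   # wrap large submodules automatically
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,     # prefetch next all-gather before gradient computation
        mixed_precision=MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.bfloat16,
            buffer_dtype=torch.bfloat16,
        ),
        sync_module_states=True,                             # broadcast rank-0 weights and buffers to all ranks
        use_orig_params=True,                                # keep the module's original parameters
        limit_all_gathers=True,
    )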