DDP-based Scaling of Gaudi on PyTorch

mpirun Configuration

The PE attribute value of the mpirun --map-by option may vary depending on your setup and should be calculated as: socket:PE = floor((number of physical cores) / (number of Gaudi devices per node)).

The following sample code calculates the number of physical CPU cores and the HPU count to derive the appropriate PE value, stored in MPI_PE. It can be incorporated into the launch script of any model:

export PHY_CPU_COUNT=$(lscpu --all --parse=CORE,SOCKET | grep -Ev "^#" | sort -u | wc -l)
export PHY_HPU_COUNT=$(ls /dev/hl? | wc -l)
export MPI_PE=$(($PHY_CPU_COUNT/$PHY_HPU_COUNT))

The PE value in the Model References examples may be set to a common value to ensure functionality. For optimal performance on your host CPU, calculate it as described above.
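The same calculation can be sketched in Python, for example when a launcher script is written in Python rather than shell. This is a minimal sketch, not part of the Intel Gaudi tooling: the function names are hypothetical, the core count mirrors the `lscpu --all --parse=CORE,SOCKET | grep -Ev "^#" | sort -u | wc -l` pipeline by counting unique CORE,SOCKET lines, and it assumes the result is passed to mpirun as `--map-by socket:PE=<value>`.

```python
from typing import List


def count_physical_cores(lscpu_parse_output: str) -> int:
    """Count unique CORE,SOCKET lines, like the lscpu | grep | sort -u | wc -l pipeline."""
    pairs = set()
    for line in lscpu_parse_output.splitlines():
        line = line.strip()
        # lscpu --parse prints comment header lines that start with '#'
        if not line or line.startswith("#"):
            continue
        pairs.add(line)
    return len(pairs)


def compute_mpi_pe(physical_cores: int, gaudi_devices_per_node: int) -> int:
    """socket:PE = floor(physical cores / Gaudi devices per node)."""
    if gaudi_devices_per_node <= 0:
        raise ValueError("at least one Gaudi device is required")
    return physical_cores // gaudi_devices_per_node


def map_by_argument(pe: int) -> List[str]:
    """Hypothetical helper that formats the mpirun --map-by option for the PE value."""
    return ["--map-by", f"socket:PE={pe}"]
```

For example, a host with 160 physical cores and 8 Gaudi devices gives `compute_mpi_pe(160, 8) == 20`, so each process is bound to 20 cores.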

Scale-up Using Gaudi NICs Within a Server

The following simple example, hccl_example.py, demonstrates distributed training support. The script loads the HCCL communication backend and initializes the process group with hccl as its communication backend. The code runs in multiple processes, one for each Gaudi:

import os
import torch
# Registers the HPU device with PyTorch
import habana_frameworks.torch.core as htcore

torch.manual_seed(0)
device = torch.device('hpu')

def setup(rank, world_size):
    # Keep MASTER_ADDR/MASTER_PORT if they were passed in by mpirun (-x),
    # otherwise fall back to single-host defaults
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12340')
    #Import the distributed package for HCCL, set the backend to HCCL
    import habana_frameworks.torch.distributed.hccl
    torch.distributed.init_process_group(backend='hccl', rank=rank, world_size=world_size)

def cleanup():
    torch.distributed.destroy_process_group()

def allReduce(rank):
    # Every rank contributes a tensor of ones; all_reduce sums them by default,
    # so each element of the result equals world_size
    _tensor = torch.ones(8).to(device)
    torch.distributed.all_reduce(_tensor)
    # Copy the result back to the host
    return _tensor.cpu()

def run_allreduce(rank, world_size):
    setup(rank, world_size)

    for i in range(100):
        result = allReduce(rank)

    if rank == 0:
        print(f"all_reduce result: {result}")

    cleanup()

def main():
    #Run Habana's Initialize HPU function to collect the world size and rank
    from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu
    world_size, rank, local_rank = initialize_distributed_hpu()
    run_allreduce(rank, world_size)

if __name__ == '__main__':
    main()

To launch distributed training for eight Gaudi devices within one host, run the following command:

$ mpirun --allow-run-as-root -np 8 python3 hccl_example.py

Note

Open MPI is required for host communication and launching processes.
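When the script is launched with mpirun, each process learns its place in the job from environment variables that Open MPI sets, which is the kind of information initialize_distributed_hpu() returns as world size, rank, and local rank. The sketch below is an illustration under that assumption only, not the implementation of initialize_distributed_hpu(): it reads the standard Open MPI variables OMPI_COMM_WORLD_SIZE, OMPI_COMM_WORLD_RANK, and OMPI_COMM_WORLD_LOCAL_RANK, and the function name is hypothetical.

```python
import os
from typing import Mapping, Tuple


def ranks_from_openmpi_env(env: Mapping[str, str] = os.environ) -> Tuple[int, int, int]:
    """Hypothetical helper: return (world_size, rank, local_rank) from Open MPI variables.

    Falls back to a single-process run (1, 0, 0) when the script was not
    started by mpirun.
    """
    world_size = int(env.get("OMPI_COMM_WORLD_SIZE", "1"))
    rank = int(env.get("OMPI_COMM_WORLD_RANK", "0"))
    # Index of the process on its own host, typically used to pick a device
    local_rank = int(env.get("OMPI_COMM_WORLD_LOCAL_RANK", "0"))
    if not 0 <= rank < world_size:
        raise ValueError(f"rank {rank} is outside world size {world_size}")
    return world_size, rank, local_rank
```

With `mpirun -np 8`, the eight processes would see world size 8 and ranks 0 through 7; in a two-server run with 16 processes, rank 9 would have local rank 1 on the second server.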

Scale-out Across Servers

Scale-out Using Host NICs

To use scale-out over Host NICs, install the OFI Wrapper and libfabric as detailed in Scale-out via Host NIC over OFI. The OFI Wrapper and libfabric are installed by default when using Intel Gaudi containers.

Scale-out over Host NICs is enabled with Gaudi Direct functionality. For further information on how the network setting is detected, refer to Scale-out via Host NIC. Follow the steps below for your libfabric provider to use Gaudi Direct:

Using the EFA provider:

  1. Ensure LD_LIBRARY_PATH includes /usr/lib/habanalabs.

  2. Set the RDMAV_FORK_SAFE=1 and FI_EFA_USE_DEVICE_RDMA=1 variables. These variables are set by default in the Intel Gaudi Docker image.

  3. Run the mpirun command, passing RDMAV_FORK_SAFE=1 and FI_EFA_USE_DEVICE_RDMA=1 to all processes with -x:

    mpirun --allow-run-as-root  \
       -x RDMAV_FORK_SAFE=1   \
       -x FI_EFA_USE_DEVICE_RDMA=1 \
       ...
    
Using the verbs provider:

  1. Ensure LD_LIBRARY_PATH includes /usr/lib/habanalabs.

  2. Enable fork safety by setting the RDMAV_FORK_SAFE=1 environment variable. If your system configuration uses huge pages, also set RDMAV_HUGEPAGES_SAFE=1.

  3. Set the MLX5_SCATTER_TO_CQE=0 environment variable.

  4. Ensure PCIe Access Control Services (ACS) is disabled.

  5. Use the Intel Gaudi proprietary libfabric.

  6. When building libfabric, use the --with-synapseai configuration option.

  7. Run the mpirun command:

    mpirun --allow-run-as-root  \
       -x RDMAV_FORK_SAFE=1   \
       -x RDMAV_HUGEPAGES_SAFE=1 \
       -x MLX5_SCATTER_TO_CQE=0 \
       ...

Note

Gaudi Direct with the verbs provider is supported on Gaudi 3 and Gaudi 2.
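The environment variables from the two sets of Gaudi Direct steps above (the set that uses FI_EFA_USE_DEVICE_RDMA, that is, the EFA provider, and the set for the verbs provider) can be collected into the -x flags that mpirun forwards to every process. This is a minimal sketch: the provider keys "efa" and "verbs" and the helper names are hypothetical, RDMAV_HUGEPAGES_SAFE=1 is added only when huge pages are in use as step 2 of the verbs steps describes, and the ACS and libfabric build requirements cannot be checked this way.

```python
from typing import Dict, List, Mapping

# Variables named in the Gaudi Direct steps, keyed by hypothetical provider labels
GAUDI_DIRECT_ENV: Dict[str, Dict[str, str]] = {
    "efa": {"RDMAV_FORK_SAFE": "1", "FI_EFA_USE_DEVICE_RDMA": "1"},
    "verbs": {"RDMAV_FORK_SAFE": "1", "MLX5_SCATTER_TO_CQE": "0"},
}

HABANA_LIB_DIR = "/usr/lib/habanalabs"


def habana_lib_on_path(env: Mapping[str, str]) -> bool:
    """Return True if LD_LIBRARY_PATH contains /usr/lib/habanalabs."""
    entries = env.get("LD_LIBRARY_PATH", "").split(":")
    return any(entry.rstrip("/") == HABANA_LIB_DIR for entry in entries)


def gaudi_direct_mpirun_flags(provider: str, huge_pages: bool = False) -> List[str]:
    """Build the '-x NAME=value' pairs for the chosen provider."""
    env = dict(GAUDI_DIRECT_ENV[provider])
    if provider == "verbs" and huge_pages:
        # Only needed when the system configuration uses huge pages
        env["RDMAV_HUGEPAGES_SAFE"] = "1"
    flags: List[str] = []
    for name, value in env.items():
        flags += ["-x", f"{name}={value}"]
    return flags
```

For example, `["mpirun", "--allow-run-as-root"] + gaudi_direct_mpirun_flags("efa")` reproduces the start of the first mpirun command above; the remaining options stay as in your own setup.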

Scale-out Using Gaudi NICs

Run the script using the command below. Specify the IP addresses of the two servers, separated by commas, with the -H option, as in the following example:

mpirun --allow-run-as-root  \
    -x LD_LIBRARY_PATH=".."  \
    -x HABANA_LOGS=".."  \
    -x GC_KERNEL_PATH=".."  \
    --prefix /usr/local/openmpi  \
    --mca btl_tcp_if_include 10.211.160.0/16  \
    -x MASTER_ADDR=10.211.160.154  \
    -x MASTER_PORT=12345  \
    --mca plm_rsh_args "-p 3022"  \
    -H 10.211.160.154,10.211.160.52 \
    -n 16  \
    ...
    python3 hccl_example.py

By default, the shell scripts connect to port 3022. However, the port on which the SSH server listens may differ between environments. If your environment requires a different port for the remote SSH server, you can specify it with the SSHD_PORT environment variable.

To change the port on which the SSH server listens, restart the SSH server with the new port. The following example sets the port to 22:

$ /etc/init.d/ssh restart '-p 22'

Note

You may need to add the --mca pml parameter to the mpirun command in your setup. It specifies the point-to-point messaging layer (PML) component to use, for example --mca pml ob1. Some systems may generate the following error if this flag is not included:

selected pml ob1, but peer selected pml cm

For further details, refer to the README on the PyTorch Model References GitHub page.
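The scale-out command above follows a simple pattern: the first server acts as MASTER_ADDR, and one process is started per Gaudi device on every server (two servers with eight devices each give -n 16). The sketch below assembles that argument list under those assumptions. The helper name and its defaults are hypothetical, and it deliberately omits the LD_LIBRARY_PATH, HABANA_LOGS, and GC_KERNEL_PATH exports, --prefix, and the options elided with "..." in the original command.

```python
import shlex
from typing import List, Optional, Sequence


def scale_out_mpirun_cmd(
    hosts: Sequence[str],
    script: str = "hccl_example.py",
    devices_per_host: int = 8,
    master_port: int = 12345,
    ssh_port: int = 3022,
    tcp_subnet: Optional[str] = None,
    pml: Optional[str] = None,
) -> List[str]:
    """Hypothetical helper that assembles an mpirun argument list for a multi-server run."""
    if not hosts:
        raise ValueError("at least one host is required")
    cmd = [
        "mpirun", "--allow-run-as-root",
        # The first host serves as the rendezvous address for all ranks
        "-x", f"MASTER_ADDR={hosts[0]}",
        "-x", f"MASTER_PORT={master_port}",
        # Port of the remote SSH servers used to start the processes
        "--mca", "plm_rsh_args", f"-p {ssh_port}",
    ]
    if tcp_subnet:
        # Restrict Open MPI's TCP traffic to this subnet
        cmd += ["--mca", "btl_tcp_if_include", tcp_subnet]
    if pml:
        # For example "ob1", to avoid "selected pml ob1, but peer selected pml cm"
        cmd += ["--mca", "pml", pml]
    # One process per Gaudi device on every host
    cmd += ["-H", ",".join(hosts), "-n", str(len(hosts) * devices_per_host)]
    cmd += ["python3", script]
    return cmd


if __name__ == "__main__":
    cmd = scale_out_mpirun_cmd(
        ["10.211.160.154", "10.211.160.52"],
        tcp_subnet="10.211.160.0/16",
        pml="ob1",
    )
    # shlex.join quotes "-p 3022" so the printed line can be pasted into a shell
    print(shlex.join(cmd))
```

Building the command as a list and quoting it only for display keeps multi-word values such as `-p 3022` intact if the list is later passed to subprocess.run.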

HCCL with Distributed Data Parallel (DDP) Hook

The example below, from the PyTorch documentation, includes the Gaudi-specific modifications that show how to initialize and use the HCCL package with PyTorch's DDP hook, ddp_model = DDP(model).

#reference https://pytorch.org/tutorials/intermediate/ddp_tutorial.html

import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp

from torch.nn.parallel import DistributedDataParallel as DDP

import habana_frameworks.torch.core as htcore

device = torch.device('hpu')

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    os.environ["ID"] = str(rank)
    #distributed package for HCCL
    import habana_frameworks.torch.distributed.hccl
    dist.init_process_group(backend='hccl', rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def demo_basic(rank, world_size):
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)

    model = ToyModel().to(device)
    # DDP all-reduces the gradients over HCCL during backward()
    ddp_model = DDP(model)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10).to(device))
    labels = torch.randn(20, 5).to(device)
    loss_fn(outputs, labels).backward()
    # In Lazy mode, mark_step() triggers execution of the accumulated graph
    htcore.mark_step()
    optimizer.step()
    htcore.mark_step()

    cleanup()


def run_demo(demo_fn, world_size):
    mp.spawn(demo_fn,
             args=(world_size,),
             nprocs=world_size,
             join=True)

if __name__ == "__main__":
    world_size = 8
    run_demo(demo_basic, world_size)
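In the example above, wrapping the model with DDP(model) means each rank computes gradients on its own batch, and during backward() DDP all-reduces those gradients (sum) and divides by the world size, so every rank applies the same averaged update. The sketch below models only that arithmetic on plain Python lists; it is an illustration, not DDP's bucketed implementation, and the function name is hypothetical.

```python
from typing import List, Sequence


def ddp_average_gradients(per_rank_grads: Sequence[Sequence[float]]) -> List[List[float]]:
    """Simulate DDP gradient synchronization: all-reduce (SUM), then divide by world size.

    Each inner sequence holds one rank's flattened gradients. The returned list
    holds what every rank sees after synchronization.
    """
    world_size = len(per_rank_grads)
    if world_size == 0:
        return []
    length = len(per_rank_grads[0])
    if any(len(grads) != length for grads in per_rank_grads):
        raise ValueError("all ranks must hold gradients of the same shape")
    # all_reduce with SUM: element-wise total across ranks
    summed = [sum(grads[i] for grads in per_rank_grads) for i in range(length)]
    averaged = [value / world_size for value in summed]
    # Every rank receives an identical copy of the averaged gradients
    return [list(averaged) for _ in range(world_size)]
```

Because all ranks start from the same parameters and apply identical averaged gradients, their model replicas stay in sync without broadcasting parameters after each step.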

Advanced Usage: ZeroRedundancyOptimizer with DDP

The example below, from the PyTorch documentation, includes the Gaudi-specific modifications that show how to initialize and use the HCCL package with PyTorch's DDP hook and a custom fused optimizer.

The fused optimizer is wrapped in a functional optimizer class so that it is accessible to ZeRO-1. For further details on custom optimizers, refer to Fused Optimizers.

#reference https://pytorch.org/tutorials/recipes/zero_redundancy_optimizer.html
#usage: NATIVE=0 OVERLAP=1 python -u testZero1.py

import copy
import os
import time
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel import DistributedDataParallel as DDP
from habana_frameworks.torch.hpex.optimizers import FusedAdamW
from torch.distributed.algorithms.ddp_comm_hooks.ddp_zero_hook import (
    hook_with_zero_step,
    hook_with_zero_step_interleaved,
)
from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import (
    allreduce_hook,
)

import habana_frameworks.torch.core as htcore

#To register a functional optimizer, import the optimizer and
#invoke register_functional_optim(key,optimizer) from torch.distributed.optim.utils
#to register the optimizer
from habana_frameworks.torch.hpex.optimizers.distributed import FusedAdamW as FunctionalFusedAdamW
from torch.distributed.optim.utils import register_functional_optim
register_functional_optim(FunctionalFusedAdamW, FunctionalFusedAdamW)

# NATIVE=1 uses torch.optim.Adam; NATIVE=0 (default) uses the Gaudi FusedAdamW
use_native = int(os.environ.get('NATIVE', '0'))
# OVERLAP=1 overlaps the ZeRO optimizer step with the DDP backward pass
use_overlap = int(os.environ.get('OVERLAP', '0'))
# Number of initial iterations without a parameter update
WARMUP_STEPS = 2

device = torch.device('hpu')
torch.manual_seed(0)
# One input/target pair per rank (world_size = 2)
inputs = [torch.randn(2000, 2000), torch.randn(2000, 2000)]
targets = [torch.randn(2000, 2000), torch.randn(2000, 2000)]

def example(rank, world_size, use_zero):
    torch.manual_seed(0)
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    # create default process group
    import habana_frameworks.torch.distributed.hccl
    dist.init_process_group("hccl", rank=rank, world_size=world_size)

    # create local model
    model = nn.Sequential(*[nn.Linear(2000, 2000) for _ in range(20)]).to(device)
    # construct DDP model from a copy of the local model; bucket_cap_mb is in megabytes,
    # so this very large value lets DDP group gradients into as few buckets as possible
    ddp_model = DDP(copy.deepcopy(model).to(device), bucket_cap_mb=10000*1024*1024, gradient_as_bucket_view=True)

    # define loss function and optimizer
    loss_fn = nn.MSELoss()
    if use_zero:
        optimizer = ZeroRedundancyOptimizer(
            ddp_model.parameters(),
            optimizer_class=torch.optim.Adam if use_native else FunctionalFusedAdamW,
            lr=0.01,
            overlap_with_ddp=bool(use_overlap),
            weight_decay=1e-2,
            eps=1e-8,
        )
        if use_overlap:
            print("registering comm hook")
            # The ZeRO optimizer step runs inside the DDP communication hook
            ddp_model.register_comm_hook(
                None,
                hook_with_zero_step_interleaved(allreduce_hook, ddp_model, optimizer, shard_buckets=True)
            )
    elif use_native:
        optimizer = torch.optim.Adam(ddp_model.parameters(), lr=0.01, weight_decay=1e-2, eps=1e-8)
    else:
        optimizer = FusedAdamW(ddp_model.parameters(), lr=0.01, weight_decay=1e-2, eps=1e-8)

    for i in range(1, 10):
        # forward pass
        outputs = ddp_model(inputs[rank].to(device))
        labels = targets[rank].to(device)
        # backward pass
        loss = loss_fn(outputs, labels)
        loss.backward()
        #break the graph
        htcore.mark_step()
        # update parameters after the warm-up steps; with ZeRO and OVERLAP=1
        # the step already ran inside the comm hook, so it is skipped here
        if (not use_zero or not use_overlap) and i > WARMUP_STEPS:
            optimizer.step()
            htcore.mark_step()
        if rank == 0:
            print(" loss for step ", i, " is ", loss.to("cpu"))

    # Sum the trained parameters (ddp_model.module), not the untouched original model
    print(f"params sum is: {sum(p.sum() for p in ddp_model.module.parameters())}")
    dist.destroy_process_group()

def main():
    world_size = 2
    print("=== Using ZeroRedundancyOptimizer ===")
    start_time = time.time()
    mp.spawn(example,
        args=(world_size, True),
        nprocs=world_size,
        join=True)
    print("Time : ", time.time() - start_time)
    print("=== Not Using ZeroRedundancyOptimizer ===")
    start_time = time.time()
    mp.spawn(example,
        args=(world_size, False),
        nprocs=world_size,
        join=True)
    print("Time : ", time.time() - start_time)

if __name__ == "__main__":
    main()
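ZeRO-1, as used through ZeroRedundancyOptimizer above, saves memory by having each rank keep optimizer state only for the parameters it owns. The sketch below shows a greedy, size-balanced way to assign parameters to ranks, in the spirit of ZeroRedundancyOptimizer's default partitioning: parameters are visited from largest to smallest and each goes to the rank that currently owns the fewest elements. It is a simplified illustration with a hypothetical function name, not PyTorch's code, and it does not model how assignment works when overlap_with_ddp=True.

```python
from typing import Dict, List, Sequence, Tuple


def partition_params(
    param_sizes: Sequence[Tuple[str, int]], world_size: int
) -> Dict[int, List[str]]:
    """Assign (name, numel) parameters to ranks so each rank owns a similar number of elements."""
    if world_size < 1:
        raise ValueError("world_size must be >= 1")
    owned: Dict[int, List[str]] = {rank: [] for rank in range(world_size)}
    load = [0] * world_size
    # Largest parameters first; sorted() is stable, so ties keep their original order
    for name, numel in sorted(param_sizes, key=lambda p: p[1], reverse=True):
        rank = load.index(min(load))  # least-loaded rank, lowest index on ties
        owned[rank].append(name)
        load[rank] += numel
    return owned
```

For the 20-layer model in the example (a 2000 x 2000 weight and a 2000-element bias per layer) and world_size = 2, this assignment gives each rank ten weights and ten biases, so each rank holds the Adam state for roughly half of the parameters.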

Examples of Real-world Applications

Examples of real-world applications, such as ResNet50, BERT, and other models, including their performance and results, can be found on the PyTorch Model References GitHub page, each with its own README and examples.