DDP-based Scaling of Gaudi on PyTorch

mpirun Configuration

The PE attribute value of the mpirun map-by option may vary depending on your setup and should be calculated as: socket:PE = floor((number of physical cores) / (number of Gaudi devices per node)).

The following sample code calculates the number of physical CPU cores and the HPU count, and derives the appropriate PE value, shown as MPI_PE below. It can be incorporated into any model:

export PHY_CPU_COUNT=$(lscpu --all --parse=CORE,SOCKET | grep -Ev "^#" | sort -u | wc -l)
export PHY_HPU_COUNT=$(ls /dev/hl? | wc -l)
export MPI_PE=$(($PHY_CPU_COUNT/$PHY_HPU_COUNT))

The PE value in the Model-References examples may be set to a common value to ensure functionality. For optimal system performance on your host CPU, calculate it as described above.
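The same calculation can be sketched in Python, for example in a launcher that assembles the mpirun command line. This is an illustrative sketch, not part of the Habana tooling: the function names are hypothetical, and count_physical_cores() assumes it receives the text printed by `lscpu --all --parse=CORE,SOCKET`, de-duplicating (core, socket) pairs the same way the shell pipeline does.

```python
def count_physical_cores(lscpu_output: str) -> int:
    """Count unique (CORE, SOCKET) pairs in `lscpu --all --parse=CORE,SOCKET` output.

    Hyper-threads share a core ID, so de-duplicating the pairs leaves one entry
    per physical core, mirroring `grep -Ev "^#" | sort -u | wc -l`.
    """
    pairs = {
        line.strip()
        for line in lscpu_output.splitlines()
        if line.strip() and not line.startswith("#")  # skip lscpu header comments
    }
    return len(pairs)


def compute_pe(physical_cores: int, devices_per_node: int) -> int:
    """socket:PE = floor(physical cores / Gaudi devices per node)."""
    if devices_per_node <= 0:
        raise ValueError("devices_per_node must be positive")
    return physical_cores // devices_per_node


def map_by_option(pe: int) -> str:
    """Format the value passed to mpirun's map-by option."""
    return f"socket:PE={pe}"


# Hypothetical host: two sockets with two physical cores each and
# hyper-threading enabled, so every (core, socket) pair appears twice.
sample = "# CORE,SOCKET\n0,0\n0,0\n1,0\n1,0\n2,1\n2,1\n3,1\n3,1\n"
cores = count_physical_cores(sample)
pe = compute_pe(cores, devices_per_node=2)
print(cores, pe, map_by_option(pe))
```

On a real host, the input would be the stdout of `subprocess.run(["lscpu", "--all", "--parse=CORE,SOCKET"], capture_output=True, text=True)`, and the device count would be `len(glob.glob("/dev/hl?"))`, matching the shell snippet above.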

Scale-up Using Gaudi NICs Within a Server

The following simple example, example.py, shows distributed training support. The HCCL communication backend is loaded, and the process group communication backend is initialized as “hccl” with the following script changes:

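import os
import torch
import torch.distributed as dist
import habana_frameworks.torch.core as htcore

torch.manual_seed(0)
# Load the HPU backend for PyTorch
device = torch.device('hpu')

def setup(rank, world_size):
    # Keep values passed with `mpirun -x`; fall back to single-host defaults
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12340')
    # Import the distributed package for HCCL (registers the 'hccl' backend)
    import habana_frameworks.torch.distributed.hccl
    dist.init_process_group(backend='hccl', rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

def allReduce():
    _tensor = torch.ones(10).to(device)
    dist.all_reduce(_tensor)
    # Copy the result to the host; every element now equals world_size
    _tensor_cpu = _tensor.cpu()

def main(rank, world_size):
    setup(rank, world_size)

    for i in range(100):
        allReduce()

    cleanup()

if __name__ == '__main__':
    # mpirun (Open MPI) exports the global rank and size of each process
    rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
    world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
    main(rank, world_size)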

The code above runs in multiple processes, one for each Gaudi.

To launch distributed training on eight Gaudi devices within one host, run the following command:

$ mpirun -np 8 python3 example.py

Note

Open MPI is required for host communication and launching processes.
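Because mpirun, rather than torch.multiprocessing, starts these processes, each one has to discover its own rank and the world size before calling dist.init_process_group(). Open MPI exports this information through the OMPI_COMM_WORLD_RANK, OMPI_COMM_WORLD_SIZE, and OMPI_COMM_WORLD_LOCAL_RANK environment variables. The helper below is a minimal sketch that reads them; its name is hypothetical, and it assumes Open MPI is the launcher (other MPI implementations use different variable names).

```python
import os
from typing import Mapping, NamedTuple


class MpiRankInfo(NamedTuple):
    rank: int        # global rank across all servers
    world_size: int  # total number of processes
    local_rank: int  # rank within the current server


def rank_from_openmpi_env(env: Mapping[str, str] = os.environ) -> MpiRankInfo:
    """Read the rank layout that Open MPI's mpirun exports to each process."""
    try:
        return MpiRankInfo(
            rank=int(env["OMPI_COMM_WORLD_RANK"]),
            world_size=int(env["OMPI_COMM_WORLD_SIZE"]),
            local_rank=int(env["OMPI_COMM_WORLD_LOCAL_RANK"]),
        )
    except KeyError as missing:
        # Running the script directly (without mpirun) leaves these unset
        raise RuntimeError(f"{missing} is not set; launch the script with mpirun") from None
```

In example.py, `info = rank_from_openmpi_env()` followed by `main(info.rank, info.world_size)` would start the all-reduce loop. The local rank becomes relevant once processes span several servers.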

Scale-out Across Servers

Scale-out Using AWS DL1/Host NICs

The environment variable HCCL_OVER_OFI=1 must be set to enable multi-server communication in HCCL using Libfabric.

Set HCCL_SOCKET_IFNAME to the name of the network interface or subnet that HCCL will use to communicate:

export HCCL_SOCKET_IFNAME=interface_name
export HCCL_OVER_OFI=1
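When processes are started with mpirun, these variables must reach every rank, including ranks on remote servers. Open MPI's `-x NAME=VALUE` option exports an environment variable to the launched processes. The sketch below builds those flags from a mapping; the helper name and the interface name `eth0` are hypothetical placeholders.

```python
import shlex
from typing import List, Mapping


def mpirun_env_flags(env: Mapping[str, str]) -> List[str]:
    """Turn NAME -> VALUE pairs into mpirun `-x NAME=VALUE` arguments."""
    flags: List[str] = []
    for name, value in env.items():
        flags += ["-x", f"{name}={value}"]
    return flags


# Hypothetical interface name; use the NIC name or subnet HCCL should use.
HOST_NIC_ENV = {"HCCL_OVER_OFI": "1", "HCCL_SOCKET_IFNAME": "eth0"}

# Prints the flags as they would appear on an mpirun command line
print(shlex.join(mpirun_env_flags(HOST_NIC_ENV)))
```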

Scale-out Using Gaudi NICs

Run the script using the command below. The IP addresses of the two servers must be passed to the -H option, separated by commas:

mpirun --allow-run-as-root  \
    -x LD_LIBRARY_PATH=".."  \
    -x HABANA_LOGS=".."  \
    -x GC_KERNEL_PATH=".."  \
    --prefix /usr/local/openmpi  \
    --mca btl_tcp_if_include 10.211.160.0/16  \
    -x MASTER_ADDR=10.211.160.154  \
    -x MASTER_PORT=12345  \
    --mca plm_rsh_args "-p 3022"  \
    -H 10.211.160.154,10.211.160.52 \
    -n 16  \
    ... \
    python3 example.py
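The host-dependent part of this command can be generated from a list of servers. The sketch below is hypothetical: it covers only MASTER_ADDR, MASTER_PORT, the SSH port passed through plm_rsh_args, -H, and -n, and omits the other options shown above. It assumes an Open MPI release that accepts the `host:slots` form of -H, so that each server accepts devices_per_host ranks, and it uses the first host as MASTER_ADDR, as in the example.

```python
import shlex
from typing import List, Sequence


def build_scaleout_command(
    hosts: Sequence[str],
    devices_per_host: int,
    master_port: int,
    ssh_port: int = 3022,
    script: str = "example.py",
) -> List[str]:
    """Assemble an mpirun command for len(hosts) servers with devices_per_host ranks each."""
    if not hosts:
        raise ValueError("at least one host is required")
    # host:slots lets each server accept devices_per_host ranks
    host_list = ",".join(f"{host}:{devices_per_host}" for host in hosts)
    return [
        "mpirun", "--allow-run-as-root",
        # The first host acts as the rendezvous point for init_process_group()
        "-x", f"MASTER_ADDR={hosts[0]}",
        "-x", f"MASTER_PORT={master_port}",
        # Port of the SSH server on the remote hosts
        "--mca", "plm_rsh_args", f"-p {ssh_port}",
        "-H", host_list,
        # One rank per Gaudi device across all servers
        "-n", str(len(hosts) * devices_per_host),
        "python3", script,
    ]


cmd = build_scaleout_command(["10.211.160.154", "10.211.160.52"], devices_per_host=8, master_port=12345)
print(shlex.join(cmd))
```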

By default, the shell scripts connect to port 3022. However, the port the SSH server listens on may differ between environments. If your environment requires a different port for the remote SSH server, you can specify it with the SSHD_PORT environment variable.

To change the port the SSH server listens on, use the command below. The example sets the port to 22:

$ /etc/init.d/ssh restart '-p 22'
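Before launching across servers, it can help to confirm that each remote SSH server is reachable on the expected port. The sketch below is a hypothetical pre-flight check, not part of the Model References scripts: it reads SSHD_PORT with a fallback to 3022, mirroring the behavior described above, and attempts a plain TCP connection. It only tests TCP reachability, not SSH authentication.

```python
import os
import socket
from typing import Mapping

DEFAULT_SSH_PORT = 3022  # port the launch scripts use when SSHD_PORT is unset


def ssh_port_from_env(env: Mapping[str, str] = os.environ) -> int:
    """Return SSHD_PORT if it is set, otherwise the default port 3022."""
    return int(env.get("SSHD_PORT", DEFAULT_SSH_PORT))


def is_port_open(host: str, port: int, timeout: float = 3.0) -> bool:
    """Return True if a TCP connection to host:port succeeds within timeout seconds."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:  # refused, unreachable, or timed out
        return False
```

For example, `is_port_open("10.211.160.52", ssh_port_from_env())` returning False suggests the remote SSH server is listening on a different port or is not reachable.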

For further details, refer to the README on the PyTorch Model References GitHub page.

HCCL with Distributed Data Parallel (DDP) Hook

The example below, taken from the PyTorch documentation, includes Habana-specific modifications that show how to initialize and use the HCCL package with PyTorch’s DDP hook, ddp_model = DDP(model).

#reference https://pytorch.org/tutorials/intermediate/ddp_tutorial.html

import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp

from torch.nn.parallel import DistributedDataParallel as DDP

import habana_frameworks.torch.core as htcore

device = torch.device('hpu')

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    os.environ["ID"] = str(rank)
    # Import the distributed package for HCCL (registers the 'hccl' backend)
    import habana_frameworks.torch.distributed.hccl
    dist.init_process_group(backend='hccl', rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def demo_basic(rank, world_size):
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)

    model = ToyModel().to(device)
    # DDP all-reduces gradients across ranks during backward()
    ddp_model = DDP(model)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10).to(device))
    labels = torch.randn(20, 5).to(device)
    loss_fn(outputs, labels).backward()
    # Lazy mode: execute the accumulated graph after backward() and after step()
    htcore.mark_step()
    optimizer.step()
    htcore.mark_step()

    cleanup()


def run_demo(demo_fn, world_size):
    # Start one process per Gaudi device; each receives its rank as the first argument
    mp.spawn(demo_fn,
             args=(world_size,),
             nprocs=world_size,
             join=True)

if __name__ == "__main__":
    world_size = 8
    run_demo(demo_basic, world_size)
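During backward(), the DDP wrapper synchronizes gradients so that every rank applies the same update: gradients are summed across ranks with an all-reduce and divided by the world size. The pure-Python sketch below models only that arithmetic (no HCCL or PyTorch), using hypothetical helper names. It also shows why the all_reduce of torch.ones(10) in example.py leaves every element equal to the world size.

```python
from typing import List, Sequence


def all_reduce_sum(per_rank: Sequence[Sequence[float]]) -> List[List[float]]:
    """Simulate a SUM all-reduce: every rank receives the element-wise sum."""
    summed = [sum(values) for values in zip(*per_rank)]
    return [list(summed) for _ in per_rank]


def ddp_average(per_rank_grads: Sequence[Sequence[float]]) -> List[List[float]]:
    """Model DDP gradient synchronization: all-reduce (sum), then divide by world size."""
    world_size = len(per_rank_grads)
    return [[g / world_size for g in grads] for grads in all_reduce_sum(per_rank_grads)]


# example.py: 8 ranks each contribute torch.ones(10)
print(all_reduce_sum([[1.0] * 10 for _ in range(8)])[0][0])  # 8.0
# Two ranks with different local gradients end up with identical averages
print(ddp_average([[1.0, 2.0], [3.0, 6.0]]))  # [[2.0, 4.0], [2.0, 4.0]]
```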

Advanced Usage: ZeroRedundancyOptimizer with DDP

The example below, taken from the PyTorch documentation, includes Habana-specific modifications that show how to initialize and use the HCCL package with PyTorch’s DDP hook and a custom fused optimizer.

The fused optimizer is wrapped with the functional optimizer class so that it is accessible to ZeroRedundancyOptimizer (ZeRO-1). For further details on custom optimizers, refer to Custom Optimizers.

#reference https://pytorch.org/tutorials/recipes/zero_redundancy_optimizer.html
#usage: NATIVE=0 OVERLAP=1 python -u testZero1.py

import copy
import os
import time
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel import DistributedDataParallel as DDP
from habana_frameworks.torch.hpex.optimizers import FusedAdamW
from torch.distributed.algorithms.ddp_comm_hooks.ddp_zero_hook import (
    hook_with_zero_step,
    hook_with_zero_step_interleaved,
)
from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import (
    allreduce_hook,
)

import habana_frameworks.torch.core as htcore

# To register a functional optimizer, import the optimizer and
# invoke register_functional_optim(key, optimizer) from torch.distributed.optim.utils
from habana_frameworks.torch.hpex.optimizers.distributed import FusedAdamW as FunctionalFusedAdamW
from torch.distributed.optim.utils import register_functional_optim
register_functional_optim(FunctionalFusedAdamW, FunctionalFusedAdamW)

# NATIVE=1 selects torch.optim.Adam; NATIVE=0 (default) selects Habana FusedAdamW
use_native = int(os.environ.get('NATIVE', '0'))
# OVERLAP=1 overlaps the ZeRO optimizer step with DDP gradient communication
use_overlap = int(os.environ.get('OVERLAP', '0'))
# Iterations without an explicit optimizer.step(); ZeRO overlap needs warm-up iterations
WARMUP_STEPS = 2

device = torch.device('hpu')
torch.manual_seed(0)
# One input/label pair per rank (world_size = 2)
inputs = [torch.randn(2000, 2000), torch.randn(2000, 2000)]
targets = [torch.randn(2000, 2000), torch.randn(2000, 2000)]

def example(rank, world_size, use_zero):
    torch.manual_seed(0)
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    # create default process group
    import habana_frameworks.torch.distributed.hccl
    dist.init_process_group("hccl", rank=rank, world_size=world_size)

    # create local model
    model = nn.Sequential(*[nn.Linear(2000, 2000) for _ in range(20)]).to(device)
    # construct DDP model; the very large bucket cap keeps all gradients in one bucket
    ddp_model = DDP(copy.deepcopy(model), bucket_cap_mb=10000*1024*1024, gradient_as_bucket_view=True)

    # define loss function and optimizer
    loss_fn = nn.MSELoss()
    if use_zero:
        optimizer = ZeroRedundancyOptimizer(
            ddp_model.parameters(),
            optimizer_class=torch.optim.Adam if use_native else FunctionalFusedAdamW,
            lr=0.01,
            overlap_with_ddp=bool(use_overlap),
            weight_decay=1e-2,
            eps=1e-8,
        )
        if use_overlap:
            print("registering comm hook")
            # The hook runs the sharded optimizer step as gradient buckets are reduced
            ddp_model.register_comm_hook(
                None,
                hook_with_zero_step_interleaved(allreduce_hook, ddp_model, optimizer, shard_buckets=True)
            )
    else:
        optimizer_class = torch.optim.Adam if use_native else FusedAdamW
        optimizer = optimizer_class(ddp_model.parameters(), lr=0.01, weight_decay=1e-2, eps=1e-8)

    for i in range(1, 10):
        # clear gradients in place so they do not accumulate across iterations
        ddp_model.zero_grad(set_to_none=False)
        # forward pass
        outputs = ddp_model(inputs[rank].to(device))
        labels = targets[rank].to(device)
        # backward pass
        loss = loss_fn(outputs, labels)
        loss.backward()
        # break the graph
        htcore.mark_step()
        # explicit update, unless ZeRO overlap already stepped inside the comm hook
        if not (use_zero and use_overlap) and i > WARMUP_STEPS:
            optimizer.step()
            htcore.mark_step()
        if rank == 0:
            print(" loss for step ", i, " is ", loss.to("cpu"))

    # sum over the trained (DDP-wrapped) parameters
    params_sum = sum(p.sum() for p in ddp_model.parameters())
    print(f"params sum is: {params_sum.to('cpu')}")
    dist.destroy_process_group()

def main():
    world_size = 2
    print("=== Using ZeroRedundancyOptimizer ===")
    start_time = time.time()
    mp.spawn(example,
        args=(world_size, True),
        nprocs=world_size,
        join=True)
    print("Time : ", time.time() - start_time)
    print("=== Not Using ZeroRedundancyOptimizer ===")
    start_time = time.time()
    mp.spawn(example,
        args=(world_size, False),
        nprocs=world_size,
        join=True)
    print("Time : ", time.time() - start_time)

if __name__ == "__main__":
    main()
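ZeroRedundancyOptimizer (ZeRO stage 1) reduces memory by sharding optimizer state: each rank keeps the optimizer state (for AdamW, two moment tensors per parameter) only for the parameters it owns, updates those parameters, and the updated values are then broadcast to the other ranks. The sketch below illustrates one way to balance ownership: a greedy assignment that places the largest tensors first on the least-loaded rank. It is an illustration with hypothetical names, not ZeroRedundancyOptimizer's actual implementation, whose partitioning details may differ.

```python
from typing import Dict, List, Sequence


def partition_params(sizes: Sequence[int], world_size: int) -> Dict[int, List[int]]:
    """Assign parameter indices to ranks, balancing the total element count per rank."""
    owners: Dict[int, List[int]] = {rank: [] for rank in range(world_size)}
    load = [0] * world_size
    # Place the largest tensors first so the greedy choice stays balanced
    for idx in sorted(range(len(sizes)), key=lambda i: sizes[i], reverse=True):
        rank = min(range(world_size), key=lambda r: load[r])
        owners[rank].append(idx)
        load[rank] += sizes[idx]
    return owners


# Parameters of the example model: 20 x nn.Linear(2000, 2000), weight + bias each
sizes = [2000 * 2000, 2000] * 20
owners = partition_params(sizes, world_size=2)
for rank, idxs in owners.items():
    # rank, number of owned tensors, number of owned elements
    print(rank, len(idxs), sum(sizes[i] for i in idxs))
```

With two ranks, each rank owns half of the 80,040,000 parameters, so each holds roughly half of the AdamW moment buffers.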

Examples of Real-world Applications

Examples of real-world applications, such as ResNet-50, BERT, and other models, including their performance and results, can be found on the PyTorch Models GitHub page, each with its own README and examples.