Using DistributedTensor with Intel Gaudi

DistributedTensor (DTensor) is a PyTorch extension that provides a distributed tensor abstraction. It allows you to shard tensors across multiple devices and to perform operations on those tensors in a distributed manner. This is useful for training large models across multiple devices or machines, or for running inference on large datasets. DTensor is supported on the Intel® Gaudi® 2 AI accelerator through the standard PyTorch interfaces. For more details, refer to Distributed Tensor Theory of Operations. DTensor on Intel Gaudi is supported only in Eager mode or with torch.compile (PT_HPU_LAZY_MODE=0).

Note

  • This feature is currently experimental.

  • CustomOp is currently not supported.

  • Lazy mode is not supported.

  • Some ops do not support DTensor. Additional ops support will be available in future releases.

  • Reduction ops such as mean are not supported on Gaudi.

Overview

DTensor provides a number of features that simplify working with distributed tensors, including automatic data sharding and replication, efficient communication between devices, and tools for debugging and monitoring distributed training.

When parallelizing a workload, there are different ways to spread data and compute across devices. DTensor expresses these through the following distribution mechanisms:

  • Shard - Splits the tensor on the specified dimension across devices.

  • Replicate - Replicates the tensor across devices.

  • Partial - Stores only partial values in the tensor while maintaining the global shape. This is used for intermediate results that are later combined, for example with an allreduce.

The core building blocks of a DTensor are the logical device mesh and the placement strategy that describes how tensors are distributed:

  • Device Mesh - Describes the layout of devices, which allows the devices in a group to communicate during an operation. It is an n-dimensional array of devices over which tensors are placed.

  • PlacementSpec - Captures the different tensor distribution types such as shard, replicate, and partial. It describes how the tensor data is distributed across a specific dimension of the devices in the device mesh. A short sketch follows this list.
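
The following minimal sketch (not part of toy_example.py below; it assumes a process group has already been initialized on Gaudi, as in the setup() function of the example) shows how a device mesh and placements are used with distribute_tensor:

    import torch
    from torch.distributed._tensor import DeviceMesh, Shard, Replicate, distribute_tensor

    def shard_and_replicate(world_size):
        # 1-D logical device mesh over all Gaudi (hpu) devices in the job.
        mesh = DeviceMesh("hpu", torch.arange(world_size))

        big = torch.randn(8, 16)

        # Shard(0): each rank holds a slice of dimension 0.
        sharded = distribute_tensor(big, mesh, placements=[Shard(0)])

        # Replicate(): every rank holds a full copy of the tensor.
        replicated = distribute_tensor(big, mesh, placements=[Replicate()])

        # Both DTensors report the global shape; only the local shards differ.
        print(sharded.shape, sharded.to_local().shape)
        print(replicated.shape, replicated.to_local().shape)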

Running a Simple Model Using DTensor on Gaudi

The toy_example.py below is a simple example that demonstrates the usage of distributed tensors on Gaudi. This example uses four Gaudi devices (one worker process per device):

import os
import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor, DeviceMesh, Shard, Replicate, distribute_module, distribute_tensor 
import torch.nn as nn
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    RowwiseParallel,
)
import torch.multiprocessing as mp

torch.manual_seed(0)

def setup(rank, world_size):
    print("rank", rank)
    print("world_size", world_size)
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    import habana_frameworks.torch.distributed.hccl
    dist.init_process_group(backend='hccl', rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

ITER_TIME = 10

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 32)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(32, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def print0(msg, rank):
    if rank == 0:
        print(msg)

def run_example(rank, world_size):
    """
    Main body of the demo of a basic version of tensor parallel by using
    PyTorch native APIs.
    """
    # set up world pg
    setup(rank, world_size)
    print0("Create a sharding plan based on the given world_size", rank)
    # create a sharding plan based on the given world_size.
    device_mesh = DeviceMesh(
        "hpu",
        torch.arange(world_size),
    )

    # create model
    model = ToyModel()
    print0("Parallelize the module based on the given Parallel Style", rank)
    parallelize_plan = {
        "net1": ColwiseParallel(),
        "net2": RowwiseParallel(),
    }
    # Parallelize the module based on the given Parallel Style.
    model = parallelize_module(model, device_mesh, parallelize_plan)

    # Create an optimizer after parallelization so it references the distributed
    # (DTensor) parameters installed by parallelize_module.
    LR = 0.25
    optimizer = torch.optim.SGD(model.parameters(), lr=LR)

    # Perform a num of iterations of forward/backward
    # and optimizations for the sharded module.
    for i in range(ITER_TIME):
        inp = torch.rand(20, 10)
        output = model(inp)
        output.sum().backward()
        optimizer.step()
        optimizer.zero_grad()  # clear gradients so they do not accumulate across iterations
        print0(f" Iteration {i}", rank)

    # shutting down world pg
    cleanup()

if __name__ == "__main__":
    import habana_frameworks.torch as htorch
    WORLD_SIZE = htorch.hpu.device_count()
    print("world size", WORLD_SIZE)
    assert WORLD_SIZE == 4  # our example uses 4 worker ranks
    mp.spawn(run_example, args=(WORLD_SIZE,), nprocs=WORLD_SIZE, join=True)
  • Import habana_frameworks.torch.distributed.hccl before initializing the process group so that the HCCL backend is available on Gaudi:

    import habana_frameworks.torch.distributed.hccl
    
  • Create a device mesh of Gaudi devices based on the sharding plan:

    device_mesh = DeviceMesh(
            "hpu",
            torch.arange(world_size),
    )
    
  • Parallelize the module based on the given parallel style on the Gaudi devices:

    model = parallelize_module(model, device_mesh, parallelize_plan)
    

Executing the Example

Execute toy_example.py by running the command below. Since Lazy mode is the default, setting PT_HPU_LAZY_MODE=0 disables it:

PT_HPU_LAZY_MODE=0 python toy_example.py
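
As noted above, DTensor on Gaudi also works with torch.compile. The following is a minimal sketch, assuming the hpu_backend torch.compile backend provided by the Intel Gaudi PyTorch bridge, that compiles the parallelized module before the training loop in run_example(); PT_HPU_LAZY_MODE=0 is still required:

    # Sketch only: compile the parallelized module before the forward/backward loop.
    # Assumes the Intel Gaudi torch.compile backend named "hpu_backend".
    model = parallelize_module(model, device_mesh, parallelize_plan)
    model = torch.compile(model, backend="hpu_backend")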

Supported Features

The following list describes the DTensor features supported on Gaudi. For more details on DTensor, refer to the DTensor GitHub page.

  • Sharding - Shards a tensor along a specified dimension across devices.

  • Replication - Replicates a tensor across devices.

  • Partial - A tensor that keeps the global shape but holds only partial values on each device.

  • Sharding+Replication - Combines sharding and replication placements for the same tensor across the device mesh.

  • DeviceMesh - Abstraction that describes the global view of devices within a cluster.

  • distribute_tensor - Distributes a tensor according to the device mesh and placements.

  • distribute_module - Converts all module parameters to distributed tensor parameters.

  • parallelize_module - Parallelizes a module's tensors for tensor parallelism.

  • redistribute - Converts a DTensor from one placement to another, for example from rowwise to columnwise sharding.

  • dtensor_from_local - Converts a torch.Tensor to a DTensor.

  • dtensor_to_local - Converts a DTensor to a torch.Tensor.

  • checkpoint - Saves/loads large sharded models.
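
The from_local, to_local, and redistribute entries can be combined as in the minimal sketch below. It assumes a process group and a one-dimensional "hpu" device mesh named mesh already exist, created as in toy_example.py:

    import torch
    from torch.distributed._tensor import DTensor, Shard, Replicate

    def roundtrip(mesh):
        # Each rank contributes its local shard; from_local assembles them into a
        # DTensor that is sharded on dimension 0 across the mesh.
        local = torch.randn(4, 8, device="hpu")
        dt = DTensor.from_local(local, mesh, placements=[Shard(0)])

        # redistribute converts one placement into another, here rowwise sharding
        # to replication, performing the required collective communication.
        replicated = dt.redistribute(mesh, placements=[Replicate()])

        # to_local returns the plain torch.Tensor held by this rank.
        return replicated.to_local()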