Distributed Tensor Theory of Operations

Construct DistributedTensor

The following outlines how to construct a DistributedTensor (DTensor) directly to represent different data distributions, such as sharding, replication, and sharding combined with replication:

  1. Initialize an HCCL process group on each rank.

  2. Construct a device mesh with available devices (multi-host or single host).

  3. Define the data distribution (placement - shard, replicate, partial).

  4. Create the distributed tensor with the device mesh and placement defined in steps 2 and 3. The tensor is distributed according to the specified data distribution type (see the sketch after this list).
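
A minimal construction sketch follows. It assumes a multi-process launch (for example with torchrun), an Ascend PyTorch build that provides the hccl backend and the npu device type (the torch_npu adapter), and the DTensor APIs under torch.distributed.tensor (older releases expose them as torch.distributed._tensor); the shapes are illustrative.

```python
# Minimal sketch: construct DTensors directly.
import torch
import torch.distributed as dist
import torch_npu  # assumed: Ascend adapter registering the "npu" device and "hccl" backend
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, Shard, distribute_tensor

# 1. Initialize the HCCL process group on each rank.
dist.init_process_group(backend="hccl")

# 2. Construct a 1-D device mesh over all available devices.
mesh = init_device_mesh("npu", (dist.get_world_size(),))

# 3./4. Define the placement and create the distributed tensor from a full-size tensor.
full = torch.randn(8, 16)
sharded = distribute_tensor(full, mesh, placements=[Shard(0)])        # dim 0 split across the mesh
replicated = distribute_tensor(full, mesh, placements=[Replicate()])  # full copy on every device
# Partial placements normally arise from computation (pending reductions)
# rather than from direct construction.
```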

Use Module Level APIs

The following shows how to use the module-level APIs to distribute the module parameters directly:

  1. Initialize an HCCL process group on each rank.

  2. Construct a device mesh with available devices (multi-host or single host).

  3. Convert the module parameters to DistributedTensor parameters using the distribute_module() API according to the specified partition_fn. The API can also convert the module's inputs and outputs when input_fn and output_fn are specified.

  4. The forward and backward passes then operate on the distributed tensors (see the sketch after this list).
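
A sketch of the module-level flow is shown below, under the same environment assumptions as the previous sketch. The partition_fn, input_fn, and output_fn bodies are illustrative choices (replicate parameters, shard activations on the batch dimension, return local outputs), and the three-argument callback signatures follow recent PyTorch DTensor releases.

```python
# Sketch: distribute module parameters with distribute_module().
import torch
import torch.distributed as dist
import torch.nn as nn
import torch_npu  # assumed: Ascend adapter registering the "npu" device and "hccl" backend
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_module, distribute_tensor

dist.init_process_group(backend="hccl")
mesh = init_device_mesh("npu", (dist.get_world_size(),))

def partition_fn(name, module, device_mesh):
    # Turn every parameter of this submodule into a replicated DTensor parameter.
    for pname, param in list(module.named_parameters(recurse=False)):
        dparam = nn.Parameter(distribute_tensor(param, device_mesh, [Replicate()]))
        module.register_parameter(pname, dparam)

def input_fn(module, inputs, device_mesh):
    # Shard incoming activations on the batch dimension.
    return tuple(distribute_tensor(t, device_mesh, [Shard(0)]) for t in inputs)

def output_fn(module, outputs, device_mesh):
    # Hand plain local tensors back to the caller.
    return outputs.to_local() if isinstance(outputs, DTensor) else outputs

model = distribute_module(nn.Linear(16, 16), mesh,
                          partition_fn=partition_fn, input_fn=input_fn, output_fn=output_fn)
out = model(torch.randn(8, 16))  # forward (and backward) now run on DTensors
```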

Enable DDP

To enable DDP with DistributedTensor:

  1. Init step - Define the sharding strategy:

    1. Initialize an HCCL process group on each rank.

    2. Construct a device mesh with available devices (multi-host or single host).

    3. Convert the module parameters to DistributedTensor parameters using the distribute_module() API according to the partition_fn specified. It also converts the input/output of the module by specifying the input_fn and output_fn.

  2. Forward step:

    1. With DistributedTensor, parameters are replicated and input activations are sharded on the batch dimension, so the output activations are also sharded on the batch dimension. No communication is triggered (replication is the default placement type for module parameters).

    2. Forward pass is calculated.

  3. Backward step:

    1. Backward pass is calculated.

    2. If the input activations and the activation gradients are both ShardedTensors sharded on the batch dimension, the resulting parameter gradient is a PartialTensor. An AllReduce is generated on the PartialTensor param_grad before it is read (see the sketch after this list).
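
The sketch below illustrates this DDP-style data flow under the same environment assumptions as the earlier sketches: parameters are replicated, each rank's local batch is treated as one shard of the global batch (via DTensor.from_local), and the weight gradient that may carry a partial placement is all-reduced when it is materialized. The callback bodies and shapes are illustrative.

```python
# Sketch: DDP-style flow with DTensor (replicated params, batch-sharded activations).
import torch
import torch.distributed as dist
import torch.nn as nn
import torch_npu  # assumed: Ascend adapter registering the "npu" device and "hccl" backend
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_module, distribute_tensor

dist.init_process_group(backend="hccl")
mesh = init_device_mesh("npu", (dist.get_world_size(),))

def replicate_params(name, module, device_mesh):
    # Parameters keep the replicated (DDP-style) placement.
    for pname, param in list(module.named_parameters(recurse=False)):
        module.register_parameter(
            pname, nn.Parameter(distribute_tensor(param, device_mesh, [Replicate()])))

def shard_batch(module, inputs, device_mesh):
    # Each rank contributes its local micro-batch as one shard of the global batch.
    return tuple(DTensor.from_local(t, device_mesh, [Shard(0)]) for t in inputs)

model = distribute_module(nn.Linear(16, 16), mesh,
                          partition_fn=replicate_params, input_fn=shard_batch)

out = model(torch.randn(8, 16, device="npu"))  # output DTensor is sharded on the batch dim
out.to_local().sum().backward()

# As described above, the weight gradient may carry a partial placement; reading it
# as a replicated tensor triggers the AllReduce that sums the per-rank contributions.
full_grad = model.weight.grad.redistribute(mesh, [Replicate()])
```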

Enable FSDP

FSDP performs AllGather → compute → discard in the forward pass, and AllGather → compute → ReduceScatter in the backward pass. This can also be expressed with regular DistributedTensor behavior: in the forward pass, parameters start out as ShardedTensors, and a redistribute converts them to ReplicatedTensors (a sketch follows the list below). To enable FSDP with DistributedTensor:

  1. Init step - Define the sharding strategy:

    1. Initialize an HCCL process group on each rank.

    2. Construct a device mesh with available devices (multi-host or single host).

    3. Convert the module parameters to DistributedTensor parameters using the distribute_module() API according to the specified partition_fn. The API can also convert the module's inputs and outputs when input_fn and output_fn are specified.

  2. Forward step:

    1. Parameters start as sharded DTensors; redistribute them to replicated DTensors. This is the same as AllGather.

    2. Forward pass is calculated.

    3. Redistribute the replicated parameters back to sharded DTensors. This is the same as discarding the non-owned shards.

  3. Backward step:

    1. Redistribute the sharded parameters to replicated DTensors. This is the same as AllGather.

    2. Backward pass is calculated.

    3. The parameter gradients are produced as partial DTensors; redistribute them to sharded DTensors. This is the same as ReduceScatter. The replicated parameters are redistributed back to sharded DTensors, discarding the non-owned shards.
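
A compact sketch of this FSDP-style flow, expressed purely as DTensor redistributes under the same environment assumptions as the earlier sketches, is shown below. Real FSDP additionally controls when the gathered copy is freed and wraps this logic into per-module hooks; a single weight tensor is used here for illustration.

```python
# Sketch: FSDP-style AllGather -> compute -> discard (forward) and
# AllGather -> compute -> ReduceScatter (backward), expressed with redistribute().
import torch
import torch.distributed as dist
import torch.nn as nn
import torch_npu  # assumed: Ascend adapter registering the "npu" device and "hccl" backend
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, Shard, distribute_tensor

dist.init_process_group(backend="hccl")
mesh = init_device_mesh("npu", (dist.get_world_size(),))

# Init: the parameter lives as a sharded DTensor (each rank owns one shard).
w = nn.Parameter(distribute_tensor(torch.randn(16, 16), mesh, [Shard(0)]))
x = distribute_tensor(torch.randn(8, 16), mesh, [Shard(0)])  # batch-sharded input

# Forward: redistribute shard -> replicate (AllGather), then compute. Dropping the
# gathered copy after use is the "discard" step.
w_full = w.redistribute(mesh, [Replicate()])
out = x @ w_full
loss = out.to_local().sum()

# Backward: the weight gradient is produced as a pending (partial) reduction, and the
# backward of the redistribute places it back onto the Shard(0) parameter, which
# corresponds to a ReduceScatter.
loss.backward()
print(w.grad.placements)  # expected to be sharded, matching the parameter
```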