Distributed Tensor Theory of Operations¶
Construct DistributedTensor¶
The following outlines how to construct a DistributedTensor (DTensor) directly to represent different data layouts: sharding, replication, or sharding combined with replication:
Initialize an HCCL process group on each rank.
Construct a device mesh with available devices (multi-host or single host).
Define the data distribution (placement - shard, replicate, partial).
Create the distributed tensor with the device mesh and placement defined in steps 2 and 3. The tensor is distributed based on the data distribution type specified.
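A minimal sketch of these four steps, assuming a PyTorch-style DTensor API (init_device_mesh, distribute_tensor, Shard, Replicate) with an HCCL backend registered in torch.distributed; the exact import paths, the "hccl" backend name, and the "hpu" device type are assumptions that may vary by release.

```python
import torch
import torch.distributed as dist
# DTensor APIs; in some releases these live under torch.distributed._tensor instead.
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Shard, Replicate

# 1. Initialize an HCCL process group on each rank
#    (the backend name, and any vendor plug-in it requires, are assumptions).
dist.init_process_group(backend="hccl")

# 2. Construct a device mesh over the available devices.
#    A 1-D mesh over the whole world; multi-host meshes are built the same way.
mesh = init_device_mesh("hpu", (dist.get_world_size(),))

# 3. Define the data distribution (placement): shard, replicate, or a mix.
#    (A partial placement usually arises from computation rather than direct construction.)
sharded = [Shard(0)]        # shard along dim 0 across the mesh
replicated = [Replicate()]  # full copy on every device

# 4. Create the distributed tensors from the device mesh and placements above.
big_tensor = torch.randn(8192, 1024)
dt_sharded = distribute_tensor(big_tensor, mesh, sharded)
dt_replicated = distribute_tensor(big_tensor, mesh, replicated)
```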
Use Module Level APIs¶
The following shows how to use module level APIs to directly distribute the module parameters:
Initialize an HCCL process group on each rank.
Construct a device mesh with available devices (multi-host or single host).
Convert the module parameters to DistributedTensor parameters using the distribute_module() API according to the partition_fn specified. It also converts the module input/output by specifying the input_fn and output_fn. The forward/backward pass will then use the distributed tensors.
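A hedged sketch of this module-level flow under the same assumptions (a PyTorch-style distribute_module(), an "hccl" backend, an "hpu" device mesh). The partition_fn, input_fn, and output_fn bodies below are illustrative only, and the exact hook signatures can differ between releases.

```python
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_module, distribute_tensor, Shard, Replicate

dist.init_process_group(backend="hccl")                    # assumption: HCCL backend
mesh = init_device_mesh("hpu", (dist.get_world_size(),))   # assumption: "hpu" device type

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1024, 4096)
        self.fc2 = nn.Linear(4096, 1024)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

# Illustrative partition_fn: shard fc1 row-wise, replicate every other parameter,
# so that all module parameters become DistributedTensors.
def partition_fn(name, module, device_mesh):
    for p_name, param in module.named_parameters(recurse=False):
        placement = [Shard(0)] if name == "fc1" else [Replicate()]
        module.register_parameter(
            p_name, nn.Parameter(distribute_tensor(param, device_mesh, placement)))

# Illustrative input_fn/output_fn: replicate the inputs, return outputs unchanged.
def input_fn(module, inputs, device_mesh):
    return tuple(distribute_tensor(t, device_mesh, [Replicate()]) for t in inputs)

def output_fn(module, outputs, device_mesh):
    return outputs

model = distribute_module(MLP(), mesh, partition_fn=partition_fn,
                          input_fn=input_fn, output_fn=output_fn)

# The forward/backward pass now runs on DistributedTensor parameters and inputs.
out = model(torch.randn(32, 1024))
```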
Enable DDP¶
To enable DDP with DistributedTensor:
Init step - Define the sharding strategy:
Initialize an HCCL process group on each rank.
Construct a device mesh with available devices (multi-host or single host).
Convert the module parameters to DistributedTensor parameters using the distribute_module() API according to the partition_fn specified. It also converts the module input/output by specifying the input_fn and output_fn.
Forward step:
With DistributedTensor, parameters are replicated and input activations are sharded on the batch dimension in the forward pass, so the output activations are also sharded on the batch dimension. No communication is triggered (the default placement for module parameters is replication).
Forward pass is calculated.
Backward step:
Backward pass is calculated.
If the input activations are ShardedTensors and the activation gradients are also ShardedTensors, both sharded on the batch dimension, the resulting parameter gradient is a PartialTensor. An AllReduce is generated on the PartialTensor param_grad before it is read.
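A hedged sketch of this DDP-style flow: an illustrative partition_fn replicates every parameter, an input_fn shards the inputs on the batch dimension, and the partial parameter gradient is reduced when it is materialized. The "hccl" backend and "hpu" device type remain assumptions, and the exact point at which the AllReduce fires may vary by release.

```python
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_module, distribute_tensor, Shard, Replicate

dist.init_process_group(backend="hccl")                    # assumption: HCCL backend
mesh = init_device_mesh("hpu", (dist.get_world_size(),))   # assumption: "hpu" device type

# Init step: replicate every module parameter (the DDP layout).
def replicate_params(name, module, device_mesh):
    for p_name, param in module.named_parameters(recurse=False):
        module.register_parameter(
            p_name, nn.Parameter(distribute_tensor(param, device_mesh, [Replicate()])))

# Shard the input activations on the batch dimension (dim 0).
def shard_batch(module, inputs, device_mesh):
    return tuple(distribute_tensor(t, device_mesh, [Shard(0)]) for t in inputs)

model = distribute_module(nn.Linear(1024, 1024), mesh,
                          partition_fn=replicate_params, input_fn=shard_batch)

# Forward: replicated parameters x batch-sharded input -> batch-sharded output,
# with no communication triggered.
out = model(torch.randn(64, 1024))

# Backward: the parameter gradient is produced as a partial tensor.
out.to_local().sum().backward()       # local scalar loss just to drive the backward pass
param_grad = model.weight.grad        # expected to carry a Partial placement
full_grad = param_grad.full_tensor()  # materializing it triggers the AllReduce
```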
Enable FSDP¶
FSDP performs AllGather → compute → discard in the forward pass, and AllGather → compute → ReduceScatter in the backward pass. This can also be expressed using regular DistributedTensor behaviors: in the forward pass, parameters start out as ShardedTensors, and a redistribute converts them to ReplicatedTensors. To enable FSDP with DistributedTensor:
Init step - Define the sharding strategy:
Initialize an HCCL process group on each rank.
Construct a device mesh with available devices (multi-host or single host).
Convert the module parameters to DistributedTensor parameters using the distribute_module() API according to the partition_fn specified. It also converts the module input/output by specifying the input_fn and output_fn.
Forward step:
If parameters are sharded DTensors, redistribute to convert them to replicated DTensors. This is the same as AllGather.
Forward pass is calculated.
Redistribute parameters from replicated DTensors back to sharded DTensors. This is the same as discarding the non-owned shards.
Backward step:
If parameters are sharded DTensors, redistribute to convert them to replicated DTensors. This is the same as AllGather.
Backward pass is calculated.
The partial parameter gradients are redistributed to sharded DTensors. This is the same as ReduceScatter.
Redistribute parameters from replicated DTensors back to sharded DTensors. This is the same as discarding the non-owned shards.
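A hedged sketch of the FSDP-style redistribution described above, applied to a single parameter tensor; a real implementation would wire these conversions into pre-forward and post-backward hooks, and the "hccl" backend and "hpu" device type are again assumptions.

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Shard, Replicate

dist.init_process_group(backend="hccl")                    # assumption: HCCL backend
mesh = init_device_mesh("hpu", (dist.get_world_size(),))   # assumption: "hpu" device type

# Init step: the parameter lives as a sharded DTensor (each rank owns one shard).
weight = distribute_tensor(torch.randn(4096, 1024), mesh, [Shard(0)])

# Forward step: Shard -> Replicate is the AllGather.
weight_full = weight.redistribute(mesh, [Replicate()])
# ... forward compute with weight_full ...

# Replicate -> Shard discards the non-owned shards (a local slice, no communication).
weight = weight_full.redistribute(mesh, [Shard(0)])

# Backward step mirrors this: AllGather the parameter again (Shard -> Replicate),
# run the backward compute, then redistribute the partial parameter gradient to a
# sharded DTensor, which is the ReduceScatter.
```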