DeepSpeed User Guide for Training
The purpose of this document is to guide Data Scientists to run PyTorch models on the Habana® Gaudi® infrastructure using a DeepSpeed interface.
DeepSpeed Validated Configurations
The following DeepSpeed configurations have been validated to be fully functioning with both first-gen Gaudi and Gaudi2:
Configuration | Description
---|---
Distributed Data Parallel (multi-card) | Trains the same model across multiple ranks by splitting the datasets between the workers to achieve better performance compared to a single card.
ZeRO-1 | Partitions the optimizer states across the ranks so that each process updates its own partition.
ZeRO-2 | On top of ZeRO-1, each process retains only the gradients corresponding to its portion of the optimizer states.
ZeRO-3 | The full model state is partitioned across the processes (including 16-bit weights). ZeRO-3 automatically collects and partitions them during the forward and backward passes. Make sure to use only optimizers that have been tested with DeepSpeed ZeRO. For further details, refer to Using ZeRO-3 section.
ZeRO-Offload | Offloads the optimizer's memory and computation from HPU to the host CPU. The implementation of Adam on CPU is made more efficient by DeepSpeedCPUAdam.
ZeRO-Infinity | Extends ZeRO-3 functionality by allowing the offload of both the model and optimizer parameters to the CPU memory.
Model Pipeline Parallelism | Splits the model layers between several workers so that each one executes the forward and backward passes of its own layers. To optimize the evaluation process during training, refer to Optimizing Pipeline Parallelism section.
Model Tensor Parallelism (using Megatron-DeepSpeed) | Splits the model tensors into chunks so that each tensor resides on its designated HPU. Megatron introduces an approach to model tensor parallelism for transformer-based models.
BF16 precision | Reduces the model memory consumption and improves performance by training with BF16 precision.
BF16Optimizer | Allows BF16 precision training with pipeline parallelism. An optimizer that implements ZeRO-1 for BF16 with gradient accumulation at FP32.
Activation Checkpointing | Recomputes forward pass activations during the backward pass in order to save memory. For further details, refer to Using Activation Checkpointing section.
Note
Model Pipeline and Tensor Parallelism (via Megatron-DeepSpeed) are currently supported only on Gaudi2.

There is a known issue with HCCL that impacts ZeRO features on Gaudi2 and may cause incorrect behavior during training. This will be fixed in an upcoming release.

DeepSpeed's multi-node training uses pdsh for invoking the processes on remote hosts. Make sure it is installed on your machine before using it.

Upon initialization, Habana-DeepSpeed enforces deterministic behavior by setting habana_frameworks.torch.hpu.setDeterministic(True).

All further information on DeepSpeed configurations can be found in the DeepSpeed documentation.
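As an illustration of how several of the validated features above are enabled, the sketch below builds a DeepSpeed configuration for ZeRO-2 with BF16 precision as a Python dict; the batch-size values are placeholders, and such a dict can be passed to deepspeed.initialize() in place of a JSON configuration file.

ds_config = {
    "train_micro_batch_size_per_gpu": 8,   # placeholder value
    "gradient_accumulation_steps": 1,      # placeholder value
    "bf16": {"enabled": True},             # BF16 precision
    "zero_optimization": {"stage": 2}      # ZeRO-2: partition optimizer states and gradients
}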
Installing DeepSpeed Library¶
The HabanaAI GitHub organization hosts a fork of the DeepSpeed library that includes changes adding support for SynapseAI. To use DeepSpeed with Gaudi, install Habana's fork directly from the HabanaAI GitHub repository:
pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.11.0
DeepSpeed training was validated on this fork, which is based on DeepSpeed v0.9.4.
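A quick way to verify the installation is to import the package and print its version; the exact version string reported by the fork may differ from upstream DeepSpeed.

import deepspeed
print(deepspeed.__version__)  # the fork is based on DeepSpeed v0.9.4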
Integrating DeepSpeed with Gaudi
To run DeepSpeed on Gaudi, make sure to:
Prepare your PyTorch model to run on Gaudi by following the steps in the Importing Habana Torch Library section, as sketched below. If you have an existing training script that runs on Gaudi, migrating your model is not required.
Follow the steps in DeepSpeed Requirements to enable DeepSpeed on Gaudi.
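For the first step, a minimal sketch of preparing a model for Gaudi is shown below; MyModel is a placeholder for your own module.

import torch
import habana_frameworks.torch.core as htcore  # loads the Habana PyTorch bridge

device = torch.device("hpu")
model = MyModel()           # hypothetical model definition
model = model.to(device)    # run the model on Gaudi (HPU)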
DeepSpeed Requirements
Follow the instructions in https://www.deepspeed.ai/getting-started/ with the following modifications:
Replace loss.backward() and optimizer.step() with model_engine.backward(loss) and model_engine.step().

Replace all usages of the model object with the new model_engine object returned by the deepspeed.initialize() call.

Remove from torch.nn.parallel import DistributedDataParallel as DDP and remove the DDP call for the model.

In deepspeed.init_distributed(), make sure that dist_backend is set to HCCL:

deepspeed.init_distributed(dist_backend='hccl', init_method = <init_method>)

For the current release, the following steps are required in this specific order before calling deepspeed.initialize():

Move your model to HPU and cast it to BF16 if required:

model.to(hpu, bf16)

If your model uses weight sharing, make sure these weights are created inside the module. Refer to Weight Sharing.

Initialize the optimizer.

Update the DeepSpeed run command to include the --use_hpu flag. The name of the model script (model.py below) is up to you:

deepspeed model.py --deepspeed --deepspeed_config <json file path> --use_hpu
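A minimal sketch of how these modifications fit together is shown below. MyModel and train_dataset are placeholders, and args is assumed to come from an argparse parser extended with deepspeed.add_config_arguments(), so that the --deepspeed_config path is visible to deepspeed.initialize().

import torch
import deepspeed
import habana_frameworks.torch.core as htcore  # Habana PyTorch bridge

# Initialize distributed training with the HCCL backend
deepspeed.init_distributed(dist_backend='hccl')

model = MyModel()                  # hypothetical model definition
model = model.to("hpu")            # move the model to HPU (cast to BF16 here if required)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)  # initialize the optimizer before deepspeed.initialize()

# deepspeed.initialize() wraps the model; use the returned engine from here on
model_engine, optimizer, train_loader, _ = deepspeed.initialize(
    args=args, model=model, optimizer=optimizer, training_data=train_dataset)

for batch in train_loader:
    loss = model_engine(batch)     # forward pass; assumes the model's forward returns the loss
    model_engine.backward(loss)    # replaces loss.backward()
    model_engine.step()            # replaces optimizer.step()

The script is then launched with the deepspeed run command shown above, including the --use_hpu flag.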
Note
It is highly recommended to review our pretraining examples.
Using ZeRO-3
For optimal performance of ZeRO-3, it is recommended to configure the following parameters in the DeepSpeed ZeRO settings as explained below:
overlap_comm=false
contiguous_gradients=true
reduce_scatter": false
Note
If you encounter accuracy issues, it is recommended to set contiguous_gradients to false.
The following shows a usage example:
"zero_optimization": {
"stage": 3,
"overlap_comm": false,
...
"contiguous_gradients": true,
"reduce_scatter": false
}
For further information on how to configure ZeRO, refer to ZeRO Configuration section.
Using Activation Checkpointing
To use activation checkpointing with Gaudi, integrate the deepspeed.runtime.activation_checkpointing.checkpointing.checkpoint wrapper from Habana's DeepSpeed fork into your model according to the instructions in the TORCH.UTILS.CHECKPOINT guide. For example, see the following excerpt from the DeepSpeed-BERT script modeling.py.
class BertEncoder(nn.Module):
    def __init__(self, config):
        super(BertEncoder, self).__init__()
        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
        self.output_all_encoded_layers = config.output_all_encoded_layers
        self._checkpoint_activations = False
        self._checkpoint_activations_interval = 1
        ...

    def forward(self, hidden_states, attention_mask):
        all_encoder_layers = []
        layer_norm_input = 0
        if self._checkpoint_activations:
            hidden_states, layer_norm_input = self.checkpointed_forward(
                hidden_states, layer_norm_input, attention_mask, self._checkpoint_activations_interval)
        else:
            for i, layer_module in enumerate(self.layer):
                hidden_states, layer_norm_input = layer_module(hidden_states, layer_norm_input, attention_mask)
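The checkpointed_forward method referenced above is not part of the excerpt. The sketch below shows what such a method could look like: it groups self._checkpoint_activations_interval consecutive layers into segments and wraps each segment with the DeepSpeed checkpoint function. The body is illustrative rather than a copy of the DeepSpeed-BERT implementation.

from deepspeed.runtime.activation_checkpointing import checkpointing

# A method of BertEncoder (see the class above)
def checkpointed_forward(self, hidden_states, layer_norm_input, attention_mask, interval):
    def custom(start, end):
        # Run the layers in [start, end) as one checkpointed segment
        def custom_forward(hidden_states, layer_norm_input, attention_mask):
            for layer in self.layer[start:end]:
                hidden_states, layer_norm_input = layer(hidden_states, layer_norm_input, attention_mask)
            return hidden_states, layer_norm_input
        return custom_forward

    num_layers = len(self.layer)
    for start in range(0, num_layers, interval):
        end = min(start + interval, num_layers)
        # Activations inside this segment are recomputed during the backward pass
        hidden_states, layer_norm_input = checkpointing.checkpoint(
            custom(start, end), hidden_states, layer_norm_input, attention_mask)
    return hidden_states, layer_norm_input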
The values of the following parameters have been validated to be fully functioning on Gaudi:

"partition_activations": true/false

"cpu_checkpointing": true/false

"contiguous_memory_optimization": true/false - As per DeepSpeed documentation, contiguous_memory_optimization can be set to true only when partition_activations=true.

"synchronize_checkpoint_boundary": true/false

"profile": false

For further details, refer to Configuring Activation Checkpointing section.
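For reference, these switches are set under the activation_checkpointing section of the DeepSpeed configuration. The sketch below shows one combination of the validated values expressed as a Python dict; the combination itself is illustrative, and a dict like this can be passed to deepspeed.initialize() through its config argument in place of a JSON file path.

ds_config = {
    # ... other DeepSpeed settings (batch size, ZeRO, precision, etc.) ...
    "activation_checkpointing": {
        "partition_activations": True,
        "cpu_checkpointing": False,
        "contiguous_memory_optimization": False,  # may be true only when partition_activations is true
        "synchronize_checkpoint_boundary": False,
        "profile": False
    }
}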
Optimizing Pipeline Parallelism
During model training with pipeline parallelism, communication redundancy among ranks can be eliminated to optimize the evaluation process.
This can be achieved by setting the bcast_loss flag to False. Consequently, the return value of non-zero ranks within the pipeline groups changes, and only rank 0 of each group returns the actual evaluation loss obtained from the eval_batch call.
def eval_batch(self,
               data_iter,
               return_logits=False,
               compute_loss=True,
               bcast_loss=True,
               reduce_output='avg')
To maintain the original behavior of DeepSpeed, the default value of bcast_loss has been kept as True.
The example below, extracted from the BLOOM13B Megatron-DeepSpeed training script, shows this usage:
if args.deepspeed and args.ds_pipeline_enabled:
    # DeepSpeed uses eval_batch() and already aggregates losses.
    assert isinstance(model, list) and len(model) == 1
    loss = model[0].eval_batch(data_iterator, bcast_loss=False, eval_micro_batches=num_eval_microbatches)
    loss_dicts = [{'lm loss' : loss}] * num_eval_microbatches
else:
    assert args.micro_batch_size == args.eval_micro_batch_size, \
        "Unsupported for split micro batch"
    loss_dicts = forward_backward_func(
        forward_step_func, data_iterator, model, optimizer=None,
        timers=None, forward_only=True)
DeepSpeed Runtime Environment Variables
The following table describes runtime flags that should be set in the environment to handle out-of-memory issues in large-scale models. Make sure to set all the flags listed below.
Flag | Description
---|---
| Limits the internal graph size to 1000 ops and reduces the lazy mode memory overhead. This will be improved in future releases. Note: this may affect performance.
| Sets the memory pool to consume the entire HBM memory.