DeepSpeed User Guide for Training

The purpose of this document is to guide data scientists in running PyTorch models on the Intel® Gaudi® AI accelerator using the DeepSpeed interface.

DeepSpeed Runtime Environment Variables

The following runtime flags should be set in the environment to handle Out of Memory issues in large models. Make sure to set all the flags listed below.

  • PT_HPU_MAX_COMPOUND_OP_SIZE=1000 - Limits the internal graph size to 1000 ops and reduces the lazy mode memory overhead. This will be improved in future releases. Note: This may affect performance.

  • PT_HPU_POOL_MEM_ACQUIRE_PERC=100 - Sets the memory pool to consume the entire HBM memory.

  • DEEPSPEED_HPU_ZERO3_SYNC_MARK_STEP_REQUIRED=1 - Handles OOM in large models when using ZeRO-3. Note: This may affect performance.
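
For example, the flags can be exported in the shell before launching the training command; this is just one way to place them in the environment:

export PT_HPU_MAX_COMPOUND_OP_SIZE=1000
export PT_HPU_POOL_MEM_ACQUIRE_PERC=100
export DEEPSPEED_HPU_ZERO3_SYNC_MARK_STEP_REQUIRED=1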

DeepSpeed Validated Configurations

The following DeepSpeed configurations have been validated to be fully functioning with HPU. A minimal configuration sketch follows the list.

  • Distributed Data Parallel (multi-card) - Trains the same model across multiple ranks by splitting the dataset between the workers to achieve better performance compared to a single card. Example: N/A

  • ZeRO-1 - Partitions the optimizer states across the ranks so that each process updates its own partition. Example: Bert

  • ZeRO-2 - On top of ZeRO-1, each process retains only the gradients corresponding to its portion of the optimizer states. Example: Bert

  • ZeRO-3 - The full model state is partitioned across the processes (including 16-bit weights). ZeRO-3 automatically collects and partitions them during the forward and backward passes. Make sure to use only optimizers that have been tested with DeepSpeed ZeRO. For further details, refer to the Using ZeRO-3 section. Example: Flan_T5_XXL

  • ZeRO++ hpZ - ZeRO++ is a set of optimization methods that extend ZeRO capabilities and enhance large model training efficiency. It can only be used with ZeRO-3. Hierarchical partitioning ZeRO (hpZ) is one of the three ZeRO++ communication optimizations. Support for the other two methods will be added in future releases. Unlike ZeRO, hpZ keeps a complete model copy on each machine. Although this approach increases memory usage, it replaces the costly cross-machine all-gather/broadcast on weights with an intra-machine alternative, which is faster due to the high intra-machine communication bandwidth. Example: DeepSpeed ZeRO++ Tutorial

  • ZeRO-Offload - Offloads the optimizer's memory and computation from the HPU to the host CPU. The implementation of Adam on the CPU is made more efficient by DeepSpeedCPUAdam. Example: offload_optimizer_to_cpu

  • ZeRO-Infinity - Extends ZeRO-3 functionality by allowing the offload of both the model and optimizer parameters to CPU memory. Example: offload_optimizer_param_to_cpu

  • Model Pipeline Parallelism - Splits the model layers between several workers so that each one executes the forward and backward passes of its own layers. To optimize the evaluation process during training, refer to the Optimizing Pipeline Parallelism section. Note: Pipeline parallelism is not supported on first-gen Gaudi. Example: LLaMA

  • Model Tensor Parallelism (using Megatron-DeepSpeed) - Splits the model tensors into chunks so that each tensor resides on its designated HPU. Megatron introduces an approach for model tensor parallelism for transformer-based models. Example: LLaMA

  • Model Sequence Parallelism (using Megatron-DeepSpeed) - Splits the input sequences into smaller sub-sequences that are processed in parallel by each HPU. For further details, refer to the Using Sequence Parallelism section. Example: LLaMA

  • BF16 Precision - Reduces the model memory consumption and improves performance by training with BF16 precision. Example: Bert

  • BF16 Optimizer - Allows BF16 precision training with pipeline parallelism. An optimizer that implements ZeRO-1 for BF16, with gradient accumulation performed in FP32. Example: Bert

  • Activation Checkpointing - Recomputes forward pass activations during the backward pass in order to save memory. For further details, refer to the Using Activation Checkpointing section. Example: Bert
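
As a minimal sketch, several of the configurations above are enabled through the DeepSpeed JSON configuration file. The example below shows illustrative values (not tuned recommendations) for ZeRO-2, BF16 precision, and ZeRO-Offload:

{
    "train_micro_batch_size_per_gpu": 8,
    "bf16": {
        "enabled": true
    },
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {
            "device": "cpu"
        }
    }
}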

Note

  • Model Pipeline Parallelism, Model Tensor Parallelism (via Megatron-DeepSpeed), and Model Sequence Parallelism (via Megatron-DeepSpeed) are currently supported on Gaudi 2 only.

  • DeepSpeed’s multi-node training uses pdsh for invoking the processes on remote hosts. Make sure it is installed on your machine before using it.

  • Upon initialization, Intel Gaudi DeepSpeed enforces deterministic behavior by setting habana_frameworks.torch.hpu.setDeterministic(True).

  • All further information on DeepSpeed configurations can be found in the DeepSpeed documentation.

For further details on using DeepSpeed and Megatron 3D Parallelism configurations with large language models, see Optimizing Large Language Models.

Installing DeepSpeed Library

The Intel Gaudi GitHub has a fork of the DeepSpeed library that includes changes to add support for the Intel Gaudi software. To use DeepSpeed with Gaudi, install Intel Gaudi's DeepSpeed fork directly from its repository:

pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.15.1

DeepSpeed training was tested on this fork, which is based on DeepSpeed v0.12.4.
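
After installation, you can optionally verify that the package is importable and check its version (the exact version string depends on the installed release):

python -c "import deepspeed; print(deepspeed.__version__)"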

Integrating DeepSpeed with Gaudi

To run DeepSpeed on Gaudi, make sure to:

  • Prepare your PyTorch model to run on Gaudi by following the steps in the Importing PyTorch Models Manually section. If you have an existing training script that runs on Gaudi, migrating your model is not required.

  • Follow the steps in DeepSpeed Requirements to enable DeepSpeed on Gaudi.

DeepSpeed Requirements

  • Follow the instructions in https://www.deepspeed.ai/getting-started/ with the following modifications:

    • Replace loss.backward() and optimizer.step() with model_engine.backward(loss) and model_engine.step().

    • Replace all usages of the original model object with the new model_engine object returned by the deepspeed.initialize() call.

    • Remove from torch.nn.parallel import DistributedDataParallel as DDP and remove the DDP call for the model.

  • In deepspeed.init_distributed(), make sure that dist_backend is set to HCCL:

deepspeed.init_distributed(dist_backend='hccl', init_method=<init_method>)

  • For the current release, the following steps are required in this specific order before calling deepspeed.initialize():

    • Move your model to HPU and cast it to BF16 if required:

    model.to("hpu", torch.bfloat16)
    
    • If your model uses weight sharing, make sure these weights are created inside the module. Refer to Weight Sharing.

    • Initialize the optimizer.

  • Update the DeepSpeed run command to include the --use_hpu flag. The name of the model script file (model.py below) is up to you. A minimal end-to-end sketch follows the command.

deepspeed model.py --deepspeed --deepspeed_config <json file path> --use_hpu
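
The following is a minimal sketch that pulls the requirements above together. It assumes a toy linear model, a random-data loss, and a DeepSpeed JSON configuration file named ds_config.json; these are placeholders for your own model, data, and configuration:

import torch
import deepspeed
import habana_frameworks.torch.core  # registers the HPU backend

# Move the model to HPU and cast it to BF16 (if required) before deepspeed.initialize().
model = torch.nn.Linear(1024, 1024)
model.to("hpu", torch.bfloat16)

# Initialize the optimizer before calling deepspeed.initialize().
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Use the HCCL backend for collective communication on Gaudi.
deepspeed.init_distributed(dist_backend='hccl')

# Use the returned model_engine object in place of the original model.
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model, optimizer=optimizer, config="ds_config.json")

for _ in range(10):
    inputs = torch.randn(8, 1024, device="hpu", dtype=torch.bfloat16)
    loss = model_engine(inputs).pow(2).mean()
    # Replace loss.backward() and optimizer.step() with the engine equivalents.
    model_engine.backward(loss)
    model_engine.step()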

Note

It is highly recommended to review our pre-training example: DeepSpeed-BERT.

Using ZeRO-3

For optimal performance of ZeRO-3, it is recommended to configure the following parameters in the DeepSpeed ZeRO settings as explained below:

  • overlap_comm: false

  • contiguous_gradients: true

  • reduce_scatter: false

The following shows a usage example:

"zero_optimization": {
    "stage": 3,
    "overlap_comm": false,
    ...

    "contiguous_gradients": true,
    "reduce_scatter": false
}

Note

  • If you encounter Out of Memory issues, set the following environment variable: DEEPSPEED_HPU_ZERO3_SYNC_MARK_STEP_REQUIRED=1. This may affect performance.

  • If you encounter accuracy issues, it is recommended to set contiguous_gradients to false.

For further information on how to configure ZeRO, refer to ZeRO Configuration section.

Using Activation Checkpointing

To use activation checkpointing with Gaudi, integrate the deepspeed.runtime.activation_checkpointing.checkpointing.checkpoint wrapper from Intel Gaudi's DeepSpeed fork into your model according to the instructions in the TORCH.UTILS.CHECKPOINT guide. For example, see the following excerpt from the DeepSpeed-BERT script/modeling.py.

class BertEncoder(nn.Module):
    def __init__(self, config):
        super(BertEncoder, self).__init__()
        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
        self.output_all_encoded_layers = config.output_all_encoded_layers
        self._checkpoint_activations = False
        self._checkpoint_activations_interval = 1

        ...

    def forward(self, hidden_states, attention_mask):
        all_encoder_layers = []

        layer_norm_input = 0
        if self._checkpoint_activations:
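            # checkpointed_forward() (defined elsewhere in modeling.py, not shown here)
            # runs groups of encoder layers through the DeepSpeed checkpoint wrapper.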
            hidden_states, layer_norm_input = self.checkpointed_forward(
                hidden_states, layer_norm_input, attention_mask, self._checkpoint_activations_interval)
        else:
            for i, layer_module in enumerate(self.layer):
                hidden_states, layer_norm_input = layer_module(hidden_states, layer_norm_input, attention_mask)

The values of the following parameters have been validated to be fully functioning on Gaudi; a configuration sketch follows the list:

  • “partition_activations”: true/false

  • “cpu_checkpointing”: true/false

  • “contiguous_memory_optimization”: true/false - As per the DeepSpeed documentation, contiguous_memory_optimization can be set to true only when partition_activations=true.

  • “synchronize_checkpoint_boundary”: true/false

  • “profile”: false
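
As a sketch, these parameters might appear as follows in the activation_checkpointing section of the DeepSpeed JSON configuration file (the values shown are illustrative only):

"activation_checkpointing": {
    "partition_activations": true,
    "contiguous_memory_optimization": true,
    "cpu_checkpointing": false,
    "synchronize_checkpoint_boundary": false,
    "profile": false
}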

For further details, refer to Configuring Activation Checkpointing section.

Optimizing Pipeline Parallelism

During model training with pipeline parallelism, communication redundancy among ranks can be eliminated to optimize the evaluation process. This is achieved by setting the bcast_loss flag to False. Consequently, the return value of non-zero ranks within the pipeline groups changes, and only rank 0 of each group returns the actual evaluation loss obtained from the eval_batch call.

def eval_batch(self,
               data_iter,
               return_logits=False,
               compute_loss=True,
               bcast_loss=True,
               reduce_output='avg')

To maintain the original behavior of DeepSpeed, the default value of bcast_loss has been kept as True.

The example below, extracted from the LLaMA model training script, shows how bcast_loss is used:

if args.deepspeed and args.ds_pipeline_enabled:
    # DeepSpeed uses eval_batch() and already aggregates losses.
    assert isinstance(model, list) and len(model) == 1
    loss = model[0].eval_batch(data_iterator, bcast_loss=False, eval_micro_batches=num_eval_microbatches)
    loss_dicts = [{'lm loss' : loss}] * num_eval_microbatches
else:
    assert args.micro_batch_size == args.eval_micro_batch_size, \
           "Unsupported for split micro batch"
    loss_dicts = forward_backward_func(
        forward_step_func, data_iterator, model, optimizer=None,
        timers=None, forward_only=True)

Note

Pipeline parallelism is not supported on first-gen Gaudi.

Using Sequence Parallelism

Sequence Parallelism is used when training with Tensor Parallelism. This approach splits the Layer-Norm and Dropout operations, which occur after the attention and MLP blocks and are replicated across the Tensor Parallel group, along the sequence dimension. As a result, a significant amount of activation memory is saved.

To configure DeepSpeed pipeline module to support Sequence Parallelism:

1. Mark the sequence parallel weights with an attribute. Each weight split under the sequence parallel region should have a sequence_parallel attribute added during its initialization and set to True:

if sequence_parallel:
    # set sequence parallelism flag on weight parameter
    setattr(self.weight, 'sequence_parallel', True)

2. Configure the DeepSpeed pipeline engine to disable partitioning of activations and gradients. Set the pipe_partitioned and grad_partitioned attributes to false under the “pipeline” section in the DeepSpeed JSON configuration file:

  "pipeline": {
  "pipe_partitioned": false,
  "grad_partitioned": false
}

You can find example usage code in the LLaMA model.