DeepSpeed User Guide for Training

The purpose of this document is to guide Data Scientists in running PyTorch models on the Habana® Gaudi® infrastructure using the DeepSpeed interface.

DeepSpeed Validated Configurations

The following DeepSpeed configurations have been validated to be fully functioning with Gaudi:

  • Distributed Data Parallel (multi-card): Trains the same model across multiple ranks by splitting the dataset between the workers to achieve better performance compared to a single card.

  • ZeRO-1: Partitions the optimizer states across the ranks so that each process updates its own partition.

  • ZeRO-2: On top of ZeRO-1, each process retains only the gradients corresponding to its portion of the optimizer states (see the configuration sketch after this list).

  • Model Pipeline Parallelism: Splits the model layers between several workers so that each worker executes the forward and backward passes of its own layers.

  • Model Tensor Parallelism (using Megatron-DeepSpeed): Splits the model tensors into chunks so that each chunk resides on its designated HPU. Megatron introduces this approach to model tensor parallelism for transformer-based models.

  • BF16 precision: Reduces the model memory consumption and improves performance by training with BF16 precision.

  • BF16Optimizer: Allows BF16 precision training with pipeline parallelism; an optimizer that implements ZeRO-1 for BF16 parameters and accumulates gradients in FP32.

  • Activation Checkpointing: Recomputes forward-pass activations during the backward pass to save memory. For further details, refer to the Using Activation Checkpointing section.
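
Many of these features are enabled through the DeepSpeed configuration file passed to the launcher. As an illustration only, a minimal configuration enabling ZeRO-2 and BF16 precision might look like the following sketch (standard DeepSpeed configuration keys; the batch sizes are placeholders, not tuned values):

# Sketch of a DeepSpeed configuration enabling ZeRO-2 and BF16 precision.
# The same content would normally live in the JSON file passed via --deepspeed_config.
ds_config = {
    "train_batch_size": 64,               # placeholder global batch size
    "train_micro_batch_size_per_gpu": 8,  # placeholder per-rank micro batch size
    "zero_optimization": {
        "stage": 2                        # ZeRO-2: partition optimizer states and gradients
    },
    "bf16": {
        "enabled": True                   # train in BF16 precision
    },
    "gradient_clipping": 1.0
}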

Note

  • Further information on DeepSpeed configurations can be found in the DeepSpeed documentation.

  • DeepSpeed’s multi-node training uses pdsh to invoke the processes on remote hosts. Make sure pdsh is installed on your machine before using multi-node training.

Installing DeepSpeed Library

The HabanaAI GitHub has a fork of the DeepSpeed library that includes changes to add support for SynapseAI. To use DeepSpeed with Gaudi, you must install Habana’s fork directly from the DeepSpeed fork repository located in the HabanaAI GitHub:

pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.8.0
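
To confirm that the fork was installed into the active Python environment, a quick sanity check is to import the package and print its version (the exact version string depends on the installed fork tag):

import deepspeed
print(deepspeed.__version__)  # should report the version provided by the HabanaAI fork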

Integrating DeepSpeed with Gaudi

To run DeepSpeed on Gaudi, make sure to:

  • Prepare your PyTorch model to run on Gaudi by following the steps in the Porting PyTorch Models to Gaudi guide. If you have an existing training script that already runs on Gaudi, migrating your model is not required.

  • Follow the steps in DeepSpeed Requirements to enable DeepSpeed on Gaudi.

DeepSpeed Requirements

  • Follow the instructions in https://www.deepspeed.ai/getting-started/ with the following modifications:

    • Replace loss.backward() and optimizer.step() with model_engine.backward(loss) and model_engine.step().

    • Replace all usages of the model object with the new model_engine object returned by the deepspeed.initialize() call.

    • Remove the from torch.nn.parallel import DistributedDataParallel as DDP import and remove the DDP wrapping of the model.

  • In deepspeed.init_distributed(), make sure that dist_backend is set to HCCL:

deepspeed.init_distributed(dist_backend='hccl', init_method=<init_method>)

  • For the current release, the following steps are required in this specific order before calling deepspeed.initialize():

    • Move your model to HPU and cast it to BF16 if required.

    model.to(torch.device("hpu"), torch.bfloat16)
    
    • If your model uses weight sharing, make sure these weights are created inside the module. Refer to Weight Sharing.

    • Initialize the optimizer.

  • Update the DeepSpeed run command to include the --use_hpu flag. The script name (model.py below) is only an example; a minimal end-to-end sketch follows the command.

deepspeed model.py --deepspeed --deepspeed_config <json file path> --use_hpu
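
Putting the steps above together, a minimal training script might look like the following sketch. MyModel and MyDataset are placeholders for your own model and dataset, and the script assumes it is launched with the deepspeed command shown above:

import argparse

import torch
import deepspeed

from my_project import MyModel, MyDataset  # placeholders for your own code


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--use_hpu", action="store_true", help="Run on Gaudi (HPU)")
    parser = deepspeed.add_config_arguments(parser)  # adds --deepspeed / --deepspeed_config
    args = parser.parse_args()

    # Use the HCCL backend for collective communication on Gaudi.
    deepspeed.init_distributed(dist_backend="hccl")

    # 1) Move the model to HPU and cast it to BF16 if required.
    model = MyModel()
    model.to(torch.device("hpu"), torch.bfloat16)

    # 2) Initialize the optimizer before calling deepspeed.initialize().
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    # 3) Wrap the model; use the returned model_engine instead of model from here on.
    model_engine, optimizer, train_loader, _ = deepspeed.initialize(
        args=args, model=model, optimizer=optimizer, training_data=MyDataset())

    for inputs, labels in train_loader:
        loss = model_engine(inputs.to("hpu"), labels.to("hpu"))
        model_engine.backward(loss)  # replaces loss.backward()
        model_engine.step()          # replaces optimizer.step()


if __name__ == "__main__":
    main()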

Note

It is highly recommended to review our pretraining examples.

Using Activation Checkpointing

To use activation checkpointing with Gaudi, integrate the deepspeed.runtime.activation_checkpointing.checkpointing.checkpoint wrapper from Habana’s DeepSpeed into your model according to the instructions in the TORCH.UTILS.CHECKPOINT guide. For example, see the following snippet extracted from modeling.py in the DeepSpeed-BERT script (a sketch of the checkpointed_forward helper follows the snippet).

class BertEncoder(nn.Module):
    def __init__(self, config):
        super(BertEncoder, self).__init__()
        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
        self.output_all_encoded_layers = config.output_all_encoded_layers
        self._checkpoint_activations = False
        self._checkpoint_activations_interval = 1

        ...

    def forward(self, hidden_states, attention_mask):
        all_encoder_layers = []

        layer_norm_input = 0
        if self._checkpoint_activations:
            hidden_states, layer_norm_input = self.checkpointed_forward(
                hidden_states, layer_norm_input, attention_mask, self._checkpoint_activations_interval)
        else:
            for i, layer_module in enumerate(self.layer):
                hidden_states, layer_norm_input = layer_module(hidden_states, layer_norm_input, attention_mask)
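
The checkpointed_forward helper called above is not part of the extract. The following is a minimal sketch of how such a method could group the encoder layers into chunks and run each chunk under the DeepSpeed checkpoint wrapper; it assumes the checkpointing module has already been configured (for example through the activation_checkpointing section of the DeepSpeed configuration) and is not the verbatim DeepSpeed-BERT implementation:

from deepspeed.runtime.activation_checkpointing import checkpointing

class BertEncoder(nn.Module):
    # ... __init__ and forward as in the extract above ...

    def checkpointed_forward(self, hidden_states, layer_norm_input, attention_mask, interval):
        # Build a forward function that runs layers [start, end) sequentially.
        def custom(start, end):
            def custom_forward(hidden_states, layer_norm_input, attention_mask):
                for layer in self.layer[start:end]:
                    hidden_states, layer_norm_input = layer(hidden_states, layer_norm_input, attention_mask)
                return hidden_states, layer_norm_input
            return custom_forward

        # Checkpoint each chunk of `interval` layers: their activations are dropped
        # after the forward pass and recomputed during the backward pass.
        num_layers = len(self.layer)
        for start in range(0, num_layers, interval):
            end = min(start + interval, num_layers)
            hidden_states, layer_norm_input = checkpointing.checkpoint(
                custom(start, end), hidden_states, layer_norm_input, attention_mask)
        return hidden_states, layer_norm_input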

Note

The default values of the following parameters have been validated to be fully functioning on Gaudi:

  • “partition_activations”: false

  • “cpu_checkpointing”: false

  • “contiguous_memory_optimization”: false

  • “synchronize_checkpoint_boundary”: false

  • “profile”: false

For further details, refer to the Configuring Activation Checkpointing section; a configuration sketch with these defaults is shown below.
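
These parameters live under the activation_checkpointing section of the DeepSpeed configuration. A sketch spelling out the validated defaults explicitly, extending the ds_config dictionary from the earlier configuration sketch:

# Sketch of the activation_checkpointing section of a DeepSpeed configuration,
# listing the validated default values explicitly (ds_config is the dictionary
# from the earlier configuration sketch).
ds_config["activation_checkpointing"] = {
    "partition_activations": False,
    "cpu_checkpointing": False,
    "contiguous_memory_optimization": False,
    "synchronize_checkpoint_boundary": False,
    "profile": False
}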

DeepSpeed Runtime Environment Variables

The following runtime flags should be set in the environment to handle out-of-memory issues in large-scale models. Make sure to set all the flags listed below.

  • PT_HPU_MAX_COMPOUND_OP_SIZE=1000: Limits the internal graph size to 1000 ops and reduces the lazy mode memory overhead. This will be improved in future releases. Note: This may affect performance.

  • TENSORS_KEEP_ALLOCATED=0: Restricts keeping tensors allocated to avoid unnecessary memory dependencies. To use TENSORS_KEEP_ALLOCATED=0, make sure to also set ENABLE_EXPERIMENTAL_FLAGS=true. Note: This may affect performance.

  • PT_HPU_POOL_MEM_ACQUIRE_PERC=100: Sets the memory pool to consume the entire HBM memory.