FP8 Training with Intel Gaudi Transformer Engine

The Intel® Gaudi® 2 AI accelerator supports 8-bit floating point (FP8) training through the Intel Gaudi Transformer Engine (TE) library.

The Intel Gaudi Transformer Engine library provides optimized implementations of PyTorch modules that perform their computations in the FP8 data type, which improves performance and lowers memory utilization. The building blocks provided by TE can be used to build popular Transformer architectures, and an API similar to PyTorch automatic mixed precision (AMP) configures and controls the operation of the FP8-enabled modules in your code.

Using Transformer Engine on Gaudi

To use TE on Gaudi, perform the following steps:

  1. Import TE and use TE modules in your model, for example te.Linear:

    import torch
    import habana_frameworks.torch.hpex.experimental.transformer_engine as te
    
    # Set dimensions.
    in_features = 768
    out_features = 3072
    hidden_size = 2048
    
    # Initialize model and inputs.
    model = te.Linear(in_features, out_features, bias=True)
    inp = torch.randn(hidden_size, in_features, device="hpu")
    
  2. Wrap the forward pass of the training step in fp8_autocast:

    from habana_frameworks.torch.hpex.experimental.transformer_engine import recipe
    
    # Create an FP8 recipe. Note: All input args are optional.
    fp8_recipe = recipe.DelayedScaling(margin=0, interval=1)
    
    # Enable autocasting for the forward pass
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        out = model(inp)
    
    loss = out.sum()
    loss.backward()
    

For an example model using Transformer Engine on Gaudi, see MLPerf GPT3, in particular layers.py and training.py.

Registered APIs

Linear

Class:

  • habana_frameworks.torch.hpex.experimental.transformer_engine.Linear

Description:

Applies a linear transformation to the incoming data y = xA^T + b. On Gaudi, it is a drop-in replacement for torch.nn.Linear.

Parameters:

  • in_features (int): Size of each input sample.
  • out_features (int): Size of each output sample.
  • bias (bool): Default = True. If set to False, the layer does not learn an additive bias.
  • init_method (callable): Default = None. Function used to initialize the weights, called as init_method(weight). When set to None, the weights are initialized with torch.nn.init.normal_(mean=0.0, std=0.023).
  • skip_weight_param_allocation (bool): Default = False. If set to True, the weight parameter is not allocated and must instead be passed as the keyword argument weight during the forward pass.
  • params_dtype (torch.dtype): Default = torch.get_default_dtype(). Controls the data type used to allocate the initial parameters. This is useful when the model is trained in lower precision and the original FP32 parameters would not fit in device memory.
  • minimize_memory (bool): Default = False. When set to True, the FP8 weights are recomputed in the backward pass. This reduces memory usage but decreases performance. It works especially well with the DeepSpeed pipelining mechanism.
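To make the y = xA^T + b convention concrete, the following pure-Python sketch computes what the layer does mathematically, ignoring FP8 casting and device placement entirely. It assumes that, as in torch.nn.Linear, the weight A is stored with shape (out_features, in_features); reference_linear is a hypothetical helper for illustration, not part of TE.

```python
from typing import List, Optional

Matrix = List[List[float]]


def reference_linear(x: Matrix, weight: Matrix, bias: Optional[List[float]] = None) -> Matrix:
    """Compute y = x A^T + b in plain Python (no FP8, no Gaudi).

    x:      shape (batch, in_features)
    weight: shape (out_features, in_features), assumed to match torch.nn.Linear
    bias:   shape (out_features,), or None to mimic bias=False
    """
    out_features = len(weight)
    y = []
    for row in x:
        out_row = []
        for j in range(out_features):
            # Dot product of the input row with row j of A, i.e. column j of A^T.
            acc = sum(xi * aij for xi, aij in zip(row, weight[j]))
            if bias is not None:
                acc += bias[j]
            out_row.append(acc)
        y.append(out_row)
    return y


# in_features = 3, out_features = 2, batch of 2 samples.
x = [[1.0, 2.0, 3.0],
     [0.0, 1.0, 0.0]]
A = [[1.0, 0.0, 1.0],   # weights for output feature 0
     [0.5, 0.5, 0.5]]   # weights for output feature 1
b = [0.1, -1.0]
print(reference_linear(x, A, b))
```

Each output row has out_features entries, one per row of A, which is why the weight is transposed in the formula.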

fp8_autocast

Class:

fp8_autocast(enabled: bool = False, force_measurement: Optional[bool] = None, fp8_recipe: Optional[DelayedScaling] = None, fp8_group: Optional[dist_group_type] = None)

Description:

Context manager that enables FP8 execution for the TE modules called within it, for example:

with fp8_autocast(enabled=True):
    out = model(inp)

Parameters:

  • enabled (bool): Default = False. Whether or not to enable FP8.
  • force_measurement (bool): Default = None. Whether or not to force amax measurement regardless of the fp8_recipe.interval setting.
  • fp8_recipe (recipe.DelayedScaling): Default = None. Recipe used for FP8 training.
  • fp8_group (torch._C._distributed_c10d.ProcessGroup): Default = None. Distributed group over which the amaxes of the FP8 tensors are reduced at the end of each training step.

Format

Class:

Format

Description:

Supported FP8 formats:

  • E5M2 - all FP8 tensors are in the E5M2 format (5 exponent bits, 2 mantissa bits).
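The FP8_MAX value used by the scaling factor formula is the largest finite number representable in the chosen format. As a worked example, the sketch below derives it for E5M2, assuming IEEE-754-style conventions (exponent bias 15 and the all-ones exponent reserved for infinities and NaNs); these encoding details are assumptions not stated in this document, and max_finite_ieee_like is a hypothetical helper.

```python
def max_finite_ieee_like(exp_bits: int, man_bits: int) -> float:
    """Largest finite value of a small IEEE-754-style binary float format.

    Assumes the all-ones exponent field is reserved for Inf/NaN, so the
    largest usable exponent field is 2**exp_bits - 2.
    """
    bias = 2 ** (exp_bits - 1) - 1            # 15 for E5M2
    max_exp = (2 ** exp_bits - 2) - bias      # 30 - 15 = 15
    max_mantissa = 2.0 - 2.0 ** (-man_bits)   # binary 1.11 = 1.75 for 2 mantissa bits
    return max_mantissa * 2.0 ** max_exp


E5M2_MAX = max_finite_ieee_like(exp_bits=5, man_bits=2)
print(E5M2_MAX)  # 57344.0
```

The same helper gives 65504.0 for IEEE half precision (5 exponent bits, 10 mantissa bits), which is a quick sanity check that the derivation follows the IEEE-style rules.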

DelayedScaling

Class:

DelayedScaling

Description:

Uses the delayed scaling factor strategy: the scaling factor computed in a previous iteration is used for the current one, the factor is recomputed once every interval steps, and the amax history of the last amax_history_len steps is recorded.
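The strategy can be illustrated with a small state machine in plain Python. This is a sketch of the description above, not the TE implementation: the class DelayedScaler and its methods are hypothetical, amaxes are plain floats rather than tensors, the initial scale of 1.0 is an assumption, and the new scale is computed with the default power-of-two formula given in the Note at the end of this section.

```python
import math
from collections import deque
from typing import Deque

E5M2_MAX = 57344.0  # largest finite E5M2 value


class DelayedScaler:
    """Hypothetical sketch of delayed scaling for a single FP8 tensor."""

    def __init__(self, margin: int = 0, interval: int = 1,
                 amax_history_len: int = 1, amax_compute_algo: str = "most_recent"):
        self.margin = margin
        self.interval = interval
        self.algo = amax_compute_algo
        # Fixed-length window: the oldest amax drops out automatically.
        self.history: Deque[float] = deque(maxlen=amax_history_len)
        self.scale = 1.0  # assumed initial scale, used until the first recompute
        self.step_count = 0

    def _select_amax(self) -> float:
        if self.algo == "max":
            return max(self.history)
        return self.history[-1]  # "most_recent"

    def step(self, tensor_amax: float) -> float:
        """Return the scale used for this step, then update state for later steps."""
        scale_used = self.scale            # delayed: computed in an earlier step
        self.history.append(tensor_amax)   # record this step's amax
        self.step_count += 1
        if self.step_count % self.interval == 0:
            amax = self._select_amax()
            if amax > 0 and math.isfinite(amax):
                exp = math.floor(math.log2(E5M2_MAX / amax)) - self.margin
                self.scale = 2.0 ** exp
        return scale_used


for algo in ("max", "most_recent"):
    scaler = DelayedScaler(interval=2, amax_history_len=3, amax_compute_algo=algo)
    used = [scaler.step(a) for a in (100.0, 300.0, 50.0, 10.0)]
    print(algo, used, scaler.scale)
# prints: max [1.0, 1.0, 128.0, 128.0] 128.0
# then:   most_recent [1.0, 1.0, 128.0, 128.0] 4096.0
```

With interval=2 the scale only changes after steps 2 and 4. At step 4, the spike of 300.0 is still inside the three-step window, so 'max' keeps the scale at 128, while 'most_recent' reacts to the latest amax of 10.0 and raises it to 4096.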

Parameters:

  • margin (int): Default = 0. Margin for the scaling factor computation.
  • interval (int): Default = 1. Controls how often the scaling factor is recomputed.
  • fp8_format ({Format.E5M2}): Default = Format.E5M2. Controls the FP8 data format used during the forward and backward passes.
  • amax_history_len (int): Default = 1. Length of the amax history window used for the scaling factor computation.
  • amax_compute_algo ({'max', 'most_recent', Callable}): Default = 'most_recent'. Algorithm used to choose the amax value for the scaling factor computation. There are two predefined choices: 'max' chooses the largest amax in the history window, while 'most_recent' always chooses the most recently seen value. Alternatively, you can pass a function with the signature:

        def amax_compute(amax_history: Tensor) -> Tensor

    where `Tensor` is a framework tensor type.
  • scaling_factor_compute_algo (Callable): Default = None. Algorithm used to compute the new scaling factor from the amax value. It should be a function with the signature:

        def scaling_factor_compute(amax: Tensor,
                                   old_scaling_factor: Tensor,
                                   fp8_max: Tensor,
                                   recipe: DelayedScaling) -> Tensor

    where `Tensor` is a framework tensor type.
  • override_linear_precision (Tuple (bool, bool, bool)): Default = (False, False, False). Whether or not to execute the fprop, dgrad, and wgrad GEMMs, respectively, in higher precision when using FP8.
  • reduce_amax (bool): Default = True. By default, if torch.distributed is initialized, the amax values of the FP8 tensors are reduced across the fp8_group specified in the fp8_autocast call. This keeps the amaxes and scaling factors synchronized across the given distributed group. If set to False, this reduction is skipped and each Gaudi device maintains its own local amaxes and scaling factors. In that case, to keep results numerically identical across checkpointing boundaries, all ranks must save checkpoints so that their local tensors are stored.
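Ranks that agree on the amax and use the same recipe compute the same scaling factor, which is why reducing the amax is enough to keep scaling factors in sync. The sketch below simulates both settings of reduce_amax, with ordinary dictionaries standing in for ranks. It assumes the reduction is an element-wise maximum (NVIDIA's Transformer Engine performs it with torch.distributed.all_reduce and ReduceOp.MAX; the Gaudi implementation is assumed to behave the same way), and sync_amaxes is a hypothetical helper.

```python
from typing import Dict, List

# Per-rank amax records: FP8 tensor name -> amax observed locally on that rank.
RankAmaxes = Dict[str, float]


def sync_amaxes(per_rank: List[RankAmaxes], reduce_amax: bool = True) -> List[RankAmaxes]:
    """Simulate end-of-step amax handling for all ranks of one fp8_group.

    reduce_amax=True: emulate an all-reduce with a MAX operation, so every
    rank receives the largest amax any rank observed for each tensor.
    reduce_amax=False: each rank keeps its own local amaxes.
    Assumes every rank tracks the same set of tensor names.
    """
    if not reduce_amax:
        return [dict(rank) for rank in per_rank]
    names = per_rank[0].keys()
    global_max = {name: max(rank[name] for rank in per_rank) for name in names}
    # After the collective, each rank holds an identical copy of the result.
    return [dict(global_max) for _ in per_rank]


ranks = [
    {"fc1.input": 3.5, "fc1.weight": 0.20},  # rank 0
    {"fc1.input": 7.0, "fc1.weight": 0.15},  # rank 1
]
print(sync_amaxes(ranks))
print(sync_amaxes(ranks, reduce_amax=False))
```

With the reduction, both ranks derive their scales from the same amaxes; without it, rank 0 would pick a larger scale for fc1.input than rank 1, which is why each rank's local state must then be checkpointed.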

Note

  • By default (when scaling_factor_compute_algo is left as None), the scaling factor is computed from the final amax value using the formula:

    FP8_MAX = maximum_representable_value(fp8_format)
    exp = get_exponent(FP8_MAX / amax) - margin
    new_scaling_factor = 2.0 ^ exp
    
  • The scaling factor should always be a power of two so as not to introduce numerical error during the conversion from FP8 to a higher-precision format, since scaling by a power of two changes only the exponent, not the mantissa.
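The default formula from the Note can be written out in plain Python as a check on the arithmetic. This is a sketch under stated assumptions: get_exponent() is interpreted as floor(log2(.)), which guarantees amax * scale <= FP8_MAX; keeping the old scaling factor when amax is zero or not finite mirrors NVIDIA's Transformer Engine and is assumed here; and default_scaling_factor is a hypothetical name.

```python
import math

E5M2_MAX = 57344.0  # largest finite E5M2 value (1.75 * 2**15)


def default_scaling_factor(amax: float, old_scaling_factor: float,
                           fp8_max: float = E5M2_MAX, margin: int = 0) -> float:
    """Power-of-two scaling factor following the Note's formula.

    get_exponent() is interpreted as floor(log2(.)). If amax is zero,
    negative, NaN, or infinite, the old scaling factor is kept (assumption).
    """
    if not (amax > 0 and math.isfinite(amax)):
        return old_scaling_factor
    exp = math.floor(math.log2(fp8_max / amax)) - margin
    return 2.0 ** exp


for amax in (1.0, 100.0, 1e6):
    sf = default_scaling_factor(amax, old_scaling_factor=1.0)
    # A power of two has a mantissa of exactly 0.5 in math.frexp's convention.
    assert math.frexp(sf)[0] == 0.5
    # The scaled amax always fits in the FP8 range.
    assert amax * sf <= E5M2_MAX
    print(amax, sf, amax * sf)
```

For amax = 100.0, FP8_MAX / amax is 573.44, whose exponent is 9, so the scale is 512 and the scaled amax is 51200. Each extra unit of margin halves the scale, trading dynamic range at the top for headroom against amax growth between recomputations.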