FP8 Training with Intel Gaudi Transformer Engine¶
Intel® Gaudi® 2 AI accelerator supports 8-bit floating point precision (FP8) training using the Intel Gaudi Transformer Engine (TE) library.
The Intel Gaudi Transformer Engine library provides optimized implementations of PyTorch modules that perform computations in FP8 data type. This allows for better performance with lower memory utilization. The building blocks provided by TE can be used for popular Transformer architectures. An automatic mixed precision-like API can be used to configure and control the operation of the FP8-enabled modules in the code.
Using Transformer Engine on Gaudi¶
To use TE on Gaudi, perform the following steps:
1. Import TE and use TE modules in your model, e.g. a te.Linear:

       import torch
       import habana_frameworks.torch.hpex.experimental.transformer_engine as te

       # Set dimensions.
       in_features = 768
       out_features = 3072
       hidden_size = 2048

       # Initialize model and inputs.
       model = te.Linear(in_features, out_features, bias=True)
       inp = torch.randn(hidden_size, in_features, device="hpu")
2. Wrap the forward pass of the training with fp8_autocast (a combined sketch of both steps follows this list):

       from habana_frameworks.torch.hpex.experimental.transformer_engine import recipe

       # Create an FP8 recipe. Note: All input args are optional.
       fp8_recipe = recipe.DelayedScaling(margin=0, interval=1)

       # Enable autocasting for the forward pass
       with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
           out = model(inp)

       loss = out.sum()
       loss.backward()
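Putting the two steps together, a minimal end-to-end training iteration might look like the following sketch. The optimizer choice (torch.optim.SGD), the learning rate, and the loop length are illustrative assumptions, not part of the TE API:

    import torch
    import habana_frameworks.torch.hpex.experimental.transformer_engine as te
    from habana_frameworks.torch.hpex.experimental.transformer_engine import recipe

    # Illustrative dimensions, matching the example above.
    in_features, out_features, hidden_size = 768, 3072, 2048

    model = te.Linear(in_features, out_features, bias=True)
    inp = torch.randn(hidden_size, in_features, device="hpu")

    # The recipe is created once and reused across iterations.
    fp8_recipe = recipe.DelayedScaling(margin=0, interval=1)

    # Assumed optimizer; any standard PyTorch optimizer can be used here.
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    for _ in range(3):
        optimizer.zero_grad()
        # Only the forward pass runs under fp8_autocast.
        with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
            out = model(inp)
        loss = out.sum()
        loss.backward()
        optimizer.step()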
For an example model using Transformer Engine on Gaudi, see MLPerf GPT3 - refer to layers.py and training.py.
Registered APIs¶
Linear¶
Class:
habana_frameworks.torch.hpex.experimental.transformer_engine.Linear
Description:
Applies a linear transformation to the incoming data: y = xA^T + b. On Gaudi, it is a drop-in replacement for torch.nn.Linear.
Parameters:
| Parameter | Type | Description |
|---|---|---|
| in_features | int | Size of each input sample. |
| out_features | int | Size of each output sample. |
| bias | bool | Default = True. If set to False, the layer will not learn an additive bias. |
| init_method | callable | Default = None. Used for initializing weights in the following way: init_method(weight). |
| skip_weight_param_allocation | bool | Default = False. If set to True, the weight parameter is not allocated and must be passed as a keyword argument weight during the forward pass. |
| params_dtype | torch.dtype | Default = torch.get_default_dtype(). Dtype used to allocate the initial parameters. |
| minimize_memory | bool | Default = False. When set to True, memory usage is decreased by recalculating the FP8 weight in the backward pass. This reduces memory usage but decreases performance. It works especially well with the DeepSpeed pipelining mechanism. |
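As a hedged illustration of these parameters, the sketch below constructs a te.Linear with a custom initializer, bfloat16 parameters, and weight recomputation enabled. The initializer function and the chosen values are assumptions for demonstration, and it is assumed that init_method is invoked on the weight tensor as init_method(weight):

    import torch
    import habana_frameworks.torch.hpex.experimental.transformer_engine as te

    def normal_init(weight):
        # Hypothetical initializer: in-place normal initialization of the weight.
        torch.nn.init.normal_(weight, mean=0.0, std=0.02)

    layer = te.Linear(
        768, 3072,
        bias=True,
        init_method=normal_init,      # assumed to be called as init_method(weight)
        params_dtype=torch.bfloat16,  # dtype used for the allocated parameters
        minimize_memory=True,         # recompute the FP8 weight in the backward pass
    )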
fp8_autocast¶
Class:
fp8_autocast(enabled: bool = False, force_measurement: Optional[bool] = None, fp8_recipe: Optional[DelayedScaling] = None, fp8_group: Optional[dist_group_type] = None)
Description:
Context manager for FP8 usage.
    with fp8_autocast(enabled=True):
        out = model(inp)
Parameters:
| Parameter | Type | Description |
|---|---|---|
| enabled | bool | Default = False. Whether or not to enable FP8. |
| force_measurement | bool | Default = None. Whether or not to force amax measurement regardless of the FP8 recipe settings. |
| fp8_recipe | recipe.DelayedScaling | Default = None. Recipe used for FP8 training. |
| fp8_group | torch._C._distributed_c10d.ProcessGroup | Default = None. Distributed group over which amaxes for the FP8 tensors are reduced at the end of each training step. |
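For a multi-card run, fp8_group can be used to keep amaxes in sync across workers. The sketch below is a hedged example that passes the default process group; initializing torch.distributed (for example with the hccl backend) is assumed to be handled by your launcher, and model and inp are as defined in the earlier example:

    import torch.distributed as dist
    import habana_frameworks.torch.hpex.experimental.transformer_engine as te
    from habana_frameworks.torch.hpex.experimental.transformer_engine import recipe

    # Assumes the process group has already been initialized, e.g.
    # dist.init_process_group(backend="hccl").
    fp8_recipe = recipe.DelayedScaling()

    with te.fp8_autocast(enabled=True,
                         fp8_recipe=fp8_recipe,
                         fp8_group=dist.group.WORLD):
        out = model(inp)  # model and inp as in the earlier example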
Format¶
Class:
Format
Description:
Supported FP8 formats:
E5M2 - All FP8 tensors are in e5m2 format
HYBRID - FP8 tensors in the forward pass are in e4m3 format, FP8 tensors in the backward pass are in e5m2 format
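To illustrate, the format is selected through the FP8 recipe. A minimal sketch, assuming Format is exposed alongside DelayedScaling in the recipe module:

    from habana_frameworks.torch.hpex.experimental.transformer_engine import recipe

    # HYBRID: e4m3 tensors in the forward pass, e5m2 tensors in the backward pass.
    hybrid_recipe = recipe.DelayedScaling(fp8_format=recipe.Format.HYBRID)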
DelayedScaling¶
Class:
DelayedScaling
Description:
Use the delayed scaling factor strategy: use the scale factor from the previous iteration, recompute it once every interval steps, and record the amax history of the last amax_history_len steps.
Parameters:
| Parameter | Type | Description |
|---|---|---|
| margin | int | Default = 0. Margin for the scaling factor computation. |
| interval | int | Default = 1. Controls how often the scaling factor is recomputed. |
| fp8_format | {Format.E5M2, Format.HYBRID} | Default = Format.E5M2. Controls the FP8 data format used during the forward and backward pass. |
| amax_history_len | int | Default = 1. The length of the amax history window used for scaling factor computation. |
| amax_compute_algo | {max, most_recent, Callable} | Default = most_recent. Algorithm used for choosing the amax value for the scaling factor computation. |
| scaling_factor_compute_algo | Callable | Default = None. Algorithm used for computing the new scaling factor based on the value of amax and the previous scaling factor. |
| override_linear_precision | Tuple (bool, bool, bool) | Default = (False, False, False). Whether or not to execute the fprop, dgrad, and wgrad GEMMs in higher precision. |
| reduce_amax | bool | Default = True. By default, if torch.distributed is initialized, the amax value for FP8 tensors is reduced across the fp8_group specified in the fp8_autocast call. |
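As a hedged illustration of how these parameters combine, the recipe below recomputes the scaling factor every 16 steps and takes the maximum amax over a 16-step history. The specific values are assumptions chosen for the example rather than recommendations; Format is assumed to be importable from the recipe module as above, and amax_compute_algo is assumed to accept the string form "max":

    from habana_frameworks.torch.hpex.experimental.transformer_engine import recipe

    fp8_recipe = recipe.DelayedScaling(
        margin=0,                         # no extra headroom in the exponent
        interval=16,                      # recompute the scaling factor every 16 steps
        fp8_format=recipe.Format.HYBRID,  # e4m3 forward, e5m2 backward
        amax_history_len=16,              # keep the last 16 amax values
        amax_compute_algo="max",          # use the largest amax in the window
    )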
Note
By default (when scaling_factor_compute_algo is left as None), the scaling factor is computed from the final amax value using the formula:

    FP8_MAX = maximum_representable_value(fp8_format)
    exp = get_exponent(FP8_MAX / amax) - margin
    new_scaling_factor = 2.0 ^ exp
The scaling factor should always be a power of two so as not to introduce numerical error during the conversion from FP8 to a higher precision format.
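To make the formula concrete, here is a minimal Python sketch of the default computation. It assumes get_exponent(x) behaves like floor(log2(x)) and uses 57344.0 as the largest representable e5m2 value; both the helper name and the function below are illustrative, not part of the TE API:

    import math

    def default_scaling_factor(amax: float, margin: int = 0,
                               fp8_max: float = 57344.0) -> float:
        # Sketch of the default delayed-scaling update: a power-of-two
        # scaling factor derived from the final amax value.
        exp = math.floor(math.log2(fp8_max / amax)) - margin  # get_exponent assumed to be floor(log2(x))
        return 2.0 ** exp

    # Example: an observed amax of 3.5 with margin 0 gives 2.0 ** 14 = 16384.0.
    print(default_scaling_factor(3.5))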