PyTorch Mixed Precision Training on Gaudi


The Habana Mixed Precision (HMP) package lets you run mixed precision training on HPU without extensive modifications to existing FP32 model scripts. To add mixed precision training support to a model script, add the following lines anywhere in the script before the start of the training loop:

from habana_frameworks.torch.hpex import hmp
hmp.convert()

Any segment of script (e.g. optimizer) in which you want to avoid using mixed precision should be kept under the following Python context:

from habana_frameworks.torch.hpex import hmp
with hmp.disable_casts():
  code line:1
  code line:2

Basic Design Rules

  • Two different lists are maintained: (i) OPs that always run in BF16 only, (ii) OPs that always run in FP32 only.

  • Python decorators are used to add required functionality (bf16 or fp32 casts on OP input(s)) to torch functions (refer to code snippet below).

  • Any OP not in the above two lists runs with the precision type of its first input (apart from the exceptions listed below).

  • For OPs with multiple tensor inputs (maintained in a separate list, e.g. add, sub, cat, stack), all inputs are cast to the widest precision type among the inputs. If any of these OPs is also in the BF16 or FP32 list, that list takes precedence.

  • For in-place OPs (output and first input share storage), all inputs are cast to the precision type of the first input.

from functools import wraps

def op_wrap(op, cast_fn):
    """Adds wrapper function to OPs. All tensor inputs
    for the OP are cast to the type determined by cast_fn.

    Args:
        op (torch.nn.functional/torch/torch.Tensor): Input OP
        cast_fn (to_bf16/to_fp32): Fn to cast input tensors

    Returns:
        Wrapper function that shall be inserted back to
        corresponding module for this OP.
    """
    @wraps(op)  # preserve the name and docstring of the wrapped OP
    def wrapper(*args, **kwds):
        # get_new_args is an HMP-internal helper (not shown here) that
        # applies cast_fn to the tensor inputs
        args_cast = get_new_args(cast_fn, args, kwds)
        return op(*args_cast, **kwds)

    return wrapper
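
The design rules above can be made concrete with a small, self-contained model. This is only a sketch, not HMP's implementation: dtypes are represented as plain strings, the two OP lists are copied from the O1 defaults given under Configuration Options, and the function name resolve_precision, the PROMOTE_OPS set, and the order in which the rules are checked (in-place first, then the BF16/FP32 lists, then widest-type promotion) are assumptions made for illustration.

```python
# Hypothetical model of the HMP design rules. Not HMP code.
BF16_OPS = {"addmm", "bmm", "conv1d", "conv2d", "conv3d", "dot", "mm", "mv"}
FP32_OPS = {"batch_norm", "cross_entropy", "log_softmax", "softmax", "nll_loss", "topk"}
PROMOTE_OPS = {"add", "sub", "cat", "stack"}  # multi-tensor-input OPs named in the text
WIDTH = {"bf16": 16, "fp32": 32}


def resolve_precision(op, input_dtypes, inplace=False):
    """Return the precision ("bf16" or "fp32") an OP's inputs would be cast to."""
    if inplace:
        # Assumption: checked first, because the output shares storage
        # with the first input and so must keep its precision.
        return input_dtypes[0]
    if op in BF16_OPS:
        return "bf16"
    if op in FP32_OPS:
        return "fp32"
    if op in PROMOTE_OPS:
        # Widest precision type among all inputs.
        return max(input_dtypes, key=WIDTH.__getitem__)
    # Any other OP follows the precision of its first input.
    return input_dtypes[0]


print(resolve_precision("conv2d", ["fp32", "fp32"]))             # BF16 list
print(resolve_precision("softmax", ["bf16"]))                    # FP32 list
print(resolve_precision("add", ["bf16", "fp32"]))                # promotion
print(resolve_precision("add", ["bf16", "fp32"], inplace=True))  # in-place
print(resolve_precision("relu", ["bf16"]))                       # first input
```

In this model, a mixed bf16/fp32 add is promoted to fp32, while the in-place variant keeps the precision of its first input.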

Configuration Options

HMP provides two modes of mixed precision training, O1 and O2. Select a mode by passing opt_level="O1" or opt_level="O2" as an argument to hmp.convert().

O1 is the default and recommended mode of operation when using HMP. O2 can be used for debugging convergence issues as well as for initial iterations of converting a new model to run with mixed precision.

Opt_level = O1

In this mode, OPs that always run in BF16 and OPs that always run in FP32 are selected from a BF16 list and an FP32 list respectively. The BF16 list contains OPs that are numerically safe to run in lower precision on HPU, whereas the FP32 list contains OPs that should be run in higher precision (a conservative choice that works across models).

  • Default BF16 list = [addmm, bmm, conv1d, conv2d, conv3d, dot, mm, mv]

  • Default FP32 list = [batch_norm, cross_entropy, log_softmax, softmax, nll_loss, topk]

HMP provides the option of overriding these internal lists, allowing you to provide your own BF16 and FP32 lists (pass bf16_file_path=<.txt> and fp32_file_path=<.txt> as arguments to hmp.convert()). This is particularly useful when customizing mixed precision training for a particular model. For example:

  • Custom BF16 list for ResNet50 = [addmm, avg_pool2d, bmm, conv2d, dot, max_pool2d, mm, mv, relu, t, linear]

  • Custom FP32 list for ResNet50 = [cross_entropy, log_softmax, softmax, nll_loss, topk]
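
As a sketch of how the custom ResNet50 lists above could be supplied, the snippet below writes each list to a text file and passes the absolute paths to hmp.convert(). The file names, the helper write_op_list, and the one-OP-per-line file format are assumptions for illustration; check the HMP documentation for the exact format expected. The HMP import is guarded so the file-writing part also runs on a machine without the Habana software stack.

```python
import tempfile
from pathlib import Path

# OP names taken from the custom ResNet50 lists above.
bf16_ops = ["addmm", "avg_pool2d", "bmm", "conv2d", "dot",
            "max_pool2d", "mm", "mv", "relu", "t", "linear"]
fp32_ops = ["cross_entropy", "log_softmax", "softmax", "nll_loss", "topk"]


def write_op_list(path, ops):
    """Write one OP name per line (assumed format) and return the absolute path."""
    path.write_text("\n".join(ops) + "\n")
    return str(path.resolve())


list_dir = Path(tempfile.mkdtemp())
bf16_path = write_op_list(list_dir / "ops_bf16_resnet50.txt", bf16_ops)  # hypothetical name
fp32_path = write_op_list(list_dir / "ops_fp32_resnet50.txt", fp32_ops)  # hypothetical name

try:
    from habana_frameworks.torch.hpex import hmp
except ImportError:
    hmp = None  # Habana software stack not installed on this machine

if hmp is not None:
    hmp.convert(opt_level="O1", bf16_file_path=bf16_path,
                fp32_file_path=fp32_path, isVerbose=False)
```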

Opt_level = O2

In this mode, only GEMM and Convolution type OPs (e.g. conv1d, conv2d, conv3d, addmm, mm, bmm, mv, dot) run in BF16; all other OPs run in FP32.

Usage Examples

import torch
from habana_frameworks.torch.hpex import hmp

N, D_in, D_out = 64, 1024, 512
x = torch.randn(N, D_in, device="hpu")
y = torch.randn(N, D_out, device="hpu")

# enable mixed precision training with optimization level O1, default BF16 list, default FP32 list and logging disabled
# use opt_level to select desired mode of operation
# use bf16_file_path to provide absolute path to a file with custom BF16 list
# use fp32_file_path to provide absolute path to a file with custom FP32 list
# use isVerbose to disable/enable debug logs
hmp.convert(opt_level="O1", bf16_file_path="", fp32_file_path="", isVerbose=False)
model = torch.nn.Linear(D_in, D_out).to(torch.device("hpu"))
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

for t in range(500):
   y_pred = model(x)
   loss = torch.nn.functional.mse_loss(y_pred, y)

   optimizer.zero_grad()
   loss.backward()

   # disable mixed precision for optimizer block
   with hmp.disable_casts():
      optimizer.step()

HMP Logs

HMP provides the ability to log precision decisions for each OP for debugging purposes. You can enable verbose logs by passing isVerbose = True as an argument to hmp.convert(). The log prints the precision type each time an OP (covered by a Python decorator) is called in the model being run. See the example below:

casting  <method '__mul__' of 'torch._C._TensorBase' objects>  to  to_fp32
casting  <built-in method embedding of type object at 0x7feab47edfa0> to_fp32
casting  <function layer_norm at 0x7feaab2a4320> to_bf16
casting  <function dropout at 0x7feaab2a2e60> to_bf16
casting  <method 'matmul' of 'torch._C._TensorBase' objects> to_bf16
casting  <method '__iadd__' of 'torch._C._TensorBase' objects>  to  to_bf16
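
When a verbose log grows long, it can help to summarize which OPs were cast to which precision. The following is a minimal sketch that parses lines shaped like the excerpt above; the regular expressions are assumptions based only on those six sample lines and may need adjusting for other HMP versions or OP types.

```python
import re
from collections import defaultdict

# Sample lines copied from the log excerpt above.
LOG = """\
casting  <method '__mul__' of 'torch._C._TensorBase' objects>  to  to_fp32
casting  <built-in method embedding of type object at 0x7feab47edfa0> to_fp32
casting  <function layer_norm at 0x7feaab2a4320> to_bf16
casting  <function dropout at 0x7feaab2a2e60> to_bf16
casting  <method 'matmul' of 'torch._C._TensorBase' objects> to_bf16
casting  <method '__iadd__' of 'torch._C._TensorBase' objects>  to  to_bf16
"""

# "casting <description> [to] to_xxx"; the literal "to" is optional.
LINE_RE = re.compile(r"^casting\s+<(?P<desc>.*)>\s+(?:to\s+)?(?P<cast>to_\w+)\s*$")
# Pull the OP name out of the Python repr inside the angle brackets.
NAME_RE = re.compile(r"(?:built-in method|method|function)\s+'?(?P<name>\w+)'?")


def summarize(log_text):
    """Group OP names by the cast function reported in the log."""
    ops_by_cast = defaultdict(list)
    for line in log_text.splitlines():
        match = LINE_RE.match(line.strip())
        if not match:
            continue  # skip anything that is not a cast line
        name = NAME_RE.search(match.group("desc"))
        op = name.group("name") if name else match.group("desc")
        ops_by_cast[match.group("cast")].append(op)
    return dict(ops_by_cast)


summary = summarize(LOG)
for cast, ops in summary.items():
    print(cast, ops)
```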

Visualizing Torch Graph

Use torchviz to visualize the model graph and check whether cast nodes are inserted at the expected positions to convert portions of the model to BF16.


Cast nodes show up as “CopyBackwards” in the graph.
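
As a text-only complement to the torchviz picture, the autograd graph can also be walked directly to list cast nodes. The sketch below runs on CPU and uses an explicit .to(torch.bfloat16) as a stand-in for a cast inserted by HMP; the helper grad_fn_names is hypothetical. The exact node name depends on the PyTorch version (for example "CopyBackwards" or "ToCopyBackward0"), so the filter only looks for the substring "Copy".

```python
import torch


def grad_fn_names(tensor):
    """Collect the class names of all autograd nodes reachable from tensor."""
    names, stack, seen = [], [tensor.grad_fn], set()
    while stack:
        fn = stack.pop()
        if fn is None or fn in seen:
            continue
        seen.add(fn)
        names.append(type(fn).__name__)
        # next_functions is a tuple of (node, input_index) pairs.
        stack.extend(parent for parent, _ in fn.next_functions)
    return names


model = torch.nn.Linear(8, 4)
x = torch.randn(2, 8)
# Explicit cast used here in place of an HMP-inserted cast (assumption).
out = model(x).to(torch.bfloat16)

names = grad_fn_names(out)
cast_nodes = [n for n in names if "Copy" in n]
print(cast_nodes)
```

If torchviz is installed, torchviz.make_dot(out, params=dict(model.named_parameters())) builds the graphical view of the same graph.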