Habana Mixed Precision

Overview

The Habana Mixed Precision (HMP) package allows you to run mixed precision training on HPU without extensive modifications to existing FP32 model scripts. You can add mixed precision training support to a model script by adding the following lines anywhere in the script before the start of the training loop:

from habana_frameworks.torch.hpex import hmp
hmp.convert()

Any segment of the script (e.g. the optimizer, a fallback to the CPU backend, etc.) in which you want to avoid mixed precision should be kept under the disable_casts Python context. Since code under this context expects FP32-only inputs, make sure that any tensor inputs entering the disable_casts context are explicitly cast to FP32:

from habana_frameworks.torch.hpex import hmp

with hmp.disable_casts():
    input_1 = input_1.float()
    ...
    input_n = input_n.float()
    # code line 1
    # code line 2

Basic Design Rules

  • Two different lists are maintained: (i) OPs that always run in BF16 only, and (ii) OPs that always run in FP32 only.

  • Python decorators are used to add the required functionality (BF16 or FP32 casts on OP inputs) to torch functions (refer to the code snippet below).

  • Ensure that BFloat16-specific OPs and functions are used in place of their Float16 counterparts; for example, tensor.bfloat16() should be used instead of tensor.half().

  • Any OP not in the above two lists runs with the precision type of its first input (except for the cases listed below).

  • For OPs with multiple tensor inputs (maintained in a separate list, e.g. add, sub, cat, stack, etc.), all inputs are cast to the widest precision type among the input precision types. If any of these OPs also appears in the BF16 or FP32 list, that list takes precedence.

  • For in-place OPs (where the output and the first input share storage), all inputs are cast to the precision type of the first input.

from functools import wraps

def op_wrap(op, cast_fn):
    """Adds a wrapper function to an OP. All tensor inputs
    for the OP are cast to the type determined by the
    cast_fn provided.

    Args:
        op (torch.nn.functional/torch/torch.Tensor): Input OP
        cast_fn (to_bf16/to_fp32): Fn to cast input tensors

    Returns:
        Wrapper function that shall be inserted back into the
        corresponding module for this OP.
    """
    @wraps(op)
    def wrapper(*args, **kwds):
        args_cast = get_new_args(cast_fn, args, kwds)
        return op(*args_cast, **kwds)

    return wrapper
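
For illustration only, the sketch below shows one way the helpers referenced above (to_bf16, to_fp32, get_new_args) could look and how a wrapped OP from the BF16 list might be inserted back into its module. This is a simplified assumption of the mechanism, not the actual HMP implementation; the helper bodies and the choice of torch.nn.functional.linear are placeholders.

import torch

def to_bf16(t):
    # Assumed cast helper: convert floating-point tensors to BF16, pass everything else through.
    if isinstance(t, torch.Tensor) and t.is_floating_point():
        return t.bfloat16()
    return t

def to_fp32(t):
    # Assumed cast helper: convert floating-point tensors to FP32, pass everything else through.
    if isinstance(t, torch.Tensor) and t.is_floating_point():
        return t.float()
    return t

def get_new_args(cast_fn, args, kwds):
    # Assumed helper: cast positional tensor arguments; kwds are passed through unchanged by the wrapper.
    return tuple(cast_fn(a) for a in args)

# A BF16-list OP (e.g. linear) gets its wrapped version inserted back into its module:
torch.nn.functional.linear = op_wrap(torch.nn.functional.linear, to_bf16)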

Configuration Options

By default, OPs that always run in BF16 and OPs that always run in FP32 are selected from internal hard-coded BF16 and FP32 lists, respectively. These internal lists are expected to work across models. The BF16 list contains OPs that are numerically safe to run in lower precision on HPU, whereas the FP32 list contains OPs that should run in higher precision.

  • Default BF16 list = [addmm, batch_norm, bmm, conv1d, conv2d, conv3d, conv_transpose1d, conv_transpose2d, conv_transpose3d, dot, dropout, dropout1d, dropout2d, dropout3d, group_norm, instance_norm, layer_norm, leaky_relu, linear, matmul, mm, mul, mv, relu, t]

  • Default FP32 list = [binary_cross_entropy, binary_cross_entropy_with_logits, cross_entropy, div, divide, embedding, embedding_bag, log, log2, log_softmax, nll_loss, smooth_l1_loss, softmax, topk, truediv]

HMP provides the option of overriding these internal hard-coded lists, allowing advanced users to provide their own BF16 and FP32 lists. To do so, pass bf16_file_path=<path to custom .txt file> and fp32_file_path=<path to custom .txt file> as arguments to hmp.convert(). This is particularly useful when customizing mixed precision training for a particular model to get the best possible performance. Examples of such custom lists can be found in the PyTorch ResNet BF16 List and PyTorch ResNet FP32 List.
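
For illustration, the snippet below assumes a custom list is a plain text file with one OP name per line (as in the ResNet example lists referenced above); the file paths and OP selections are placeholders.

# Hypothetical custom list files (one OP name per line):
#   /path/to/ops_bf16.txt  ->  linear, relu, conv2d (each on its own line)
#   /path/to/ops_fp32.txt  ->  softmax, cross_entropy (each on its own line)
from habana_frameworks.torch.hpex import hmp

hmp.convert(bf16_file_path="/path/to/ops_bf16.txt",
            fp32_file_path="/path/to/ops_fp32.txt",
            isVerbose=False)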

The below shows a usage example:

import torch
from habana_frameworks.torch.hpex import hmp

N, D_in, D_out = 64, 1024, 512
x = torch.randn(N, D_in, device='hpu')
y = torch.randn(N, D_out, device='hpu')

# enable mixed precision training with default BF16 list, default FP32 list and logging disabled
# use bf16_file_path to provide absolute path to a file with custom BF16 list
# use fp32_file_path to provide absolute path to a file with custom FP32 list
# use isVerbose to disable/enable debug logs
hmp.convert(bf16_file_path="", fp32_file_path="", isVerbose=False)
model = torch.nn.Linear(D_in, D_out).to(torch.device("hpu"))
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

for t in range(500):
   y_pred = model(x)
   loss = torch.nn.functional.mse_loss(y_pred, y)
   optimizer.zero_grad()
   loss.backward()

   # disable mixed precision for optimizer block
   with hmp.disable_casts():
      optimizer.step()

Debugging

HMP Logs

HMP provides the ability to log precision decisions for each OP for debugging purposes. You can enable verbose logs by passing isVerbose=True as an argument to hmp.convert(). Each time an OP covered by a Python decorator is called in the model being run, the log prints the precision it is cast to. See the example below:

casting  <method '__mul__' of 'torch._C._TensorBase' objects>  to  to_fp32
casting  <built-in method embedding of type object at 0x7feab47edfa0>  to  to_fp32
casting  <function layer_norm at 0x7feaab2a4320>  to  to_bf16
casting  <function dropout at 0x7feaab2a2e60>  to  to_bf16
casting  <method 'matmul' of 'torch._C._TensorBase' objects>  to  to_bf16
casting  <method '__iadd__' of 'torch._C._TensorBase' objects>  to  to_bf16

Visualizing Torch Graph

Use torchviz (https://github.com/szagoruyko/pytorchviz) to visualize the model graph and check if cast nodes are inserted at expected positions to convert portions of the model to BF16.

Note

Cast nodes show up as “CopyBackwards” in the graph.
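
A minimal sketch of this workflow, assuming torchviz is installed and reusing the model and x from the usage example above:

from torchviz import make_dot

# Trace one forward pass and render the autograd graph; cast nodes appear as CopyBackwards.
y_pred = model(x)
make_dot(y_pred, params=dict(model.named_parameters())).render("hmp_graph")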