PyTorch Mixed Precision Training on Gaudi¶
The Habana Mixed Precision (HMP) package is a tool that allows you to run mixed precision training on HPU without extensive modifications to existing FP32 model scripts. You can add mixed precision training support to a model script by adding the following lines anywhere in the script before the start of the training loop:
from habana_frameworks.torch.hpex import hmp
hmp.convert()
Any segment of the script (e.g. optimizer, fallback to CPU backend, etc.) in which you want to avoid using mixed precision should be kept under the disable_casts Python context. Make sure that any FP32-only tensor inputs going into the disable_casts context are explicitly cast to FP32:
from habana_frameworks.torch.hpex import hmp

with hmp.disable_casts():
    input_1 = input_1.float()
    ...
    input_n = input_n.float()
    # code line 1
    # code line 2
Basic Design Rules¶
Two different lists are maintained: (i) OPs that always run in BF16 only, (ii) OPs that always run in FP32 only.
Python decorators are used to add required functionality (bf16 or fp32 casts on OP input(s)) to torch functions (refer to code snippet below).
Ensure that BFloat16 specific OPs and functions are used in place of Float16; for example, tensor.bfloat16() should be used instead of tensor.half().
Any OP not in the above two lists will run with the precision type of its 1st input (with the exceptions listed below).
For OPs with multiple tensor inputs (maintained in a separate list, e.g. add, sub, cat, stack, etc.), cast all inputs to the widest precision type among all input precision types. If any of these OPs appear in the BF16 or FP32 list, that list takes precedence.
For in-place OPs (where the output and 1st input share storage), cast all inputs to the precision type of the 1st input.
from functools import wraps

def op_wrap(op, cast_fn):
    """Adds wrapper function to OPs. All tensor inputs for the OP
    are casted to the type determined by the cast_fn provided.

    Args:
        op (torch.nn.functional/torch/torch.Tensor): Input OP
        cast_fn (to_bf16/to_fp32): Fn to cast input tensors

    Returns:
        Wrapper function that shall be inserted back into the
        corresponding module for this OP.
    """
    @wraps(op)
    def wrapper(*args, **kwds):
        args_cast = get_new_args(cast_fn, args, kwds)
        return op(*args_cast, **kwds)
    return wrapper
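The wrapper above references helpers (get_new_args and the cast functions) that are internal to HMP and not shown here. For illustration only, a minimal sketch of what such helpers might look like, assuming casts are applied to floating-point tensor arguments (the names mirror the snippet above, but the real HMP implementations may differ):

import torch

def to_bf16(t):
    # Cast only floating-point tensors; pass everything else through unchanged.
    return t.bfloat16() if torch.is_tensor(t) and t.is_floating_point() else t

def to_fp32(t):
    return t.float() if torch.is_tensor(t) and t.is_floating_point() else t

def get_new_args(cast_fn, args, kwds):
    # Apply cast_fn to each positional argument (kwds handling omitted for brevity).
    return tuple(cast_fn(a) for a in args)

With such helpers in place, an OP could be wrapped as, for instance, torch.mm = op_wrap(torch.mm, to_bf16), after which any FP32 tensor inputs to mm would be cast to BF16 on the way in.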
By default, OPs that always run in BF16 and OPs that always run in FP32 are selected from an internal hard-coded BF16 list and FP32 list, respectively. These internal lists are expected to work across models. The BF16 list contains OPs that are numerically safe to run in lower precision on HPU, whereas the FP32 list contains OPs that should run in higher precision.
Default BF16 list = [addmm, batch_norm, bmm, conv1d, conv2d, conv3d, conv_transpose1d, conv_transpose2d, conv_transpose3d, dot, dropout, dropout1d, dropout2d, dropout3d, group_norm, instance_norm, layer_norm, leaky_relu, linear, matmul, mm, mul, mv, relu, t]
Default FP32 list = [binary_cross_entropy, binary_cross_entropy_with_logits, cross_entropy, div, divide, embedding, embedding_bag, log, log2, log_softmax, nll_loss, smooth_l1_loss, softmax, topk, truediv]
HMP provides the option of overriding these internal hard-coded lists, allowing advanced users to provide their own BF16 and FP32 lists. Make sure to pass bf16_file_path=<.txt> and fp32_file_path=<.txt> as arguments to hmp.convert(). This is particularly useful when customizing mixed precision training for a particular model to get the best possible performance. Examples of such custom lists can be found in PyTorch ResNet BF16 List and PyTorch ResNet FP32 List.
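For illustration, a minimal sketch of wiring up such custom lists (the file paths are placeholders, and the one-OP-name-per-line file format is an assumption based on the linked ResNet examples):

from habana_frameworks.torch.hpex import hmp

# Assumed file contents (one OP name per line), e.g. in ops_bf16_custom.txt:
#   addmm
#   conv2d
#   linear
# and in ops_fp32_custom.txt:
#   cross_entropy
#   softmax
hmp.convert(bf16_file_path="/path/to/ops_bf16_custom.txt",
            fp32_file_path="/path/to/ops_fp32_custom.txt",
            isVerbose=False)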
The following shows a usage example:
import torch
from habana_frameworks.torch.hpex import hmp

N, D_in, D_out = 64, 1024, 512

x = torch.randn(N, D_in, device='hpu')
y = torch.randn(N, D_out, device='hpu')

# enable mixed precision training with default BF16 list, default FP32 list and logging disabled
# use bf16_file_path to provide absolute path to a file with custom BF16 list
# use fp32_file_path to provide absolute path to a file with custom FP32 list
# use isVerbose to disable/enable debug logs
hmp.convert(bf16_file_path="", fp32_file_path="", isVerbose=False)

model = torch.nn.Linear(D_in, D_out).to(torch.device('hpu'))
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

for t in range(500):
    y_pred = model(x)
    loss = torch.nn.functional.mse_loss(y_pred, y)

    optimizer.zero_grad()
    loss.backward()

    # disable mixed precision for optimizer block
    with hmp.disable_casts():
        optimizer.step()
HMP provides the ability to log precision decisions for each OP for debugging purposes. You can enable verbose logs by passing isVerbose=True as an argument to hmp.convert(). The log prints the precision type each time an OP (covered by a Python decorator) is called in the model being run. See the example below:
casting <method '__mul__' of 'torch._C._TensorBase' objects> to to_fp32
casting <built-in method embedding of type object at 0x7feab47edfa0> to_fp32
casting <function layer_norm at 0x7feaab2a4320> to_bf16
casting <function dropout at 0x7feaab2a2e60> to_bf16
casting <method 'matmul' of 'torch._C._TensorBase' objects> to_bf16
casting <method '__iadd__' of 'torch._C._TensorBase' objects> to to_bf16
Visualizing Torch Graph¶
Use torchviz (https://github.com/szagoruyko/pytorchviz) to visualize the model graph and check if cast nodes are inserted at expected positions to convert portions of the model to BF16.
Cast nodes show up as “CopyBackwards” in the graph.
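For example, a minimal sketch of rendering the autograd graph with torchviz's make_dot (this reuses the model and input from the usage example above; the output filename is arbitrary):

from torchviz import make_dot

# Assumes `model` and `x` from the earlier usage example are in scope.
y_pred = model(x)
dot = make_dot(y_pred, params=dict(model.named_parameters()))
dot.render("hmp_graph", format="pdf")  # cast nodes appear as "CopyBackwards"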