PyTorch Mixed Precision Training on Gaudi¶
The Habana Mixed Precision (HMP) package is a tool that allows you to run mixed precision training on HPU without extensive modification of existing FP32 model scripts. You can add mixed precision training support to a model script by adding the following lines anywhere in the script before the start of the training loop:
```python
from habana_frameworks.torch.hpex import hmp
hmp.convert()
```
Any segment of the script (e.g. the optimizer) in which you want to avoid using mixed precision should be kept under the following Python context:
```python
from habana_frameworks.torch.hpex import hmp

with hmp.disable_casts():
    code line:1
    code line:2
```
Two different lists are maintained: (i) OPs that always run in BF16 only, and (ii) OPs that always run in FP32 only.
Python decorators are used to add the required functionality (BF16 or FP32 casts on OP inputs) to torch functions (refer to the code snippet below).
Any OP not in the above two lists runs with the precision type of its first input (except for the cases listed below).
For OPs with multiple tensor inputs (maintained in a separate list, e.g. add, sub, cat, stack, etc.), all inputs are cast to the widest precision type among the input precision types. If any of these OPs also appears in the BF16 or FP32 list, that list takes precedence.
For in-place OPs (where the output and first input share storage), all inputs are cast to the precision type of the first input.
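The widest-precision rule for multi-input OPs can be sketched in plain PyTorch. This is a simplified illustration of the rule, not HMP's actual implementation; `resolve_dtype` is a hypothetical helper:

```python
import torch

def resolve_dtype(*tensors):
    """Hypothetical helper: pick the widest precision among the inputs.

    FP32 is wider than BF16, so if any input is FP32 the OP runs in FP32;
    only when all inputs are BF16 does it run in BF16.
    """
    if any(t.dtype == torch.float32 for t in tensors):
        return torch.float32
    return torch.bfloat16

a = torch.ones(2, 2, dtype=torch.bfloat16)
b = torch.ones(2, 2, dtype=torch.float32)

# Both inputs are cast to the widest type (FP32 here) before the add.
target = resolve_dtype(a, b)
result = a.to(target) + b.to(target)
```

If `a` and `b` were both BF16, `resolve_dtype` would return `torch.bfloat16` and the add would run in the lower precision instead.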
```python
from functools import wraps

def op_wrap(op, cast_fn):
    """Adds wrapper function to OPs.

    All tensor inputs for the OP are cast to the type determined by the
    cast_fn provided.

    Args:
        op (torch.nn.functional/torch/torch.Tensor): Input OP
        cast_fn (to_bf16/to_fp32): Fn to cast input tensors

    Returns:
        Wrapper function that shall be inserted back to the
        corresponding module for this OP.
    """
    @wraps(op)
    def wrapper(*args, **kwds):
        args_cast = get_new_args(cast_fn, args, kwds)
        return op(*args_cast, **kwds)
    return wrapper
```
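The snippet above calls `get_new_args`, which is not shown in the documentation. A minimal sketch of what such a helper could look like is below; this is an assumption for illustration, not HMP's actual code, which would also need to handle nested containers and keyword arguments:

```python
import torch

def to_bf16(x):
    # Hypothetical cast function: cast only floating-point tensors,
    # pass everything else (ints, scalars, None) through unchanged.
    if torch.is_tensor(x) and x.is_floating_point():
        return x.to(torch.bfloat16)
    return x

def get_new_args(cast_fn, args, kwds):
    """Illustrative sketch: apply cast_fn to every positional argument."""
    return tuple(cast_fn(a) for a in args)

# Example: only the tensor argument is cast, the scalar is untouched.
args = (torch.ones(2, 2), 3)
casted = get_new_args(to_bf16, args, {})
```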
HMP provides two modes of mixed precision training to choose from (opt_level=O1 or opt_level=O2). The mode is selected by passing opt_level as an argument to hmp.convert().
O1 is the default and recommended mode of operation when using HMP. O2 can be used for debugging convergence issues as well as for initial iterations of converting a new model to run with mixed precision.
In O1 mode, OPs that always run in BF16 and OPs that always run in FP32 are selected from a BF16 list and an FP32 list, respectively. The BF16 list contains OPs that are numerically safe to run in lower precision on HPU, whereas the FP32 list contains OPs that should run in higher precision (a conservative choice that works across models).
Default BF16 list = [addmm, bmm, conv1d, conv2d, conv3d, dot, mm, mv]
Default FP32 list = [batch_norm, cross_entropy, log_softmax, softmax, nll_loss, topk]
HMP provides the option of overriding these internal lists, allowing you to provide your own BF16 and FP32 lists by passing bf16_file_path=<path to .txt file> and fp32_file_path=<path to .txt file> as arguments to hmp.convert(). This is particularly useful when customizing mixed precision training for a particular model. For example:
Custom BF16 list for ResNet50 = [addmm, avg_pool2d, bmm, conv2d, dot, max_pool2d, mm, mv, relu, t, linear]
Custom FP32 list for ResNet50 = [cross_entropy, log_softmax, softmax, nll_loss, topk]
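As a sketch, such a custom list can be written to a plain text file whose path is then passed to hmp.convert(). The one-OP-name-per-line file format used here is an assumption; verify the exact format expected by hmp.convert() against the HMP documentation:

```python
import os
import tempfile

# Hypothetical custom BF16 list for ResNet50, one OP name per line
# (file format is assumed, not confirmed by this document).
bf16_ops = ["addmm", "avg_pool2d", "bmm", "conv2d", "dot",
            "max_pool2d", "mm", "mv", "relu", "t", "linear"]

path = os.path.join(tempfile.gettempdir(), "ops_bf16.txt")
with open(path, "w") as f:
    f.write("\n".join(bf16_ops))

# The absolute path would then be supplied before the training loop, e.g.:
#   hmp.convert(opt_level="O1", bf16_file_path=path, ...)
```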
```python
import torch
from habana_frameworks.torch.hpex import hmp

N, D_in, D_out = 64, 1024, 512
x = torch.randn(N, D_in, device="hpu")
y = torch.randn(N, D_out, device="hpu")

# Enable mixed precision training with optimization level O1,
# default BF16 list, default FP32 list, and logging disabled:
# - use opt_level to select the desired mode of operation
# - use bf16_file_path to provide an absolute path to a file with a custom BF16 list
# - use fp32_file_path to provide an absolute path to a file with a custom FP32 list
# - use isVerbose to enable/disable debug logs
hmp.convert(opt_level="O1", bf16_file_path="", fp32_file_path="",
            isVerbose=False)

model = torch.nn.Linear(D_in, D_out).to(torch.device("hpu"))
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

for t in range(500):
    y_pred = model(x)
    loss = torch.nn.functional.mse_loss(y_pred, y)
    optimizer.zero_grad()
    loss.backward()
    # disable mixed precision for the optimizer block
    with hmp.disable_casts():
        optimizer.step()
```
HMP provides the ability to log the precision decision for each OP for debugging purposes. You can enable verbose logs by passing isVerbose=True as an argument to hmp.convert(). The log prints the precision type each time an OP (covered by a Python decorator) is called in the model being run. See the example below:
```
casting <method '__mul__' of 'torch._C._TensorBase' objects> to to_fp32
casting <built-in method embedding of type object at 0x7feab47edfa0> to_fp32
casting <function layer_norm at 0x7feaab2a4320> to_bf16
casting <function dropout at 0x7feaab2a2e60> to_bf16
casting <method 'matmul' of 'torch._C._TensorBase' objects> to_bf16
casting <method '__iadd__' of 'torch._C._TensorBase' objects> to to_bf16
```
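When skimming a long verbose log, a quick way to summarize the precision decisions is to count casts per target type. This is a plain-Python sketch over the log format shown above (the sample lines are taken from that example):

```python
# Sample of HMP verbose log output, as printed by hmp.convert(isVerbose=True)
log = """\
casting <function layer_norm at 0x7feaab2a4320> to_bf16
casting <function dropout at 0x7feaab2a2e60> to_bf16
casting <built-in method embedding of type object at 0x7feab47edfa0> to_fp32
"""

# Count how many OP calls were cast to each precision type
counts = {"to_bf16": 0, "to_fp32": 0}
for line in log.splitlines():
    for target in counts:
        if line.endswith(target):
            counts[target] += 1
# counts -> {'to_bf16': 2, 'to_fp32': 1}
```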