Optimizing PyTorch Models

Batch Size

A large batch size generally benefits throughput. However, the following limitations apply when using a large batch size:

  1. Batch size is limited by the size of Gaudi’s device memory (HBM). A larger batch size usually consumes more device memory, and the device memory size is fixed.

  2. A large batch size is not suitable when low latency, rather than throughput, is required.

  3. A large per-device batch size may impair convergence in data-parallel distributed training. For example, the largest global batch size with which ResNet50 (RN50) still converges is around 32K. Therefore, as the number of Gaudi devices increases, the batch size on each device should be reduced.
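The trade-off in item 3 comes down to simple arithmetic: the per-device batch is bounded both by what fits in HBM and by the global convergence cap divided by the device count. The sketch below assumes a global cap of 32K (taken here as 32 * 1024) and a hypothetical preferred per-device batch of 256; the function per_device_batch_size is illustrative and not part of any Habana or PyTorch API.

```python
def per_device_batch_size(global_batch_cap: int, num_devices: int, preferred: int) -> int:
    """Return a per-device batch size that respects a global convergence cap.

    global_batch_cap: largest global batch assumed to still converge (e.g. ~32K for ResNet50).
    num_devices: number of Gaudi devices used for data-parallel training.
    preferred: per-device batch chosen for throughput alone (bounded by HBM capacity).
    """
    if num_devices <= 0:
        raise ValueError("num_devices must be positive")
    # Never exceed the throughput-driven choice, keep the global batch within the cap
    # where possible, and always keep at least one sample per device.
    return max(1, min(preferred, global_batch_cap // num_devices))


# Assumed values: a 32K global cap (32 * 1024) and a preferred per-device batch of 256.
GLOBAL_CAP = 32 * 1024
for devices in (8, 64, 128, 256, 512):
    print(devices, per_device_batch_size(GLOBAL_CAP, devices, preferred=256))
```

Under these assumptions the per-device batch stays at 256 up to 128 devices, then halves each time the device count doubles (128 at 256 devices, 64 at 512 devices).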

The table below lists example batch sizes used for different models, all trained with mixed precision.

Model                                Batch Size
Bert Large pre-training Phase 1
Bert Large pre-training Phase 2

PyTorch Mixed Precision

For details on how to run mixed precision training of PyTorch models on Gaudi, refer to PyTorch Mixed Precision Training on Gaudi.

Usage of Fused Operators

Use custom fused ops for optimizers (e.g. FusedSGD, FusedAdamW) and for other complex ops (e.g. FusedClipNorm) to minimize the host overhead of running many small ops. This can improve the overlap of execution between host and device.

The Habana PyTorch package provides several such custom ops; see Handling Custom Habana Ops for PyTorch.

For an example, refer to the custom operator FusedSGD in ResNet50 FusedSGD.
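FusedSGD and FusedAdamW are Habana-specific, and their import paths are not covered in this section. As a hardware-neutral illustration of the same idea (a few batched launches instead of one small op chain per parameter), the sketch below compares stock PyTorch SGD with foreach=False and foreach=True. The foreach option is assumed to be available (recent PyTorch releases); it is PyTorch's multi-tensor implementation, not Habana's fused kernel, and the helpers make_params and sgd_step are hypothetical.

```python
import torch


def make_params(num_tensors: int = 50, seed: int = 0):
    # Many small parameter tensors: the case where per-tensor op overhead dominates.
    gen = torch.Generator().manual_seed(seed)
    return [torch.nn.Parameter(torch.randn(16, generator=gen)) for _ in range(num_tensors)]


def sgd_step(params, foreach: bool):
    # Give every parameter the same simple gradient so the two runs are comparable.
    for p in params:
        p.grad = torch.full_like(p, 0.5)
    opt = torch.optim.SGD(params, lr=0.1, momentum=0.9, foreach=foreach)
    opt.step()
    return [p.detach().clone() for p in params]


# foreach=False updates parameters one tensor at a time;
# foreach=True groups the updates into multi-tensor operations.
loop_result = sgd_step(make_params(), foreach=False)
multi_result = sgd_step(make_params(), foreach=True)
print(all(torch.allclose(a, b) for a, b in zip(loop_result, multi_result)))
```

On CPU both variants produce the same parameters; the benefit of grouping updates shows up as reduced host overhead when the ops are dispatched to an accelerator.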

Adjust the Gradient Bucket Size in Multi-card/Multi-node Training

Depending on the size of the model, the gradient bucket size can be adjusted to minimize the number of all-reduce invocations in the backward pass of each training iteration. See the PyTorch DDP documentation.
For example, a bucket size of 100 MB is optimal for ResNet50, whereas ResNeXt101 requires a bucket size of 200 MB. Refer to the implementation here.
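To see why the bucket size changes the number of all-reduce calls, the sketch below uses a simplified greedy model of DDP bucketing: gradients are packed in reverse parameter order, and a bucket is closed once it reaches the cap. It ignores details such as DDP's smaller first bucket. The count_buckets helper and the 100-tensor model are hypothetical. In real code the setting is the bucket_cap_mb argument of torch.nn.parallel.DistributedDataParallel, which defaults to 25 MB.

```python
from typing import List


def count_buckets(param_numels: List[int], bucket_cap_mb: float, bytes_per_elem: int = 4) -> int:
    """Estimate how many gradient buckets (one all-reduce each) a backward pass needs.

    Simplified greedy model, not DDP's exact algorithm.
    In real training the cap would be set with, for example:
        DistributedDataParallel(model, bucket_cap_mb=100)
    """
    cap_bytes = bucket_cap_mb * 1024 * 1024
    buckets = 0
    current = 0
    # Gradients become ready roughly in reverse order of registration.
    for numel in reversed(param_numels):
        current += numel * bytes_per_elem
        if current >= cap_bytes:
            buckets += 1  # bucket is full: one all-reduce is launched
            current = 0
    if current > 0:
        buckets += 1  # partially filled last bucket still needs its own all-reduce
    return buckets


# Hypothetical model: 100 gradient tensors of 256K fp32 elements each (100 MiB in total).
sizes = [256 * 1024] * 100
for cap in (25, 100, 200):
    print(cap, count_buckets(sizes, cap))
```

In this model, raising the cap from 25 MB to 100 MB cuts the all-reduce count from 4 to 1. A very large bucket, however, delays the start of communication until more of the backward pass has finished, which is why the best value depends on the model, as the ResNet50 and ResNeXt101 numbers above suggest.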

Setting Gradients as View of Gradient Buckets in Multi-card/Multi-node Training

PyTorch DDP allows parameter gradient tensors to be views into the gradient buckets. This improves performance by reducing device-to-device copies, and it also lowers the device memory requirement. See the PyTorch DDP documentation.

Refer to the implementation for ResNet50.
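In DDP this behavior is enabled with the gradient_as_bucket_view=True constructor argument. The sketch below does not use DDP; it reproduces the underlying mechanism with plain tensors through a hypothetical helper, attach_grads_as_views. Each parameter's .grad becomes a slice of one flat bucket, so an in-place operation on the bucket (standing in for the all-reduce, with an assumed world size of 2) updates every gradient without a copy back.

```python
import torch


def attach_grads_as_views(params, flat_bucket: torch.Tensor) -> int:
    # Point each parameter's .grad at a slice of one flat buffer instead of a separate tensor.
    # With DDP the equivalent is: DistributedDataParallel(model, gradient_as_bucket_view=True)
    offset = 0
    for p in params:
        n = p.numel()
        p.grad = flat_bucket[offset:offset + n].view_as(p)
        offset += n
    return offset  # number of bucket elements in use


torch.manual_seed(0)
params = [torch.nn.Parameter(torch.randn(3, 4)), torch.nn.Parameter(torch.randn(5))]
bucket = torch.ones(sum(p.numel() for p in params))
attach_grads_as_views(params, bucket)

# Stand-in for the all-reduce average: modify the bucket in place (world size 2 assumed).
bucket.div_(2)
print(params[0].grad[0, 0].item(), params[1].grad[-1].item())
```

Because the gradients share storage with the bucket, there is no separate copy from the bucket back into each .grad tensor, which is the saving described above.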

Reducing the Frequency of Printing Quantities

Once a model has been fully optimized and is set up for production use, some output messages should be reduced or eliminated for best performance. Two specific examples:

  • Reporting loss using loss.item() or calculating loss to display to the user

  • Showing the progress bar (using tqdm or other libraries) during runtime

Both of these rely on additional communication between the host CPU and the Gaudi HPU to obtain the loss or progress and then display it. Printing a device tensor in the training script pulls it to the host CPU, which requires the device to finish the work that produces it first. As a result, host and device execution no longer overlap, leading to sub-optimal performance.

To reduce how often the loss is computed for display or the progress bar is updated, set the print frequency option --print-freq to a high value, or remove printing altogether. In the model run command, --print-freq can be set to a value similar to the optimizer step size. For the progress bar, it is recommended to wait until a run completes 20 or more iterations, to minimize unnecessary synchronization.
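A minimal sketch of the pattern on a toy CPU model: loss.item() is called only every print_freq steps. Here print_freq mirrors the --print-freq option, while the train function, the model, and its defaults are hypothetical. Each .item() call forces a host-device synchronization, so the loop avoids it on all other steps.

```python
import torch


def train(num_steps: int = 100, print_freq: int = 20):
    torch.manual_seed(0)
    model = torch.nn.Linear(8, 1)
    opt = torch.optim.SGD(model.parameters(), lr=0.01)
    x = torch.randn(32, 8)
    y = torch.randn(32, 1)
    logged = []
    for step in range(1, num_steps + 1):
        opt.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        opt.step()
        # loss.item() copies the value to the host and waits for pending device work,
        # so it is only done every print_freq steps.
        if step % print_freq == 0:
            logged.append((step, loss.item()))
            print(f"step {step}: loss {logged[-1][1]:.4f}")
    return logged


logged = train()
```

The same step % print_freq condition can gate progress-bar updates.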

Pinning Memory For Dataloader

Pinning memory when instantiating the dataloader avoids a redundant host-side copy during each training iteration. See the pin_memory support in PyTorch DataLoader.

Refer to the implementation for ResNet50 Dataloader.
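A minimal sketch using a hypothetical in-memory dataset: pin_memory=True is a standard torch.utils.data.DataLoader argument that places collated batches in page-locked host memory. On a machine without a supported accelerator, PyTorch may emit a warning and skip pinning, so this sketch only demonstrates the configuration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical dataset: 64 small "images" with integer labels.
dataset = TensorDataset(torch.randn(64, 3, 8, 8), torch.randint(0, 10, (64,)))

loader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    num_workers=0,    # real training would typically use worker processes
    pin_memory=True,  # collate batches into page-locked (pinned) host memory
)

for images, labels in loader:
    print(images.shape, labels.shape)
```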

Avoiding Constant Variables in Loops

Avoiding the use of loop iterator variables within a loop may reduce the need for recompilations in consecutive iterations. Such a loop iterator variable may cause different constant operators to be created in the execution graph on every iteration.

For example, in the original V-Diffusion code, the value of the iterator variable changes each time the loop iterates. To avoid triggering a recompilation after each iteration, Habana’s V-Diffusion model does not use the loop iterator variable i:

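for i in range(4, num_steps):
    # The following 3 lines avoid graph recompilation (variable "i" is not used).
    # steps is assumed to be rolled already so that steps[0] holds the value for step i.
    t_1 = steps[0]  # before: steps[i]
    t_2 = steps[1]  # before: steps[i + 1]
    steps = torch.roll(steps, shifts=-1, dims=0)  # move the next pair into positions 0 and 1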


Refer to the implementation for Habana’s V-Diffusion model and compare it with the original V-Diffusion code.
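The rewrite relies on the rolled tensor always holding the current pair at positions 0 and 1. The sketch below checks that equivalence on CPU with a hypothetical 10-step schedule, assuming that steps is rolled by the starting offset once before the loop. It only shows that both versions read the same values; the recompilation benefit applies to graph-mode execution on Gaudi and is not demonstrated here. The helper names are illustrative.

```python
import torch


def pairs_by_index(steps: torch.Tensor, start: int) -> torch.Tensor:
    # Original pattern: the Python integer i becomes a different constant slice each iteration.
    num_steps = steps.numel() - 1
    return torch.stack([torch.stack((steps[i], steps[i + 1])) for i in range(start, num_steps)])


def pairs_by_roll(steps: torch.Tensor, start: int) -> torch.Tensor:
    # Rewritten pattern: every iteration reads positions 0 and 1 and rolls by one,
    # so the operations are identical from one iteration to the next.
    num_steps = steps.numel() - 1
    steps = torch.roll(steps, shifts=-start, dims=0)  # assumed alignment: position 0 = steps[start]
    out = []
    for _ in range(start, num_steps):
        t_1 = steps[0]
        t_2 = steps[1]
        out.append(torch.stack((t_1, t_2)))
        steps = torch.roll(steps, shifts=-1, dims=0)
    return torch.stack(out)


steps = torch.linspace(1.0, 0.0, 11)  # hypothetical schedule with num_steps = 10
print(torch.equal(pairs_by_index(steps, 4), pairs_by_roll(steps, 4)))
```

The trade-off is that torch.roll copies the whole tensor on each iteration, which is usually far cheaper than recompiling the graph when the schedule is short.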