Handling Dynamic Shapes

In this document, Dynamic Shapes is a generic term for model behavior in which some topologies are subject to dynamicity in their input data, or contain operators that produce output tensors of variable shape. Such dynamicity in a model may cause constant recompilations, leading to longer training time.

The following sections discuss methods to detect Dynamic Shapes and mitigate their impact on performance.

Types of Dynamicity

Dynamic Shapes can be broadly classified into two categories:

  • Inputs - Dynamic shapes due to varying input shapes during training, such as:

    • Varying sentence lengths in language models

    • Differing image resolutions in image models

  • Ops - Dynamic shapes due to Ops occur for certain Ops whose output shape depends on the actual input data rather than only on the input shapes, that is, Ops whose output shapes cannot be inferred from their input shapes.

Detecting and Mitigating Dynamicity Overview

Dynamicity, resulting from changing input shapes or dynamic ops, can lead to multiple recompilations, causing a longer training time and reducing performance. Below are the guidelines to detect and mitigate this issue:

  1. Detecting Model Recompilation - To detect whether the model is recompiling, set the following environment variables: PT_HPU_METRICS_FILE=/root/metricslog.json PT_HPU_METRICS_DUMP_TRIGGERS=process_exit,metric_change. The image below shows the graph_compilation metric, which is dumped into the specified JSON file every time a compilation occurs. For static graphs, the number of recompilations is expected to drop after the first few steps. If recompilations continue, go to step 2.

  2. Enabling Dynamicity Support from Graph Compiler - To enable dynamicity support, set PT_HPU_ENABLE_REFINE_DYNAMIC_SHAPES=1 (it is disabled by default). For further details, refer to Graph Compiler Dynamicity Support. If recompilations continue, or if you encounter instability and want to achieve better performance, go to steps 3 and 4.

  3. Detecting Dynamic Shapes due to Varying Inputs - Use the data_dynamicity tool to generate a report on input dynamicity as described in Detecting Dynamic Inputs. If there is a large amount of variability in the input data, use data bucketing and padding as described in Mitigation Through Bucketing and Padding.

  4. Detecting Dynamic Shapes due to Ops - Use the detect_recompilation_auto_model tool to automatically detect which part of the model has dynamic ops as described in Detecting Dynamic Ops. If the model has dynamic ops, replace them with static ops as described in Replacing Dynamic Ops with Static Ops. If that is not possible, split the dynamic and static sections to make the recompiling section smaller as described in Splitting Dynamic and Static Sections.

  5. If you are running steps 2, 3 and 4 in a multi-card scenario, enable recipe caching via PT_HPU_RECIPE_CACHE_CONFIG. For details, refer to Runtime Environment Variables. This allows one card to use graphs previously compiled by another card.
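As a minimal sketch, the environment variables from the detection and graph compiler steps can also be prepared from Python, for example when launching training as a subprocess. The helper name dynamic_shape_debug_env is hypothetical, and only the variables named above are set; recipe caching is left out because its value format is described in Runtime Environment Variables.

```python
import os


def dynamic_shape_debug_env(metrics_file, refine_dynamic_shapes=False, base=None):
    """Return a copy of an environment with the dynamic-shape debugging variables set.

    Hypothetical helper: `base` defaults to the current process environment.
    """
    env = dict(os.environ if base is None else base)
    # Dump HPU metrics (including graph_compilation) to a JSON file
    env["PT_HPU_METRICS_FILE"] = metrics_file
    env["PT_HPU_METRICS_DUMP_TRIGGERS"] = "process_exit,metric_change"
    # Optionally enable graph compiler dynamicity support (disabled by default)
    if refine_dynamic_shapes:
        env["PT_HPU_ENABLE_REFINE_DYNAMIC_SHAPES"] = "1"
    return env


env = dynamic_shape_debug_env("/root/metricslog.json", refine_dynamic_shapes=True)
print({k: v for k, v in env.items() if k.startswith("PT_HPU_")})
```

The returned dictionary can be passed as env= to subprocess.run([sys.executable, "train.py"], env=env), where train.py stands for your own training script (a placeholder here). Setting the variables in the child's environment ensures they are present before the process imports the Intel Gaudi PyTorch bridge.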

Graph Compiler Dynamicity Support

To enable dynamicity support, set PT_HPU_ENABLE_REFINE_DYNAMIC_SHAPES=1; it is disabled by default. If a model experiences excessive recompilations due to dynamic data or ops, setting this variable enables the Intel Gaudi PyTorch bridge and graph compiler to automatically manage dynamic shapes in model scripts. The graphs are automatically bucketed and padded into ranges to achieve a common size, reducing recompilations and improving performance when working with dynamic workloads.

In a multi-card scenario, it is also essential to enable recipe caching via PT_HPU_RECIPE_CACHE_CONFIG. For details, refer to Runtime Environment Variables. This allows one card to use graphs previously compiled by another card. If recompilations continue, or if you encounter instability, refer to the following sections to detect dynamicity and for tips on rewriting the model.

Dynamic Shapes due to Varying Inputs

Detecting Dynamic Inputs

The first step is to find the distribution of the input data’s shapes by using the data_dynamicity tool to generate a report on input dynamicity. The example below shows how to use this tool on a sample MNIST dataset:

from habana_frameworks.torch.utils.experimental import data_dynamicity
import torchvision
from torch.utils.data import DataLoader

# Creating a sample MNIST dataloader
mnist_ds = torchvision.datasets.MNIST('mnist', download=True, transform=torchvision.transforms.ToTensor())
mnist_dl = DataLoader(mnist_ds, batch_size=7, num_workers=2)

# Generate the input dynamicity report for the dataloader
res = data_dynamicity(mnist_dl)


Note that although the MNIST dataset has a constant image shape, two distinct input shapes are obtained: unless drop_last=True is passed to the DataLoader, the last batch contains fewer images than the others. This is a very low amount of input dynamicity, so no changes to the input are required.

The Food101 dataset has images of different shapes. In the example below, batches are created by cropping the images to the smallest height and width found in each batch:

from habana_frameworks.torch.utils.experimental import data_dynamicity
import torchvision
from torch.utils.data import DataLoader
import torch

def collate(batch):
   dim1 = min([k[0].shape[1] for k in batch])
   dim2 = min([k[0].shape[2] for k in batch])
   images = torch.stack([k[0][:,:dim1,:dim2] for k in batch])
   labels = torch.tensor([k[1] for k in batch])
   return (images,labels)

food101_ds = torchvision.datasets.Food101('food101', download=True, transform=torchvision.transforms.ToTensor())
food101_dl = DataLoader(food101_ds, batch_size=7, num_workers=2, collate_fn=collate)

# Generate the input dynamicity report for the dataloader
res = data_dynamicity(food101_dl)

A large amount of input dynamicity is obtained for Food101.

Number of unique shapes:  66
# There is a lot of dynamicity in input data shapes

If there is a large amount of variability in the input data, data padding and/or data bucketing can be used. Using a bucketing and padding mechanism normalizes the input shapes to a limited number of shapes without sacrificing training accuracy while potentially improving training performance on Gaudi.

Mitigation Through Bucketing and Padding

To reduce the number of unique shapes in the input data, bucketing and padding can be applied. This involves dividing the range of input shapes, from the minimum size to the maximum size, into a set of buckets, and padding the datapoints of each batch up to the smallest bucket that can hold the largest datapoint in that batch. For bucketing algorithms, the number of buckets is a hyperparameter: choosing too many buckets causes many recompilations, while choosing too few requires excess padding, which wastes computation.
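To make this tradeoff concrete, the sketch below pads each batch to the smallest bucket boundary that fits its longest item and reports how many distinct padded lengths remain and how many padding elements are added. The helper names pad_target and bucketing_cost and the sample lengths are illustrative assumptions.

```python
import bisect


def pad_target(length, boundaries):
    """Smallest bucket boundary >= length (boundaries sorted ascending)."""
    i = bisect.bisect_left(boundaries, length)
    if i == len(boundaries):
        raise ValueError(f"length {length} exceeds the largest bucket {boundaries[-1]}")
    return boundaries[i]


def bucketing_cost(batches, boundaries):
    """Return (distinct padded lengths, total padding elements) for batches of sample lengths."""
    targets = [pad_target(max(batch), boundaries) for batch in batches]
    padding = sum(t * len(b) - sum(b) for t, b in zip(targets, batches))
    return len(set(targets)), padding


# Made-up sequence lengths, already grouped into batches of similar length
batches = [[3, 5], [7, 8], [12, 15], [16]]
print(bucketing_cost(batches, [16]))        # (1, 46): one shape, most padding
print(bucketing_cost(batches, [8, 16]))     # (2, 14)
print(bucketing_cost(batches, [6, 8, 16]))  # (3, 10): more shapes, least padding
```

Each additional bucket can add one more compiled shape, while removing buckets shifts the cost into padding.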

The example below shows the steps needed to take a random dataset of wav files, create two buckets and then sort and pad the data.
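A minimal standard-library sketch of the same steps follows. It assumes that wav files can be represented as lists of float samples of random length; make_batches and its parameters are hypothetical, and the two bucket boundaries are placed at the median and the maximum length so that each bucket receives roughly half of the files.

```python
import bisect
import random


def make_batches(waves, num_buckets=2, batch_size=4):
    """Sort variable-length waves, split them into batches and zero-pad each batch to a bucket length."""
    if not waves:
        return [], []
    waves = sorted(waves, key=len)  # sort by length so similar lengths share a batch
    lengths = [len(w) for w in waves]
    # Bucket boundaries at equal-count positions; the last one is always the maximum length
    boundaries = sorted({lengths[len(lengths) * (i + 1) // num_buckets - 1] for i in range(num_buckets)})
    batches = []
    for start in range(0, len(waves), batch_size):
        batch = waves[start:start + batch_size]
        longest = len(batch[-1])  # sorted, so the last wave is the longest
        target = boundaries[bisect.bisect_left(boundaries, longest)]
        # Zero-pad every wave in the batch up to the bucket length
        batches.append([w + [0.0] * (target - len(w)) for w in batch])
    return boundaries, batches


random.seed(0)
dataset = [[random.uniform(-1.0, 1.0) for _ in range(random.randint(50, 400))] for _ in range(10)]
boundaries, batches = make_batches(dataset)
print(boundaries, [len(batch[0]) for batch in batches])
```

After padding, every batch has one of at most two lengths, regardless of how many distinct lengths the raw files had.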


Bucketing algorithms can be classified into two categories. You can also use any novel padding/bucketing strategy to get the best tradeoff between bucketing algorithm time, compilation delays and HPU runtime speed. The suggested bucketing algorithm types are:

  • Optimization based - These algorithms optimize a criterion function. For example, fairseq uses a linear programming approach to reduce the amount of padding wastage. These methods are usually slower, but closer to optimal in terms of using less padding.

  • Fast heuristics-based - These are quick and easy bucketing solutions, such as the percentile-based bucket sizes used in the Wav2vec2 case study in this document. However, these methods usually result in more padding wastage than optimization-based methods.
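fairseq's linear programming formulation is not reproduced here. As an illustration of the optimization-based category only, the sketch below swaps in a different exact technique: dynamic programming over the histogram of lengths that picks num_buckets boundaries minimizing the total padding, assuming each sample is padded individually to the smallest boundary that fits it. optimal_buckets is a hypothetical helper; its cost grows as O(num_buckets * distinct_lengths^2), which is one reason the preprocessing tips that follow help.

```python
from collections import Counter
from itertools import accumulate


def optimal_buckets(lengths, num_buckets):
    """Choose bucket boundaries minimizing total padding (exact dynamic programming).

    Returns (sorted boundaries, total padding elements).
    """
    if not lengths:
        return [], 0
    hist = sorted(Counter(lengths).items())  # (length, count), ascending
    values = [v for v, _ in hist]
    m = len(values)
    k = min(num_buckets, m)
    # Prefix sums make the padding cost of any contiguous run of lengths O(1)
    cnt = [0, *accumulate(c for _, c in hist)]
    tot = [0, *accumulate(v * c for v, c in hist)]

    def cost(i, j):
        """Padding needed to bring all samples with lengths values[i..j] up to values[j]."""
        return values[j] * (cnt[j + 1] - cnt[i]) - (tot[j + 1] - tot[i])

    inf = float("inf")
    # best[b][j]: least padding covering values[0..j] with b buckets, the last one ending at values[j]
    best = [[inf] * m for _ in range(k + 1)]
    prev = [[-1] * m for _ in range(k + 1)]
    for j in range(m):
        best[1][j] = cost(0, j)
    for b in range(2, k + 1):
        for j in range(b - 1, m):
            for i in range(b - 2, j):
                candidate = best[b - 1][i] + cost(i + 1, j)
                if candidate < best[b][j]:
                    best[b][j], prev[b][j] = candidate, i
    # Walk the back-pointers from the largest length to recover the boundaries
    boundaries, j = [], m - 1
    for b in range(k, 0, -1):
        boundaries.append(values[j])
        j = prev[b][j]
    return boundaries[::-1], best[k][m - 1]


print(optimal_buckets([1, 2, 3, 10, 11, 12], 2))  # ([3, 12], 6)
```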

Preprocessing Input Dataset

When the input dataset is large, processing all the shapes in it, especially for costlier optimization based bucketing algorithms, might be too slow. To work around this issue the following can be done:

  • Shuffle the shapes of the dataset and pick a small number of shapes to pass to the bucketing algorithm, reducing its workload.

  • Instead of creating a histogram bin for every individual shape, use wider bins, so that multiple contiguous shapes fall into the same bin.
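Both ideas can be sketched in a few lines of standard-library Python. The function name, the default sample size and the bin width below are illustrative assumptions; each shape is mapped to the upper edge of its bin, since that is the length a bucket covering the bin would need to pad to.

```python
import random
from collections import Counter


def coarse_shape_histogram(shapes, sample_size=10_000, bin_width=16, seed=0):
    """Histogram of a random subset of `shapes` (a list), using bins of `bin_width` contiguous sizes.

    Keys are bin upper edges (sizes 1..16 -> 16, 17..32 -> 32, ... for bin_width=16).
    """
    rng = random.Random(seed)
    # Shuffle and keep only a subset to reduce the bucketing algorithm's workload
    sample = rng.sample(shapes, min(sample_size, len(shapes)))
    # Round every size up to a multiple of bin_width (ceiling division)
    return Counter(-(-s // bin_width) * bin_width for s in sample)


lengths = [random.randint(1, 500) for _ in range(100_000)]
hist = coarse_shape_histogram(lengths)
print(len(hist), "bins instead of", len(set(lengths)), "distinct shapes")
```

The resulting (bin edge, count) pairs can be given to a bucketing algorithm in place of the full list of shapes.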

Saving Bucketing Results

If the bucketing algorithm takes a lot of time, it might not be ideal to compute it every time you run the script. You can store the results of the bucketing algorithm based on the hash of the input shapes. For example, see here.

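import hashlib
import json

lst = [2, 4, 5, 3, 6, 3, 2]  # input shapes, e.g. sequence lengths
# bytes(sorted(lst)) only accepts values below 256, so serialize the sorted shapes as JSON text instead
key = hashlib.sha256(json.dumps(sorted(lst)).encode()).hexdigest().upper()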
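Building on such a key, the sketch below stores each bucketing result in a JSON file named after the hash and reuses it on later runs. cached_buckets, the cache directory and the compute_buckets callback are hypothetical. The bucketing parameters are hashed together with the shapes so that changing, for example, the number of buckets does not return a stale result.

```python
import hashlib
import json
from pathlib import Path


def cached_buckets(shapes, compute_buckets, cache_dir="bucket_cache", **params):
    """Return compute_buckets(shapes, **params), reusing a result saved by an earlier run.

    The result must be JSON-serializable (e.g. call .tolist() on a NumPy array first).
    """
    payload = json.dumps({"shapes": sorted(shapes), "params": params}, sort_keys=True)
    key = hashlib.sha256(payload.encode()).hexdigest().upper()
    path = Path(cache_dir) / f"{key}.json"
    if path.exists():
        return json.loads(path.read_text())  # cache hit: skip the slow algorithm
    buckets = compute_buckets(shapes, **params)
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(buckets))
    return buckets
```

Because the shapes are sorted before hashing, the same dataset read in a different order maps to the same cache entry.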

Further Notes

Whenever the input shapes follow more than one probability distribution, you may need to run bucketing separately for each distribution. Some examples of multiple distributions are:

  • Multiple inputs - A model might have two inputs, each with its own distribution of dynamic input shapes. In this case, running bucketing separately on each input is recommended. An example of such a scenario is a translation task, where the source and target language sentences might have different length distributions.

  • Multiple datasets - If you have multiple datasets, running the bucketing algorithm for each dataset separately is recommended. For example, Roboflow 100 is made up of 100 different datasets, each used to train the same model in a separate run.

Variable Batch Size

Variable batch size can be combined with bucketing. If the same batch size is used for all buckets, batches of shorter inputs underutilize the device, since a larger batch size could have been accommodated. However, if the batch size is too large, batches of longer inputs might cause Out-of-Memory errors. Therefore, it is recommended to use a larger batch size for shorter sequences and a smaller one for longer sequences. One possibility is batch_size=max_num_words/sentence_length.
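Applying batch_size=max_num_words/sentence_length per bucket could look like the sketch below, where each bucket length stands in for sentence_length. The function name, the optional cap and the 8192-word budget are illustrative assumptions; in practice the budget would be tuned against the available device memory.

```python
def batch_sizes_per_bucket(bucket_lengths, max_num_words, max_batch_size=None):
    """Map each bucket length to batch_size = max_num_words // length (at least 1)."""
    sizes = {}
    for length in bucket_lengths:
        batch_size = max(1, max_num_words // length)
        if max_batch_size is not None:
            batch_size = min(batch_size, max_batch_size)  # optional cap for very short buckets
        sizes[length] = batch_size
    return sizes


print(batch_sizes_per_bucket([64, 128, 256, 512], max_num_words=8192))
# {64: 128, 128: 64, 256: 32, 512: 16}
```

Because each batch size is tied to one bucket length, this keeps roughly one input shape per bucket (plus a possibly smaller final batch) while every batch holds about the same number of words.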

Refer to Memory Stats APIs for more information on memory usage.

Dynamic Shapes due to Ops

For certain Ops, such as torch.nonzero, torch.unique and boolean indexing, the output shape cannot be predicted from the input shape alone, because it depends on the values in the input tensor.

Detecting Dynamic Ops

To detect dynamic Ops, perform a simple string search in the repo for such Ops, or run the script with GRAPH_VISUALIZATION=1 to create a .graph_dumps folder and monitor whether the number of graphs dumped in that folder keeps increasing. If dynamic shapes are present, multiple recompilations will cause the number of dumps in .graph_dumps to increase.

When detecting dynamic Ops, make sure to use same-sized inputs, so that recompilations caused by varying input shapes are not mistaken for recompilations caused by dynamic Ops.
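Monitoring the dump folder can be scripted. The sketch below, with hypothetical helper names, counts the files under .graph_dumps at a fixed interval from a separate process while the training script runs with GRAPH_VISUALIZATION=1. It assumes only that new files appear as new graphs are compiled; the exact number of files per graph is not relied on, so look at whether the count keeps growing after warm-up rather than at its absolute value.

```python
import os
import time


def count_graph_dumps(dump_dir=".graph_dumps"):
    """Total number of files under dump_dir, including subfolders (0 if it does not exist yet)."""
    return sum(len(files) for _, _, files in os.walk(dump_dir))


def watch_graph_dumps(dump_dir=".graph_dumps", interval_s=60.0, checks=10):
    """Print the growth of the dump folder; steady growth after warm-up hints at recompilation."""
    history = [count_graph_dumps(dump_dir)]
    for _ in range(checks):
        time.sleep(interval_s)
        history.append(count_graph_dumps(dump_dir))
        print(f"{history[-1]} dumps (+{history[-1] - history[-2]} in the last {interval_s:.0f} s)")
    return history
```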

For automatic detection of which part of the model has dynamic ops, use the detect_recompilation_auto_model tool. Consider the code snippet below which shows inference on a model with a dynamic op (boolean indexing) and dynamic inputs (changing batch size):

from habana_frameworks.torch.utils.experimental import detect_recompilation_auto_model
import torch

class InnerNet(torch.nn.Module):
   def __init__(self):
      super(InnerNet, self).__init__()
      self.conv = torch.nn.Conv2d(1, 8, 3, 3)

   def forward(self, x):
      x = torch.flatten(self.conv(x), 1)
      x = x[x>0] # This is dynamic
      return x.sum()

net = torch.nn.Sequential(torch.nn.ReLU(), InnerNet()).to('hpu')
net = detect_recompilation_auto_model(net)

for bs in [20,20,20,30,30]: #Input shape changes at 4th step
   inp = torch.rand(bs, 1, 50, 50).to('hpu')
   print(net(inp)) # Run the model so that every step is recorded
net.analyse_dynamicity() # Call this after a few steps to generate the dynamicity report

This produces the following report (along with the CSV versions of it):

  • Step 0: The first four lines of the first table show all four modules recompile since it is the first step.

  • Step 1: The next two lines show that Net and InnerNet recompile. The “Comment” column, however, shows that InnerNet might be dynamic because it recompiled even without dynamic children modules, while Net might not be dynamic, as it might have recompiled only because its child (InnerNet) recompiled.

  • Step 2: The next two lines show Step 2, which is similar to Step 1.

  • Step 3: The next four lines show Step 3, where a new input shape is seen, so every module recompiles, as expected and as shown in the “Comment” column.

  • Step 4: The last two lines for Step 4 again point to InnerNet as having dynamic ops.

The table below is a summarized view that shows which modules recompile the most. As expected, Net/1 of type InnerNet (and its parent) show up at the top.



The detect_recompilation_auto_model tool slows down execution, so it is only intended for debugging purposes and should be removed for the actual run.

Mitigation Techniques for Dynamic Ops

Replacing Dynamic Ops with Static Ops

Removing dynamicity requires case-by-case inspection, but the general guideline is to identify the section of the code where dynamicity starts and ends, and then replace that section with static code. See the examples below.

  • Example 1 - Boolean indexing is dynamic, and can be replaced by a multiplication with the mask to make the code static:

#Dynamic code

non_pad_mask = targets.ne(self.padding_idx).view(-1)
smooth_loss = smooth_loss[non_pad_mask]
smooth_loss = smooth_loss.sum()

#Static Code

non_pad_mask = targets.ne(self.padding_idx).view(-1)
smooth_loss = smooth_loss.squeeze() * non_pad_mask.int()
smooth_loss = smooth_loss.sum()

  • Example 2 - Given tensors A, B and C of length N, filter out the positions where C[i] is -1, multiply A[i]*B[i] at the remaining positions, and sum the results. Pseudocode: res = sum([A[i]*B[i] for i in range(len(C)) if C[i] != -1]).

#Dynamic code

import torch
import habana_frameworks.torch.core as htcore

A = torch.tensor([1,2,3,4,5]).to('hpu')
B = torch.tensor([6,7,8,9,10]).to('hpu')
C = torch.tensor([1,-1,1,1,-1]).to('hpu')

indices = torch.where(C!=-1) # Start of dynamicity: the number of indices depends on the values in C
A_filtered=torch.gather(A, 0, indices[0])
B_filtered=torch.gather(B, 0, indices[0])

res_dyn = torch.sum(A_filtered * B_filtered)

#Static Code

res_tmp = A*B
prod_filtered = torch.where(C!=-1, res_tmp, torch.zeros_like(res_tmp))
res_stat = torch.sum(prod_filtered)

assert (res_stat == res_dyn).item()

In this example we identify the start of dynamicity (where) and its end (sum) and replace it with a static implementation. The diagram below shows the two possible paths, with the dynamic nodes and tensors marked in red.


Explicit CPU Fallback

In normal mode, Ops that are not supported on HPU are already placed on the CPU automatically, as explained in Placement of Ops on HPU. Explicit CPU fallback refers to explicitly moving certain dynamic Ops to the CPU (and possibly bringing their results back to the HPU later).

#say x is on HPU
z = torch.unique(x) #this op is done on HPU because input x is on HPU
# can be replaced by:
x = x.to('cpu') # explicitly move the input to the CPU
z = torch.unique(x) # runs on cpu because x is now on the CPU
#move tensors back to HPU if needed
x = x.to('hpu')
z = z.to('hpu')

Sections that run on the CPU can be sped up with TorchScript by decorating the corresponding functions with @torch.jit.script.

Splitting Dynamic and Static Sections

Consider a case where the network starts with static layers, has a few dynamic layers, and then ends with static layers. Normally, the whole network recompiles in each step, making execution slow. However, a mark_step can be added between the static and dynamic sections of the network. With this change, the static parts compile only once, and the dynamic part is smaller than the whole network and hence compiles faster.

import habana_frameworks.torch.core as htcore

# The whole network static1->dynamic->static2 recompiles
x = static1(x)
x = dynamic(x)
x = static2(x)

# After splitting them, now only dynamic recompiles. static1 and static2 compile only once
x = static1(x)
htcore.mark_step() # end the graph after the static section
x = dynamic(x)
htcore.mark_step() # end the graph after the dynamic section
x = static2(x)


When possible, replacing dynamic code with static code, as detailed in Replacing Dynamic Ops with Static Ops, is recommended. If that is not possible, try CPU fallback, splitting dynamic and static sections, and leaving the code unchanged, and measure which option is fastest. Which method works best depends on the model, as each method has its disadvantages: CPU fallback might trigger costly host-to-device or device-to-host copies, and although splitting reduces the compile time, the dynamicity remains, so compilations still happen.

Dynamic Shapes in Distributed Training

If there are S shapes and N cards, in the worst case N*(S-1) + 1 compilation delays are observed. This happens when, at each step after the first, only one card compiles a new shape while the other cards receive shapes they have compiled before. Those cards finish their step faster and remain idle. The figures below show the inefficiency of cards sitting idle while waiting for another card to compile a new shape (highlighted).
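The worst-case count can be checked with a toy model of lockstep training. In the sketch below (hypothetical helpers, assuming every card waits for the slowest card at each step and that compiling any new shape delays that step by one unit), the worst-case schedule lets every card compile shape 0 in the first step and then gives exactly one card a new shape in each later step.

```python
def count_compile_delays(schedules):
    """schedules[c][t] is the shape card c sees at step t. Cards step in lockstep,
    so a step is delayed whenever at least one card meets a shape it has not compiled."""
    compiled = [set() for _ in schedules]
    delays = 0
    for step in zip(*schedules):
        new_shape_seen = False
        for card, shape in enumerate(step):
            if shape not in compiled[card]:
                compiled[card].add(shape)
                new_shape_seen = True
        delays += new_shape_seen
    return delays


def worst_case_schedules(num_shapes, num_cards):
    """Step 0: every card sees shape 0. Afterwards exactly one card meets a new shape per step."""
    steps = [[0] * num_cards]
    for card in range(num_cards):
        for shape in range(1, num_shapes):
            step = [0] * num_cards
            step[card] = shape
            steps.append(step)
    return [list(per_card) for per_card in zip(*steps)]  # one shape sequence per card


S, N = 4, 8
print(count_compile_delays(worst_case_schedules(S, N)), N * (S - 1) + 1)  # 25 25
```

For comparison, when all cards see the same sequence of shapes, for example count_compile_delays([list(range(S))] * N), only S delays occur.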


To mitigate this issue, use cross-rank recipe caching (PT_HPU_RECIPE_CACHE_CONFIG). When a new shape is encountered, its recipe is saved to a file, allowing other cards to load it instead of recompiling the same shape.


Case Study Using Wav2vec2 for Dynamic Inputs to Models

Models with dynamic input shapes are constantly recompiled, resulting in longer training time. In audio models, the inputs are commonly wave files of varying length. The steps below outline the bucketing and padding solution required to achieve better training performance on Gaudi, using Wav2vec2 as an example.

Bucketing and Padding Solution Example

  1. Sort the wave files according to the length. This is to make sure that wave files with similar length are grouped together to decrease possible padding.

  2. Define the size of buckets. Based on the distribution of wave file lengths in the dataset and the number of buckets defined by the user, define the length of each bucket. A general rule in Wav2vec is to define the bucket lengths so that a similar number of files falls into each bucket.

Below is the example code that Wav2vec uses to define the lengths of the buckets via NumPy's percentile function, where sizes is an array containing the sizes of all wave files and num_buckets is the number of buckets you want to use:

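import numpy as np

def get_buckets(sizes, num_buckets):
   # Boundaries at evenly spaced percentiles of the file sizes; [1:] drops the 0th percentile (the minimum)
   # method="lower" requires NumPy >= 1.22 (older versions use interpolation="lower")
   buckets = np.unique(
            np.percentile(sizes, np.linspace(0, 100, num_buckets + 1), method="lower")[1:])
   return buckets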

  3. Split the dataset into different batches. For example, Wav2vec defines a threshold value and makes sure that the total file size in one batch does not exceed it. Refer to the example here.

  4. Padding:

    1. Pad wave files in the same batch to the same file size. To make sure each wave file in the same batch has the same size, it is padded to the max file size in that batch.

    2. Pad wave files in the same batch to the bucket size. When the file length in the batch does not match a bucket length, the files are padded to the closest bucket length that is not smaller than the file length. This way, even if there are many batches, the number of distinct batch shapes stays limited.

    You can find a padding example using Wav2vec2 here.
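Step 3 can be sketched as a greedy pass over the sorted files that also applies the bucket padding from step 4. batch_by_threshold is a hypothetical helper. It interprets the threshold as a limit on the padded batch size (number of files times the bucket length), which is an assumption; summing the unpadded file sizes is another possible reading. The boundaries are expected to come from get_buckets, whose last boundary is the largest file size, so every file fits in some bucket.

```python
import bisect


def batch_by_threshold(sizes, boundaries, max_padded_size):
    """Group file indices (sorted by size) into batches whose padded size
    (number of files x bucket length) stays within max_padded_size.

    Returns a list of (file indices, bucket length to pad the batch to).
    """
    order = sorted(range(len(sizes)), key=lambda i: sizes[i])  # sort by length

    def bucket_len(n):
        # Closest bucket length that is not smaller than n
        return boundaries[bisect.bisect_left(boundaries, n)]

    batches, current = [], []
    for i in order:
        candidate = current + [i]
        # Sizes are sorted, so the new file is the longest one in the candidate batch
        if current and len(candidate) * bucket_len(sizes[i]) > max_padded_size:
            batches.append(current)
            candidate = [i]
        current = candidate
    if current:
        batches.append(current)
    return [(batch, bucket_len(sizes[batch[-1]])) for batch in batches]


sizes = [10, 3, 7, 12, 5, 9]
for indices, length in batch_by_threshold(sizes, boundaries=[5, 9, 12], max_padded_size=24):
    print(indices, "-> pad to", length)
# [1, 4] -> pad to 5
# [2, 5] -> pad to 9
# [0, 3] -> pad to 12
```

A file larger than the threshold on its own still gets its own single-file batch rather than being dropped.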

With this Bucketing and Padding solution, the number of dynamic shapes of the input to Wav2vec can be greatly decreased. Refer to the implementation in Wav2vec2 Model Reference on GitHub.