Handling Dynamic Shapes

In this document, Dynamic Shapes is a generic term for model behavior in which the shapes of the input data vary, or in which operators produce output tensors of variable shape. Such dynamicity may cause frequent recompilations, leading to longer training time.

The following sections discuss methods to detect Dynamic Shapes and mitigate their impact on performance.

Types of Dynamicity

Dynamic Shapes can be broadly classified into two categories:

  • Inputs - Dynamic shapes due to varying input shapes during training, such as:

    • Varying sentence lengths in language models

    • Differing image resolutions in image models

  • Ops - Dynamic shapes due to Ops occur for certain Ops whose output shape depends on the actual input data rather than only on the input shapes; that is, Ops whose output shapes cannot be inferred from their input shapes.

The table below briefly describes how to detect and mitigate dynamicity:

Type of Model                     | Detection and Mitigation
All models                        | Check if the model has input dynamicity. See Detecting Dynamic Inputs.
Has input dynamicity              | Fix the input shape for testing purposes and check if the model has dynamic Ops. See Detecting Dynamic Ops.
Does not have input dynamicity    | Check if the model has dynamic Ops. See Detecting Dynamic Ops.
Dynamic shapes in multi-card      | See Dynamic Shapes in Distributed Training.

Dynamic Shapes due to Varying Inputs

Detecting Dynamic Inputs

The first step is to find the distribution of the input data shapes. You can run through the training loop's data loading without performing the actual training to collect this distribution. The following is example code:

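from collections import Counter

# Count how often each input shape occurs (assumes `dataloader` yields input tensors)
histogram_of_shapes = Counter()
for data in dataloader:
    histogram_of_shapes[tuple(data.shape)] += 1

print(histogram_of_shapes.most_common())  # most frequent shapes first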

If the input data shapes vary widely, use data padding and/or data bucketing. Bucketing and padding normalize the input shapes to a limited number of shapes without sacrificing training accuracy, while potentially improving training performance on Gaudi.

Mitigation Through Bucketing and Padding

To reduce the number of unique input shapes, apply Bucketing and Padding. This involves dividing the range of input data shapes, from minimum size to maximum size, into a set of buckets, and padding the datapoints in each batch up to the smallest bucket that fits the largest datapoint in that batch. In bucketing algorithms, the number of buckets is a hyperparameter: choosing too many buckets causes many recompilations, while choosing too few requires excess padding, which wastes computation.
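The padding step described above can be sketched in plain Python. This is a minimal illustration, not Gaudi or framework code: the helper name pad_batch_to_bucket, the list-of-token-IDs representation, and the pad value 0 are all assumptions made for the example.

```python
import bisect
from typing import List, Sequence


def pad_batch_to_bucket(batch: List[List[int]], buckets: Sequence[int],
                        pad_value: int = 0) -> List[List[int]]:
    """Hypothetical helper: pad every sequence to the smallest bucket that fits the longest one."""
    longest = max(len(seq) for seq in batch)
    # `buckets` must be sorted ascending; bisect_left finds the first bucket >= longest
    idx = bisect.bisect_left(buckets, longest)
    if idx == len(buckets):
        raise ValueError(f"length {longest} exceeds the largest bucket {buckets[-1]}")
    target = buckets[idx]
    return [seq + [pad_value] * (target - len(seq)) for seq in batch]


buckets = [8, 16, 32]
padded = pad_batch_to_bucket([[1, 2, 3], [4, 5, 6, 7, 8, 9, 10, 11, 12]], buckets)
print([len(seq) for seq in padded])  # [16, 16]
```

Whatever the raw lengths are, every padded batch ends up with one of only three lengths here, which bounds the number of shapes the device has to compile.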

A simple bucketing strategy is to pad all datapoints to the largest input data size. However, this wastes computation on the smaller datapoints. A common bucketing strategy is percentile-based bucketing, as shown in Case Study Using Wav2vec2 for Dynamic Inputs to Models.
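To make the wasted computation concrete, the following sketch measures the fraction of padded (wasted) elements for a given set of bucket lengths. It is a simplification that assumes each sample is padded to its own bucket rather than per batch; the function name padding_waste is hypothetical.

```python
import bisect
from typing import Sequence


def padding_waste(lengths: Sequence[int], buckets: Sequence[int]) -> float:
    """Fraction of padded elements when each sample is padded up to its bucket.

    Assumes `buckets` is sorted ascending and buckets[-1] >= max(lengths).
    """
    padded_total = 0
    for n in lengths:
        # Smallest bucket that can hold a sample of length n
        padded_total += buckets[bisect.bisect_left(buckets, n)]
    return (padded_total - sum(lengths)) / padded_total


lengths = [1, 2, 3, 4]
print(padding_waste(lengths, [4]))     # pad everything to the max: 0.375
print(padding_waste(lengths, [2, 4]))  # two buckets: about 0.167
```

Adding buckets lowers the waste but increases the number of distinct shapes, which is exactly the tradeoff described above.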

You can also design your own padding and bucketing strategies to find the best tradeoff between compilation delays and runtime speed.

Variable Batch Size

Bucketing also makes it possible to use a variable batch size. With a fixed batch size, batches of shorter inputs underutilize the device, since a larger batch size could have been accommodated. However, if the batch size is too large, batches of longer inputs might cause Out-of-Memory errors. Therefore, it is recommended to use a larger batch size for shorter sequences and a smaller batch size for longer ones. One possibility is batch_size = max_num_words / sentence_length.
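The batch_size = max_num_words / sentence_length rule can be sketched as follows. This is an illustration under stated assumptions: the helper names are hypothetical, and the lengths passed in are assumed to be bucket lengths (after padding), so each bucket gets one fixed batch size.

```python
from typing import Dict, List


def batch_size_for(sentence_length: int, max_num_words: int) -> int:
    # batch_size = max_num_words / sentence_length, floored, with at least one sample
    return max(1, max_num_words // sentence_length)


def make_batches(lengths: List[int], max_num_words: int) -> List[List[int]]:
    """Hypothetical helper: group sample indices by (bucket) length, then cut each
    group into batches whose size shrinks as the length grows."""
    by_length: Dict[int, List[int]] = {}
    for idx, n in enumerate(lengths):
        by_length.setdefault(n, []).append(idx)
    batches: List[List[int]] = []
    for n in sorted(by_length):
        indices = by_length[n]
        size = batch_size_for(n, max_num_words)
        batches.extend(indices[i:i + size] for i in range(0, len(indices), size))
    return batches


print(batch_size_for(128, 4096))           # 32
print(make_batches([4, 4, 4, 8, 8], 8))    # [[0, 1], [2], [3], [4]]
```

Note that the last batch of each group may be smaller than the rest, which introduces an extra shape; dropping or padding that batch keeps the number of shapes fixed.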

Refer to Memory Stats APIs for more information on memory usage.

Dynamic Shapes due to Ops

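For certain Ops, the output shape cannot be predicted from the input shape alone, because it depends on the values in the input. Examples include boolean-mask indexing, torch.nonzero, torch.where with a single argument, and torch.unique.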
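As a quick illustration, torch.nonzero is one such Op. The sketch below runs on CPU only to show the behavior; on HPU, the same data-dependence is what makes the Op dynamic.

```python
import torch

# Two inputs with identical shape (4,) but different contents
a = torch.tensor([0, 1, 0, 1])
b = torch.tensor([1, 1, 1, 1])

# torch.nonzero returns one row per non-zero element, so the output
# shape depends on the data, not just on the input shape
print(torch.nonzero(a).shape)  # torch.Size([2, 1])
print(torch.nonzero(b).shape)  # torch.Size([4, 1])
```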

Detecting Dynamic Ops

To detect dynamic Ops, perform a simple string search in the repository for such Ops, or run the script with GRAPH_VISUALIZATION=1 to create a .graph_dumps folder and monitor whether the number of graphs dumped in that folder keeps increasing. If dynamic shapes are present, repeated recompilations cause the number of dumps in .graph_dumps to grow.


When detecting dynamic Ops, make sure to use inputs of the same size, so that recompilations caused by varying input shapes are not mistaken for dynamic Ops.
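Both detection approaches can be scripted. The following is a minimal sketch: the list of search patterns is a hypothetical starting point to extend for your model (boolean-mask indexing, for instance, is hard to find with a plain string search), and the folder name .graph_dumps is the one created when running with GRAPH_VISUALIZATION=1.

```python
import os
import re
from typing import Dict, List

# Hypothetical patterns for Ops whose output shape depends on the data; extend as needed
DYNAMIC_OP_PATTERNS = [r"\bnonzero\(", r"\bunique\(", r"\bmasked_select\("]


def find_dynamic_ops(repo_root: str) -> Dict[str, List[int]]:
    """Return {file path: [line numbers]} for .py lines matching a dynamic-Op pattern."""
    pattern = re.compile("|".join(DYNAMIC_OP_PATTERNS))
    hits: Dict[str, List[int]] = {}
    for dirpath, _, filenames in os.walk(repo_root):
        for name in filenames:
            if not name.endswith(".py"):
                continue
            path = os.path.join(dirpath, name)
            with open(path, encoding="utf-8", errors="ignore") as f:
                for lineno, line in enumerate(f, start=1):
                    if pattern.search(line):
                        hits.setdefault(path, []).append(lineno)
    return hits


def count_graph_dumps(dump_dir: str = ".graph_dumps") -> int:
    """Number of entries in the graph dump folder (0 if it does not exist yet)."""
    if not os.path.isdir(dump_dir):
        return 0
    return len(os.listdir(dump_dir))
```

One way to use count_graph_dumps is to call it after every training step: if the count keeps growing after the first few steps even though the inputs have a fixed size, the model is likely recompiling because of dynamic Ops.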

Mitigation Techniques for Dynamic Ops

Replacing Dynamic Ops with Static Ops

Removing dynamicity requires case-by-case inspection, but the general guideline is to identify the section of the code where dynamicity starts and ends, and then replace that section with static code. See the examples below.

  • Example 1 - Boolean indexing is dynamic, because the number of selected elements depends on the data. It can be replaced to make the code static:

# Dynamic code: boolean indexing keeps only the non-padding entries

non_pad_mask = targets.ne(self.padding_idx).view(-1)
smooth_loss = smooth_loss[non_pad_mask]
smooth_loss = smooth_loss.sum()

# Static code: padding entries are zeroed out instead, so all shapes stay fixed

non_pad_mask = targets.ne(self.padding_idx).view(-1)
smooth_loss = smooth_loss.squeeze() * non_pad_mask.int()
smooth_loss = smooth_loss.sum()

  • Example 2 - Suppose tensors A, B, and C each have length \(N\). For every index \(i\) where \(C[i]\) is \(-1\), we want to filter the element out; for the remaining indices, we want to multiply \(A[i]*B[i]\) and sum the results. Pseudocode: res = sum([A[i]*B[i] for i in range(len(C)) if C[i] != -1]).

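# Dynamic code

import torch
import habana_frameworks.torch.core as htcore

A = torch.tensor([1, 2, 3, 4, 5]).to('hpu')
B = torch.tensor([6, 7, 8, 9, 10]).to('hpu')
C = torch.tensor([1, -1, 1, 1, -1]).to('hpu')

# torch.where with a single argument returns the indices of matching elements;
# their number depends on the data, so this is where dynamicity starts
indices = torch.where(C != -1)
A_filtered = torch.gather(A, 0, indices[0])
B_filtered = torch.gather(B, 0, indices[0])
res_dyn = torch.sum(A_filtered * B_filtered)  # dynamicity ends here: the result is a scalar

# Static code

res_tmp = A * B
# Zero out the products to be filtered instead of removing them, so every shape stays fixed
prod_filtered = torch.where(C != -1, res_tmp, torch.zeros_like(res_tmp))
res_stat = torch.sum(prod_filtered)

assert (res_stat == res_dyn).item()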

In this example, we identify where dynamicity starts (torch.where) and where it ends (torch.sum), and replace that section with a static implementation. The diagram below shows the two possible paths, with the dynamic nodes and tensors marked in red.
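Before switching to the static version of Example 1, it can help to check on CPU that both versions compute the same result. The sketch below uses hypothetical stand-ins for targets, padding_idx, and smooth_loss; the (N, 1) shape of smooth_loss is an assumption suggested by the squeeze() call in the static code.

```python
import torch

# Hypothetical stand-ins for the tensors in Example 1 (CPU is enough to compare results)
padding_idx = 0
targets = torch.tensor([[5, 3, 0], [7, 0, 0]])
smooth_loss = torch.rand(targets.numel(), 1)  # one loss value per target token, shape (N, 1)

non_pad_mask = targets.ne(padding_idx).view(-1)

# Dynamic: boolean indexing keeps only the non-padding entries (data-dependent shape)
dynamic = smooth_loss[non_pad_mask].sum()

# Static: zero out padding entries instead; every tensor keeps a fixed shape
static = (smooth_loss.squeeze() * non_pad_mask.int()).sum()

print(torch.allclose(dynamic, static))  # True
```

The two sums may differ in the last bits because the additions happen in a different order, which is why the comparison uses torch.allclose rather than exact equality.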


Explicit CPU Fallback

In normal mode, Ops that are not supported on HPU are already placed on the CPU automatically, as explained in Placement of Ops on HPU. Explicit CPU fallback refers to deliberately moving certain dynamic Ops to the CPU (and possibly bringing their results back to the HPU later).

# Say x is on HPU
z = torch.unique(x)  # this op runs on HPU because input x is on HPU

# It can be replaced by an explicit CPU fallback:
x = x.to('cpu')
z = torch.unique(x)  # runs on CPU
# Move tensors back to HPU if needed
x = x.to('hpu')
z = z.to('hpu')

Code sections that run on the CPU can be sped up with TorchScript by wrapping them in a function decorated with @torch.jit.script.
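A minimal sketch of a TorchScript-compiled CPU section follows. The function name sum_positive_cpu is hypothetical, and the example is run on a CPU tensor; in HPU code you would pass x.to('cpu') into the function and move the result back to the HPU if needed.

```python
import torch


@torch.jit.script
def sum_positive_cpu(x: torch.Tensor) -> torch.Tensor:
    # Boolean indexing has a data-dependent output shape; running it on CPU
    # keeps the dynamic shape out of the HPU graph
    kept = x[x > 0]
    return kept.sum()


x = torch.tensor([1.0, -2.0, 3.0])  # on HPU you would pass x.to('cpu') here
result = sum_positive_cpu(x)
print(result)  # tensor(4.)
```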

If you do explicit CPU fallback while also using Habana Mixed Precision (hmp), you may need to disable casts around the section of the code that is on CPU, using hmp.disable_casts(). An example can be found here.

Splitting Dynamic and Static Sections

Consider a case where the network starts with static layers, has a few dynamic layers, and then ends with static layers. Under normal circumstances, the whole network recompiles in each step, which makes execution slow. However, we can add a mark_step between the static and dynamic sections of the network. With this change, the static parts compile only once, and the dynamic part, being smaller than the whole network, compiles faster.

import habana_frameworks.torch.core as htcore

# The whole network static1->dynamic->static2 recompiles
x = static1(x)
x = dynamic(x)
x = static2(x)

# After splitting them with mark_step, only dynamic recompiles; static1 and static2 compile only once
x = static1(x)
htcore.mark_step()
x = dynamic(x)
htcore.mark_step()
x = static2(x)


When possible, replacing dynamic code with static code, as detailed in Replacing Dynamic Ops with Static Ops, is recommended. If that is not possible, try CPU fallback, splitting dynamic and static Ops, and leaving the code unchanged, and measure which option is fastest. Which method works best depends on the model, as each method has its disadvantages: CPU fallback might trigger costly host-to-device or device-to-host copies, and while splitting reduces the compile time, dynamicity remains, so compilations still happen.
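Comparing the options can be as simple as timing a fixed number of training steps for each variant. The sketch below is a generic harness, not a Habana tool: time_options is a hypothetical helper, and each step function is assumed to block until the device work finishes (for example, by reading the loss with .item()), since device execution may otherwise be asynchronous. Compilation time is deliberately included, because recompilations are the cost being compared.

```python
import time
from typing import Callable, Dict


def time_options(options: Dict[str, Callable[[], None]], steps: int = 50) -> Dict[str, float]:
    """Mean wall-clock seconds per step for each candidate, including compilation time."""
    results: Dict[str, float] = {}
    for name, step_fn in options.items():
        start = time.perf_counter()
        for _ in range(steps):
            step_fn()  # should block until device work finishes, e.g. via loss.item()
        results[name] = (time.perf_counter() - start) / steps
    return results


# Example (hypothetical step functions, each running one training step):
# print(time_options({"cpu_fallback": cpu_step, "split": split_step,
#                     "unchanged": baseline_step}))
```

Run each variant over a representative stretch of real data, since with dynamic shapes the recompilation cost depends on how many new shapes appear.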

Dynamic Shapes in Distributed Training

If there are \(S\) shapes and \(N\) cards, in the worst case \(N(S-1) + 1\) compilation delays are observed. This happens when only one card compiles a new shape at a time, while the other cards receive shapes they have already compiled. Those cards finish their step faster and remain idle. The figures below show the inefficiency of cards sitting idle while waiting for new shapes to be compiled (highlighted).

To mitigate this issue, make sure all cards get all unique shapes of data in the first iterations.
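The worst-case count and the effect of the mitigation can be reproduced with a small simulation. This is a simplified model under stated assumptions: cards synchronize every step, so a single compiling card delays the whole step, and shapes are represented by hypothetical integer IDs.

```python
from typing import List, Sequence, Set


def count_compile_delays(schedule: Sequence[Sequence[int]]) -> int:
    """Count steps in which at least one card sees a shape it has not compiled yet.

    schedule[t][c] is the shape ID that card c receives at step t.
    """
    num_cards = len(schedule[0])
    compiled: List[Set[int]] = [set() for _ in range(num_cards)]
    delays = 0
    for step in schedule:
        if any(shape not in compiled[c] for c, shape in enumerate(step)):
            delays += 1
        for c, shape in enumerate(step):
            compiled[c].add(shape)
    return delays


def worst_case_schedule(num_cards: int, num_shapes: int) -> List[List[int]]:
    # Step 0: every card compiles shape 0 together. Afterwards, each step gives
    # exactly one card a new shape while the others replay shape 0.
    schedule = [[0] * num_cards]
    for card in range(num_cards):
        for shape in range(1, num_shapes):
            step = [0] * num_cards
            step[card] = shape
            schedule.append(step)
    return schedule


def aligned_schedule(num_cards: int, num_shapes: int) -> List[List[int]]:
    # Mitigation: in the first iterations, all cards receive each unique shape together
    return [[shape] * num_cards for shape in range(num_shapes)]


N, S = 4, 3
print(count_compile_delays(worst_case_schedule(N, S)))  # 9, i.e. N*(S-1) + 1
print(count_compile_delays(aligned_schedule(N, S)))     # 3, i.e. S
```

When every card sees every unique shape in the same early iterations, the cards compile in parallel and the number of delays drops from \(N(S-1) + 1\) to \(S\).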


Case Study Using Wav2vec2 for Dynamic Inputs to Models

Models with dynamic input shapes are constantly recompiled, resulting in longer training time. In audio models, the inputs are commonly wave files of varying length. The steps below outline the Bucketing and Padding solution used to achieve better training performance on Gaudi, using Wav2vec2 as an example.

Bucketing and Padding Solution Example

  1. Sort the wave files by length. This ensures that wave files of similar length are grouped together, which reduces the required padding.

  2. Define the size of the buckets. Based on the distribution of wave file lengths in the dataset and the user-defined number of bucket slots, define the length of each bucket slot. A general rule in Wav2vec is to define the bucket lengths so that a similar number of files falls into each bucket.

Below is example code following the approach Wav2vec uses to define the bucket lengths with NumPy's percentile function, where sizes is an array containing the sizes of all wave files and num_buckets is the number of buckets you want to use:

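import numpy as np

def get_buckets(sizes, num_buckets):
    # Upper bounds at evenly spaced percentiles of the sizes ([1:] drops the minimum);
    # method="lower" needs NumPy >= 1.22 (older versions use interpolation="lower")
    percentiles = np.linspace(0, 100, num_buckets + 1)
    buckets = np.unique(np.percentile(sizes, percentiles, method="lower")[1:])
    return buckets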
  3. Split the dataset into batches. For example, in Wav2vec, a threshold value is defined to make sure the total file size in one batch does not exceed it. Refer to the example here.

  4. Padding:

    1. Pad wave files in the same batch to the same file size: each wave file in a batch is padded to the maximum file size in that batch.

    2. Pad wave files in the same batch to the bucket size. When the file length in the batch differs from the bucket length, the files are padded up to the length of the nearest bucket that can hold them. This way, even if there are many batches, the number of distinct batch shapes stays limited.

    You can find a padding example using Wav2vec2 here.
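The steps above can be sketched end to end in plain Python. This is a simplified illustration, not the Wav2vec2 reference implementation: waves are hypothetical lists of floats, bucket lengths are passed in (for example, computed with get_buckets), and the batch threshold is applied to the padded size of the batch. Padding straight to the bucket length covers both padding sub-steps at once.

```python
import bisect
from typing import List, Sequence

Wave = List[float]


def bucket_and_pad(waves: List[Wave], buckets: Sequence[int],
                   max_batch_size: int) -> List[List[Wave]]:
    """Sort by length, batch under a size threshold, then pad each batch to a bucket.

    buckets: ascending bucket lengths; the last one must fit the longest wave.
    max_batch_size: threshold on (number of waves) * (longest wave) per batch.
    """
    # Step 1: sort by length so that similar lengths end up in the same batch
    ordered = sorted(waves, key=len)

    # Step 3: greedily fill batches while the batch stays under the threshold
    batches: List[List[Wave]] = []
    current: List[Wave] = []
    for wave in ordered:
        # Sorted order means the new wave is the longest in the candidate batch
        if current and (len(current) + 1) * len(wave) > max_batch_size:
            batches.append(current)
            current = [wave]
        else:
            current.append(wave)
    if current:
        batches.append(current)

    # Step 4: pad every wave to the smallest bucket that fits the batch's longest wave
    padded: List[List[Wave]] = []
    for batch in batches:
        target = buckets[bisect.bisect_left(buckets, len(batch[-1]))]
        padded.append([w + [0.0] * (target - len(w)) for w in batch])
    return padded


waves = [[0.1] * n for n in (5, 2, 9, 3)]
print([[len(w) for w in b] for b in bucket_and_pad(waves, [4, 8, 16], 16)])
# [[8, 8, 8], [16]]
```

However many raw lengths the dataset contains, the padded batches only ever use the bucket lengths, so the number of compiled shapes is bounded by the bucket count together with the distinct batch sizes.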

With this Bucketing and Padding solution, the number of dynamic shapes of the input to Wav2vec can be greatly decreased. Refer to the implementation in Wav2vec2 Model Reference on GitHub.