Intel Gaudi PyTorch Python API (habana_frameworks.torch)

This package provides the PyTorch bridge interfaces and modules, such as optimizers, mixed precision configuration, and fused kernels, for training on the Intel® Gaudi® AI accelerator.

The modules are organized as shown below:

habana_frameworks.torch
  core
  distributed
     hccl
  gpu_migration
  hpex
     kernels
     normalization
     optimizers
  hpu
  utils

The following sections provide a brief description of each module.

core

The core module provides Python bindings to the Intel Gaudi PyTorch bridge interfaces, for example mark_step, which triggers execution of the accumulated graphs in Lazy mode.
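
A minimal sketch of calling mark_step in a Lazy mode training loop is shown below:

import torch
import habana_frameworks.torch.core as htcore

device = torch.device("hpu")
model = torch.nn.Linear(10, 10).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(4, 10, device=device)
target = torch.randn(4, 10, device=device)

for _ in range(3):
    optimizer.zero_grad(set_to_none=True)
    loss = torch.nn.functional.mse_loss(model(x), target)
    loss.backward()
    htcore.mark_step()  # trigger execution of the accumulated backward graph
    optimizer.step()
    htcore.mark_step()  # trigger execution of the optimizer step graph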

distributed/hccl

  • The distributed/hccl module registers and adds support for the HCCL communication backend.

  • from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu - Imports the API.

  • initialize_distributed_hpu() - Helper function that returns world_size, rank and local_rank when the processes are launched using either MPI or torchrun-related APIs.

You can find usage code on the ResNet50 Model References GitHub page.

The following snippet shows a simple usage sequence:

from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu
world_size, rank, local_rank = initialize_distributed_hpu()
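
Importing the module also registers the hccl backend with torch.distributed. Below is a hedged sketch of initializing a process group with the values returned above, assuming the processes were launched with mpirun or torchrun:

import torch
import torch.distributed as dist
import habana_frameworks.torch.distributed.hccl  # registers the "hccl" backend
from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu

world_size, rank, local_rank = initialize_distributed_hpu()
dist.init_process_group(backend="hccl", world_size=world_size, rank=rank)

t = torch.ones(8, device="hpu")
dist.all_reduce(t)  # sums the tensor across all ranks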

gpu_migration

The gpu_migration module can be used to quickly migrate a model that strongly depends on CUDA to HPU. Refer to GPU Migration Toolkit for further details.
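
A hedged sketch of enabling the toolkit, assuming the import-based activation described in the GPU Migration Toolkit documentation:

import habana_frameworks.torch.gpu_migration  # assumption: importing activates the CUDA-to-HPU mapping
import torch

x = torch.randn(8, 8, device="cuda")  # "cuda" calls are redirected to HPU by the toolkit
print(x.device)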

hpex/kernels/CustomSoftmax

CustomSoftmax is an optimized version of torch.nn.Softmax operator. It only supports bfloat16 input and computes Softmax along the last dimension (dim=-1). Additionally, its second argument is flavor. The following flavors are available:

  • flavor=0 - Equivalent to the original torch.nn.Softmax.

  • flavor=1 - Uses fast approximation of exp and reciprocal.

  • flavor=2 - Does not subtract the maximum and uses fast approximation of exp and reciprocal.

This operator is useful for optimizing the performance of inference workloads, for example Stable Diffusion, which uses flavor=1.

Some layers can be sensitive to Softmax accuracy and numerical stability, so applying the fastest option (flavor=2) to all layers may harm the model output. Collecting statistics for each layer (the min and max of the maximums along the last dimension) can help identify layers suitable for this optimization: if the min and max are far from 0, subtracting the maximum is required for numerical stability. Applying CustomSoftmax only in selected layers can also help optimize the model when some layers are very sensitive to Softmax accuracy.
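
For instance, a minimal sketch (not part of the API) of collecting these statistics for a layer's Softmax input:

import torch

def softmax_input_stats(attn_scores: torch.Tensor):
    # Max of the values along the last dimension, i.e. the values that
    # flavor=2 would skip subtracting.
    row_max = attn_scores.max(dim=-1).values
    return row_max.min().item(), row_max.max().item()

# If both numbers are close to 0, skipping the max subtraction (flavor=2)
# is usually safe; otherwise prefer flavor=0 or flavor=1 for that layer.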

The below shows a usage example:

import habana_frameworks.torch.hpex.kernels as hpu_kernels

# Inside an attention module's forward pass (attn_weights, attention_mask and
# self.inv_scale_attn are assumed to be defined by the surrounding model code):
attn_weights = attn_weights * self.inv_scale_attn
attn_weights = attn_weights + attention_mask
attn_weights = hpu_kernels.CustomSoftmax.apply(attn_weights, 2)  # flavor=2

hpex/kernels/FBGEMM

FBGEMM (Facebook GEneral Matrix Multiplication Kernels Library) is a collection of high-performance operator libraries designed for training and inference. The habana_frameworks.torch.hpex.kernels.fbgemm package provides the following APIs:

  • bounds_check_indices - Checks and corrects any invalid values of the provided indices and offsets.

  • expand_into_jagged_permute - Expands the permute index for sparse data from the table dimension to the batch dimension in cases where the sparse features have different batch sizes across ranks.

  • permute_1D_sparse_data - Shuffles lengths, indices and weights tensors according to the values in permute tensor.

  • permute_2D_sparse_data - Shuffles lengths, indices and weights tensors according to the values in permute tensor.

  • split_embedding_codegen_lookup_function - A simple lookup table that stores embeddings for a fixed dictionary of a specific size.

  • split_permute_cat - Replaces the combination of split_with_sizes, permute, and cat operations.

The below shows a usage example:

import torch
from habana_frameworks.torch.hpex.kernels.fbgemm import split_permute_cat

hpu = torch.device("hpu")

B = 3; F = 3; D = 8
input = torch.randn(B, F * D, dtype=torch.float32)
indices = torch.randperm(F)

output = split_permute_cat(input.to(hpu), indices.to(hpu), B, F, D)
print(output.cpu())

hpex/kernels/FusedSDPA

FusedSDPA is a fused implementation of torch.nn.functional.scaled_dot_product_attention() for Gaudi. It maintains the same functionality and interface as the original function but with reduced memory usage.

FusedSDPA implements selected Flash Attention optimization approaches applicable to HPU.

There are two operation modes: Recompute mode (default) and No-recompute mode. For further details, see Using Fused Scaled Dot Product Attention (FusedSDPA).

To select or query operation modes, use the following APIs:

  • habana_frameworks.torch.hpu.sdp_kernel(enable_recompute = True) - Context manager based control to temporarily (i.e., within the current context) enable recompute mode.

  • habana_frameworks.torch.hpu.sdp_kernel(enable_recompute = False) - Context manager based control to temporarily (i.e., within the current context) disable recompute mode.

For example, to run FusedSDPA without recompute in the current context:

with habana_frameworks.torch.hpu.sdp_kernel(enable_recompute = False):
   FusedSDPA.apply(....)

  • habana_frameworks.torch.hpu.enable_recompute_sdp(True) - Globally enable recompute.

  • habana_frameworks.torch.hpu.enable_recompute_sdp(False) - Globally disable recompute.

The recompute state set by this API remains effective until this API is used again to change the state.

  • habana_frameworks.torch.hpu.recompute_sdp_enabled() - Queries whether recompute mode is enabled. Returns True or False, as shown in the example below.
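
For example, the global recompute controls and the query API can be used as follows:

import habana_frameworks.torch.hpu as hthpu

hthpu.enable_recompute_sdp(False)     # globally disable recompute mode
print(hthpu.recompute_sdp_enabled())  # False
hthpu.enable_recompute_sdp(True)      # globally re-enable recompute mode
print(hthpu.recompute_sdp_enabled())  # True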

The below shows a FusedSDPA usage example:

import torch
from habana_frameworks.torch.hpex.kernels import FusedSDPA
import habana_frameworks.torch.hpu as ht

query = torch.rand(32, 8, 128, 64, dtype=torch.bfloat16, device="hpu")
key = torch.rand(32, 8, 128, 64, dtype=torch.bfloat16, device="hpu")
value = torch.rand(32, 8, 128, 64, dtype=torch.bfloat16, device="hpu")
# No attention mask, dropout = 0.1, is_causal = True and scale = None
# i.e. scale factor =1.0/sqrt(64)
# recompute mode is enabled by default, so the following context setting
# is only for illustration.
with ht.sdp_kernel(enable_recompute = True):
    sdpa_out = FusedSDPA.apply(query, key, value, None, 0.1, True)
    print(sdpa_out.to("cpu"))

hpex/normalization

The hpex/normalization module contains Python interfaces to the Gaudi implementations of common gradient normalization and clipping operations used in some models. Using the Gaudi-provided implementation can give better performance than the equivalent operators in torch. Refer to Other Custom OPs for further details.
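
As an illustration, a hedged sketch using the FusedClipNorm interface; the exact module name and signature are described in Other Custom OPs:

import torch
from habana_frameworks.torch.hpex.normalization import FusedClipNorm  # assumption: see Other Custom OPs

model = torch.nn.Linear(16, 16).to("hpu")
fused_clip = FusedClipNorm(model.parameters(), 1.0)  # max norm = 1.0

# After loss.backward(), clip the gradients on HPU in a single fused op:
fused_clip.clip_norm(model.parameters())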

hpex/optimizers

The hpex/optimizers module contains Python interfaces to the Gaudi implementations of some common optimizers used in DL models. Using the Gaudi implementation can give better performance than the corresponding optimizer implementations in torch. Refer to Custom Optimizers for further details.
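
As an illustration, a hedged sketch using FusedAdamW as a drop-in replacement for torch.optim.AdamW; the available optimizers and exact signatures are listed in Custom Optimizers:

import torch
from habana_frameworks.torch.hpex.optimizers import FusedAdamW  # assumption: see Custom Optimizers

model = torch.nn.Linear(16, 16).to("hpu")
optimizer = FusedAdamW(model.parameters(), lr=1e-3, weight_decay=0.01)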

hpex/experimental/transformer_engine

The hpex/experimental/transformer_engine module contains Python interfaces to the Intel Gaudi Transformer Engine (TE) implementation. TE provides optimized implementations of PyTorch modules for popular Transformer architectures that perform their computations in the 8-bit floating point (FP8) data type. Using TE and FP8 can provide better performance with lower memory utilization. Refer to FP8 Training with Intel Gaudi Transformer Engine for further details.
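
As an illustration, a hedged sketch of replacing a linear layer with its Transformer Engine counterpart and running it under FP8 autocast; see FP8 Training with Intel Gaudi Transformer Engine for the supported modules and exact API:

import torch
import habana_frameworks.torch.hpex.experimental.transformer_engine as te

layer = te.Linear(768, 768, bias=True)  # assumed TE module, mirroring common TE usage
inp = torch.randn(32, 768, dtype=torch.bfloat16, device="hpu")

# Forward pass with FP8 computation enabled (assumed API name):
with te.fp8_autocast(enabled=True):
    out = layer(inp)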

hpu APIs

This package provides support for HPU tensors. The following APIs offer the same functionality as their CPU counterparts, with HPU as the underlying implementation. The package can be imported on demand.

  • import habana_frameworks.torch.hpu as hthpu - Imports the package.

  • hthpu.is_available() - Returns a boolean indicating if a HPU device is currently available.

  • hthpu.device_count() - Returns the number of compute-capable devices.

  • hthpu.get_device_name() - Returns the name of the HPU device.

  • hthpu.current_device() - Returns the index of the current selected HPU device.
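
The below shows a usage example:

import habana_frameworks.torch.hpu as hthpu

if hthpu.is_available():
    print("device count  :", hthpu.device_count())
    print("device name   :", hthpu.get_device_name())
    print("current device:", hthpu.current_device())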

utils

The utils module contains general Python utilities required for training on HPU.

Memory Stats APIs

  • htorch.hpu.max_memory_allocated - Returns peak HPU memory allocated by tensors (in bytes). reset_peak_memory_stats() can be used to reset the starting point in tracing stats.

  • htorch.hpu.memory_allocated - Returns the current HPU memory occupied by tensors.

  • htorch.hpu.memory_stats - Returns the list of HPU memory stats. The following summarizes the fields of a sample memory stats printout:

    • Limit - Amount of total memory on HPU device

    • InUse - Amount of allocated memory at any instance. Starting point after reset_peak_memory_stats()

    • MaxInUse - Amount of total active memory allocated

    • NumAllocs - Number of allocations

    • NumFrees - Number of freed chunks

    • ActiveAllocs - Number of active allocations

    • MaxAllocSize - Maximum allocated size

    • TotalSystemAllocs - Total number of system allocations

    • TotalSystemFrees - Total number of system frees

    • TotalActiveAllocs - Total number of active allocations

  • htorch.hpu.memory_summary - Returns a human-readable printout of the current memory stats.

  • htorch.hpu.reset_accumulated_memory_stats - Clears the accumulated counters, i.e. the number of allocations and number of frees.

  • htorch.hpu.reset_peak_memory_stats - Resets starting point of memory occupied by tensors.

The below shows a usage example:

import torch
import habana_frameworks.torch as htorch

if __name__ == '__main__':
    hpu = torch.device('hpu')
    input1 = torch.randn((64,28,28,20),dtype=torch.float, requires_grad=True)
    input1_hpu = input1.contiguous(memory_format=torch.channels_last).to(hpu)
    mem_summary1 = htorch.hpu.memory_summary()
    print('memory_summary1:')
    print(mem_summary1)
    htorch.hpu.reset_peak_memory_stats()
    input2 = torch.randn((64,28,28,20),dtype=torch.float, requires_grad=True)
    input2_hpu = input2.contiguous(memory_format=torch.channels_last).to(hpu)
    mem_summary2 = htorch.hpu.memory_summary()
    print('memory_summary2:')
    print(mem_summary2)
    mem_allocated = htorch.hpu.memory_allocated()
    print('memory_allocated: ', mem_allocated)
    mem_stats = htorch.hpu.memory_stats()
    print('memory_stats:')
    print(mem_stats)
    max_mem_allocated = htorch.hpu.max_memory_allocated()
    print('max_memory_allocated: ', max_mem_allocated)

Metric APIs

The Metric APIs provide various performance-related metrics, such as the number of graph compilations, the total time of graph compilations, and more. This section provides instructions on how to retrieve metrics using Python APIs listed below. You can also obtain metrics by saving them to a file without requiring any changes to the script using the environment variables listed in Runtime Environment Variables.

The Metric API functions are defined in the habana_frameworks.torch.hpu.metrics module, as described below.

  • habana_frameworks.torch.hpu.metrics.metric_global - Returns a global metric object based on the name provided. The object is active and present during the entire execution process.

    The below shows a usage example:

    import torch
    from habana_frameworks.torch.hpu.metrics import metric_global
    
    gc_metric = metric_global("graph_compilation")
    
    print(gc_metric.stats())
    
    device = torch.device('hpu')
    do_computation(device)
    
    print(gc_metric.stats())
    # 'gc_metric.stats' returns the list of counters collected during the entire execution process. Each counter is expressed as a tuple of the counter's name and its value.
    
  • habana_frameworks.torch.hpu.metrics.metric_localcontext - Returns a context manager object for the requested metric. The metric collection is limited to the scope of the context manager defined within the with statement.

    The below shows a usage example:

    import torch
    from habana_frameworks.torch.hpu.metrics import metric_localcontext
    with metric_localcontext("graph_compilation") as local_metric:
        do_computation()
    
    print(local_metric.stats())
    # 'local_metric.stats' returns only the list of counters reflecting the data collected within the "with" statement. Each counter is expressed as a tuple of the counter's name and its value.
    
  • habana_frameworks.torch.hpu.metrics.metrics_dump - Stores the contents collected by global metrics in a file; local metrics are not included. Both JSON and TEXT formats are supported.

    The below shows a usage example:

    import torch
    from habana_frameworks.torch.hpu.metrics import metrics_dump
    
    device = torch.device('hpu')
    do_computation(device)
    
    metrics_dump("metrics.json", "json")
    # Only global metrics will be saved in the 'metrics.json' file in JSON format. The local metrics created through 'metric_localcontext' will not be stored.
    

Below is the list of metrics that are currently supported along with their respective properties:

graph_compilation

  • TotalNumber - total number of graph compilations

  • TotalTime - total time of graph compilations in microseconds (μs)

  • AvgTime - average time of graph compilation in microseconds (μs)

cpu_fallback

  • TotalNumber - total number of CPU fallbacks

  • FallbackOps - total number of CPU fallbacks for each operator

memory_defragmentation

  • TotalNumber - total number of memory defragmentations triggered

  • TotalSuccessful - total number of memory defragmentations completed successfully

  • AvgTime - average time of memory defragmentation (ms)

  • MaxTime - maximum time of memory defragmentation (ms)

recipe_cache

  • TotalHit - total number of graph recipe cache hits

  • TotalMiss - total number of graph recipe cache misses

  • RecipeHit - total number of graph recipe cache hits for each recipe

  • RecipeMiss - total number of graph recipe cache misses for each recipe

Stream APIs

Streams and events are advanced features for concurrency that allow users to manage multiple asynchronous tasks running on the HPU.

  • import habana_frameworks.torch as htorch - Imports the package.

  • htorch.hpu.Stream - Returns a wrapper around a HPU stream.

  • htorch.hpu.stream - Wrapper around the Context-manager StreamContext that selects a given stream.

  • htorch.hpu.set_stream - Sets the current stream.

  • htorch.hpu.current_stream - Gets the current stream.

  • htorch.hpu.default_stream - Gets the default stream.

APIs available on Stream object:

  • query() - Checks if all the work submitted on a HPU stream has completed.

  • synchronize() - Waits for all the kernels in the HPU stream to complete.

  • record_event() - Records an event on the HPU stream.

  • wait_event() - Makes all future work submitted to the stream wait for an event.

  • wait_stream() - All future work submitted to this stream will wait until all kernels submitted to a given stream at the time of call are completed.

The below shows a usage example:

import torch
import habana_frameworks.torch as ht

s0 = ht.hpu.Stream()
in_shape = (10,2)
tA_h = torch.zeros(in_shape).to('hpu')
tB_h = torch.ones(in_shape).to('hpu')

tOut1 = torch.add(tA_h,tA_h) #executes on default stream

with ht.hpu.stream(s0):
    tOut2 = torch.add(tB_h,tB_h) #executes on stream s0

s0.synchronize() #synchronize with s0
print(s0.query()) #returns True as all operations on s0 have finished
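
Continuing the example above, the event and wait APIs on the Stream object can be combined as in the following sketch; the assumption that record_event() returns the recorded event and that wait_event() accepts it follows the torch.cuda.Stream convention:

s1 = ht.hpu.Stream()

with ht.hpu.stream(s0):
    tOut3 = torch.mul(tA_h, tB_h)        # executes on stream s0
ev = s0.record_event()                   # assumption: returns the recorded event

s1.wait_event(ev)                        # future work on s1 waits for ev
with ht.hpu.stream(s1):
    tOut4 = torch.add(tOut3, tB_h)       # runs only after the work recorded on s0

ht.hpu.default_stream().wait_stream(s1)  # default stream waits for s1
ht.hpu.synchronize()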

Event APIs

  • import habana_frameworks.torch as htorch - Imports the package.

  • htorch.hpu.Event - Returns a wrapper around a HPU event.

APIs available on Event object:

  • query() - Checks if all work currently captured by event has completed.

  • synchronize() - Waits until the completion of all work currently captured in this event.

  • record() - Records the event in a given stream.

  • wait() - Makes all future work submitted to the given stream wait for this event.

  • elapsed_time(end_event) - Returns the time elapsed in milliseconds after the event was recorded and before the end_event was recorded.

The below shows a usage example:

import torch
import habana_frameworks.torch as ht

in_shape = (10,2)
tA_h = torch.zeros(in_shape).to('hpu')
tB_h = torch.ones(in_shape).to('hpu')

startEv = ht.hpu.Event(enable_timing=True)
endEv = ht.hpu.Event(enable_timing=True)
startEv.record()
for _ in range(100):
  tA_h = torch.add(tA_h,tB_h)
endEv.record()
endEv.synchronize()
print(f'Time Elapsed={startEv.elapsed_time(endEv)}')

HPU Graph APIs

HPU Graph APIs for Inference

  • import habana_frameworks.torch as htorch - Imports the package.

  • htorch.hpu.HPUGraph - Returns a wrapper around a HPU graph.

  • htorch.hpu.wrap_in_hpu_graph(module, asynchronous=False, disable_tensor_cache=False, dry_run=False, max_graphs=None) - Wraps module forward function with HPU Graph.

    • module (torch.nn.Module) - The torch.nn.Module object to be cached.

    • asynchronous (bool, optional) - If True, replays for the HPU Graphs will be launched asynchronously. Set this based on the model used. This flag can be used along with streams and async D2H copies of outputs to reduce host overheads. Default: False

    • disable_tensor_cache (bool, optional) - If False, HPU Graphs cache all input, intermediate and output tensors. Setting this to True frees up previously cached input tensors, intermediate tensors except in-place, views and output tensors, thereby reducing the model’s memory requirements. Since the intermediate tensors are freed by enabling this flag, this feature cannot be used in scenarios where intermediate tensors are saved and used outside the scope of the current graph’s forward pass. Default: False

    • dry_run (bool, optional) - If True, the model’s HPU Graph is cached while avoiding the initial lazy execution. Subsequent replays launch the cached HPU Graphs. This option will be deprecated in a future release. Default: False

    • max_graphs (int, optional) - Specifies the maximum number of HPU graphs that can be cached. Any change in input combinations will lead to the creation of a new graph. Setting this to None indicates no limit, but this may lead to out-of-memory issues if memory requirements are not met. Default: None

APIs available on HPU Graph object for inference:

  • capture_begin() - Begins capturing HPU work on the current stream.

  • capture_end() - Ends capturing HPU work on the current stream.

  • replay() - Replays the HPU work captured by this graph.

The below shows a usage example:

import torch
import habana_frameworks.torch as ht

g = ht.hpu.HPUGraph()
s = ht.hpu.Stream()

with ht.hpu.stream(s):
    g.capture_begin()
    a = torch.full((100,), 1, device="hpu")
    b = a
    b = b + 1
    g.capture_end()

g.replay()
ht.hpu.synchronize()

# Alternatively, wrap a module's forward function in an HPU Graph:
import torch
import habana_frameworks.torch as ht

model = GetModel()
model = ht.hpu.wrap_in_hpu_graph(model)

HPU Graph APIs for Training

  • import habana_frameworks.torch as htorch - Imports the package.

  • htorch.hpu.HPUGraph - Returns a wrapper around a HPU graph.

APIs available on HPU Graph object for training:

  • make_graphed_callables(callables, sample_args, warmup=0, allow_unused_input=False, asynchronous=False, disable_tensor_cache=False, dry_run=False) - Makes the training graph. Each callable is graphed for both the forward and backward pass by overloading the autograd function. This API requires the model to take and return only tuples of tensors, which is incompatible with workloads that use data structures such as dicts and lists. Refer to the ModuleCacher parameter description below.

The below shows a usage example:

import torch
import habana_frameworks.torch as ht

def test_graph_training():
    N, D_in, H, D_out = 2, 2, 2, 2
    module = torch.nn.Linear(D_in, H).to('hpu')
    loss_fn = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(module.parameters(), lr=0.1)
    x = torch.randn(N, D_in, device='hpu')
    module = ht.hpu.make_graphed_callables(module, (x,))
    real_inputs = [torch.rand_like(x) for _ in range(100)]
    real_targets= [torch.randn(N, D_out, device="hpu") for _ in range(100)]

    for data, target in zip(real_inputs, real_targets):
        optimizer.zero_grad(set_to_none=True)
        output = module(data)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()

  • ht.hpu.ModuleCacher(max_graphs=10)(model=model, use_lfu=False, inplace=True, allow_unused_input=False, asynchronous=False, have_grad_accumulation=False, log_frequency=100, verbose=False, disable_tensor_cache=False, dry_run=False)

    ModuleCacher is a wrapper over the make_graphed_callables API to simplify caching graphs with static/dynamic input shapes. An object of this class can be called with a torch.nn.Module model to return a version of the model which automatically caches a static graph for unique input combinations. When an input combination which was seen before is encountered, the cached version of the model’s graph is used. To check equivalence between two input combinations, a hashed value of the set of input variables is used. To compute the hash, the values are used for all Pythonic variables and only the shape attribute is considered for torch.Tensor objects.

    • model (torch.nn.Module) - The torch.nn.Module object to be cached.

    • max_graphs (int, optional) - Specifies the maximum number of HPU Graphs that can be cached. This value indirectly controls the maximum device memory that caching can utilize. The exact memory utilized will depend on the module being cached. If use_lfu is not specified, the cache is not updated after reaching the limit. After the cache limit is hit, new shapes will not be cached and will run in lazy mode. Default: 10

    • use_lfu (bool, optional) - If True, enables priority-based least-frequently-used (LFU) caching. When enabled, the capture_start and capture_end methods need to be invoked; during this window the most frequently occurring input combinations are identified. The top candidates (up to max_graphs) are cached the next time they are encountered, and subsequent runs replay the cached HPU Graphs. The disadvantage of use_lfu is that caching begins only after capture_end is triggered. When use_lfu is disabled, caching starts automatically without any external trigger, but there is no guarantee of optimal caching, since only the first max_graphs input combinations occupy the cache. This is useful for dynamic inputs: during the first epoch, statistics on repeated input shapes are collected internally, and during the next epoch those statistics are used to cache up to max_graphs graphs. Default: False

    htorch.hpu.ModuleCacher()(model=model, use_lfu=True, inplace=True)
    for epoch in epochs:
        if epoch == 0: model.capture_start()
        ## Training loop
        if epoch == 0: model.capture_end()
    
    • inplace (bool, optional) - If True, the passed model will be modified in-place to prepare it for caching. Otherwise, a copy of the model is created. Default: True

    • allow_unused_input (bool, optional) - If False, specifying inputs that are irrelevant for computing outputs (and therefore whose grad is always zero) is an error. Setting this to True ignores the inputs that do not contribute to gradient computations and bypasses the error. Refer to torch.autograd.grad for details. Note that materialize_grads is set to the same value as allow_unused. Default: False

    • asynchronous (bool, optional) - If True, replays for the HPU Graphs will be launched asynchronously. Set this based on the model used. This flag can be used along with streams and async D2H copies of outputs to reduce host overheads. It is also useful when the output is not needed immediately after the forward call: other work can proceed, and streams can be queried to check whether the output is ready when it is actually needed. Some checks, such as input data asserts, are in place to detect anomalies, but they are not exhaustive. Verify asynchronous use cases and revert this value to False if any numerical issues appear in the results. Default: False

    • have_grad_accumulation (bool, optional) - Specifies whether the training uses gradient accumulation, so that separate HPU Graphs are kept for the first iteration and the remaining iterations. If set to True, the user is expected to call model.set_iteration_count(int) so the model can identify whether the forward pass was called immediately after setting gradients to zero. Enable this option when the model requires distinct graphs for forward passes that immediately follow a zero_grad; candidate model configurations have gradient accumulation steps greater than 1. Default: False

    • log_frequency (int, optional) - Specifies the logging frequency of ModuleCacher stats in terms of steps per log. Default: 100

    • verbose (bool, optional) - If True, enables debug logs. Default: False

    • disable_tensor_cache (bool, optional) - If False, HPU Graphs cache all input, intermediate and output tensors. Setting this to True frees up previously cached input and output tensors, thereby reducing the model’s memory requirements. This option cannot be used in certain conditions, for example when an output is passed as an input to forward. This option will be deprecated in a future release. Default: False

    • dry_run (bool, optional) - If True, the model’s HPU Graph is cached while avoiding the initial lazy execution. Subsequent replays launch the cached HPU Graphs. This option will be deprecated in a future release. Default: False

The usage of the module is shown in the example below:

import torch
import habana_frameworks.torch as htorch
from random import choice

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = torch.nn.Linear(4, 4)
        self.fc2 = torch.nn.Linear(4, 4)
        self.fc3 = torch.nn.Linear(4, 4)

    def forward(self, x, condition=False):
        N, _, _ = x.shape
        x = self.fc1(x)
        if condition:
            x = self.fc2(x)
        else:
            x = self.fc3(x)

        x = x.view(N, -1)
        return x

def generate_data(n_samples):
    shapes = [(2, 3, 4), (3, 11, 4), (3, 4, 4)]
    conditions = [True, False]

    inputs = [{'x': torch.randn(choice(shapes)).to('hpu'), 'condition': choice(conditions)} for _ in range(n_samples)]
    targets = [torch.randn(item['x'].shape[0]).to('hpu') for item in inputs]

    return inputs, targets

rand_inputs, rand_targets = generate_data(n_samples=10)

model = Net().to('hpu')
optimizer = torch.optim.SGD(model.parameters(),lr=0.1)
htorch.hpu.ModuleCacher()(model=model, inplace=True, allow_unused_input=True, verbose=True, log_frequency=1)

for x, y in zip(rand_inputs, rand_targets):
    optimizer.zero_grad(set_to_none=True)
    y_pred = torch.mean(model(**x), 1)
    loss = torch.nn.functional.mse_loss(y_pred, y)
    loss.backward()
    optimizer.step()
    htorch.core.mark_step()

Note

The training interfaces, ModuleCacher and make_graphed_callables, should be called before the DDP hook registration.

Random Number Generator APIs

import habana_frameworks.torch.hpu.random as htrandom imports the HPU random package.

Below are the APIs available in the habana_frameworks.torch.hpu.random package:

  • htrandom.get_rng_state - Returns the random number generator state of the specified HPU as a ByteTensor.

  • htrandom.get_rng_state_all - Returns a list of ByteTensor representing the random number states of all devices.

  • htrandom.set_rng_state - Sets the random number generator state of the specified HPU.

  • htrandom.set_rng_state_all - Sets the random number generator state of all devices.

  • htrandom.manual_seed - Sets the seed for generating random numbers for the current HPU device.

  • htrandom.manual_seed_all - Sets the seed for generating random numbers on all HPUs.

  • htrandom.seed - Sets the seed for generating random numbers to a random number for the current HPU.

  • htrandom.seed_all - Sets the seed for generating random numbers to a random number on all HPUs.

  • htrandom.initial_seed - Returns the current random seed of the current HPU.

The below shows a usage example:

import torch
import habana_frameworks.torch.hpu.random as htrandom
state = htrandom.get_rng_state()
htrandom.set_rng_state(state)
initial_seed = htrandom.initial_seed()
htrandom.manual_seed(2)
htrandom.seed()
print(htrandom.initial_seed())

Dynamic Shape APIs

The Dynamic Shape APIs control how the Intel Gaudi PyTorch bridge and graph compiler manage dynamic shapes in model scripts. Dynamic shape support is disabled by default. Please see the Dynamic Shapes Optimization section for more information on how to optimize for Dynamic Shapes in Data and Models.

  • import habana_frameworks.torch.hpu as ht - Imports the package.

  • ht.disable_dynamic_shape() - Disables dynamic shape feature.

  • ht.enable_dynamic_shape() - Enables dynamic shape feature.

import habana_frameworks.torch.hpu as ht

ht.disable_dynamic_shape()
ht.enable_dynamic_shape()