Habana PyTorch Python API (habana_frameworks.torch)

This package provides the PyTorch bridge interfaces and modules, such as optimizers, mixed precision configuration, and fused kernels for training on HPU.

The various modules are organized as shown below:

habana_frameworks.torch
  core
  distributed
     hccl
  hpex
     hmp
     kernels
     normalization
     optimizers
  hpu
  utils

The following sections provide a brief description of each module.

core

The core module provides Python bindings to the PyTorch-Habana bridge interfaces, for example mark_step, which triggers execution of the accumulated graphs in Lazy mode.
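
The snippet below is a minimal sketch of triggering execution of the accumulated graph in Lazy mode; the tensor shapes are placeholders:

import torch
import habana_frameworks.torch.core as htcore

device = torch.device("hpu")
x = torch.randn(8, 8).to(device)
y = torch.matmul(x, x)  # operations are accumulated lazily
htcore.mark_step()      # triggers execution of the accumulated graph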

distributed/hccl

  • The distributed/hccl module registers and adds support for the HCCL communication backend.

  • from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu - API imports

  • initialize_distributed_hpu() - A helper function that returns world_size, rank, and local_rank when the processes are launched with either MPI or torchrun related APIs.

You can find example usage on the ResNet50 Model References GitHub page.

The following snippet shows a simple usage sequence:

from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu
world_size, rank, local_rank = initialize_distributed_hpu()
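
After importing the module, the HCCL backend can be selected when initializing the PyTorch distributed process group. The snippet below is a minimal sketch; the rendezvous settings are placeholders and depend on how the processes are launched:

import os
import torch.distributed as dist
import habana_frameworks.torch.distributed.hccl  # registers the "hccl" backend
from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu

world_size, rank, local_rank = initialize_distributed_hpu()
os.environ.setdefault("MASTER_ADDR", "localhost")  # placeholder rendezvous settings
os.environ.setdefault("MASTER_PORT", "12355")
dist.init_process_group(backend="hccl", world_size=world_size, rank=rank)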

hpex/hmp

The hpex/hmp module contains the habana_mixed_precision (hmp) tool, which can be used to train a model in mixed precision on HPU. Refer to PyTorch Mixed Precision Training on Gaudi for further details.
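
The snippet below is a minimal usage sketch, assuming the hmp.convert() entry point and the hmp.disable_casts() context manager described in the mixed precision documentation; the opt_level value is an assumption:

import torch
from habana_frameworks.torch.hpex import hmp

hmp.convert(opt_level="O1")  # assumed entry point enabling automatic bf16/fp32 casting

model = torch.nn.Linear(10, 10).to("hpu")
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

out = model(torch.randn(4, 10).to("hpu"))
loss = out.sum()
loss.backward()
with hmp.disable_casts():  # optimizer step is typically run with casts disabled (assumed)
    optimizer.step()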

hpex/kernels

The hpex/kernels module contains Python interfaces to Habana-only custom operators, such as the EmbeddingBag and EmbeddingBagPreProc operators.

hpex/normalization

The hpex/normalization module contains Python interfaces to the Habana implementation of the common normalize and clip operations performed on gradients in some models. Using the Habana-provided implementation can yield better performance than the equivalent operators in torch. Refer to Other Custom OPs for further details.
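
The snippet below is a minimal sketch of gradient clipping with the Habana FusedClipNorm operator from this module; the constructor arguments are an assumption based on the Custom OPs documentation:

import torch
from habana_frameworks.torch.hpex.normalization import FusedClipNorm

model = torch.nn.Linear(10, 10).to("hpu")
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
fused_clip_norm = FusedClipNorm(model.parameters(), 1.0)  # parameters and max gradient norm (assumed)

loss = model(torch.randn(4, 10).to("hpu")).sum()
loss.backward()
fused_clip_norm.clip_norm(model.parameters())  # in place of torch.nn.utils.clip_grad_norm_
optimizer.step()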

hpex/optimizers

The hpex/optimizers module contains Python interfaces to the Habana implementations of some of the common optimizers used in DL models. Using the Habana implementations can yield better performance than the corresponding optimizer implementations in torch. Refer to Custom Optimizers for further details.
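
As a sketch, a Habana fused optimizer can be used in place of its torch counterpart; FusedAdamW is shown here, and its constructor arguments are assumed to mirror torch.optim.AdamW:

import torch
from habana_frameworks.torch.hpex.optimizers import FusedAdamW

model = torch.nn.Linear(10, 10).to("hpu")
optimizer = FusedAdamW(model.parameters(), lr=1e-3, weight_decay=0.01)  # assumed drop-in arguments

loss = model(torch.randn(4, 10).to("hpu")).sum()
loss.backward()
optimizer.step()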

hpu APIs

This package provides support for HPU tensors. The following APIs provide the same functionality as their CPU counterparts, with HPU used for the underlying implementation. The package can be imported on demand.

  • import habana_frameworks.torch.hpu as hthpu - Imports the package.

  • hthpu.is_available() - Returns a boolean indicating if an HPU device is currently available.

  • hthpu.device_count() - Returns the number of compute-capable devices.

  • hthpu.get_device_name() - Returns the name of the HPU device.

  • hthpu.current_device() - Returns the index of the currently selected HPU device.
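
The below shows a simple usage example of these APIs:

import habana_frameworks.torch.hpu as hthpu

if hthpu.is_available():
    print("device count  :", hthpu.device_count())
    print("device name   :", hthpu.get_device_name())
    print("current device:", hthpu.current_device())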

utils

utils module contains general Python utilities required for training on HPU.

Memory Stats APIs

  • htorch.hpu.max_memory_allocated - Returns the peak HPU memory allocated by tensors (in bytes). reset_peak_memory_stats() can be used to reset the starting point in tracking this metric.

  • htorch.hpu.memory_allocated - Returns the current HPU memory occupied by tensors.

  • htorch.hpu.memory_stats - Returns the HPU memory statistics. The fields in a sample printout are described below:

    • Limit - Amount of total memory on HPU device

    • InUse - Amount of memory allocated at the current instant. This is the starting point after reset_peak_memory_stats()

    • MaxInUse - Amount of total active memory allocated

    • NumAllocs - Number of allocations

    • NumFrees - Number of freed chunks

    • ActiveAllocs - Number of active allocations

    • MaxAllocSize - Maximum allocated size

    • TotalSystemAllocs - Total number of system allocations

    • TotalSystemFrees - Total number of system frees

    • TotalActiveAllocs - Total number of active allocations

  • htorch.hpu.memory_summary - Returns a human-readable printout of the current memory stats.

  • htorch.hpu.reset_accumulated_memory_stats - Clears the accumulated counters, such as the number of allocations and the number of frees.

  • htorch.hpu.reset_peak_memory_stats - Resets starting point of memory occupied by tensors.

The below shows a usage example:

import torch
import torch.nn as nn
import torch.nn.functional as F
import habana_frameworks.torch as htorch

if __name__ == '__main__':
    hpu = torch.device('hpu')
    cpu = torch.device('cpu')
    input1 = torch.randn((64,28,28,20),dtype=torch.float, requires_grad=True)
    input1_hpu = input1.contiguous(memory_format=torch.channels_last).to(hpu)
    mem_summary1 = htorch.hpu.memory_summary()
    print('memory_summary1:')
    print(mem_summary1)
    htorch.hpu.reset_peak_memory_stats()
    input2 = torch.randn((64,28,28,20),dtype=torch.float, requires_grad=True)
    input2_hpu = input2.contiguous(memory_format=torch.channels_last).to(hpu)
    mem_summary2 = htorch.hpu.memory_summary()
    print('memory_summary2:')
    print(mem_summary2)
    mem_allocated = htorch.hpu.memory_allocated()
    print('memory_allocated: ', mem_allocated)
    mem_stats = htorch.hpu.memory_stats()
    print('memory_stats:')
    print(mem_stats)
    max_mem_allocated = htorch.hpu.max_memory_allocated()
    print('max_memory_allocated: ', max_mem_allocated)

Stream APIs

Streams and events are advanced features for concurrency that allow users to manage multiple asynchronous tasks running on the HPU.

  • import habana_frameworks.torch as htorch - Imports the package.

  • htorch.hpu.Stream - Returns a wrapper around an HPU stream.

  • htorch.hpu.stream - Wrapper around the Context-manager StreamContext that selects a given stream.

  • htorch.hpu.set_stream - Sets the current stream.

  • htorch.hpu.current_stream - Returns the currently selected stream.

  • htorch.hpu.default_stream - Gets the default stream.

APIs available on Stream object:

  • query() - Checks if all the work submitted on this HPU stream has completed.

  • synchronize() - Waits for all the kernels in this HPU stream to complete.

  • record_event() - Records an event on the HPU stream.

  • wait_event() - Makes all future work submitted to the stream wait for an event.

  • wait_stream() - All future work submitted to this stream will wait until all kernels submitted to a given stream at the time of call are completed.

The below shows a usage example:

import torch
import habana_frameworks.torch.hpu as htcore
import habana_frameworks.torch as ht

s0 = ht.hpu.Stream()
in_shape = (10,2)
tA_h = torch.zeros(in_shape).to('hpu')
tB_h = torch.ones(in_shape).to('hpu')

tOut1 = torch.add(tA_h,tA_h) #executes on default stream

with ht.hpu.stream(s0):
    tOut2 = torch.add(tB_h,tB_h) #executes on stream s0

s0.synchronize() #synchronize with s0
print(s0.query()) #returns True as all operations on s0 have finished

Event APIs

  • import habana_frameworks.torch as htorch - Imports the package.

  • htorch.hpu.Event - Returns a wrapper around an HPU event.

APIs available on Event object:

  • query() - Checks if all work currently captured by event has completed.

  • synchronize() - Waits until the completion of all work currently captured in this event.

  • record() - Records the event in a given stream.

  • wait() - Makes all future work submitted to the given stream wait for this event.

  • elapsed_time(end_event) - Returns the time elapsed in milliseconds after the event was recorded and before the end_event was recorded.

The below shows a usage example:

import torch
import habana_frameworks.torch.hpu as htcore
import habana_frameworks.torch as ht

in_shape = (10,2)
tA_h = torch.zeros(in_shape).to('hpu')
tB_h = torch.ones(in_shape).to('hpu')

startEv = ht.hpu.Event(enable_timing=True)
endEv = ht.hpu.Event(enable_timing=True)
startEv.record()
for _ in range(100):
  tA_h = torch.add(tA_h,tB_h)
endEv.record()
endEv.synchronize()
print(f'Time Elapsed={startEv.elapsed_time(endEv)}')

HPU Graph APIs

HPU Graph APIs for Inference

  • import habana_frameworks.torch as htorch - Imports the package.

  • htorch.hpu.HPUGraph - Returns a wrapper around an HPU graph.

  • htorch.hpu.wrap_in_hpu_graph - Wraps the module's forward function with an HPU graph.

APIs available on HPU Graph object for inference:

  • capture_begin() - Begins capturing HPU work on the current stream.

  • capture_end() - Ends capturing HPU work on the current stream.

  • replay() - Replays the HPU work captured by this graph.

The below shows a usage example:

import torch
import habana_frameworks.torch as ht

g = ht.hpu.HPUGraph()
s = ht.hpu.Stream()

with ht.hpu.stream(s):
    g.capture_begin()
    a = torch.full((100,), 1, device="hpu")
    b = a
    b = b + 1
    g.capture_end()

g.replay()
ht.hpu.synchronize()

The below shows a usage example with wrap_in_hpu_graph:

import torch
import habana_frameworks.torch as ht

model = GetModel()
model = ht.hpu.wrap_in_hpu_graph(model)

HPU Graph APIs for Training

  • import habana_frameworks.torch as htorch - Imports the package.

  • htorch.hpu.HPUGraph - Returns a wrapper around an HPU graph.

APIs available on HPU Graph object for training:

  • make_graphed_callables(callables, sample_args, warmup=0) - Makes the training graph. The forward and backward passes of each graphed callable are captured by overloading the autograd function. This API requires the model's inputs and outputs to be tensors or tuples of tensors, which makes it incompatible with workloads that use data structures such as dicts and lists.

The below shows a usage example:

import torch
import habana_frameworks.torch as ht

def test_graph_training():
    N, D_in, H, D_out = 2, 2, 2, 2
    module = torch.nn.Linear(D_in, H).to('hpu')
    loss_fn = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(module.parameters(), lr=0.1)
    x = torch.randn(N, D_in, device='hpu')
    module = ht.hpu.make_graphed_callables(module, (x,))
    real_inputs = [torch.rand_like(x) for _ in range(100)]
    real_targets= [torch.randn(N, D_out, device="hpu") for _ in range(100)]

    for data, target in zip(real_inputs, real_targets):
        optimizer.zero_grad(set_to_none=True)
        output = module(data)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()

  • ht.hpu.ModuleCacher(max_graphs=10)(model=model, inplace=True) - Caches the HPU graphs generated for the model. max_graphs specifies the number of graphs to cache.

A larger number of cached graphs will increase the device memory usage, but also increase the number of cache hits. This API is compatible with workloads using dictionaries and lists.

The below shows a usage example:

import copy
import torch
import habana_frameworks.torch as ht

def test_cached_module_training():
    # Net is a user-defined torch.nn.Module
    model = Net().to('hpu')
    state_dict = copy.deepcopy(model.state_dict())
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    meta_args = [((3, 4), True), ((5, 4), False), ((11, 4), True)]

    net_input = []
    net_output = []
    for i in range(2):
        for item in meta_args:
            x = torch.randn(item[0]).to('hpu')
            y = torch.randn(item[0]).to('hpu')
            net_input.append({'x' : x, 'y' : y, 'boolean_var' : item[1]})
            net_output.append(torch.randn(item[0][0]).to('hpu'))

    def train_model():
        for inp, y in zip(net_input, net_output):
            output = model(**inp)
            y_pred = torch.mean(output[1], 1)
            optimizer.zero_grad(set_to_none=True)
            loss = torch.nn.functional.mse_loss(y_pred, y)
            loss.backward()
            optimizer.step()
            ht.core.mark_step()
        return loss.cpu()

    loss_original = train_model()
    model.load_state_dict(state_dict)
    ht.hpu.ModuleCacher()(model=model, inplace=True)
    loss_cached = train_model()
    return loss_original == loss_cached

Note

The training interfaces, ModuleCacher and make_graphed_callables, should be called before the DDP hook registration.
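
The snippet below is a minimal sketch of this ordering; the model construction and process-group setup are placeholders:

import torch
import habana_frameworks.torch as ht

# Assumes torch.distributed has already been initialized with the hccl backend
# and that Net is a user-defined torch.nn.Module.
model = Net().to('hpu')
ht.hpu.ModuleCacher()(model=model, inplace=True)  # call before wrapping in DDP
model = torch.nn.parallel.DistributedDataParallel(model)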

Random Number Generator APIs

import habana_frameworks.torch.hpu.random as htrandom imports the HPU random package.

Below are the APIs available in the habana_frameworks.torch.hpu.random package:

  • htrandom.get_rng_state - Returns the random number generator state of the specified HPU as a ByteTensor.

  • htrandom.get_rng_state_all - Returns a list of ByteTensor representing the random number states of all devices.

  • htrandom.set_rng_state - Sets the random number generator state of the specified HPU.

  • htrandom.set_rng_state_all - Sets the random number generator state of all devices.

  • htrandom.manual_seed - Sets the seed for generating random numbers for the current HPU device.

  • htrandom.manual_seed_all - Sets the seed for generating random numbers on all HPUs.

  • htrandom.seed - Sets the seed for generating random numbers to a random number for the current HPU.

  • htrandom.seed_all - Sets the seed for generating random numbers to a random number on all HPUs.

  • htrandom.initial_seed - Returns the current random seed of the current HPU.

The below shows a usage example:

import torch
import habana_frameworks.torch.hpu.random as htrandom
state = htrandom.get_rng_state()
htrandom.set_rng_state(state)
initial_seed = htrandom.initial_seed()
htrandom.manual_seed(2)
htrandom.seed()
print(htrandom.initial_seed())