Habana PyTorch Python API (habana_frameworks.torch)¶
This package provides PyTorch bridge interfaces and modules such as optimizers, mixed precision configuration, and fused kernels for training on HPU.
The modules are organized as follows:
habana_frameworks.torch
  core
  distributed
    hccl
  gpu_migration
  hpex
    hmp
    kernels
    normalization
    optimizers
  hpu
  utils
The following sections provide a brief description of each module.
core¶
The core module provides Python bindings to PyTorch-Habana bridge interfaces, for example mark_step, which is used to trigger execution of accumulated graphs in Lazy mode.
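As a minimal sketch (the model, data, and hyperparameters below are placeholders), a Lazy mode training loop typically calls mark_step after the backward pass and after the optimizer step:
import torch
import habana_frameworks.torch.core as htcore

device = torch.device("hpu")
model = torch.nn.Linear(10, 10).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for _ in range(10):
    inputs = torch.randn(4, 10, device=device)
    targets = torch.randn(4, 10, device=device)
    optimizer.zero_grad(set_to_none=True)
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    htcore.mark_step()   # trigger execution of the accumulated backward graph
    optimizer.step()
    htcore.mark_step()   # trigger execution of the optimizer update graph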
distributed/hccl¶
The distributed/hccl module registers and adds support for the HCCL communication backend.
from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu - Imports the API.
initialize_distributed_hpu() - Helper function that returns world_size, rank, and local_rank when the processes are launched with either MPI or torchrun-related APIs.
You can find a usage example in the ResNet50 Model References GitHub page.
The following snippet shows a simple usage sequence:
from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu
world_size, rank, local_rank = initialize_distributed_hpu()
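Importing the module also makes the hccl backend available to torch.distributed. A minimal sketch of joining the process group with it (assuming the processes were launched with MPI or torchrun):
import torch.distributed as dist
from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu

world_size, rank, local_rank = initialize_distributed_hpu()
# The import above registers the HCCL backend; join the process group with it
dist.init_process_group(backend="hccl", world_size=world_size, rank=rank)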
gpu_migration¶
The gpu_migration module can be used to quickly migrate a model that strongly depends on CUDA to HPU.
Refer to GPU Migration Toolkit for further details.
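A minimal sketch of how the toolkit is typically enabled, by importing the module at the top of an otherwise CUDA-targeted script (the tensor creation below is only illustrative):
import habana_frameworks.torch.gpu_migration   # patches CUDA-specific calls to target HPU
import torch

x = torch.randn(2, 2, device="cuda")   # redirected to the HPU device by the toolkit
print(x.device)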
hpex/hmp¶
The hpex/hmp module contains the habana_mixed_precision (hmp) tool, which can be used to train a model in mixed precision on HPU. Refer to PyTorch Mixed Precision Training on Gaudi for further details.
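A minimal sketch of enabling mixed precision with hmp using the default convert() settings (see the linked guide for the override lists and other options):
import torch
from habana_frameworks.torch.hpex import hmp

hmp.convert()   # enable default BF16/FP32 mixed precision casting

model = torch.nn.Linear(10, 10).to("hpu")
output = model(torch.randn(4, 10, device="hpu"))   # runs under mixed precision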
hpex/kernels/FBGEMM¶
FBGEMM (Facebook GEneral Matrix Multiplication Kernels Library) is a collection of high-performance operator libraries designed
for training and inference. The habana_frameworks.torch.hpex.kernels.fbgemm
package provides the following APIs:
bounds_check_indices - Checks and corrects any invalid values of the provided indices and offsets.
expand_into_jagged_permute - Expands the permute index for sparse data from the table dimension to the batch dimension in cases where the sparse features have different batch sizes across ranks.
permute_1D_sparse_data - Shuffles lengths, indices and weights tensors according to the values in the permute tensor.
permute_2D_sparse_data - Shuffles lengths, indices and weights tensors according to the values in the permute tensor.
split_embedding_codegen_lookup_function - A simple lookup table that stores embeddings for a fixed dictionary of a specific size.
split_permute_cat - Replaces the combination of split_with_sizes, permute, and cat operations.
The below shows a usage example:
import torch
hpu = torch.device("hpu")
from habana_frameworks.torch.hpex.kernels.fbgemm import split_permute_cat
B = 3; F = 3; D = 8
input = torch.randn(B, F * D, dtype=torch.float32)
indices = torch.randperm(F)
output = split_permute_cat(input.to(hpu), indices.to(hpu), B, F, D)
print(output.cpu())
hpex/kernels/FusedSDPA¶
FusedSDPA is a fused implementation of torch.nn.functional.scaled_dot_product_attention()
on HPU,
which maintains the same functionality and interface as the original function.
The FusedSDPA
class takes several parameters (query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None)
and produces the output of scaled dot product attention. For further information on the functions and parameters, refer to
Scaled Dot Product Attention.
In addition to the mentioned parameters, you can pass the following parameters to FusedSDPA.apply as described below:
attn_mask - In addition to the default shape (N,…,L,S), support is added for shape (N,…,1,S).
scale - The softmax scale factor, set to None by default. If scale=None, the actual scale factor used is 1.0/sqrt(E), where E is the embedding dimension of the query and key.
The below shows a usage example:
import torch
from habana_frameworks.torch.hpex.kernels import FusedSDPA
query = torch.rand(32, 8, 128, 64, dtype=torch.bfloat16, device="hpu")
key = torch.rand(32, 8, 128, 64, dtype=torch.bfloat16, device="hpu")
value = torch.rand(32, 8, 128, 64, dtype=torch.bfloat16, device="hpu")
# No attention mask, dropout = 0.1, is_causal = False and scale = None
# i.e. scale factor =1.0/sqrt(64)
sdpa_out = FusedSDPA.apply(query, key, value, None, 0.1)
print(sdpa_out.to("cpu"))
Note
The supported data types are FP32 and BF16.
hpex/normalization¶
The hpex/normalization module contains Python interfaces to the Habana implementation of the common normalize & clip operations performed on gradients in some models. Using the Habana-provided implementation can yield better performance compared to the equivalent operators provided in torch. Refer to Other Custom OPs for further details.
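As an illustration, a hedged sketch of gradient clipping with the fused interface (the FusedClipNorm class name and signature are assumed here; confirm the exact interface in Other Custom OPs):
import torch
from habana_frameworks.torch.hpex.normalization import FusedClipNorm  # assumed interface

model = torch.nn.Linear(10, 10).to("hpu")
fused_clip = FusedClipNorm(model.parameters(), max_norm=1.0)

loss = model(torch.randn(4, 10, device="hpu")).sum()
loss.backward()
fused_clip.clip_norm(model.parameters())   # fused normalize & clip of the gradients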
hpex/optimizers¶
The hpex/optimizers module contains Python interfaces to the Habana implementation of some common optimizers used in DL models. Using the Habana implementation can yield better performance compared to the corresponding optimizer implementations available in torch. Refer to Custom Optimizers for further details.
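For example, a drop-in replacement sketch using FusedAdamW (one of the Habana fused optimizers; see Custom Optimizers for the complete list and exact signatures):
import torch
from habana_frameworks.torch.hpex.optimizers import FusedAdamW

model = torch.nn.Linear(10, 10).to("hpu")
# Used in place of torch.optim.AdamW with the usual hyperparameters
optimizer = FusedAdamW(model.parameters(), lr=1e-3, weight_decay=0.01)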
hpu APIs¶
Support for HPU tensors is provided with this package. The following APIs provide the same functionality as their CPU counterparts, with HPU used for the underlying implementation. This package can be imported on demand.
import habana_frameworks.torch.hpu as hthpu - Imports the package.
hthpu.is_available() - Returns a boolean indicating if an HPU device is currently available.
hthpu.device_count() - Returns the number of compute-capable devices.
hthpu.get_device_name() - Returns the name of the HPU device.
hthpu.current_device() - Returns the index of the currently selected HPU device.
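The below shows a usage example combining these calls:
import habana_frameworks.torch.hpu as hthpu

if hthpu.is_available():
    print("device count:", hthpu.device_count())
    print("device name:", hthpu.get_device_name())
    print("current device:", hthpu.current_device())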
utils¶
The utils module contains general Python utilities required for training on HPU.
Memory Stats APIs¶
htorch.hpu.max_memory_allocated - Returns the peak HPU memory allocated by tensors (in bytes). reset_peak_memory_stats() can be used to reset the starting point in tracking stats.
htorch.hpu.memory_allocated - Returns the current HPU memory occupied by tensors.
htorch.hpu.memory_stats - Returns a list of HPU memory statistics. The below summarizes the sample memory stats printout and details:
Limit - Amount of total memory on the HPU device
InUse - Amount of allocated memory at any instance. Starting point after reset_peak_memory_stats()
MaxInUse - Amount of total active memory allocated
NumAllocs - Number of allocations
NumFrees - Number of freed chunks
ActiveAllocs - Number of active allocations
MaxAllocSize - Maximum allocated size
TotalSystemAllocs - Total number of system allocations
TotalSystemFrees - Total number of system frees
TotalActiveAllocs - Total number of active allocations
htorch.hpu.memory_summary - Returns a human-readable printout of the current memory stats.
htorch.hpu.reset_accumulated_memory_stats - Clears the number of allocations and the number of frees.
htorch.hpu.reset_peak_memory_stats - Resets the starting point of memory occupied by tensors.
The below shows a usage example:
import torch
import habana_frameworks.torch as htorch

if __name__ == '__main__':
    hpu = torch.device('hpu')
    cpu = torch.device('cpu')

    input1 = torch.randn((64,28,28,20), dtype=torch.float, requires_grad=True)
    input1_hpu = input1.contiguous(memory_format=torch.channels_last).to(hpu)

    mem_summary1 = htorch.hpu.memory_summary()
    print('memory_summary1:')
    print(mem_summary1)

    htorch.hpu.reset_peak_memory_stats()

    input2 = torch.randn((64,28,28,20), dtype=torch.float, requires_grad=True)
    input2_hpu = input2.contiguous(memory_format=torch.channels_last).to(hpu)

    mem_summary2 = htorch.hpu.memory_summary()
    print('memory_summary2:')
    print(mem_summary2)

    mem_allocated = htorch.hpu.memory_allocated()
    print('memory_allocated: ', mem_allocated)

    mem_stats = htorch.hpu.memory_stats()
    print('memory_stats:')
    print(mem_stats)

    max_mem_allocated = htorch.hpu.max_memory_allocated()
    print('max_memory_allocated: ', max_mem_allocated)
Metric APIs¶
The Metric APIs provide various performance-related metrics, such as the number of graph compilations, the total time of graph compilations, and more. This section provides instructions on how to retrieve metrics using Python APIs listed below. You can also obtain metrics by saving them to a file without requiring any changes to the script using the environment variables listed in Runtime Environment Variables.
The Metric API functions are defined in the habana_frameworks.torch.hpu.metrics module as described below.
habana_frameworks.torch.hpu.metrics.metric_global - Returns a global metric object based on the name provided. The object is active and present during the entire execution process.
The below shows a usage example:
import torch
from habana_frameworks.torch.hpu.metrics import metric_global

gc_metric = metric_global("graph_compilation")
print(gc_metric.stats())

device = torch.device('hpu')
do_computation(device)

print(gc_metric.stats())
# 'gc_metric.stats()' returns the list of counters collected during the entire
# execution process. Each counter is expressed as a tuple of the counter's
# name and its value.
habana_frameworks.torch.hpu.metrics.metric_localcontext - Returns a context manager object for the requested metric. The metric collection is limited to the scope of the context manager defined within the with statement.
The below shows a usage example:
import torch
from habana_frameworks.torch.hpu.metrics import metric_localcontext

with metric_localcontext("graph_compilation") as local_metric:
    do_computation()

print(local_metric.stats())
# 'local_metric.stats()' returns only the list of counters reflecting the data
# collected within the "with" statement. Each counter is expressed as a tuple
# of the counter's name and its value.
habana_frameworks.torch.hpu.metrics.metrics_dump - Stores the contents collected by the global metrics only in a file. Both JSON and TEXT formats are supported.
The below shows a usage example:
import torch
from habana_frameworks.torch.hpu.metrics import metrics_dump

device = torch.device('hpu')
do_computation(device)

metrics_dump("metrics.json", "json")
# Only global metrics will be saved in the 'metrics.json' file in JSON format.
# The local metrics created through 'metric_localcontext' will not be stored.
Below is the list of metrics that are currently supported along with their respective properties:
[Table: Metric | Properties]
Stream APIs¶
Streams and events are advanced features for concurrency that allow users to manage multiple asynchronous tasks running on the HPU.
import habana_frameworks.torch as htorch - Imports the package.
htorch.hpu.Stream - Returns a wrapper around an HPU stream.
htorch.hpu.stream - A wrapper around the context manager StreamContext that selects a given stream.
htorch.hpu.set_stream - Sets the current stream.
htorch.hpu.current_stream - Returns the currently selected stream.
htorch.hpu.default_stream - Returns the default stream.
APIs available on Stream object:
query() - Checks if all the work submitted on an HPU stream has completed.
synchronize() - Waits for all the kernels in an HPU stream to complete.
record_event() - Records an event on the HPU stream.
wait_event() - Makes all future work submitted to the stream wait for an event.
wait_stream() - All future work submitted to this stream will wait until all kernels submitted to a given stream at the time of call are completed.
The below shows a usage example:
import torch
import habana_frameworks.torch as ht
s0 = ht.hpu.Stream()
in_shape = (10,2)
tA_h = torch.zeros(in_shape).to('hpu')
tB_h = torch.ones(in_shape).to('hpu')
tOut1 = torch.add(tA_h,tA_h) #executes on default stream
with ht.hpu.stream(s0):
    tOut2 = torch.add(tB_h,tB_h) #executes on stream s0
s0.synchronize() #synchronize with s0
print(s0.query()) #returns True as all operations on s0 have finished
Event APIs¶
import habana_frameworks.torch as htorch - Imports the package.
htorch.hpu.Event - Returns a wrapper around an HPU event.
APIs available on Event object:
query() - Checks if all work currently captured by the event has completed.
synchronize() - Waits until the completion of all work currently captured in this event.
record() - Records the event in a given stream.
wait() - Makes all future work submitted to the given stream wait for this event.
elapsed_time(end_event) - Returns the time elapsed in milliseconds after the event was recorded and before the end_event was recorded.
The below shows a usage example:
import torch
import habana_frameworks.torch as ht
in_shape = (10,2)
tA_h = torch.zeros(in_shape).to('hpu')
tB_h = torch.ones(in_shape).to('hpu')
startEv = ht.hpu.Event(enable_timing=True)
endEv = ht.hpu.Event(enable_timing=True)
startEv.record()
for _ in range(100):
    tA_h = torch.add(tA_h,tB_h)
endEv.record()
endEv.synchronize()
print(f'Time Elapsed={startEv.elapsed_time(endEv)}')
HPU Graph APIs¶
HPU Graph APIs for Inference¶
import habana_frameworks.torch as htorch - Imports the package.
htorch.hpu.HPUGraph - Returns a wrapper around an HPU graph.
htorch.hpu.wrap_in_hpu_graph - Wraps the module's forward function with an HPU graph.
APIs available on HPU Graph object for inference:
capture_begin() - Begins capturing HPU work on the current stream.
capture_end() - Ends capturing HPU work on the current stream.
replay() - Replays the HPU work captured by this graph.
The below shows a usage example:
import torch
import habana_frameworks.torch as ht
g = ht.hpu.HPUGraph()
s = ht.hpu.Stream()
with ht.hpu.stream(s):
    g.capture_begin()
    a = torch.full((100,), 1, device="hpu")
    b = a
    b = b + 1
    g.capture_end()
g.replay()
ht.hpu.synchronize()
import torch
import habana_frameworks.torch as ht
model = GetModel()   # GetModel() is a placeholder returning your torch.nn.Module
model = ht.hpu.wrap_in_hpu_graph(model)
HPU Graph APIs for Training¶
import habana_frameworks.torch as htorch - Imports the package.
htorch.hpu.HPUGraph - Returns a wrapper around an HPU graph.
APIs available on HPU Graph object for training:
make_graphed_callables(callables, sample_args, warmup=0) - Makes the training graph. Each graphed callable captures both the forward and backward pass by overloading the autograd function. This API requires the model to take and return only tensors or tuples of tensors, which makes it incompatible with workloads using data structures such as dicts and lists.
The below shows a usage example:
import torch
import habana_frameworks.torch as ht

def test_graph_training():
    N, D_in, H, D_out = 2, 2, 2, 2
    module = torch.nn.Linear(D_in, H).to('hpu')
    loss_fn = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(module.parameters(), lr=0.1)
    x = torch.randn(N, D_in, device='hpu')
    module = ht.hpu.make_graphed_callables(module, (x,))
    real_inputs = [torch.rand_like(x) for _ in range(100)]
    real_targets = [torch.randn(N, D_out, device="hpu") for _ in range(100)]
    for data, target in zip(real_inputs, real_targets):
        optimizer.zero_grad(set_to_none=True)
        output = module(data)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
ht.hpu.ModuleCacher(max_graphs=10)(model=model, inplace=True)
- Specifies the number of graphs that need to be cached. A larger number of cached graphs will increase the device memory usage, but also increase the number of cache hits. This API is compatible with workloads using dictionaries and lists.
The below shows a usage example:
def test_cached_module_training():
    model = Net().to('hpu')
    state_dict = copy.deepcopy(model.state_dict())
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    meta_args = [((3, 4), True), ((5, 4), False), ((11, 4), True)]
    net_input = []
    net_output = []
    for i in range(2):
        for item in meta_args:
            x = torch.randn(item[0]).to('hpu')
            y = torch.randn(item[0]).to('hpu')
            net_input.append({'x': x, 'y': y, 'boolean_var': item[1]})
            net_output.append(torch.randn(item[0][0]).to('hpu'))

    def train_model():
        for inp, y in zip(net_input, net_output):
            output = model(**inp)
            y_pred = torch.mean(output[1], 1)
            optimizer.zero_grad(set_to_none=True)
            loss = torch.nn.functional.mse_loss(y_pred, y)
            loss.backward()
            optimizer.step()
            ht.core.mark_step()
        return loss.cpu()

    loss_original = train_model()
    model.load_state_dict(state_dict)
    ht.hpu.ModuleCacher()(model=model, inplace=True)
    loss_cached = train_model()
    return loss_original == loss_cached
Note
The training interfaces, ModuleCacher
and make_graphed_callables
, should be called before the DDP hook registration.
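A minimal sketch of that ordering (the model, sample input, and process-group setup are placeholders):
import torch
import habana_frameworks.torch as ht

# assumes torch.distributed has already been initialized with the hccl backend
model = Net().to("hpu")                       # Net is a placeholder module
x = torch.randn(8, 16, device="hpu")          # sample input for graph capture

model = ht.hpu.make_graphed_callables(model, (x,))          # capture graphs first
model = torch.nn.parallel.DistributedDataParallel(model)    # then register DDP hooks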
Random Number Generator APIs¶
import habana_frameworks.torch.hpu.random as htrandom - Imports the HPU random package.
Below are the APIs available in the habana_frameworks.torch.hpu.random package:
htrandom.get_rng_state - Returns the random number generator state of the specified HPU as a ByteTensor.
htrandom.get_rng_state_all - Returns a list of ByteTensors representing the random number states of all devices.
htrandom.set_rng_state - Sets the random number generator state of the specified HPU.
htrandom.set_rng_state_all - Sets the random number generator state of all devices.
htrandom.manual_seed - Sets the seed for generating random numbers for the current HPU device.
htrandom.manual_seed_all - Sets the seed for generating random numbers on all HPUs.
htrandom.seed - Sets the seed for generating random numbers to a random number for the current HPU.
htrandom.seed_all - Sets the seed for generating random numbers to a random number on all HPUs.
htrandom.initial_seed - Returns the current random seed of the current HPU.
The below shows a usage example:
import torch
import habana_frameworks.torch.hpu.random as htrandom
state = htrandom.get_rng_state()
htrandom.set_rng_state(state)
initial_seed = htrandom.initial_seed()
htrandom.manual_seed(2)
htrandom.seed()
print(htrandom.initial_seed())
Dynamic Shape APIs¶
The Dynamic Shape APIs control how the Habana PyTorch bridge and Graph Compiler manage dynamic shapes in model scripts. Dynamic shape support is disabled by default. See the Dynamic Shapes Optimization section for more information on how to optimize for Dynamic Shapes in Data and Models.
import habana_frameworks.torch.hpu as ht - Imports the package.
ht.disable_dynamic_shape() - Disables the dynamic shape feature.
ht.enable_dynamic_shape() - Enables the dynamic shape feature.
import habana_frameworks.torch.hpu as ht
ht.disable_dynamic_shape()
ht.enable_dynamic_shape()