Run Inference Using HPU Graphs
HPU Graphs can capture a series of operations issued on an HPU stream and replay them. They require the device-side input and output addresses to remain constant between invocations. For further details, refer to Stream APIs and HPU Graph APIs.
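The constant-address requirement can be illustrated with a toy, pure-Python model of capture and replay. This is not the Habana API: `ToyGraph`, `record`, and `make_add_one` are hypothetical names, and Python lists stand in for device buffers. A captured op keeps references to the exact buffer objects it saw during capture, so replay only sees new data that is written into those same buffers in place.

```python
from typing import Callable, List

Op = Callable[[], None]


class ToyGraph:
    """Hypothetical stand-in for an HPU Graph: records ops once, replays them later."""

    def __init__(self) -> None:
        self._ops: List[Op] = []
        self._capturing = False

    def capture_begin(self) -> None:
        self._ops.clear()
        self._capturing = True

    def record(self, op: Op) -> None:
        # During capture, an op runs once and is remembered for replay.
        if not self._capturing:
            raise RuntimeError("record() is only valid between capture_begin() and capture_end()")
        op()
        self._ops.append(op)

    def capture_end(self) -> None:
        self._capturing = False

    def replay(self) -> None:
        # Replay re-runs the recorded ops as-is; nothing is re-recorded.
        for op in self._ops:
            op()


def make_add_one(src: List[float], dst: List[float]) -> Op:
    # The op closes over the buffer objects themselves, like fixed device addresses.
    def op() -> None:
        for i, x in enumerate(src):
            dst[i] = x + 1.0
    return op


static_in = [1.0, 2.0, 3.0]
static_out = [0.0, 0.0, 0.0]

g = ToyGraph()
g.capture_begin()
g.record(make_add_one(static_in, static_out))
g.capture_end()
after_capture = list(static_out)      # [2.0, 3.0, 4.0]

static_in[:] = [10.0, 20.0, 30.0]     # in-place update of the captured input buffer
g.replay()
after_in_place = list(static_out)     # [11.0, 21.0, 31.0]

static_in = [100.0, 200.0, 300.0]     # rebinding creates a new buffer the graph never sees
g.replay()
after_rebind = list(static_out)       # still [11.0, 21.0, 31.0]
```

In this toy, the in-place update is picked up by replay, while the rebound list is ignored because the recorded op still points at the original buffer.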
HPU Graphs offer the best performance with minimal host overhead. However, their functionality is currently limited.
Only models that run completely on HPU have been tested. Models that contain CPU Ops are not supported. During HPU Graph capture, if an Op is not supported, the following message appears: “… is not supported during HPU Graph capturing”.
HPU Graphs can only be used to capture and replay static graphs. Dynamic shapes are not supported.
Data-dependent dynamic flow is not supported with HPU Graphs.
Capturing HPU Graphs on models containing in-place view updates is not supported.
Multi-card inference (DeepSpeed) using HPU Graphs is supported only when PT_HPU_ENABLE_LAZY_COLLECTIVES=true is set. For further details, refer to Inference Using DeepSpeed Guide.
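A launcher script could check this variable before any graph is captured, so that a misconfigured multi-card run fails fast with a clear message. The following is a minimal sketch, assuming the caller passes in the world size and whether HPU Graphs are enabled; `check_lazy_collectives` is a hypothetical helper, not part of Habana or DeepSpeed.

```python
import os


def check_lazy_collectives(world_size: int, use_hpu_graphs: bool) -> None:
    """Hypothetical pre-flight check for multi-card inference with HPU Graphs.

    Raises RuntimeError when more than one card is used together with HPU Graphs
    but PT_HPU_ENABLE_LAZY_COLLECTIVES is not set to "true".
    """
    if world_size <= 1 or not use_hpu_graphs:
        return  # single card, or HPU Graphs disabled: nothing to check
    value = os.environ.get("PT_HPU_ENABLE_LAZY_COLLECTIVES", "")
    # Only the exact spelling used in this guide is accepted; other values
    # the runtime might also accept are deliberately not guessed at.
    if value != "true":
        raise RuntimeError(
            "Multi-card inference with HPU Graphs requires "
            "PT_HPU_ENABLE_LAZY_COLLECTIVES=true (got %r)" % value
        )
```

The variable itself is normally exported in the shell before the processes start, for example with export PT_HPU_ENABLE_LAZY_COLLECTIVES=true; the check above only verifies that this happened.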
Please refer to the PyTorch Known Issues and Limitations section for a list of current limitations.
Follow the steps in Porting PyTorch Models to Gaudi to prepare the PyTorch model to run on Gaudi.
mark_step is not required with HPU Graphs as it is handled implicitly.
To run inference using HPU Graphs, capture the operations once and replay them on later calls. The following example class shows the capture and replay steps:
```python
import torch
import habana_frameworks.torch as ht

device = torch.device('hpu')


class GraphTest(object):
    def __init__(self, size: int) -> None:
        self.g = ht.hpu.HPUGraph()
        self.s = ht.hpu.Stream()
        self.size = size

    # The following function shows the steps to implement capture and replay with HPU Graphs
    def wrap_func(self, first: bool) -> None:
        if first:
            # Capture: record the operations issued on stream self.s into graph self.g
            with ht.hpu.stream(self.s):
                self.g.capture_begin()
                # The following code snippet is for demonstration
                a = torch.full((self.size,), 1, device=device)
                b = a
                b = b + 1
                self.g.capture_end()
        else:
            # Replay: re-run the captured operations without recording them again
            self.g.replay()
```
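The class defines the two phases, but the caller decides when to capture. A minimal driver, assuming the first call should capture and every later call should replay; `run_graph_test` and `SupportsWrapFunc` are hypothetical names. The driver relies only on the `wrap_func(first)` method of GraphTest, so it can be exercised with any object that has that method.

```python
from typing import Protocol


class SupportsWrapFunc(Protocol):
    def wrap_func(self, first: bool) -> None: ...


def run_graph_test(graph_test: SupportsWrapFunc, iterations: int) -> None:
    """Capture on the first iteration, then replay the captured graph."""
    if iterations < 1:
        raise ValueError("iterations must be at least 1")
    for i in range(iterations):
        # i == 0 -> capture_begin()/capture_end(); afterwards -> replay()
        graph_test.wrap_func(first=(i == 0))
```

On a Gaudi machine this would be called as run_graph_test(GraphTest(size=1000), iterations=10), followed by a device synchronization (assumed here to be ht.hpu.synchronize() in your installed version) before reading results.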
Use habana_frameworks.torch.hpu.wrap_in_hpu_graph (ht.hpu.wrap_in_hpu_graph in the example below) to wrap a module's forward function with HPU Graphs. The wrapper captures, caches, and replays the graph.
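```python
import torch
import habana_frameworks.torch as ht

# GetModel() is a placeholder for your own model-construction code; the model is
# assumed to be already prepared as described in Porting PyTorch Models to Gaudi.
model = GetModel()
# Wrap the module's forward so that its graph is captured, cached, and replayed.
model = ht.hpu.wrap_in_hpu_graph(model)
```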
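To see how "captures, caches, and replays" fits with the static-shape limitation, here is a toy, pure-Python model. It is not how wrap_in_hpu_graph is implemented: `ToyGraphCache` is hypothetical, and keeping one cached graph per input shape is an assumption made for illustration. Each graph in the cache is static; an input with a previously unseen shape costs a fresh capture rather than a replay.

```python
from typing import Callable, Dict, List, Tuple

Shape = Tuple[int, ...]
Forward = Callable[[List[float]], List[float]]


class ToyGraphCache:
    """Hypothetical toy of capture/cache/replay; NOT how wrap_in_hpu_graph is built.

    Assumption for illustration: one static "graph" is kept per input shape.
    """

    def __init__(self, forward: Forward) -> None:
        self._forward = forward
        self._graphs: Dict[Shape, Forward] = {}
        self.captures = 0
        self.replays = 0

    def __call__(self, x: List[float]) -> List[float]:
        shape: Shape = (len(x),)
        graph = self._graphs.get(shape)
        if graph is None:
            # Unseen shape: "capture" a new static graph and cache it.
            self.captures += 1
            graph = self._forward
            self._graphs[shape] = graph
        else:
            # Seen shape: "replay" the cached graph without capturing again.
            self.replays += 1
        return graph(x)


toy_model = ToyGraphCache(lambda x: [2.0 * v for v in x])
toy_model([1.0, 2.0])        # capture for shape (2,)
toy_model([3.0, 4.0])        # replay of the cached (2,) graph
toy_model([5.0, 6.0, 7.0])   # new shape (3,): another capture
```

In this toy, a workload whose input length changed on every call would capture on almost every call and gain nothing from replay, which is one way to see why HPU Graphs target static shapes.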