Run Inference Using HPU Graphs

HPU Graphs can capture a series of operations issued on an HPU stream and replay them. They require the device-side input and output addresses to remain constant between invocations. For further details, refer to Stream APIs and HPU Graph APIs.
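The constant-address requirement means new data has to be copied into the buffers that existed at capture time. A freshly allocated tensor is never seen by the replayed graph. The following pure-Python toy illustrates the idea without a Gaudi device. `make_add_one`, `recorded` and `replay` are hypothetical names, and Python list objects stand in for device memory. This is an analogy, not Habana's implementation.

```python
from typing import Callable, List


def make_add_one(inp: List[float], out: List[float]) -> Callable[[], None]:
    # The closure keeps references to these exact list objects. They play the role
    # of the fixed device addresses a captured graph reads from and writes to.
    def op() -> None:
        for i, x in enumerate(inp):
            out[i] = x + 1.0
    return op


# Static input and output buffers, allocated once before "capture"
static_in = [0.0] * 4
static_out = [0.0] * 4

# "Capture": record the operation once, bound to the static buffers
recorded: List[Callable[[], None]] = [make_add_one(static_in, static_out)]


def replay() -> None:
    # Re-run the recorded operations against the same buffers
    for op in recorded:
        op()


# Correct: copy new data INTO the captured input buffer, then replay
static_in[:] = [1.0, 2.0, 3.0, 4.0]
replay()
print(static_out)  # [2.0, 3.0, 4.0, 5.0]

# Incorrect: a new input buffer is invisible to the recorded operation
fresh_in = [10.0, 20.0, 30.0, 40.0]
replay()
print(static_out)  # still [2.0, 3.0, 4.0, 5.0]; fresh_in was never read
```

With real HPU Graphs, the same pattern applies. Write each new batch into the input tensor used during capture, for example with an in-place copy. Then read results from the output tensor produced during capture.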


  • HPU Graphs offer the best performance with minimal host overhead. However, their functionality is currently limited.

  • Only models that run completely on HPU have been tested. Models that run partially on CPU may not work.

  • HPU Graphs can only be used to capture and replay static graphs. Dynamic shapes are not supported.

  • Inference using HPU Graphs has been validated only on single cards.

Please refer to the PyTorch Known Issues and Limitations section for a list of current limitations.
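A captured graph only covers the shapes it was captured with. If a workload uses a small, fixed set of input shapes, one option is to keep one graph per shape and capture on first use. This is a general workaround and is not described in this guide. The sketch below shows only the bookkeeping, in pure Python. `ShapeKeyedGraphCache` and `fake_capture` are hypothetical, and the "graphs" are plain callables rather than `HPUGraph` objects.

```python
from typing import Callable, Dict, List, Tuple

Shape = Tuple[int, ...]


class ShapeKeyedGraphCache:
    """Hypothetical helper: one captured graph per static input shape."""

    def __init__(self, capture: Callable[[Shape], Callable[[], None]]) -> None:
        self._capture = capture  # builds and returns a replay callable for one shape
        self._graphs: Dict[Shape, Callable[[], None]] = {}
        self.captures = 0

    def run(self, shape: Shape) -> None:
        if shape not in self._graphs:
            # An unseen shape cannot reuse an existing graph, so capture a new one
            self._graphs[shape] = self._capture(shape)
            self.captures += 1
        # Replay the graph that matches this shape
        self._graphs[shape]()


replayed: List[Shape] = []


def fake_capture(shape: Shape) -> Callable[[], None]:
    # Stand-in for a real capture; the returned callable just logs each replay
    return lambda: replayed.append(shape)


cache = ShapeKeyedGraphCache(fake_capture)
for shape in [(8, 128), (8, 128), (8, 256), (8, 128)]:
    cache.run(shape)
print(cache.captures)  # 2
```

Every distinct shape costs one capture and its own graph, so this only makes sense when the number of shapes is small and known in advance.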

Follow the steps in Porting PyTorch Models to Gaudi to prepare the PyTorch model to run on Gaudi. Adding mark_step is not required with HPU Graphs, as it is handled implicitly.

To run inference using HPU Graphs, create an object that captures a sequence of operations once and replays it on later calls. The following example class shows the capture and replay steps:

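import torch
import habana_frameworks.torch as ht

device = torch.device('hpu')

class GraphTest(object):
    def __init__(self, size: int) -> None:
        self.g = ht.hpu.HPUGraph()
        self.s = ht.hpu.Stream()
        self.size = size

    # The following function shows the steps to implement the capture and replay of HPU Graphs
    def wrap_func(self, first: bool) -> None:
        if first:
            # Operations issued on stream self.s between capture_begin() and capture_end() are recorded into self.g
            with ht.hpu.stream(self.s):
                self.g.capture_begin()
                # The following code snippet is for demonstration
                a = torch.full((self.size,), 1, device=device)
                b = a
                b = b + 1
                self.g.capture_end()
        else:
            # Later calls replay the recorded operations
            self.g.replay()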

The following example test uses the GraphTest class defined above to capture several HPU Graphs and then replay them.

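def test_graph_capture_simple():
    # The following example shows how multiple HPU Graphs can be initialized
    gt1 = GraphTest(size=1000)
    gt2 = GraphTest(size=2000)
    gt3 = GraphTest(size=3000)
    for i in range(10):
        if i == 0:
            # HPU Graphs capture is done in the first iteration
            gt1.wrap_func(True)
            gt2.wrap_func(True)
            gt3.wrap_func(True)
        else:
            # The replay of the captured HPU Graphs is done in subsequent iterations
            gt1.wrap_func(False)
            gt2.wrap_func(False)
            gt3.wrap_func(False)
    # Wait for all queued HPU work to finish
    ht.hpu.synchronize()

if __name__ == "__main__":
    test_graph_capture_simple()
    print("test ran")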
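The test loop passes `first=True` on iteration 0 and `False` afterwards. A small wrapper can remember whether capture has already happened, so callers do not need to track the flag. The sketch below is a hypothetical pure-Python illustration of that control flow. `StubGraph` only logs method names in place of `ht.hpu.HPUGraph`, and `CaptureOnce` is not part of the Habana API. With a real graph, replay re-executes the recorded device operations. With the stub, `body` runs only once, during capture.

```python
from typing import Callable, List


class StubGraph:
    """Hypothetical stand-in for ht.hpu.HPUGraph that only logs which method was called."""

    def __init__(self, log: List[str]) -> None:
        self.log = log

    def capture_begin(self) -> None:
        self.log.append("capture_begin")

    def capture_end(self) -> None:
        self.log.append("capture_end")

    def replay(self) -> None:
        self.log.append("replay")


class CaptureOnce:
    """Runs `body` under capture on the first call and replays the graph afterwards."""

    def __init__(self, graph: StubGraph, body: Callable[[], None]) -> None:
        self.graph = graph
        self.body = body
        self.captured = False

    def __call__(self) -> None:
        if not self.captured:
            self.graph.capture_begin()
            self.body()  # operations issued here are what a real graph would record
            self.graph.capture_end()
            self.captured = True
        else:
            self.graph.replay()


log: List[str] = []
step = CaptureOnce(StubGraph(log), body=lambda: log.append("body"))
for _ in range(3):
    step()
print(log)  # ['capture_begin', 'body', 'capture_end', 'replay', 'replay']
```

Keeping the captured flag inside the wrapper keeps each graph's lifecycle in one place. The constant-address rule still applies: new inputs must be copied into the tensors used during capture.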