mark_step

This function is a key control point in Gaudi’s PyTorch stack for triggering compilation and execution of accumulated operations.

The mark_step(device_str="", sync=False) function lets you manually flush the operations accumulated in lazy mode into a computation graph for compilation and execution on the HPU. It plays a crucial role in performance optimization by grouping operations into manageable graph segments and enabling efficient graph caching.

Parameters

The mark_step function accepts the following parameters:

  • device_str (str, optional): The device identifier (e.g., "hpu").

  • sync (bool, optional, default=False): If set to True, the function waits for host-side execution of the graph to complete; the default is asynchronous execution. This does not synchronize device-side execution (i.e., it is not equivalent to torch.hpu.synchronize()).
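
As a minimal sketch of the sync flag (assuming an available HPU device and the habana_frameworks package), note that sync=True still covers only host-side work:

import torch
import habana_frameworks.torch.core as htcore

x = torch.randn(32, 32, device="hpu")
y = x * 2  # accumulated lazily; nothing executes on the HPU yet

htcore.mark_step(sync=True)  # flush pending ops and wait for host-side execution only

torch.hpu.synchronize()      # separately wait for device-side execution to finish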

Behavior

Calling mark_step performs the following actions:

  1. Collects all operations accumulated so far.

  2. Attempts to fetch a cached graph matching these operations.

  3. If the graph is not cached (cache miss): the graph is lowered, compiled, and launched.

  4. If the graph is cached (cache hit): the compiled graph is reused and launched.

  5. If sync=True, waits for host-side execution to complete.
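
To illustrate the caching behavior described above, the following minimal sketch (assuming an available HPU device) runs identical operations in a loop: the first mark_step compiles the graph (cache miss), while later iterations with the same operations and shapes can reuse the cached graph (cache hits):

import torch
import habana_frameworks.torch.core as htcore

x = torch.randn(128, 128, device="hpu")

for _ in range(10):
    y = torch.relu(x @ x)
    htcore.mark_step()  # first call: lower, compile, launch; later calls: cached graph reused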

Common Scenarios for Using mark_step

Use mark_step in scenarios where explicit control over graph boundaries is needed in lazy execution mode.

  • Splitting the Graph Intentionally: Divides the computational graph into smaller subgraphs, reducing compilation time and enabling overlap between host (CPU) and device (HPU) workloads.

  • Custom Training Loops or Graph Control: Essential when implementing manual control flow, such as custom forward/backward passes or looped training logic. Forcing execution of the accumulated operations prevents compilation of unintentionally large graphs (see the training-loop sketch after this list).

  • Debugging or Profiling Graph Compilation: Isolates graph segments, which helps identify bottlenecks during compilation or execution.

  • Preventing Oversized Graphs: Avoids accumulating large graphs in lazy mode, which can cause high compilation latency and increased HPU memory usage (OOM scenarios). Placing mark_step earlier forces smaller, more manageable graph chunks.
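
The training-loop scenario is common enough to warrant a sketch. Placing mark_step after the backward pass and after the optimizer step follows the usual lazy-mode pattern; the model, data, and hyperparameters below are illustrative placeholders:

import torch
import habana_frameworks.torch.core as htcore

model = torch.nn.Linear(128, 10).to("hpu")
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = torch.nn.CrossEntropyLoss()

for step in range(100):
    inputs = torch.randn(64, 128, device="hpu")
    targets = torch.randint(0, 10, (64,), device="hpu")

    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    htcore.mark_step()  # flush the forward/backward graph

    optimizer.step()
    htcore.mark_step()  # flush the optimizer update as a separate graph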

Usage Conditions and Example

The mark_step function is typically needed when:

  • Running in lazy execution mode (PT_HPU_LAZY_MODE=1; see the sketch after this list).

  • Exercising fine-grained control over graph execution.

  • Profiling, performance tuning, or debugging graph execution.

  • Manually triggering graph flushing to optimize memory usage or reduce compilation overhead.
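
Lazy mode is selected through the PT_HPU_LAZY_MODE environment variable, which is typically read when the Habana PyTorch modules are first imported, so set it before importing them or export it in the shell (whether lazy mode is the default depends on the software release). A minimal sketch:

import os

# Set before importing habana_frameworks; alternatively export it in the shell:
#   PT_HPU_LAZY_MODE=1 python train.py
os.environ.setdefault("PT_HPU_LAZY_MODE", "1")

import torch
import habana_frameworks.torch.core as htcore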

Note

  • In eager mode or when using torch.compile, mark_step has no functional effect.

  • In many common cases, an implicit mark_step is triggered automatically by operations such as .cpu() or .item().
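
To illustrate the second point, here is a minimal sketch (assuming lazy mode and an available HPU device) in which moving a result to the host flushes the pending graph implicitly:

import torch
import habana_frameworks.torch.core as htcore

x = torch.randn(16, 16, device="hpu")
y = (x * x).sum()  # accumulated lazily

val = y.item()     # copying the scalar to the host triggers an implicit mark_step,
                   # compiling and executing the pending graph
print(val)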

Example:

import torch
import habana_frameworks.torch.core as htcore

def run_model():
    x = torch.randn(64, 128, device="hpu")
    w1 = torch.randn(128, 256, device="hpu")
    w2 = torch.randn(256, 512, device="hpu")

    # Stage 1: Lazy ops
    y = x @ w1
    y = torch.relu(y)

    htcore.mark_step()  # Triggers graph compilation and execution for stage 1

    # Stage 2: More ops (accumulated lazily until next mark_step)
    z = y @ w2
    z = torch.nn.functional.relu(z)

    htcore.mark_step()  # Compiles and executes stage 2

    return z

out = run_model()

Best Practices & Caution

Avoid overusing :code:`mark_step()`:

  • Each call to mark_step introduces a new compilation boundary. Excessive calls can degrade performance by increasing overhead and reducing compilation cache efficiency.

  • Use mark_step only when a logical separation in the computation graph is required, such as between forward/backward passes or pipelined stages, or to enable asynchronous overlap between CPU and HPU execution.

:code:`mark_step` does not synchronize device execution:

  • mark_step only ensures that host-side dispatch has occurred; it does not guarantee completion of all operations on the HPU.

  • To fully synchronize device execution, especially when measuring runtime performance, make sure to use the following:

    torch.hpu.synchronize()
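
For example, when measuring the runtime of an operation, a timing sketch along these lines can be used (tensor sizes are illustrative):

import time

import torch
import habana_frameworks.torch.core as htcore

x = torch.randn(1024, 1024, device="hpu")
htcore.mark_step()
torch.hpu.synchronize()   # make sure setup work finishes before timing

start = time.perf_counter()
y = x @ x
htcore.mark_step()        # dispatches the graph; does not wait for the HPU
torch.hpu.synchronize()   # waits for the device to finish the queued work
print(f"matmul took {time.perf_counter() - start:.4f} s")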