mark_step¶
This function is a key control point in Gaudi’s PyTorch stack for triggering compilation and execution of accumulated operations. In lazy execution mode, operations on HPU tensors are recorded rather than executed immediately; calling mark_step(device_str="", sync=False) lets you manually flush these accumulated operations into a computation graph for compilation and execution on the HPU. It plays a crucial role in performance optimization by grouping operations into manageable graph segments and enabling efficient graph caching.
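The minimal sketch below shows the typical call pattern; it assumes the habana_frameworks package is installed and an HPU device is available.

```python
import torch
import habana_frameworks.torch.core as htcore  # Registers the HPU backend.

# Operations on HPU tensors are recorded lazily rather than executed.
a = torch.ones(8, 8, device="hpu")
b = a * 2 + 1

# Flush the accumulated operations into a graph, then compile and launch it.
htcore.mark_step()
```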
Parameters¶

The following table describes the parameters used by the mark_step function:

| Name | Type | Description |
|---|---|---|
| device_str | str, optional | The device identifier (e.g., "hpu"). |
| sync | bool, optional (default=False) | If set to True, waits for host-side execution to complete. |
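For example, a call that targets the HPU explicitly and waits for host-side dispatch to finish could look like this (a sketch based on the signature above):

```python
import habana_frameworks.torch.core as htcore

# Flush pending ops for the "hpu" device and block until the host
# has finished dispatching them (sync=True).
htcore.mark_step(device_str="hpu", sync=True)
```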
Behavior¶

Calling mark_step performs the following actions:

- Collects all operations accumulated so far.
- Attempts to fetch a cached graph matching these operations.
- If the graph is not cached (cache miss): the graph is lowered, compiled, and launched.
- If the graph is cached (cache hit): the compiled graph is reused and launched.
- If sync=True, waits for host-side execution to complete.
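The caching behavior matters most in loops: when each iteration accumulates the same sequence of operations on identically shaped tensors, only the first mark_step pays the compilation cost and later iterations reuse the cached graph. A hedged sketch of that pattern:

```python
import torch
import habana_frameworks.torch.core as htcore

x = torch.randn(32, 32, device="hpu")
for step in range(10):
    # Same ops and shapes every iteration, so after the first
    # compilation (cache miss) subsequent steps are cache hits.
    x = torch.relu(x @ x)
    htcore.mark_step()
```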
Common Scenarios for Using mark_step¶

Use mark_step in scenarios where explicit control over graph boundaries is needed in lazy execution mode.
| Scenario | Description |
|---|---|
| Splitting the Graph Intentionally | Helps divide the computational graph into smaller subgraphs, reducing compilation time and enabling overlap between host (CPU) and device (HPU) workloads. |
| Custom Training Loops or Graph Control | Essential when implementing manual control flow, such as custom forward/backward passes or looped training logic. It forces execution of accumulated operations, preventing compilation of unintentionally large graphs (see the sketch after this table). |
| Debugging or Profiling Graph Compilation | Helps isolate graph segments, assisting in identifying bottlenecks during compilation or execution. |
| Preventing Oversized Graphs | Prevents accumulation of large graphs in lazy mode, which may cause high compilation latency and increased HPU memory usage (OOM scenarios). Placing mark_step at logical graph boundaries avoids this. |
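As a concrete illustration of the custom-training-loop scenario, the sketch below places mark_step at the end of the backward pass and again after the optimizer step, so each iteration is compiled as a pair of bounded subgraphs. The model and optimizer here are placeholders, not part of the original example.

```python
import torch
import habana_frameworks.torch.core as htcore

model = torch.nn.Linear(128, 10).to("hpu")
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for _ in range(3):
    inputs = torch.randn(64, 128, device="hpu")
    targets = torch.randint(0, 10, (64,), device="hpu")

    optimizer.zero_grad()
    loss = torch.nn.functional.cross_entropy(model(inputs), targets)
    loss.backward()
    htcore.mark_step()  # Flush the forward/backward graph.

    optimizer.step()
    htcore.mark_step()  # Flush the optimizer-update graph.
```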
Usage Conditions and Example¶

The mark_step function is typically needed when:

- Running in lazy execution mode (PT_HPU_LAZY_MODE=1).
- Fine-grained control over execution is required.
- Profiling, performance tuning, or debugging graph execution.
- Manually triggering a graph flush to optimize memory usage or reduce compilation overhead.
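Lazy mode is selected through an environment variable; the variable is typically read when the Habana bridge initializes, so a common approach is to set it before the Habana modules are imported, as in this sketch:

```python
import os
os.environ["PT_HPU_LAZY_MODE"] = "1"  # Set before importing Habana modules.

import torch
import habana_frameworks.torch.core as htcore
```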
Note

In eager mode or when using torch.compile, mark_step has no functional effect. In many common cases, an implicit mark_step is triggered automatically by operations such as .cpu() or .item().
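For instance, reading a scalar back to the host forces execution without an explicit call (a sketch of the implicit-flush behavior described above):

```python
import torch
import habana_frameworks.torch.core  # Registers the HPU backend.

x = torch.randn(16, device="hpu")
s = x.sum()

# .item() copies the scalar to the host, which implicitly flushes the
# pending lazy operations, much like an explicit mark_step would.
print(s.item())
```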
Example:
```python
import torch
import habana_frameworks.torch.core as htcore

def run_model():
    x = torch.randn(64, 128, device="hpu")
    w1 = torch.randn(128, 256, device="hpu")
    w2 = torch.randn(256, 512, device="hpu")

    # Stage 1: lazy ops
    y = x @ w1
    y = torch.relu(y)
    htcore.mark_step()  # Triggers graph compilation and execution for stage 1

    # Stage 2: more ops (accumulated lazily until the next mark_step)
    z = y @ w2
    z = torch.nn.functional.relu(z)
    htcore.mark_step()  # Compiles and executes stage 2

    return z

out = run_model()
```
Best Practices & Caution¶

Avoid overusing mark_step(): Each call to mark_step introduces a new compilation boundary. Excessive calls can degrade performance by increasing overhead and reducing compilation cache efficiency. Use mark_step only when a logical separation in the computation graph is required, such as between forward/backward passes or pipelined stages, or to enable asynchronous overlap between CPU and HPU execution.

mark_step does not synchronize device execution: mark_step only ensures that host-side dispatch has occurred; it does not guarantee completion of all operations on the HPU. To fully synchronize device execution, especially when measuring runtime performance, use torch.hpu.synchronize() (see the timing sketch below).
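When timing HPU work, combine the two: flush with mark_step, then block on the device before reading the clock. A minimal sketch, assuming torch.hpu is made available by the Habana plugin:

```python
import time
import torch
import habana_frameworks.torch.core as htcore

x = torch.randn(1024, 1024, device="hpu")

start = time.perf_counter()
y = x @ x
htcore.mark_step()       # Dispatch the graph; does not wait for the HPU.
torch.hpu.synchronize()  # Block until all HPU work has completed.
elapsed = time.perf_counter() - start
print(f"elapsed: {elapsed:.4f}s")
```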