Optimize Inference on PyTorch
Inference workloads require optimization because they are more prone to host overhead than training. Since an inference step consumes less computation than a training step and is usually executed with smaller batch sizes, host overhead is more likely to reduce throughput and increase latency.
This document describes how to apply the optimization methods below to minimize host overhead and improve inference performance.
Identifying Host Overhead
The following example shows the host overhead observed in a typical inference workload on PyTorch:
import torch
import habana_frameworks.torch.core as htcore  # registers the HPU device with PyTorch

hpu = torch.device("hpu")
cpu = torch.device("cpu")

for query in inference_queries:
    # Optional pre-processing on the host, such as tokenization or normalization of images
    # Copy the result to device memory
    query_hpu = query.to(hpu)
    # Perform processing
    output_hpu = model(query_hpu)
    # Copy the result to host memory
    output = output_hpu.to(cpu)
    # Optional host post-processing, such as decoding
The model’s forward() function typically involves a series of computations. When executed without optimization, each line of Python code in the forward() call is evaluated by the Python interpreter, passed through the PyTorch Python front end, and then sent to the Habana PyTorch bridge. Processing on the device only occurs when mark_step is invoked or when the copy to the CPU is requested. See the illustration below.
The diagram shows that the HPU has extended periods of inactivity because the computation steps depend on each other and on the host.
To identify similar cases, use the Habana integration with the PyTorch Profiler; see the Profiling with PyTorch section. If there are gaps between device invocations, host overhead is impeding throughput. When host overhead is minimal, the device functions continuously following a brief ramp-up period.
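For reference, the following is a minimal profiling sketch. It assumes the HPU activity is exposed to the PyTorch Profiler as torch.profiler.ProfilerActivity.HPU, as described in the Profiling with PyTorch section, and that model and inference_queries are defined as in the loop above:
import torch
import habana_frameworks.torch.core as htcore

activities = [torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.HPU]
with torch.profiler.profile(
    activities=activities,
    schedule=torch.profiler.schedule(wait=1, warmup=1, active=3),
    on_trace_ready=torch.profiler.tensorboard_trace_handler("./profile_logs"),
) as prof:
    for query in inference_queries:
        output = model(query.to("hpu")).to("cpu")
        # Advance the profiler schedule after each inference step
        prof.step()
Inspect the resulting trace for gaps between device invocations.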
The following are three techniques to lower host overhead and enhance your inference performance.
Using HPU Graphs
As described in Run Inference Using HPU Graphs, wrapping the forward() call in htorch.hpu.wrap_in_hpu_graph minimizes the time in which each line of Python code in the forward() call is evaluated by the Python interpreter. See the example below. Minimizing this time allows the HPU to start copying the output and running the computation sooner, improving throughput.
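The following is a minimal sketch of this pattern; MyModel and query are placeholders for your own module and input:
import torch
import habana_frameworks.torch as htorch

device = torch.device("hpu")
model = MyModel().eval().to(device)

# Wrap the module so that its forward() call is captured and replayed as an HPU Graph
# instead of being re-evaluated line by line in the Python interpreter.
model = htorch.hpu.wrap_in_hpu_graph(model)

with torch.no_grad():
    output = model(query.to(device)).to("cpu")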
Note
Using HPU Graphs for optimizing inference on Gaudi is highly recommended.
Using Asynchronous Copies
By default, the host thread that submits the computation to the device waits for the copy operation to complete. However, by specifying the argument non_blocking=True during the copy operation, the Python thread can continue to execute other tasks while the copy occurs in the background. To use Asynchronous Copies, replace query_hpu = query.to(hpu) with query_hpu = query.to(hpu, non_blocking=True).
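A short sketch of the resulting loop, with hpu and cpu defined as in the first example:
for query in inference_queries:
    # Submit the host-to-device copy asynchronously; the host thread does not wait
    # for it to finish and can immediately queue the computation.
    query_hpu = query.to(hpu, non_blocking=True)
    output_hpu = model(query_hpu)
    # The device-to-host copy is synchronous, so this call returns the final result.
    output = output_hpu.to(cpu)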
See our Wav2vec inference example.
The following is an example of the timing diagram:
Note
Asynchronous Copies are currently supported from the host to the device only.
Using Software Pipelining
Software pipelining is a technique that unrolls a loop so that preprocessing of the next query overlaps with the rest of the computation, as shown in the reworked loop below.
# Preprocess the first query on the host before entering the loop
query = host_preprocessing(queries[0])
for next_query in queries[1:]:  # note that this starts with the second query
    # Copy the result of preprocessing the current query to device memory
    query_hpu = query.to(hpu)
    # Perform processing
    output_hpu = model(query_hpu)
    # Perform the preprocessing on the next query. Consider using a separate thread
    next_query_preprocessed = host_preprocessing(next_query)
    # Copy the result of the query processing to host memory
    output = output_hpu.to(cpu)
    # Optional host post-processing, such as decoding
    # Prepare the next query
    query = next_query_preprocessed
# Process the last preprocessed query after the loop
output = model(query.to(hpu)).to(cpu)
The following is an example of the timing diagram in the steady state:
Depending on the relationship between the length of the CPU tasks and the computation on the HPU, it may also be beneficial to pipeline the post-processing and to use multiple threads for the CPU tasks, especially when the time spent on the host is comparable to or longer than the device computation time.
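As a sketch of one way to do this, the preprocessing of the next query can run in a worker thread while the device processes the current query; host_postprocessing is a hypothetical helper, and the other names follow the examples above:
from concurrent.futures import ThreadPoolExecutor

with ThreadPoolExecutor(max_workers=1) as pool:
    # Preprocess the first query in the background
    future = pool.submit(host_preprocessing, queries[0])
    for next_query in queries[1:]:
        # Wait for the previous preprocessing result, then start preprocessing the next query
        query = future.result()
        future = pool.submit(host_preprocessing, next_query)
        # Run the current query on the device while the worker thread preprocesses the next one
        output = model(query.to(hpu, non_blocking=True)).to(cpu)
        host_postprocessing(output)
    # Process the last query after the loop
    output = model(future.result().to(hpu, non_blocking=True)).to(cpu)
    host_postprocessing(output)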
Note
Software pipelining is currently experimental.