Run Inference Using Native PyTorch

Follow the steps in Importing PyTorch Models Manually to prepare the PyTorch model to run on Gaudi.
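
As a quick reference, the preparation typically amounts to importing the Habana PyTorch bridge alongside torch. The sketch below shows the imports assumed by the examples on this page; see the linked page for the authoritative steps.

import torch
# Importing the Habana bridge registers the 'hpu' device with PyTorch;
# htcore.mark_step() is used in the examples below for Lazy mode execution.
import habana_frameworks.torch.core as htcore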

Use model.eval Mode

Run the forward pass (model inference) with the model in eval mode. See the example below.

  1. Create the model and move it to the HPU:

# The following example uses an untrained model for demonstration purposes only.
# For inference, load a pretrained model or a checkpoint (refer to the Saving and Loading Models page: https://pytorch.org/tutorials/beginner/saving_loading_models.html).
device = torch.device('hpu')
in_c, out_c = 3, 64
k_size = 7
stride = 2
conv = torch.nn.Conv2d(in_c, out_c, kernel_size=k_size, stride=stride, bias=True)
bn = torch.nn.BatchNorm2d(out_c, eps=0.001)
relu = torch.nn.ReLU()
model = torch.nn.Sequential(conv, bn, relu)
# Switch the model to evaluation mode so that layers such as BatchNorm
# and Dropout use their inference behavior.
model.eval()
# Place the model on HPU.
model = model.to(device)
  2. Create the inputs and move them to the HPU to run model inference:

# Create inputs and move them to the HPU.
N, H, W = 256, 224, 224
input = torch.randn((N, in_c, H, W), dtype=torch.float)
input_hpu = input.to(device)
# Invoke the model.
output = model(input_hpu)
# In Lazy mode execution, mark_step() must be added after model inference.
htcore.mark_step()

Note

This method is recommended for running inference on Gaudi since it supports most models. However, throughput may be suboptimal if the output is copied back to the host in every iteration.
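
To reduce host-device traffic, per-iteration outputs can be kept on the HPU and copied back to the host in one pass at the end. The following is a minimal, illustrative sketch; the loop length and the output-collection pattern are examples, not part of the Gaudi API.

# Minimal sketch: keep per-iteration outputs on the HPU and copy them
# back to the host once, instead of calling .cpu() every iteration.
outputs = []
for _ in range(10):
    batch = torch.randn((N, in_c, H, W), dtype=torch.float).to(device)
    # no_grad avoids building an autograd graph during inference.
    with torch.no_grad():
        out = model(batch)
    # In Lazy mode, mark_step() triggers execution of the accumulated graph.
    htcore.mark_step()
    outputs.append(out)  # stays on the HPU; no host copy here
# Copy all results back to the host at the end.
outputs_cpu = [o.cpu() for o in outputs]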

Use torch.jit.trace Mode

Save and load models in JIT trace format using torch.jit.trace mode. For further details on the JIT format, refer to the TorchScript page.

  1. Create the model and move it to the HPU:

device = torch.device('hpu')
in_c, out_c = 3, 64
k_size = 7
stride = 2
conv = torch.nn.Conv2d(in_c, out_c, kernel_size=k_size, stride=stride, bias=True)
bn = torch.nn.BatchNorm2d(out_c, eps=0.001)
relu = torch.nn.ReLU()
model = torch.nn.Sequential(conv, bn, relu)
# Switch the model to evaluation mode before tracing.
model.eval()
model = model.to(device)
  2. Save the model using torch.jit.trace:

N, H, W = 256, 224, 224
model_input = torch.randn((N,in_c,H,W), dtype=torch.float).to(device)
with torch.no_grad():
    # check_trace=False skips re-running the trace to verify its outputs;
    # strict=False lets the tracer record mutable container types.
    trace_model = torch.jit.trace(model, (model_input), check_trace=False, strict=False)
    # Save the HPU model with torch.jit.save.
    trace_model.save("trace_model.pt")
  3. Load the model using torch.jit.load:

# Load the model directly to HPU.
model = torch.jit.load("trace_model.pt", map_location=torch.device('hpu'))
  4. Create the inputs and move them to the HPU to run model inference:

# Create inputs and move them to the HPU.
N, H, W = 256, 224, 224
input = torch.randn((N, in_c, H, W), dtype=torch.float)
input_hpu = input.to(device)
# Invoke the model.
output = model(input_hpu)
# In Lazy mode execution, mark_step() must be added after model inference.
htcore.mark_step()
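
Since the trace was created with check_trace=False, it can be worth sanity-checking the loaded traced model against the original eager-mode model. The sketch below is illustrative: eager_model is a hypothetical handle to the eager model from step 1, kept under a separate name since model was overwritten by torch.jit.load above.

# Illustrative sanity check: compare the loaded traced model against the
# original eager-mode model ('eager_model' is a hypothetical handle, see above).
with torch.no_grad():
    traced_out = model(input_hpu)
    eager_out = eager_model(input_hpu)
htcore.mark_step()
# Compare on the host; the tolerance is illustrative.
print(torch.allclose(traced_out.cpu(), eager_out.cpu(), atol=1e-5))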

Note

The JIT format is functionally correct but not yet optimized on Gaudi. Optimized support will be added in a future release.