HPU Graphs for Training¶
In PyTorch, the HPU stack supports both eager and lazy mode. In lazy mode, ops are accumulated and flushed only when the user triggers mark_step.
With large models, the op accumulation time (host time) can become significantly higher than the device time, making the model host bound and hampering its performance.
The usage of HPU Graphs is optional and complements running a model in Habana's Lazy Mode.
The HPU Graphs feature for training and inference overcomes this performance bottleneck. HPU Graphs bypass op accumulation by recording a static version of the entire graph and then replaying it. These static graphs use fixed memory locations for their inputs, outputs, and the underlying graph. While HPU Graphs reduce host overhead significantly, the model's dynamic flexibility is sacrificed.
The HPU Graph APIs are similar to the CUDA graph APIs; HPU Graphs also provide some extra wrappers such as ModuleCacher.
Using HPU Graphs¶
In addition to the basic HPU Graph APIs, such as Capture and Replay, Habana also provides graph, make_graphed_callables, and ModuleCacher.
These APIs streamline your code and are the recommended HPU Graph APIs to use in your training model. For example, unlike the basic HPU Graph APIs, the ModuleCacher API can accept non-tensor and non-positional arguments.
For more details, see Simple MNIST Example with HPU Graph APIs.
While only static graphs are supported, a workaround for dynamicity is to compile separate graphs for differently shaped inputs or for different paths of the dynamic control flow. For example, you can use the ModuleCacher API for training models with dynamic inputs.
For more details, see Dynamicity in Models.
HPU Graphs can also be used to speed up a process that is host bound. Whether a process is host bound can be determined by obtaining its host profile. For profiling examples using HPU Graphs, see Profiling Examples Using HPU Graph APIs.
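For instance, a rough host-side profile can be collected with the standard torch.profiler; the tiny model, input shape, and choice of profiler below are illustrative assumptions rather than the exact profiling setup used elsewhere on this page:
import torch
import torch.nn as nn
import habana_frameworks.torch.core as htcore
from torch.profiler import profile, ProfilerActivity

# Illustrative sketch: run a small model for a few steps and inspect how much
# time the host (CPU) spends building and launching ops.
model = nn.Linear(784, 10).to('hpu')
data = torch.randn(64, 784, device='hpu')

with profile(activities=[ProfilerActivity.CPU]) as prof:
    for _ in range(20):
        out = model(data)
        htcore.mark_step()

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))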
HPU Graph APIs¶
This section describes the usage of the basic and high-level HPU Graph APIs in a training model. For a description of each API, refer to HPU Graph APIs.
capture_begin, capture_end, and replay() APIs: HPU Graphs support basic Capture and Replay APIs, which are similar to the CUDA APIs. The following example shows how to replace cuda.CUDAGraph.capture_begin with hpu.HPUGraph.capture_begin:
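A minimal sketch of the substitution is shown below; the tensor shape and the toy op are illustrative assumptions, and the HPU calls are assumed to mirror their torch.cuda.CUDAGraph counterparts as described above:
import torch
import habana_frameworks.torch.core as htcore

g = htcore.hpu.HPUGraph()                        # instead of torch.cuda.CUDAGraph()
static_input = torch.randn(8, 10, device='hpu')  # fixed memory location used by the graph

g.capture_begin()                                # instead of cuda.CUDAGraph.capture_begin
static_output = static_input * 2.0               # ops are recorded rather than executed eagerly
g.capture_end()                                  # instead of cuda.CUDAGraph.capture_end

# Replay the recorded graph on new data copied into the static input tensor.
static_input.copy_(torch.randn(8, 10, device='hpu'))
g.replay()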
graph and make_graphed_callables APIs: These APIs are similar to the CUDA graph API and the CUDA make_graphed_callables API. The following example shows how to replace cuda.make_graphed_callables with hpu.make_graphed_callables:
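A minimal sketch of the substitution is shown below; the Linear module and the sample input shape are illustrative assumptions:
import torch
import torch.nn as nn
import habana_frameworks.torch.core as htcore

model = nn.Linear(10, 4).to('hpu')
sample_input = torch.randn(8, 10, device='hpu')

# CUDA version: model = torch.cuda.make_graphed_callables(model, (sample_input,))
model = htcore.hpu.make_graphed_callables(model, (sample_input,))

output = model(sample_input)  # the forward and backward ops of the module now run as HPU Graphs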
ModuleCacher API: This API provides another way of wrapping the model and handles dynamic inputs in a training model. ModuleCacher internally keeps track of whether an input shape has changed and, if so, creates a new HPU Graph. It also makes the model capable of accepting non-tensor and non-positional arguments, which is a limitation of the basic HPU Graph APIs. See Limitations of HPU Graph APIs.
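As a minimal sketch of this capability, the module below takes a non-tensor keyword argument; the ScaledLinear module and the scale argument are illustrative assumptions, while the ModuleCacher call follows the usage shown later on this page:
import torch
import torch.nn as nn
import habana_frameworks.torch.core as htcore

class ScaledLinear(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 4)

    def forward(self, x, scale=1.0):  # non-tensor, non-positional argument
        return self.fc(x) * scale

model = ScaledLinear().to('hpu')
# Wrap the model in place; a new HPU Graph is cached whenever the input shape changes.
htcore.hpu.ModuleCacher(max_graphs=10)(model=model, inplace=True)

x = torch.randn(8, 10, device='hpu')
out = model(x, scale=0.5)  # keyword and non-tensor arguments are accepted
htcore.mark_step()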
Limitations of HPU Graph APIs¶
While the following limitations of HPU Graphs apply when using the basic APIs, ModuleCacher provides a workaround for them:
The HPU Graphs training API mandates the use of only tensors as inputs and outputs.
Inputs can only be provided as positional arguments.
Simple MNIST Example with HPU Graph APIs¶
This section shows a simple MNIST example with code examples of the following APIs:
Capture and Replay
make_graphed_callables
ModuleCacher
The common code defining the model and dataloader is shown here:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import habana_frameworks.torch.core as htcore

# Simple two-layer MLP for MNIST
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 10)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

device = torch.device("hpu")
model = Net().to(device)
optimizer = optim.Adadelta(model.parameters(), lr=0.1)

# MNIST dataset and dataloader
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
    ])
data_path = './data'
dataset1 = datasets.MNIST(data_path, train=True, download=True,
                          transform=transform)
batch_size = 200
stepsperepoch = None
train_kwargs = {'batch_size': batch_size}
train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
Training Loop Without HPU Graphs¶
The following shows a general training loop without HPU Graphs used:
def train(model, device, train_loader, optimizer):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        # mark_step is needed for HPU lazy mode before and after the optimizer step
        htcore.mark_step()
        optimizer.step()
        htcore.mark_step()

train(model, device, train_loader, optimizer)
Training Loop with Capture and Replay¶
In the following example, the capture phase records the forward and backward passes once; the recorded graph is then replayed repeatedly in the actual training loop.
def train_with_capturereplay(model, device, train_loader, optimizer, batchsize):
    # Placeholders used for capture
    static_input = torch.randn(batchsize, 1, 28, 28, device='hpu')
    static_target = torch.randint(0, 10, (batchsize,), device='hpu')

    # First we warmup
    s = htcore.hpu.Stream()
    s.wait_stream(htcore.hpu.current_stream())
    with htcore.hpu.stream(s):
        for i in range(3):
            optimizer.zero_grad(set_to_none=True)
            y_pred = model(static_input)
            loss = F.nll_loss(y_pred, static_target)
            loss.backward()
            optimizer.step()
    htcore.hpu.current_stream().wait_stream(s)

    # Then we capture
    g = htcore.hpu.HPUGraph()
    optimizer.zero_grad(set_to_none=True)
    with htcore.hpu.graph(g):
        static_y_pred = model(static_input)
        static_loss = F.nll_loss(static_y_pred, static_target)
        static_loss.backward()
        optimizer.step()

    # Finally the main training loop
    for batch_idx, (data, target) in enumerate(train_loader):
        # data must be copied to existing tensors that were used in the capture phase
        static_input.copy_(data)
        static_target.copy_(target)
        g.replay()
        # result is available in static_loss tensor after graph is replayed

train_with_capturereplay(model, device, train_loader, optimizer, batch_size)
Training Loop with make_graphed_callables¶
The make_graphed_callables API can be used to wrap an nn.Module into a standalone graph. A model wrapped in this API creates separate graphs for the forward and backward passes. make_graphed_callables internally creates HPU Graph objects, runs warmup iterations, and maintains static inputs and outputs as needed.
def train_with_make_graphed_callables(model, device, train_loader, optimizer, batchsize):
    x = torch.randn(batchsize, 1, 28, 28, device='hpu')
    model.train()
    model = htcore.hpu.make_graphed_callables(model, (x,))
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

train_with_make_graphed_callables(model, device, train_loader, optimizer, batch_size)
make_graphed_callables can accommodate dynamic control flow or dynamic shapes if a separate graph is compiled for each control path or shape. See Dynamic Control Flow for more details.
Training Loop with ModuleCacher¶
ModuleCacher handles dynamic inputs automatically and is the recommended method for using HPU Graphs in training models.
Note
ModuleCacher does not handle dynamic control flow or dynamic ops.
def train_with_modelcacher(model, device, train_loader, optimizer, batchsize):
    model.train()
    htcore.hpu.ModuleCacher(max_graphs=10)(model=model, inplace=True)
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        htcore.mark_step()
        optimizer.zero_grad(set_to_none=True)
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        htcore.mark_step()

train_with_modelcacher(model, device, train_loader, optimizer, batch_size)
Note
The basic APIs are useful for static graphs with no dynamic control flow. Also, the models can only take tensors and positional arguments as inputs.
Run a host profile to make sure the model is host bound. If it is device bound, HPU Graphs will not show any benefit.
For dynamic inputs, use ModuleCacher. ModuleCacher also provides a workaround for the limitation of accepting only tensor and positional arguments.
Make sure to have sync/mark_step calls in place.
Even simple models with static inputs might exhibit dynamicity. For example, if the batch size does not divide the dataset evenly, the last batch has a different shape than the ones before it. This can be handled by using ModuleCacher or by setting drop_last=True in the dataloader, as shown in the sketch after this note.
Inference is also supported using HPU Graph APIs. For more details, see Run Inference Using HPU Graphs.
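For example, assuming the dataset1 and batch_size objects defined in the common MNIST code above, dropping the last partial batch looks like this:
# Keep every batch the same shape by dropping the final, smaller batch.
train_loader = torch.utils.data.DataLoader(dataset1, batch_size=batch_size, drop_last=True)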
Dynamicity in Models¶
This section provides some examples for using HPU Graphs on models with dynamic control flow and dynamic ops.
Dynamic inputs are handled by ModuleCacher as described in Training Loop with ModuleCacher.
Dynamic Control Flow¶
When dynamic control flow is present, the model needs to be separated into different HPU Graphs.
In the example below, the output of module1 feeds module2 or module3 depending on the dynamic control flow.
import torch
import torch.nn as nn
import torch.optim as optim
import habana_frameworks.torch.core as htcore
import itertools
import random

device = 'hpu'
module1 = nn.Linear(784, 128).to(device)
module2 = nn.Linear(128, 100).to(device)
module3 = nn.Linear(128, 50).to(device)
all_params = itertools.chain(module1.parameters(), module2.parameters(), module3.parameters())
optimizer = optim.Adadelta(all_params, lr=0.1)

x = torch.randn(5, 784, device=device)
h = torch.randn(5, 128, device=device)

if device == 'hpu':
    module1 = torch.hpu.make_graphed_callables(module1, (x,))
    module2 = torch.hpu.make_graphed_callables(module2, (h,))
    module3 = torch.hpu.make_graphed_callables(module3, (h,))

real_inputs = [torch.rand_like(x) for _ in range(200)]
real_targets = [real_inputs[i].mean() for i in range(200)]

for data, target in zip(real_inputs, real_targets):
    data = data.to(device)
    target = target.to(device)
    htcore.mark_step()
    optimizer.zero_grad(set_to_none=True)
    tmp = module1(data)  # forward ops run as a graph
    if random.random() > 0.5:
        print('Path 1')
        tmp = module2(tmp)  # forward ops run as a graph
    else:
        print('Path 2')
        tmp = module3(tmp)  # forward ops run as a graph
    loss = torch.abs(tmp.mean() - target)
    loss.backward()
    htcore.mark_step()
    optimizer.step()
    htcore.mark_step()
    print(loss)
Dynamic Ops¶
In this example, the flow is module1 -> dynamic boolean indexing -> module2. Thus, the two static modules are wrapped in separate ModuleCacher instances, and the dynamic op part is left out.
import torch
import torch.nn as nn
import torch.optim as optim
import habana_frameworks.torch.core as htcore
import itertools
import random

device = 'hpu'
module1 = nn.Linear(10, 3).to(device)
module2 = nn.ReLU().to(device)
all_params = itertools.chain(module1.parameters(), module2.parameters())
optimizer = optim.Adadelta(all_params, lr=0.1)

if device == 'hpu':
    htcore.hpu.ModuleCacher(max_graphs=10)(model=module1, inplace=True)
    htcore.hpu.ModuleCacher(max_graphs=10)(model=module2, inplace=True)

real_inputs = [torch.randn(5, 10) for _ in range(200)]
real_targets = [real_inputs[i].mean() for i in range(200)]

for data, target in zip(real_inputs, real_targets):
    data = data.to(device)
    target = target.to(device)
    htcore.mark_step()
    optimizer.zero_grad(set_to_none=True)
    tmp = module1(data)
    # dynamic op
    htcore.mark_step()
    tmp = tmp[torch.where(tmp > 0)]
    htcore.mark_step()
    tmp = module2(tmp)
    loss = torch.abs(tmp.mean() - target)
    loss.backward()
    htcore.mark_step()
    optimizer.step()
    htcore.mark_step()
    print(loss)
Profiling Examples Using HPU Graph APIs¶
This is an example of a simple two-layer network on first-gen Gaudi. For the base case, the example is run without any HPU Graphs. Then make_graphed_callables is used, and finally Capture and Replay is used. As shown in the table below, make_graphed_callables produces some speedup, but Capture and Replay is the most efficient mode.
Mode | Time
---|---
No HPU Graphs | 10.937
make_graphed_callables | 9.747
Capture and Replay | 6.191
Next, the effect of each mode on host times is observed. The above modes are run with --steps=20 appended to study the host profile for 20 steps, with the profiling environment variables set.
The base case with no HPU Graphs has a CPU processing time of 4.8 ms.
The host profile with make_graphed_callables has a CPU processing time of around 3.9 ms.
The least host time is observed in the Capture and Replay mode, at around 1.3 ms.