Using HPU Graphs for Training

In PyTorch, the software stack supports Lazy mode. In Lazy mode, ops are accumulated and flushed only when the user triggers mark_step. With large models, the time spent accumulating ops (host time) can become significantly higher than the device time. This can make the model host bound and hamper its performance.

Using HPU Graphs is optional. When used, HPU Graphs are applied in addition to running the model in the Intel® Gaudi® AI accelerator’s Lazy mode.

The HPU Graphs feature, available for training and inference, overcomes this performance bottleneck. HPU Graphs bypasses all op accumulation by recording a static version of the entire graph and then replaying it. These static graphs use fixed memory locations for their inputs, outputs and the underlying graph. While HPU Graphs reduces host overhead significantly, the dynamic flexibility of the model is sacrificed.
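The record-once, replay-many idea can be illustrated without any Gaudi hardware. The following is a minimal pure-Python sketch under stated assumptions, not the HPU Graphs implementation: ToyGraph, record and replay are hypothetical names, and plain lists stand in for device tensors. It shows why new data has to be copied into the buffers that were used at recording time.

```python
# Toy illustration of record-once / replay-many with fixed buffers.
# ToyGraph is a hypothetical class, not part of habana_frameworks.

class ToyGraph:
    def __init__(self):
        self.ops = []  # recorded (function, input buffer, output buffer) triples

    def record(self, fn, src, dst):
        # Store the op together with the same buffer objects it will always use.
        self.ops.append((fn, src, dst))

    def replay(self):
        # No per-step op accumulation: just run the recorded ops in order.
        for fn, src, dst in self.ops:
            dst[:] = [fn(v) for v in src]


static_input = [0.0, 0.0, 0.0]   # fixed "memory location" for inputs
hidden = [0.0, 0.0, 0.0]         # fixed intermediate buffer
static_output = [0.0, 0.0, 0.0]  # fixed "memory location" for outputs

g = ToyGraph()
g.record(lambda v: v * 2.0, static_input, hidden)
g.record(lambda v: v + 1.0, hidden, static_output)

results = []
for batch in ([1.0, 2.0, 3.0], [4.0, 5.0, 6.0]):
    static_input[:] = batch      # copy new data into the existing buffer
    g.replay()
    results.append(list(static_output))

print(results)  # [[3.0, 5.0, 7.0], [9.0, 11.0, 13.0]]
```

Rebinding static_input to a new list instead of copying into it would leave the recorded ops reading the old buffer, which is the toy equivalent of the fixed-memory constraint described above.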

HPU Graph APIs are similar to the CUDA graph APIs. HPU Graphs provides some extra wrappers such as ModuleCacher.

Note

Eager mode is currently not supported.

Using HPU Graphs

In addition to the basic HPU Graph APIs, such as Capture and Replay, Gaudi also provides graph, make_graphed_callables and ModuleCacher. These APIs make your code more streamlined and are the recommended HPU Graph APIs to use in your training model. For example, unlike the basic HPU Graph APIs, the ModuleCacher API can accept non-tensor and non-positional arguments. For more details, see Simple MNIST Example with HPU Graph APIs.

While only static graphs are supported, dynamicity can be worked around by compiling separate graphs for differently shaped inputs or for different paths of the dynamic control flow. For example, you can use the ModuleCacher API for training models with dynamic inputs. For more details, see Dynamicity in Models.
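The "one static graph per input shape" workaround can be sketched in plain Python. This is an illustration only and makes no claim about how ModuleCacher is implemented: ShapeKeyedCache and compile_for_shape are hypothetical names, and nested lists stand in for tensors.

```python
# Toy sketch of keeping one "compiled graph" per input shape.
# ShapeKeyedCache and compile_for_shape are hypothetical, for illustration only.

class ShapeKeyedCache:
    def __init__(self, compile_fn):
        self.compile_fn = compile_fn
        self.graphs = {}        # shape -> compiled callable
        self.compile_count = 0

    def __call__(self, batch):
        shape = (len(batch), len(batch[0]))
        if shape not in self.graphs:
            # First time this shape is seen: build (record) a graph for it.
            self.graphs[shape] = self.compile_fn(shape)
            self.compile_count += 1
        return self.graphs[shape](batch)


def compile_for_shape(shape):
    rows, cols = shape

    def run(batch):
        # A real graph would be specialized to (rows, cols); here we only check it.
        assert (len(batch), len(batch[0])) == (rows, cols)
        return [sum(row) for row in batch]

    return run


cached = ShapeKeyedCache(compile_for_shape)
full = [[1, 2], [3, 4], [5, 6]]  # shape (3, 2)
tail = [[7, 8]]                  # shape (1, 2), e.g. a short final batch

out_full = cached(full)
out_full_again = cached(full)    # reuses the graph built for shape (3, 2)
out_tail = cached(tail)          # new shape, so a second graph is built

print(out_full, out_tail, cached.compile_count)  # [3, 7, 11] [15] 2
```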

HPU Graphs can also be used to speed up the process when it is host bound. Detecting if a process is host bound can be achieved by profiling. For profiling examples using HPU Graphs, see Profiling Examples Using HPU Graph APIs.

To detect if a process is host bound, obtain its profile by running:

PT_FORCED_TRACING_MASK=0xffffffff TRACE_POINT_ENABLE=1 HBN_SYNAPSE_LOGGER_COMMANDS=stop_data_capture:no_eager_flush:use_pid_suffix LD_PRELOAD=/usr/local/lib/python3.8/dist-packages/habana_frameworks/torch/lib/pytorch_synapse_logger.so python test.py
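The same invocation can be prepared from Python, which may be convenient when the profiling run is scripted. This is a minimal sketch: the variable values and the LD_PRELOAD path are copied from the command above (the path may differ for another Python version or install location), while profiling_env is a hypothetical helper name.

```python
import os

# Values copied from the command line shown above.
PROFILING_VARS = {
    "PT_FORCED_TRACING_MASK": "0xffffffff",
    "TRACE_POINT_ENABLE": "1",
    "HBN_SYNAPSE_LOGGER_COMMANDS": "stop_data_capture:no_eager_flush:use_pid_suffix",
    "LD_PRELOAD": "/usr/local/lib/python3.8/dist-packages/habana_frameworks/torch/lib/pytorch_synapse_logger.so",
}


def profiling_env(base_env=None):
    """Return a copy of the environment with the profiling variables set."""
    env = dict(os.environ if base_env is None else base_env)
    env.update(PROFILING_VARS)
    return env


env = profiling_env({"PATH": "/usr/bin"})

# To launch the run on a Gaudi machine (not executed here):
# import subprocess
# subprocess.run(["python", "test.py"], env=profiling_env(), check=True)
```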

HPU Graph APIs

The following section describes the usage of basic and high level HPU Graph APIs in a training model. For a description of each API, refer to HPU Graph APIs.

  • capture_begin, capture_end, and replay() APIs: HPU Graphs supports basic Capture and Replay APIs which are similar to the CUDA APIs. The following example shows how to replace cuda.CUDAGraph.capture_begin with hpu.HPUGraph.capture_begin:

    # CUDA equivalent:
    #   g = torch.cuda.CUDAGraph()
    #   g.capture_begin()
    import habana_frameworks.torch.core as htcore
    g = htcore.hpu.HPUGraph()
    g.capture_begin()
    
  • graph and make_graphed_callables APIs: These APIs are similar to the CUDA graph API and CUDA make_graphed_callables API. The following example shows how to replace cuda.make_graphed_callables with hpu.make_graphed_callables:

    #torch.cuda.make_graphed_callables(...)
    import habana_frameworks.torch.core as htcore
    htcore.hpu.make_graphed_callables(...)
    
  • ModuleCacher API: This API provides another way of wrapping the model and handles dynamic inputs in a training model. ModuleCacher internally tracks whether an input shape has changed and, if so, creates a new HPU graph. It also lets the model accept non-tensor and non-positional arguments, which the basic HPU Graph APIs do not support. See Limitations of HPU Graph APIs.

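    import habana_frameworks.torch.core as htcore
    # Net is the nn.Module defined in the Simple MNIST Example
    model = Net().to('hpu')
    htcore.hpu.ModuleCacher(max_graphs=10)(model=model, inplace=True)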
    

Limitations of HPU Graph APIs

The following limitations apply when using the basic HPU Graph APIs. Using ModuleCacher provides a workaround for both:

  • The HPU Graphs training API accepts only tensors as inputs and outputs.

  • Inputs can be provided only as positional arguments.
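When staying with the basic APIs, keyword arguments have to be turned into positional ones before the call. One possible way to do that in plain Python is to bind the call against the function signature. In this sketch, to_positional is a hypothetical helper and forward is a stand-in for a model's forward function, with lists standing in for tensors. Neither is part of the HPU Graph APIs, and the helper addresses only the positional restriction, not the tensor-only one.

```python
import inspect


def to_positional(fn, *args, **kwargs):
    """Return the arguments of a call as a tuple in signature order."""
    bound = inspect.signature(fn).bind(*args, **kwargs)
    bound.apply_defaults()
    return tuple(bound.arguments.values())


def forward(x, mask):
    # Stand-in for a model forward function: element-wise multiply.
    return [v * m for v, m in zip(x, mask)]


# The caller passes keywords in arbitrary order...
positional = to_positional(forward, mask=[1, 0, 1], x=[1.0, 2.0, 3.0])
# ...and the wrapped call receives them positionally, in signature order.
output = forward(*positional)

print(positional)  # ([1.0, 2.0, 3.0], [1, 0, 1])
print(output)      # [1.0, 0.0, 3.0]
```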

Simple MNIST Example with HPU Graph APIs

This section shows a simple MNIST example with code examples of the following APIs:

  • Capture and Replay

  • make_graphed_callables

  • ModuleCacher

The common code defining the model and dataloader is shown here:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import habana_frameworks.torch.core as htcore
from tqdm import tqdm

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 10)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

device = torch.device("hpu")
model = Net().to(device)
optimizer = optim.Adadelta(model.parameters(), lr=0.1)
transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
    ])
data_path = './data'
dataset1 = datasets.MNIST(data_path, train=True, download=True,
                    transform=transform)
batch_size = 200
stepsperepoch = None
train_kwargs = {'batch_size': batch_size}
train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)

def test(model, device, train_loader):
    print('Running Test')
    model.eval()
    total = 0
    correct = 0
    for batch_idx, (data, target) in tqdm(enumerate(train_loader)):
        data, target = data.to(device), target.to(device)
        output = model(data)
        htcore.mark_step()
        correct += (torch.max(output, 1)[1] == target).sum().item()
        total += output.shape[0]
    print("Eval:", correct/total)

Training Loop Without HPU Graphs

The following shows a general training loop that does not use HPU Graphs:

def train(model, device, train_loader, optimizer):
    model.train()
    for batch_idx, (data, target) in tqdm(enumerate(train_loader)):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        # mark step is needed for HPU Lazy mode before and after the optimizer step
        htcore.mark_step()
        optimizer.step()
        htcore.mark_step()
train(model, device, train_loader, optimizer)
test(model, device, train_loader)

Training Loop with Capture and Replay

In the following example, the capture phase records the forward and backward passes together with the optimizer step. The recorded graph is then replayed on every iteration of the actual training phase.

def train_with_capturereplay(model, device, train_loader, optimizer, batchsize):
    # Placeholders used for capture
    static_input = torch.randn(batchsize, 1, 28, 28, device='hpu')
    static_target = torch.randint(0,10,(batchsize,), device='hpu')

    # First we warmup
    s = htcore.hpu.Stream()
    s.wait_stream(htcore.hpu.current_stream())
    with htcore.hpu.stream(s):
        for i in range(3):
            optimizer.zero_grad(set_to_none=True)
            y_pred = model(static_input)
            loss = F.nll_loss(y_pred, static_target)
            loss.backward()
            optimizer.step()
    htcore.hpu.current_stream().wait_stream(s)

    # Then we capture
    g = htcore.hpu.HPUGraph()
    optimizer.zero_grad(set_to_none=True)
    with htcore.hpu.graph(g):
        static_y_pred = model(static_input)
        static_loss = F.nll_loss(static_y_pred, static_target)
        static_loss.backward()
        optimizer.step()

    # Finally the main training loop
    for batch_idx, (data, target) in tqdm(enumerate(train_loader)):
        # data must be copied to existing tensors that were used in the capture phase
        static_input.copy_(data)
        static_target.copy_(target)
        g.replay()
        # result is available in static_loss tensor after graph is replayed
batchsize=200
train_with_capturereplay(model, device, train_loader, optimizer, batchsize)
test(model, device, train_loader)

Training Loop with make_graphed_callables

The make_graphed_callables API can be used to wrap an nn.Module into a standalone graph. A model wrapped with this API gets separate graphs for the forward and backward passes. make_graphed_callables internally creates HPU Graph objects, runs warmup iterations, and maintains static inputs and outputs as needed.

def train_with_make_graphed_callables(model, device, train_loader, optimizer, batchsize):
    x = torch.randn(batchsize, 1, 28, 28, device='hpu')
    model.train()
    model = htcore.hpu.make_graphed_callables(model, (x,))
    for batch_idx, (data, target) in tqdm(enumerate(train_loader)):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        htcore.mark_step()
        optimizer.step()
        htcore.mark_step()
batchsize=200
train_with_make_graphed_callables(model, device, train_loader, optimizer, batchsize)
test(model, device, train_loader)

make_graphed_callables can accommodate dynamic control flow or dynamic shapes, if we compile a separate graph for each control path or shape. See Dynamic Control Flow for more details.

Training Loop with ModuleCacher

ModuleCacher handles dynamic inputs automatically and is the recommended method for using HPU Graphs in training models.

Note

ModuleCacher does not handle dynamic control flow or dynamic ops.

def train_with_modulecacher(model, device, train_loader, optimizer, batchsize):
    model.train()
    # Wrap the model in place; a new HPU graph is created when the input shape changes
    htcore.hpu.ModuleCacher(max_graphs=10)(model=model, inplace=True)
    for batch_idx, (data, target) in tqdm(enumerate(train_loader)):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad(set_to_none=True)
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        # mark step is needed for HPU Lazy mode before and after the optimizer step
        htcore.mark_step()
        optimizer.step()
        htcore.mark_step()
batchsize=200
train_with_modulecacher(model, device, train_loader, optimizer, batchsize)
test(model, device, train_loader)

Note

  1. The basic APIs are suited to static graphs with no dynamic control flow. With them, the model can take only tensors as inputs, passed as positional arguments.

  2. Run a host profile to make sure the model is host bound. If it is device bound, HPU Graphs will not show any benefit.

  3. For dynamic inputs, use ModuleCacher. ModuleCacher also provides a workaround for the limitation that inputs must be tensors passed as positional arguments.

  4. Make sure the training loop includes the required sync/mark_step calls, as in the examples above.

  5. Even simple models with static inputs might exhibit dynamicity. For example, if the batch size does not divide the dataset size evenly, the final batch has a different shape from the ones before it. This can be handled by using ModuleCacher or by setting drop_last=True in the dataloader.

  6. Inference is also supported using HPU Graph APIs. For more details, see Run Inference Using HPU Graphs.
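The short-final-batch case from item 5 can be checked with a little arithmetic. The sketch below uses a hypothetical helper, batch_sizes, to list the distinct batch sizes a dataloader would produce over one epoch. The 60,000 MNIST training images divide evenly by the batch size of 200 used in this section, whereas a batch size of 128 (chosen here only for illustration) leaves a final batch of 96 unless drop_last=True is set.

```python
def batch_sizes(num_samples, batch_size, drop_last=False):
    """Return the sorted distinct batch sizes produced over one epoch."""
    full, remainder = divmod(num_samples, batch_size)
    sizes = set()
    if full:
        sizes.add(batch_size)
    if remainder and not drop_last:
        sizes.add(remainder)
    return sorted(sizes)


print(batch_sizes(60000, 200))                  # [200]
print(batch_sizes(60000, 128))                  # [96, 128]
print(batch_sizes(60000, 128, drop_last=True))  # [128]
```

Each distinct size corresponds to a distinct input shape, so a result with more than one entry means a single static graph would not cover the whole epoch.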

Dynamicity in Models

This section provides some examples for using HPU Graphs on models with dynamic control flow and dynamic ops. Dynamic inputs are handled by ModuleCacher as described in Training Loop with ModuleCacher.

Dynamic Control Flow

When dynamic control flow is present, the model needs to be separated into different HPU Graphs. In the example below, the output of module1 feeds module2 or module3 depending on the dynamic control flow.

import torch
import torch.nn as nn
import torch.optim as optim
import habana_frameworks.torch.core as htcore
import itertools
import random

device = 'hpu'
module1 = nn.Linear(784, 128).to(device)
module2 = nn.Linear(128, 100).to(device)
module3 = nn.Linear(128, 50).to(device)
all_params = itertools.chain(module1.parameters(), module2.parameters(), module3.parameters())
optimizer = optim.Adadelta(all_params, lr=0.1)

x = torch.randn(5, 784, device=device)
h = torch.randn(5, 128, device=device)

if device == 'hpu':
    module1 = torch.hpu.make_graphed_callables(module1, (x,))
    module2 = torch.hpu.make_graphed_callables(module2, (h,))
    module3 = torch.hpu.make_graphed_callables(module3, (h,))

real_inputs = [torch.rand_like(x) for _ in range(200)]
real_targets = [real_inputs[i].mean() for i in range(200)]
for data, target in zip(real_inputs, real_targets):
    data = data.to(device)
    target = target.to(device)
    htcore.mark_step()
    optimizer.zero_grad(set_to_none=True)
    tmp = module1(data) # forward ops run as a graph
    if random.random() > 0.5:
        print('Path 1')
        tmp = module2(tmp) # forward ops run as a graph
    else:
        print('Path 2')
        tmp = module3(tmp) # forward ops run as a graph
    loss = torch.abs(tmp.mean() - target)
    loss.backward()
    htcore.mark_step()
    optimizer.step()
    htcore.mark_step()
    print(loss)

Dynamic Ops

In this example, the model is module1 -> dynamic boolean indexing -> module2. Each static module is therefore wrapped in its own ModuleCacher, and the dynamic op is left outside the graphs.

import torch
import torch.nn as nn
import torch.optim as optim
import habana_frameworks.torch.core as htcore
import itertools
import random

device = 'hpu'
module1 = nn.Linear(10, 3).to(device)
module2 = nn.ReLU().to(device)

all_params = itertools.chain(module1.parameters(), module2.parameters())
optimizer = optim.Adadelta(all_params, lr=0.1)

if device == 'hpu':
    htcore.hpu.ModuleCacher(max_graphs=10)(model=module1, inplace=True)
    htcore.hpu.ModuleCacher(max_graphs=10)(model=module2, inplace=True)

real_inputs = [torch.randn(5, 10) for _ in range(200)]
real_targets = [real_inputs[i].mean() for i in range(200)]
for data, target in zip(real_inputs, real_targets):
    data = data.to(device)
    target = target.to(device)
    htcore.mark_step()
    optimizer.zero_grad(set_to_none=True)
    tmp = module1(data)

    # dynamic op
    htcore.mark_step()
    tmp = tmp[torch.where(tmp > 0)]
    htcore.mark_step()

    tmp = module2(tmp)
    loss = torch.abs(tmp.mean() - target)
    loss.backward()
    htcore.mark_step()
    optimizer.step()
    htcore.mark_step()
    print(loss)

Profiling Examples Using HPU Graph APIs

This is an example of a simple two-layer network on first-gen Gaudi. The base case runs without any HPU Graphs. make_graphed_callables is used next, and capture-and-replay last.

As shown in the table below, make_graphed_callables produces some speedup, but capture-and-replay is the most efficient mode.

Mode                      Time

No HPU Graphs             10.937
make_graphed_callables    9.747
capture replay            6.191
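The relative speedups implied by these timings can be computed directly. The snippet below is a small illustrative calculation using only the three values from the table. The table does not state a time unit, so only ratios are derived.

```python
# Timings copied from the table above (unit not stated in the source).
times = {
    "No HPU Graphs": 10.937,
    "make_graphed_callables": 9.747,
    "capture replay": 6.191,
}

baseline = times["No HPU Graphs"]
# Speedup relative to the run without HPU Graphs, rounded to two decimals.
speedups = {mode: round(baseline / t, 2) for mode, t in times.items()}

for mode, factor in speedups.items():
    print(f"{mode}: {factor}x")
```

With these numbers, make_graphed_callables comes out at roughly 1.12x and capture-and-replay at roughly 1.77x over the base case.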

Next, the effect of each mode on host times is observed. The above modes are used with --steps=20 appended to study the host profile for 20 steps, along with the following environment variables:

PT_FORCED_TRACING_MASK=0xffffffff TRACE_POINT_ENABLE=1 HBN_SYNAPSE_LOGGER_COMMANDS=stop_data_capture:no_eager_flush:use_pid_suffix LD_PRELOAD=/usr/local/lib/python3.8/dist-packages/habana_frameworks/torch/lib/pytorch_synapse_logger.so

The base case with no HPU Graphs has a CPU processing time of 4.8 ms:

../../_images/nograph.png

The host profile with make_graphed_callables has a CPU processing time of around 3.9 ms:

../../_images/make_graphed_callable.png

The lowest host time is observed in capture-and-replay mode, at around 1.3 ms:

../../_images/capturereplay.png
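To put the three host times side by side, the reduction relative to the base case can be expressed as a percentage. This is an illustrative calculation on the figures quoted above (4.8 ms, about 3.9 ms and about 1.3 ms), converted to microseconds to keep the arithmetic in integers. The helper name is hypothetical.

```python
# Host times from the profiles above, in microseconds.
host_time_us = {
    "No HPU Graphs": 4800,
    "make_graphed_callables": 3900,
    "capture replay": 1300,
}
baseline_us = host_time_us["No HPU Graphs"]


def host_time_reduction_percent(mode):
    """Percentage of the baseline host time removed by the given mode."""
    return 100.0 * (baseline_us - host_time_us[mode]) / baseline_us


for mode in host_time_us:
    print(f"{mode}: {host_time_reduction_percent(mode):.1f}% less host time")
```

On these figures, make_graphed_callables removes a little under a fifth of the host time and capture-and-replay a little under three quarters. Since two of the inputs are approximate, the percentages are approximate as well.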