Getting Started with PyTorch and Gaudi
This guide provides simple steps for preparing a PyTorch model to run on Gaudi. Make sure to install the PyTorch packages provided by Habana. Installing public PyTorch packages is not supported.
To set up the PyTorch environment, refer to the Installation Guide. The supported PyTorch versions are listed in the Support Matrix.
Note
Please refer to the PyTorch Known Issues and Limitations section for a list of current limitations.
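Before moving on to a full model, a quick way to verify the setup is to move a small tensor to the Gaudi device and run a simple operation on it. The following is a minimal sanity-check sketch, assuming the Habana PyTorch packages described above are installed:

import torch
# Importing the Habana Torch Library registers the "hpu" device with PyTorch
import habana_frameworks.torch.core as htcore

# Create a tensor on the host and move it to the Gaudi HPU device
x = torch.ones(2, 3).to(torch.device("hpu"))
y = x + x

# Copying the result back to the CPU forces execution; expect a 2x3 tensor of 2s
print(y.to("cpu"))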
Creating a Simple PyTorch Example
The example below contains Habana-specific modifications, marked with comments, that have been added to the PyTorch Hello World example. Create a file named example.py with the code below.
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import os

# Import Habana Torch Library
import habana_frameworks.torch.core as htcore

class SimpleModel(nn.Module):

    def __init__(self):
        super(SimpleModel, self).__init__()

        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        out = x.view(-1, 28*28)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)

        return out

def train(net, criterion, optimizer, trainloader, device):

    net.train()
    train_loss = 0.0
    correct = 0
    total = 0

    for batch_idx, (data, targets) in enumerate(trainloader):

        data, targets = data.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = net(data)
        loss = criterion(outputs, targets)

        loss.backward()

        # API call to trigger execution
        htcore.mark_step()

        optimizer.step()

        # API call to trigger execution
        htcore.mark_step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    train_loss = train_loss/(batch_idx+1)
    train_acc = 100.0*(correct/total)

    print("Training loss is {} and training accuracy is {}".format(train_loss, train_acc))

def test(net, criterion, testloader, device):

    net.eval()
    test_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():

        for batch_idx, (data, targets) in enumerate(testloader):

            data, targets = data.to(device), targets.to(device)

            outputs = net(data)
            loss = criterion(outputs, targets)

            # API call to trigger execution
            htcore.mark_step()

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    test_loss = test_loss/(batch_idx+1)
    test_acc = 100.0*(correct/total)

    print("Testing loss is {} and testing accuracy is {}".format(test_loss, test_acc))

def main():

    epochs = 20
    batch_size = 128
    lr = 0.01
    milestones = [10, 15]
    load_path = './data'
    save_path = './checkpoints'

    if not os.path.exists(save_path):
        os.makedirs(save_path)

    # Target the Gaudi HPU device
    device = torch.device("hpu")

    # Data
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    trainset = torchvision.datasets.MNIST(root=load_path, train=True,
                                          download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                              shuffle=True, num_workers=2)

    testset = torchvision.datasets.MNIST(root=load_path, train=False,
                                         download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                             shuffle=False, num_workers=2)

    net = SimpleModel()
    net.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=lr,
                          momentum=0.9, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

    for epoch in range(1, epochs+1):
        print("=====================================================================")
        print("Epoch : {}".format(epoch))

        train(net, criterion, optimizer, trainloader, device)
        test(net, criterion, testloader, device)

        torch.save(net.state_dict(), os.path.join(save_path, 'epoch_{}.pth'.format(epoch)))

        scheduler.step()

if __name__ == '__main__':
    main()
The example.py script above is a basic PyTorch training example. The Habana-specific lines are explained below.
Import the Habana Torch Library at the top of the script:

import habana_frameworks.torch.core as htcore

Target the Gaudi HPU device when selecting the device in main():

device = torch.device("hpu")

In Lazy mode, mark_step() must be added in all training scripts right after loss.backward() and optimizer.step():

htcore.mark_step()
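Taken together, a single training iteration in Lazy mode follows the pattern sketched below. This is only a condensed restatement of the train() function above; net, criterion, optimizer, data and targets stand in for your own objects:

optimizer.zero_grad()
outputs = net(data)                  # forward pass is recorded into the lazy graph
loss = criterion(outputs, targets)
loss.backward()
htcore.mark_step()                   # trigger execution of the forward and backward graph
optimizer.step()
htcore.mark_step()                   # trigger execution of the optimizer update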
Executing the Example
After creating example.py, perform the following:
Set PYTHON to point to the Python executable:
export PYTHON=/usr/bin/python3.8
Execute example.py by running:
$PYTHON example.py
The script prints the training and testing loss and accuracy for each epoch. Output similar to the following should appear as part of the output (exact values will vary from run to run):

=====================================================================
Epoch : 1
Training loss is 1.2647 and training accuracy is 72.08
Testing loss is 0.4482 and testing accuracy is 88.69
Since the first iteration includes graph compilation, it takes longer to run than later iterations. The software stack compiles the graph and saves the recipe to cache. Unless the graph changes or a new graph is introduced, no recompilation is needed for the rest of training. Typically, graph compilation happens at the beginning of training and at the beginning of evaluation.
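One way to observe this behavior is to time individual iterations. The sketch below is illustrative only and reuses net, criterion, optimizer, trainloader and device from the example above; the first step is typically much slower because it includes graph compilation, while later steps reuse the cached recipe:

import time

for batch_idx, (data, targets) in enumerate(trainloader):
    start = time.perf_counter()
    data, targets = data.to(device), targets.to(device)
    optimizer.zero_grad()
    loss = criterion(net(data), targets)
    loss.backward()
    htcore.mark_step()
    optimizer.step()
    htcore.mark_step()
    loss.item()  # synchronize so the measurement includes device execution
    print("step {}: {:.3f} s".format(batch_idx, time.perf_counter() - start))
    if batch_idx == 4:
        break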
Torch Multiprocessing for DataLoaders
If training scripts use multiprocessing with multiple workers for the PyTorch DataLoader, change the start method to spawn or forkserver using the PyTorch API multiprocessing.set_start_method(...). For example:

torch.multiprocessing.set_start_method('spawn')

The default start method is fork, which may result in undefined behavior.