MediaPipe for PyTorch ResNet
This section describes how to define a ResnetMediaPipe class and implement a MediaPipe API dataloader, MediaApiDataLoader, for PyTorch ResNet.
Defining a ResnetMediaPipe Class
The ResnetMediaPipe class is derived from the MediaPipe class, as described in Creating and Executing Media Pipeline, and implements the MediaPipe API for ResNet.
An implementation of the ResnetMediaPipe class can be found in resnet_media_pipe.py in the Intel Gaudi Model References GitHub repository.
The following operations are performed on the data:
- Reading images using ReadImageDatasetFromDir
- Decoding images using ImageDecoder
- For training data only:
  - Random cropping as part of ImageDecoder, with RANDOMIZED_AREA_AND_ASPECT_RATIO_CROP as random_crop_type
  - Resizing as part of ImageDecoder
  - Random flipping using the RandomFlip interface
- For evaluation data only:
  - Resizing as part of ImageDecoder
  - Cropping as part of CropMirrorNorm
- Normalizing using CropMirrorNorm to produce float32 output
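The random cropping step can be illustrated with a plain-Python sketch of the common Inception-style "random area and aspect ratio" crop (the scheme used by torchvision's RandomResizedCrop). The function name random_area_aspect_crop and its default ranges (area fraction 0.08 to 1.0, aspect ratio 3/4 to 4/3, 10 attempts) are assumptions for illustration; they are not taken from ImageDecoder, whose exact parameters may differ. The function only computes a crop box; the decoder would then resize the crop to the network input size.

```python
import math
import random


def random_area_aspect_crop(width, height, scale=(0.08, 1.0),
                            ratio=(3 / 4, 4 / 3), attempts=10, rng=random):
    """Return (x, y, w, h) of a random crop box inside a width x height image.

    Hypothetical sketch of an Inception-style randomized area/aspect-ratio
    crop; the default scale/ratio ranges are assumptions, not ImageDecoder's.
    """
    area = width * height
    log_lo, log_hi = math.log(ratio[0]), math.log(ratio[1])
    for _ in range(attempts):
        # Sample a target area fraction and a log-uniform aspect ratio (w / h).
        target_area = area * rng.uniform(scale[0], scale[1])
        aspect = math.exp(rng.uniform(log_lo, log_hi))
        w = int(round(math.sqrt(target_area * aspect)))
        h = int(round(math.sqrt(target_area / aspect)))
        if 0 < w <= width and 0 < h <= height:
            # The box fits: place it uniformly at random inside the image.
            x = rng.randint(0, width - w)
            y = rng.randint(0, height - h)
            return x, y, w, h
    # Fallback: center crop clamped to the allowed aspect-ratio range.
    in_ratio = width / height
    if in_ratio < ratio[0]:
        w, h = width, int(round(width / ratio[0]))
    elif in_ratio > ratio[1]:
        h, w = height, int(round(height * ratio[1]))
    else:
        w, h = width, height
    return (width - w) // 2, (height - h) // 2, w, h
```

Passing a seeded random.Random instance as rng makes the crops reproducible, which is convenient when comparing pipelines.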
The API functions are first defined in the ResnetMediaPipe constructor. Then, the sequence of operations is set up in the definegraph method:
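```python
def definegraph(self):
    # Read a batch of encoded JPEGs and their labels
    jpegs, data = self.input()
    # Decode (training: random crop + resize; evaluation: resize)
    images = self.decode(jpegs)
    if self.is_training:
        # Random flip, applied to training data only
        flip = self.random_flip_input()
        images = self.random_flip(images, flip)
    # Crop (evaluation only) and normalize to float32 with CropMirrorNorm
    mean = self.norm_mean()
    std = self.norm_std()
    images = self.cmn(images, mean, std)
    return images, data
```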
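To illustrate the pattern of creating ops once in the constructor and wiring them together in definegraph, the following toy class mirrors the structure of the method above using plain Python lists instead of MediaPipe ops. ToyPipe and all of its lambdas are hypothetical placeholders: decode just copies the input, the flip reverses each row of an image, and cmn only normalizes (it does not crop).

```python
import random


class ToyPipe:
    """Hypothetical stand-in mirroring ResnetMediaPipe's structure.

    Ops are created in __init__ and wired together in definegraph(); the op
    implementations are plain-Python placeholders, not MediaPipe ops.
    """

    def __init__(self, images, labels, is_training, mean, std, seed=0):
        self.is_training = is_training
        self._rng = random.Random(seed)
        # "Ops" are defined once in the constructor.
        self.input = lambda: (images, labels)
        self.decode = lambda batch: [[list(row) for row in img] for img in batch]
        self.random_flip_input = lambda: [self._rng.random() < 0.5 for _ in images]
        self.random_flip = lambda batch, flip: [
            [row[::-1] for row in img] if f else img for img, f in zip(batch, flip)
        ]
        self.norm_mean = lambda: mean
        self.norm_std = lambda: std
        self.cmn = lambda batch, m, s: [
            [[(float(px) - m) / s for px in row] for row in img] for img in batch
        ]

    def definegraph(self):
        jpegs, data = self.input()
        images = self.decode(jpegs)
        if self.is_training:
            # Flip decisions are drawn per image, only on the training path.
            flip = self.random_flip_input()
            images = self.random_flip(images, flip)
        images = self.cmn(images, self.norm_mean(), self.norm_std())
        return images, data
```

With is_training=False the output is deterministic, which makes the evaluation path easy to check; with is_training=True each image comes out either as-is or mirrored.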
Implementing MediaApiDataLoader
The MediaApiDataLoader includes an HPUResnetPytorchIterator interface, which iterates over the image and label pairs produced by the ResnetMediaPipe instance.
This interface is compatible with torch.utils.data.DataLoader, allowing you to reuse PyTorch ResNet model code written for torch.utils.data.DataLoader without additional changes.
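In practice, this compatibility means that a training loop which simply iterates over (images, labels) pairs works unchanged. The hypothetical train_one_epoch helper below depends only on that iteration contract, so in this dependency-free sketch it accepts any Python iterable of pairs; the step callable is an assumed stand-in for a real forward and backward pass.

```python
from typing import Any, Callable, Iterable, Tuple


def train_one_epoch(loader: Iterable[Tuple[Any, Any]],
                    step: Callable[[Any, Any], float]) -> float:
    """Run one epoch and return the mean per-batch value reported by step.

    Hypothetical helper: it relies only on iterating (images, labels) pairs,
    the contract shared by torch.utils.data.DataLoader-style loaders.
    """
    total, batches = 0.0, 0
    for images, labels in loader:
        total += step(images, labels)
        batches += 1
    # Avoid division by zero for an empty loader.
    return total / max(batches, 1)
```

Because nothing in the loop depends on the loader's concrete type, swapping one loader for another is a one-line change at construction time.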
The relevant code is defined in the MediaApiDataLoader class in data_loaders.py:
```python
import torch


class MediaApiDataLoader(torch.utils.data.DataLoader):
    def __init__(self, dataset, sampler, batch_size, num_workers, pin_memory=True,
                 pin_memory_device=None, is_training=False):
        # setting parameters here, code removed for clarity
        # (defines root, queue_depth, num_instances, instance_id, device_string, self.shuffle)
        from resnet_media_pipe import ResnetMediaPipe
        pipeline = ResnetMediaPipe(is_training=is_training, root=root, batch_size=batch_size,
                                   shuffle=self.shuffle, drop_last=False, queue_depth=queue_depth,
                                   num_instances=num_instances, instance_id=instance_id,
                                   device=device_string)
        from habana_frameworks.mediapipe.plugins.iterator_pytorch import HPUResnetPytorchIterator
        # The iterator yields (images, labels) pairs from the media pipeline
        self.iterator = HPUResnetPytorchIterator(mediapipe=pipeline)
        print("Running with Media API DataLoader")
```
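The num_instances, instance_id, batch_size, and drop_last arguments passed to ResnetMediaPipe control how each worker's share of the data is batched. The ShardedBatchLoader class below is a hypothetical, dependency-free sketch of that idea, exposing __iter__ and __len__ like a DataLoader. The strided sharding (samples[instance_id::num_instances]) is an assumption chosen for simplicity and is not necessarily how the media pipeline partitions the dataset.

```python
import math


class ShardedBatchLoader:
    """Hypothetical stand-in for a DataLoader-compatible wrapper.

    Each of num_instances workers sees a strided shard of the samples and
    iterates over it in batches. Strided sharding is an assumption, not
    MediaPipe's documented partitioning scheme.
    """

    def __init__(self, samples, batch_size, num_instances=1, instance_id=0,
                 drop_last=False):
        if not 0 <= instance_id < num_instances:
            raise ValueError("instance_id must be in [0, num_instances)")
        self.shard = samples[instance_id::num_instances]
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __len__(self):
        # Number of batches; a trailing partial batch counts unless drop_last.
        n = len(self.shard)
        if self.drop_last:
            return n // self.batch_size
        return math.ceil(n / self.batch_size)

    def __iter__(self):
        for i in range(len(self)):
            yield self.shard[i * self.batch_size:(i + 1) * self.batch_size]
```

With drop_last=False, as in the constructor call above, the final batch of a shard may be smaller than batch_size.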
Selecting a DataLoader for PyTorch ResNet
Intel Gaudi's PyTorch ResNet implements three dataloaders:
- MediaApiDataLoader
- Native PyTorch DataLoader
- Habana Dataloader
Note
Habana Dataloader (AeonDataLoader) is a software-accelerated dataloader used on first-gen Gaudi, where Media API acceleration is not available. The PyTorch ResNet implementation of Aeon is covered by AeonDataLoader in the data_loaders.py script.
The MediaApiDataLoader can be used only if the following conditions are met:
- The habana_media_loader package is installed.
- The workload is run on Gaudi 2. First-gen Gaudi does not support Media API.
- The --dl-worker-type argument for the PyTorch ResNet script is set to "HABANA" (the default value).
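The selection rules above can be summarized in a small decision function. This is a hypothetical sketch, not the choose_data_loader function from data_loaders.py: the conditions are passed in as booleans rather than detected at runtime, the local DataLoaderType enum only borrows the member names used in build_data_loader, and both the Python fallback for other worker types and the Aeon fallback when Media API cannot be used are assumptions based on the note above.

```python
from enum import Enum


class DataLoaderType(Enum):
    # Local stand-in; member names match those used in build_data_loader.
    Python = 1
    MediaAPI = 2
    Aeon = 3


def choose_data_loader_sketch(dl_worker_type="HABANA",
                              media_loader_installed=False,
                              is_gaudi2=False):
    """Hypothetical selection logic based on the documented conditions."""
    if dl_worker_type != "HABANA":
        # Assumption: any other worker type uses the native PyTorch DataLoader.
        return DataLoaderType.Python
    if media_loader_installed and is_gaudi2:
        # All Media API conditions are met.
        return DataLoaderType.MediaAPI
    # Assumption: otherwise use the software-accelerated Aeon dataloader.
    return DataLoaderType.Aeon
```

Keeping detection separate from the decision makes each rule easy to test in isolation, without Gaudi hardware or the media loader package.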
These conditions are checked by the choose_data_loader function in the data_loaders.py file.
Choosing and creating the appropriate dataloader is managed by build_data_loader in the same file:
```python
import os

import torch

# DataLoaderType, choose_data_loader, MediaApiDataLoader and AeonDataLoader
# are defined elsewhere in data_loaders.py


def build_data_loader(is_training, dl_worker_type, **kwargs):
    data_loader_type = choose_data_loader(dl_worker_type)
    use_fallback = False
    try:
        if data_loader_type == DataLoaderType.MediaAPI:
            return MediaApiDataLoader(**kwargs, is_training=is_training)
        elif data_loader_type == DataLoaderType.Aeon:
            return AeonDataLoader(**kwargs)
    except Exception as e:
        # Fall back to the native PyTorch DataLoader unless
        # DATALOADER_FALLBACK_EN is set to something other than "True"
        if os.getenv('DATALOADER_FALLBACK_EN', "True") == "True":
            print(f"Failed to initialize Habana Dataloader, error: {str(e)}\nRunning with PyTorch Dataloader")
            return torch.utils.data.DataLoader(**kwargs)
        else:
            print(f"Habana dataloader configuration failed: {e}")
            raise

    if data_loader_type == DataLoaderType.Python:
        return torch.utils.data.DataLoader(**kwargs)
    else:
        raise ValueError(f"Unknown data_loader_type {data_loader_type}")
```
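If all conditions are met, choose_data_loader returns DataLoaderType.MediaAPI, and build_data_loader creates a MediaApiDataLoader.
If an error occurs while building the MediaApiDataLoader (or the AeonDataLoader), build_data_loader falls back to the native PyTorch DataLoader, provided the DATALOADER_FALLBACK_EN environment variable is "True" (the default). Information about the failure and the fallback is printed to the output log. If DATALOADER_FALLBACK_EN is set to any other value, the error is re-raised instead.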
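The fallback behavior of build_data_loader can be exercised without Gaudi hardware by factoring it into a helper that takes the primary and fallback constructors as callables. build_with_fallback and its env parameter are hypothetical; the environment variable name, its default value, and the log messages are copied from the code above.

```python
import os


def build_with_fallback(primary, fallback, env=None):
    """Hypothetical helper reproducing build_data_loader's fallback rule.

    Calls primary(); if it raises and DATALOADER_FALLBACK_EN is "True"
    (the default when unset), returns fallback() instead of failing.
    """
    env = os.environ if env is None else env
    try:
        return primary()
    except Exception as e:
        if env.get("DATALOADER_FALLBACK_EN", "True") == "True":
            print(f"Failed to initialize Habana Dataloader, error: {e}\n"
                  "Running with PyTorch Dataloader")
            return fallback()
        print(f"Habana dataloader configuration failed: {e}")
        raise
```

For example, running with DATALOADER_FALLBACK_EN=False makes a Media API initialization failure fatal instead of switching to the PyTorch DataLoader, which may be preferable when you need to be sure the accelerated dataloader is in use.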