Media Pipe for PyTorch ResNet
This section describes how to define a ResnetMediaPipe class and implement a Media Pipe API data loader, MediaApiDataLoader, for PyTorch ResNet.
Defining a ResnetMediaPipe Class
The ResnetMediaPipe class is derived from the MediaPipe class, as described in Creating and Executing Media Pipeline, and implements the Media Pipe API for ResNet.
An implementation of the ResnetMediaPipe class can be found in resnet_media_pipe.py on Habana’s Model References GitHub page.
The following operations are performed on the data:
- Read images using ReadImageDatasetFromDir.
- Decode images using ImageDecoder.
- For training data only:
  - Random crop as part of ImageDecoder, with RANDOMIZED_AREA_AND_ASPECT_RATIO_CROP as random_crop_type.
  - Resize as part of ImageDecoder.
  - Random flip using the RandomFlip interface.
- For evaluation data only:
  - Resize as part of ImageDecoder.
  - Crop as part of CropMirrorNorm.
- Normalize using CropMirrorNorm, producing float32 output.
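To make the training-path steps concrete, here is a rough CPU analogue in NumPy. It is not the Habana Media API: the area range (0.08 to 1.0), the aspect-ratio range (3/4 to 4/3), the 0.5 flip probability, the 224x224 output size, and the helper names random_area_aspect_crop_box, resize_nearest, and train_augment are all assumptions borrowed from the common torchvision RandomResizedCrop convention. Nearest-neighbour resizing stands in for whatever interpolation ImageDecoder actually uses.

```python
import math

import numpy as np


def random_area_aspect_crop_box(height, width, rng,
                                scale=(0.08, 1.0), ratio=(3 / 4, 4 / 3),
                                attempts=10):
    """Return (top, left, h, w) for a crop with random area fraction and aspect ratio.

    scale and ratio are assumed torchvision-style defaults, not values read
    from ResnetMediaPipe.
    """
    area = height * width
    log_lo, log_hi = math.log(ratio[0]), math.log(ratio[1])
    for _ in range(attempts):
        target_area = area * rng.uniform(scale[0], scale[1])
        aspect = math.exp(rng.uniform(log_lo, log_hi))  # width / height
        w = int(round(math.sqrt(target_area * aspect)))
        h = int(round(math.sqrt(target_area / aspect)))
        if 0 < w <= width and 0 < h <= height:
            top = int(rng.integers(0, height - h + 1))
            left = int(rng.integers(0, width - w + 1))
            return top, left, h, w
    # No valid box found: fall back to the whole image.
    return 0, 0, height, width


def resize_nearest(image, out_h, out_w):
    """Nearest-neighbour resize of an HWC array (a stand-in for the decoder's resize)."""
    in_h, in_w = image.shape[:2]
    rows = np.arange(out_h) * in_h // out_h
    cols = np.arange(out_w) * in_w // out_w
    return image[rows[:, None], cols]


def train_augment(image, rng, out_size=224):
    """Random crop, then resize, then random horizontal flip."""
    top, left, h, w = random_area_aspect_crop_box(image.shape[0], image.shape[1], rng)
    crop = image[top:top + h, left:left + w]
    out = resize_nearest(crop, out_size, out_size)
    if rng.random() < 0.5:  # assumed flip probability
        out = out[:, ::-1]
    return out


rng = np.random.default_rng(0)
image = (np.arange(300 * 400 * 3) % 256).astype(np.uint8).reshape(300, 400, 3)
print(train_augment(image, rng).shape)  # (224, 224, 3)
```

Sampling the aspect ratio in log space keeps wide and tall crops equally likely, which is the usual reason for that choice.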
The API operations are first defined in the ResnetMediaPipe constructor.
The sequence of operations is then set up in the definegraph method:
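def definegraph(self):
    # Read encoded JPEG images and their labels
    jpegs, data = self.input()
    # Decode and resize; training images are also randomly cropped by the decoder
    images = self.decode(jpegs)
    if self.is_training:
        # Randomly flip training images only
        flip = self.random_flip_input()
        images = self.random_flip(images, flip)
    # Normalize with CropMirrorNorm (which also crops evaluation images)
    mean = self.norm_mean()
    std = self.norm_std()
    images = self.cmn(images, mean, std)
    return images, data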
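The final cmn step can be approximated on the CPU as shown below, assuming CropMirrorNorm center-crops the image, optionally mirrors it horizontally, then subtracts a per-channel mean and divides by a per-channel standard deviation. The ImageNet statistics, the center-crop placement, and the name crop_mirror_norm are assumptions for illustration; the values ResnetMediaPipe actually uses are set in resnet_media_pipe.py.

```python
import numpy as np

# Assumed ImageNet channel statistics on the 0-255 scale; the pipeline's own
# values are defined in resnet_media_pipe.py.
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32) * 255
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32) * 255


def crop_mirror_norm(image, crop_h, crop_w, mean=MEAN, std=STD, mirror=False):
    """Center crop, optional horizontal mirror, per-channel normalization to float32."""
    h, w = image.shape[:2]
    top = (h - crop_h) // 2
    left = (w - crop_w) // 2
    out = image[top:top + crop_h, left:left + crop_w].astype(np.float32)
    if mirror:
        out = out[:, ::-1]
    return (out - mean) / std


eval_image = np.full((256, 256, 3), 128, dtype=np.uint8)
normalized = crop_mirror_norm(eval_image, 224, 224)
print(normalized.shape, normalized.dtype)  # (224, 224, 3) float32
```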
Implementing MediaApiDataLoader
The MediaApiDataLoader includes an HPUResnetPytorchIterator interface that iterates over the (image, label) pairs produced by a ResnetMediaPipe instance.
This interface is compatible with torch.utils.data.DataLoader, so you can reuse PyTorch ResNet model code written for torch.utils.data.DataLoader without additional changes.
The relevant code is defined in the MediaApiDataLoader class in data_loaders.py:
import torch


class MediaApiDataLoader(torch.utils.data.DataLoader):
    def __init__(self, dataset, sampler, batch_size, num_workers, pin_memory=True, pin_memory_device=None, is_training=False):
        # setting parameters here, code removed for clarity
        from resnet_media_pipe import ResnetMediaPipe
        pipeline = ResnetMediaPipe(is_training=is_training, root=root, batch_size=batch_size,
                                   shuffle=self.shuffle, drop_last=False, queue_depth=queue_depth,
                                   num_instances=num_instances, instance_id=instance_id, device=device_string)
        # Wrap the pipeline in an iterator that yields (image, label) batches
        from habana_frameworks.mediapipe.plugins.iterator_pytorch import HPUResnetPytorchIterator
        self.iterator = HPUResnetPytorchIterator(mediapipe=pipeline)
        print("Running with Media API DataLoader")
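Because the training loop only iterates over (image, label) batches, any object that exposes __iter__ (and, optionally, __len__) can stand in for torch.utils.data.DataLoader. The pure-Python sketch below shows one plausible way a wrapper can delegate iteration to an inner iterator, similar in spirit to MediaApiDataLoader holding self.iterator. The class IteratorBackedLoader and the function run_epoch are hypothetical and are not taken from data_loaders.py.

```python
from typing import Callable, Iterable, Iterator, Tuple

# A batch is an (images, labels) pair; plain lists keep the sketch dependency-free.
Batch = Tuple[list, list]


class IteratorBackedLoader:
    """Hypothetical DataLoader-compatible wrapper that delegates iteration."""

    def __init__(self, make_iterator: Callable[[], Iterable[Batch]], num_batches: int):
        # make_iterator plays the role of constructing something like
        # HPUResnetPytorchIterator; here it is any zero-argument callable.
        self._make_iterator = make_iterator
        self._num_batches = num_batches

    def __iter__(self) -> Iterator[Batch]:
        # Build a fresh iterator per epoch so the loader can be reused.
        return iter(self._make_iterator())

    def __len__(self) -> int:
        return self._num_batches


def run_epoch(loader: Iterable[Batch]) -> int:
    """Consume one epoch and return the number of samples seen.

    The loop body is the same whether `loader` is a torch DataLoader,
    a media-pipe-backed loader, or a plain list of batches.
    """
    samples = 0
    for images, labels in loader:
        # A real training step would run the model on `images` here.
        samples += len(labels)
    return samples


batches = [([0.1, 0.2], [1, 0]), ([0.3], [2])]
loader = IteratorBackedLoader(lambda: batches, num_batches=len(batches))
print(run_epoch(loader), len(loader))  # 3 2
```

Creating a fresh iterator on each __iter__ call is what lets the same loader object be passed to the training loop epoch after epoch.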
Selecting a DataLoader for PyTorch ResNet
In Habana’s PyTorch ResNet, three DataLoaders are implemented in the model:
- MediaApiDataLoader
- Native PyTorch DataLoader
- Habana DataLoader
Note
Habana DataLoader (AeonDataLoader) is a software-accelerated data loader used on first-gen Gaudi, where Media API acceleration is not available. The PyTorch ResNet implementation of Aeon is provided by AeonDataLoader in data_loaders.py.
MediaApiDataLoader can be used only if all of the following conditions are met:
- The hpu_media_loader package is installed.
- The workload is run on Gaudi2. First-gen Gaudi does not support Media API.
- The PT_HPU_MEDIA_PIPE environment variable is either not defined or set to a value other than 0, “f”, or “false” (case insensitive).
- The --dl-worker-type argument of the PyTorch ResNet script is set to “HABANA” (the default value).
These conditions are checked by the choose_data_loader function in data_loaders.py.
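The selection conditions can be sketched as a small predicate. This is a hypothetical restatement, not the code of choose_data_loader: the is_gaudi2 flag stands in for whatever device detection the real function performs, and checking for hpu_media_loader with importlib.util.find_spec assumes the package is importable under that name.

```python
import importlib.util
import os
from typing import Mapping

# Values that disable the media pipe, compared case-insensitively.
_DISABLING_VALUES = {"0", "f", "false"}


def media_pipe_env_enabled(environ: Mapping[str, str] = os.environ) -> bool:
    """True unless PT_HPU_MEDIA_PIPE is set to 0, "f", or "false" (any case)."""
    value = environ.get("PT_HPU_MEDIA_PIPE")
    return value is None or value.lower() not in _DISABLING_VALUES


def can_use_media_api(dl_worker_type: str, is_gaudi2: bool,
                      environ: Mapping[str, str] = os.environ) -> bool:
    """Hypothetical check mirroring the four documented conditions."""
    # Assumption: the package is importable as `hpu_media_loader`.
    package_installed = importlib.util.find_spec("hpu_media_loader") is not None
    return (package_installed
            and is_gaudi2                      # hypothetical device flag
            and media_pipe_env_enabled(environ)
            and dl_worker_type == "HABANA")    # the script's default value
```

Passing the environment as a mapping, rather than reading os.environ directly, makes the variable handling easy to check in isolation.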
Choosing and creating the appropriate data loader is handled by build_data_loader in the same file:
import os

import torch


def build_data_loader(is_training, dl_worker_type, **kwargs):
    # Pick MediaAPI, Aeon, or Python based on the conditions listed above
    data_loader_type = choose_data_loader(dl_worker_type)
    use_fallback = False
    try:
        if data_loader_type == DataLoaderType.MediaAPI:
            return MediaApiDataLoader(**kwargs, is_training=is_training)
        elif data_loader_type == DataLoaderType.Aeon:
            return AeonDataLoader(**kwargs)
    except Exception as e:
        # Fall back to the native PyTorch DataLoader unless DATALOADER_FALLBACK_EN disables it
        if os.getenv('DATALOADER_FALLBACK_EN', "True") == "True":
            print(f"Failed to initialize Habana Dataloader, error: {str(e)}\nRunning with PyTorch Dataloader")
            return torch.utils.data.DataLoader(**kwargs)
        else:
            print(f"Habana dataloader configuration failed: {e}")
            raise
    if data_loader_type == DataLoaderType.Python:
        return torch.utils.data.DataLoader(**kwargs)
    else:
        raise ValueError(f"Unknown data_loader_type {data_loader_type}")
If all conditions are met, choose_data_loader returns DataLoaderType.MediaAPI and build_data_loader creates a MediaApiDataLoader.
If an exception is raised while building MediaApiDataLoader (or AeonDataLoader), build_data_loader falls back to the native PyTorch DataLoader, provided DATALOADER_FALLBACK_EN is unset or set to “True”; otherwise the error is re-raised.
In either case, information about the failure is printed to the output log.
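Note that build_data_loader compares DATALOADER_FALLBACK_EN with the exact string “True”, so the check is case sensitive: an unset variable or “True” enables the fallback, while any other value, including “true” or “1”, disables it and lets the original exception propagate. The helper below (fallback_enabled is a hypothetical name) restates that condition so the behaviour can be checked directly.

```python
import os
from typing import Mapping


def fallback_enabled(environ: Mapping[str, str] = os.environ) -> bool:
    # Same test as build_data_loader: default "True", exact (case-sensitive) match.
    return environ.get("DATALOADER_FALLBACK_EN", "True") == "True"


for setting in ({}, {"DATALOADER_FALLBACK_EN": "True"},
                {"DATALOADER_FALLBACK_EN": "true"}, {"DATALOADER_FALLBACK_EN": "1"}):
    print(setting, fallback_enabled(setting))
```

For example, running with DATALOADER_FALLBACK_EN=False makes a MediaApiDataLoader initialization failure raise instead of switching to the native PyTorch DataLoader.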