MediaPipe for PyTorch ResNet
This section describes how to define a ResnetMediaPipe class
and implement a MediaPipe API dataloader, MediaApiDataLoader, for PyTorch ResNet.
Defining a ResnetMediaPipe Class
The ResnetMediaPipe class is derived from the MediaPipe class, as described in Creating and Executing Media Pipeline, and implements the MediaPipe API for ResNet.
An implementation of the ResnetMediaPipe class can be found in resnet_media_pipe.py
in the Intel Gaudi Model References GitHub repository.
The following operations are performed on the data:
- Reading images using ReadImageDatasetFromDir
- Decoding images using ImageDecoder
- For training data only:
  - Random cropping as part of ImageDecoder, with RANDOMIZED_AREA_AND_ASPECT_RATIO_CROP as random_crop_type
  - Resizing as part of ImageDecoder
  - Random flipping using the RandomFlip interface
- For evaluation data only:
  - Resizing as part of ImageDecoder
  - Cropping as part of CropMirrorNorm
- Normalizing using CropMirrorNorm to produce float32 output
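The training-time crop type RANDOMIZED_AREA_AND_ASPECT_RATIO_CROP is named but not specified in this section. As a minimal sketch, assuming it samples a crop window the way torchvision's RandomResizedCrop does (area fraction drawn from [0.08, 1.0], aspect ratio drawn log-uniformly from [3/4, 4/3]), window selection could look like the following. The function names, default ranges, retry count, and full-image fallback are assumptions for illustration, not the ImageDecoder implementation; the evaluation path is shown as a plain center crop.

```python
import math
import random


def sample_crop_window(width, height, scale=(0.08, 1.0), ratio=(3 / 4, 4 / 3),
                       attempts=10, rng=random):
    """Hypothetical random-area/aspect-ratio crop; returns (x, y, w, h)."""
    area = width * height
    log_lo, log_hi = math.log(ratio[0]), math.log(ratio[1])
    for _ in range(attempts):
        # Target area: a random fraction of the source image area
        target_area = area * rng.uniform(*scale)
        # Aspect ratio (w / h) sampled uniformly in log space
        aspect = math.exp(rng.uniform(log_lo, log_hi))
        w = int(round(math.sqrt(target_area * aspect)))
        h = int(round(math.sqrt(target_area / aspect)))
        if 0 < w <= width and 0 < h <= height:
            # Random top-left corner that keeps the window inside the image
            x = rng.randint(0, width - w)
            y = rng.randint(0, height - h)
            return x, y, w, h
    # Simplified fallback (assumption): use the full image
    return 0, 0, width, height


def center_crop_window(width, height, crop_w, crop_h):
    """Centered (x, y, w, h) window, as an evaluation-time crop would use."""
    return (width - crop_w) // 2, (height - crop_h) // 2, crop_w, crop_h
```

With example sizes, center_crop_window(256, 256, 224, 224) returns (16, 16, 224, 224). In the pipeline described above, the sampled training window would then be resized to the network input size.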
API functions are first defined in the ResnetMediaPipe constructor.
Then, a sequence of operations is set up in the definegraph method:
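def definegraph(self):
    # Read encoded JPEG files together with their labels
    jpegs, data = self.input()
    # Decode; for training data, random crop and resize also happen here
    images = self.decode(jpegs)
    if self.is_training:
        # Random flip, applied to training data only
        flip = self.random_flip_input()
        images = self.random_flip(images, flip)
    # Normalize to float32 (CropMirrorNorm also crops evaluation data)
    mean = self.norm_mean()
    std = self.norm_std()
    images = self.cmn(images, mean, std)
    return images, data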
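To make the pattern concrete (operators created in the constructor, then wired together in definegraph), here is a hypothetical pure-Python stand-in that runs without Gaudi software. ToyResnetPipe and its lambda "operators" are illustrative only: they reuse the attribute names from the method above but are not the MediaPipe API, each "image" is a single row of grayscale pixels, and normalization is assumed to be (x - mean) / std.

```python
import random


class ToyResnetPipe:
    """Hypothetical stand-in: ops are created in __init__ and wired in definegraph."""

    def __init__(self, samples, is_training, mean=127.5, std=127.5, seed=0):
        self.is_training = is_training
        rng = random.Random(seed)
        # Each "op" is a plain callable standing in for a MediaPipe operator
        self.input = lambda: ([img for img, _ in samples], [lbl for _, lbl in samples])
        self.decode = lambda imgs: [list(img) for img in imgs]  # already decoded here
        self.random_flip_input = lambda: [rng.random() < 0.5 for _ in samples]
        self.random_flip = lambda imgs, flip: [img[::-1] if f else img
                                               for img, f in zip(imgs, flip)]
        self.norm_mean = lambda: mean
        self.norm_std = lambda: std
        # Normalize to floats: (x - mean) / std
        self.cmn = lambda imgs, m, s: [[(px - m) / s for px in img] for img in imgs]

    def definegraph(self):
        jpegs, data = self.input()
        images = self.decode(jpegs)
        if self.is_training:
            flip = self.random_flip_input()
            images = self.random_flip(images, flip)
        mean = self.norm_mean()
        std = self.norm_std()
        images = self.cmn(images, mean, std)
        return images, data
```

For example, ToyResnetPipe([([0, 255], 7)], is_training=False).definegraph() returns ([[-1.0, 1.0]], [7]).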
Implementing MediaApiDataLoader
MediaApiDataLoader uses an HPUResnetPytorchIterator interface, which iterates over the image and label pairs produced by a ResnetMediaPipe instance.
This interface is compatible with torch.utils.data.DataLoader, so PyTorch ResNet model code
written for torch.utils.data.DataLoader can be reused without changes.
The relevant code is the MediaApiDataLoader class in
data_loaders.py:
import torch


class MediaApiDataLoader(torch.utils.data.DataLoader):
    def __init__(self, dataset, sampler, batch_size, num_workers, pin_memory=True, pin_memory_device=None, is_training=False):
        # Setup of root, queue_depth, num_instances, instance_id, device_string,
        # and self.shuffle is omitted for clarity
        from resnet_media_pipe import ResnetMediaPipe
        pipeline = ResnetMediaPipe(is_training=is_training, root=root, batch_size=batch_size,
                                   shuffle=self.shuffle, drop_last=False, queue_depth=queue_depth,
                                   num_instances=num_instances, instance_id=instance_id, device=device_string)
        from habana_frameworks.mediapipe.plugins.iterator_pytorch import HPUResnetPytorchIterator
        self.iterator = HPUResnetPytorchIterator(mediapipe=pipeline)
        print("Running with Media API DataLoader")
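The compatibility described above rests on the training loop needing only iteration over (images, labels) batches, plus len() in many scripts. A minimal sketch of that delegation follows. FakePipeIterator stands in for HPUResnetPytorchIterator, and DelegatingLoader is a plain class rather than a torch.utils.data.DataLoader subclass so it runs without torch or Gaudi software; how the real MediaApiDataLoader exposes self.iterator is not shown above, so the forwarding methods are an assumption.

```python
class FakePipeIterator:
    """Hypothetical stand-in for HPUResnetPytorchIterator: yields (images, labels) batches."""

    def __init__(self, batches):
        self._batches = batches

    def __iter__(self):
        # A fresh iterator per epoch
        return iter(self._batches)

    def __len__(self):
        return len(self._batches)


class DelegatingLoader:
    """Sketch of a loader that forwards iteration to a wrapped pipeline iterator."""

    def __init__(self, iterator):
        self.iterator = iterator

    def __iter__(self):
        return iter(self.iterator)

    def __len__(self):
        return len(self.iterator)


def train_one_epoch(loader):
    # Unchanged training-loop code: relies only on iterating over (images, labels)
    seen = 0
    for images, labels in loader:
        seen += len(labels)
    return seen
```

With loader = DelegatingLoader(FakePipeIterator([([0.1, 0.2], [3, 5]), ([0.3], [1])])), len(loader) is 2 and train_one_epoch(loader) returns 3 on every call, just as it would for any other iterable loader.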
Selecting a DataLoader for PyTorch ResNet
In Intel Gaudi’s PyTorch ResNet model, two dataloaders are implemented:
- MediaApiDataLoader
- Native PyTorch DataLoader
The MediaApiDataLoader can be used only if all of the following conditions are met:
- The habana_media_loader package is installed.
- The workload is run on Gaudi 2.
- The --dl-worker-type argument for the PyTorch ResNet script is set to “HABANA” (the default value).
These conditions are handled by the choose_data_loader function in the
data_loaders.py file.
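choose_data_loader itself is not shown in this section. The following hypothetical sketch shows how the three conditions could be checked. It assumes the Gaudi 2 check is supplied by the caller as a boolean (is_gaudi2), that the package is importable under the name habana_media_loader, and it omits the Aeon branch that build_data_loader also handles. The enum member names match those used in build_data_loader, but their values are invented.

```python
import enum
import importlib.util


class DataLoaderType(enum.Enum):
    Python = "python"
    MediaAPI = "media_api"
    Aeon = "aeon"


def choose_data_loader(dl_worker_type, is_gaudi2, media_loader_pkg="habana_media_loader"):
    """Hypothetical check of the three MediaApiDataLoader conditions."""
    # Condition 1: the media loader package can be found on the import path
    has_media_loader = importlib.util.find_spec(media_loader_pkg) is not None
    # Conditions 2 and 3: Gaudi 2 hardware and --dl-worker-type set to "HABANA"
    if has_media_loader and is_gaudi2 and dl_worker_type == "HABANA":
        return DataLoaderType.MediaAPI
    # Any unmet condition selects the native PyTorch DataLoader
    return DataLoaderType.Python
```

Checking for the package with importlib.util.find_spec avoids actually importing it, so the decision can be made before any device initialization.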
Selecting and creating the appropriate dataloader is handled by the build_data_loader function in the same file:
import os

import torch


def build_data_loader(is_training, dl_worker_type, **kwargs):
    # choose_data_loader checks the conditions listed above
    data_loader_type = choose_data_loader(dl_worker_type)
    try:
        if data_loader_type == DataLoaderType.MediaAPI:
            return MediaApiDataLoader(**kwargs, is_training=is_training)
        elif data_loader_type == DataLoaderType.Aeon:
            return AeonDataLoader(**kwargs)
    except Exception as e:
        # Fall back to the native PyTorch DataLoader unless DATALOADER_FALLBACK_EN != "True"
        if os.getenv('DATALOADER_FALLBACK_EN', "True") == "True":
            print(f"Failed to initialize Habana Dataloader, error: {str(e)}\nRunning with PyTorch Dataloader")
            return torch.utils.data.DataLoader(**kwargs)
        else:
            print(f"Habana dataloader configuration failed: {e}")
            raise
    if data_loader_type == DataLoaderType.Python:
        return torch.utils.data.DataLoader(**kwargs)
    else:
        raise ValueError(f"Unknown data_loader_type {data_loader_type}")
If all conditions are met, choose_data_loader returns DataLoaderType.MediaAPI and build_data_loader creates a MediaApiDataLoader.
If an error is raised while building the MediaApiDataLoader and the DATALOADER_FALLBACK_EN environment variable is "True" (the default),
build_data_loader falls back to the native PyTorch DataLoader; if the variable is set to any other value, the error is re-raised.
In both cases, information about the failure is written to the output log.
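The fallback logic can be exercised in isolation. This minimal sketch keeps only the DATALOADER_FALLBACK_EN check from build_data_loader and swaps in injectable constructors; build_with_fallback, failing_media_loader, and native_loader are hypothetical stubs standing in for MediaApiDataLoader and torch.utils.data.DataLoader.

```python
import os


def build_with_fallback(make_habana_loader, make_native_loader, **kwargs):
    """Mirror of the try/except in build_data_loader, with injectable constructors."""
    try:
        return make_habana_loader(**kwargs)
    except Exception as e:
        # Fallback is on unless DATALOADER_FALLBACK_EN is set to something other than "True"
        if os.getenv("DATALOADER_FALLBACK_EN", "True") == "True":
            print(f"Habana dataloader failed ({e}); using native loader")
            return make_native_loader(**kwargs)
        print(f"Habana dataloader failed ({e}); fallback disabled")
        raise


def failing_media_loader(**kwargs):
    # Simulated initialization failure
    raise RuntimeError("media loader unavailable")


def native_loader(**kwargs):
    return ("native", kwargs)
```

With DATALOADER_FALLBACK_EN unset, build_with_fallback(failing_media_loader, native_loader, batch_size=4) returns ("native", {"batch_size": 4}); with DATALOADER_FALLBACK_EN=False, the RuntimeError propagates to the caller.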