Media Pipe for PyTorch ResNet

This section describes how to define a ResnetMediaPipe class and implement a Media Pipe API data loader, MediaApiDataLoader, for PyTorch ResNet.

Defining a ResnetMediaPipe Class

The ResnetMediaPipe class is derived from the MediaPipe class, as described in Creating and Executing Media Pipeline, and implements the Media Pipe API for ResNet. An implementation of the ResnetMediaPipe class can be found in resnet_media_pipe.py on Habana’s Model References GitHub page.

API functions are first defined in the ResnetMediaPipe constructor. The sequence of operations performed on the data is then set up in the definegraph method:

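def definegraph(self):
    # Read a batch of encoded JPEGs with their labels, then decode the images
    jpegs, data = self.input()
    images = self.decode(jpegs)

    # Random flipping is applied only during training
    if self.is_training:
        flip = self.random_flip_input()
        images = self.random_flip(images, flip)

    # Normalize the images using the mean and standard deviation tensors
    mean = self.norm_mean()
    std = self.norm_std()
    images = self.cmn(images, mean, std)

    return images, data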
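The order of operations in definegraph can be illustrated without Gaudi hardware. The following pure-Python sketch is not the Media Pipe API: images are assumed to be already decoded into hypothetical H x W x C nested lists, the flip probability of 0.5 is an assumption, and the helper names random_flip, normalize and define_graph are illustrative only. It applies an optional per-sample horizontal flip during training and then per-channel mean/std normalization for every batch, mirroring the random_flip and cmn steps.

```python
import random
from typing import List, Sequence, Tuple

# Hypothetical decoded image layout: rows of pixels, each pixel a list of channel values.
Image = List[List[List[float]]]


def random_flip(image: Image, flip: bool) -> Image:
    """Return a horizontally mirrored copy of the image when flip is True."""
    if not flip:
        return image
    return [list(reversed(row)) for row in image]


def normalize(image: Image, mean: Sequence[float], std: Sequence[float]) -> Image:
    """Apply (value - mean[c]) / std[c] to every channel c of every pixel."""
    return [
        [[(pixel[c] - mean[c]) / std[c] for c in range(len(pixel))] for pixel in row]
        for row in image
    ]


def define_graph(
    images: List[Image],
    labels: List[int],
    is_training: bool,
    mean: Sequence[float],
    std: Sequence[float],
    rng: random.Random,
) -> Tuple[List[Image], List[int]]:
    """Mimic the step order of ResnetMediaPipe.definegraph (decoding omitted)."""
    if is_training:
        # One flip decision per sample, analogous to random_flip_input(); p=0.5 is assumed.
        flips = [rng.random() < 0.5 for _ in images]
        images = [random_flip(img, f) for img, f in zip(images, flips)]
    # Normalization runs for both training and evaluation, like the cmn step.
    images = [normalize(img, mean, std) for img in images]
    return images, labels


if __name__ == "__main__":
    batch = [[[[1.0], [3.0]]]]  # one 1x2 single-channel image
    out, lbl = define_graph(batch, [7], is_training=False, mean=[1.0], std=[2.0],
                            rng=random.Random(0))
    print(out, lbl)  # [[[[0.0], [1.0]]]] [7]
```

Passing an explicit random.Random instance keeps the flip decisions reproducible, which makes the training branch easy to test.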

Implementing MediaApiDataLoader

MediaApiDataLoader includes an HPUResnetPytorchIterator interface that iterates over (image, label) pairs produced by a ResnetMediaPipe instance. This interface is compatible with torch.utils.data.DataLoader, so you can reuse PyTorch ResNet model code written for torch.utils.data.DataLoader without changes.
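Compatibility in this sense means the training loop only needs something it can iterate over that yields (images, labels) batches. The sketch below illustrates that idea with a generic loop and a hypothetical generator, fake_batches, standing in for HPUResnetPytorchIterator (whose internals are not shown here); plain Python lists stand in for tensors.

```python
from typing import Iterable, Iterator, List, Tuple

Batch = Tuple[List[float], List[int]]  # (images, labels); lists stand in for tensors


def run_epoch(loader: Iterable[Batch]) -> int:
    """Training-loop skeleton that relies only on iterating over (images, labels) pairs."""
    samples = 0
    for images, labels in loader:
        # A real loop would move tensors to the device and run forward/backward here.
        samples += len(labels)
    return samples


def fake_batches() -> Iterator[Batch]:
    """Hypothetical batch source standing in for an HPU iterator."""
    yield [0.1, 0.2], [1, 0]
    yield [0.3], [1]


if __name__ == "__main__":
    print(run_epoch(fake_batches()))                  # 3
    print(run_epoch([([0.5, 0.6, 0.7], [0, 1, 1])]))  # 3
```

Because run_epoch never inspects the loader's type, the same function accepts a generator, a list, or any DataLoader-like object.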

The relevant code is defined in the MediaApiDataLoader class in data_loaders.py:

import torch

class MediaApiDataLoader(torch.utils.data.DataLoader):
    def __init__(self, dataset, sampler, batch_size, num_workers, pin_memory=True, pin_memory_device=None, is_training=False):

        # setting parameters here, code removed for clarity

        # Build the ResNet media pipeline that decodes and preprocesses the images
        from resnet_media_pipe import ResnetMediaPipe
        pipeline = ResnetMediaPipe(is_training=is_training, root=root, batch_size=batch_size,
                                   shuffle=self.shuffle, drop_last=False, queue_depth=queue_depth,
                                   num_instances=num_instances, instance_id=instance_id, device=device_string)

        # Wrap the pipeline in an iterator that yields (image, label) batches
        from habana_frameworks.mediapipe.plugins.iterator_pytorch import HPUResnetPytorchIterator
        self.iterator = HPUResnetPytorchIterator(mediapipe=pipeline)
        print("Running with Media API DataLoader")
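The excerpt stores the iterator in self.iterator but omits how batches reach the training loop. One common Python pattern, sketched below with hypothetical names (DelegatingLoader, ListIterator) rather than the actual code in data_loaders.py, is for the wrapper to forward __iter__ and __len__ to the wrapped iterator. The sketch does not subclass torch.utils.data.DataLoader, so it runs without PyTorch installed.

```python
from typing import Any, Iterator


class ListIterator:
    """Hypothetical stand-in for a pipeline iterator that yields batches."""

    def __init__(self, batches: list):
        self.batches = batches

    def __iter__(self) -> Iterator[Any]:
        return iter(self.batches)

    def __len__(self) -> int:
        return len(self.batches)


class DelegatingLoader:
    """Wrapper that hides how batches are produced behind the iterable protocol."""

    def __init__(self, iterator: Any):
        # In MediaApiDataLoader this role is played by self.iterator.
        self.iterator = iterator

    def __iter__(self) -> Iterator[Any]:
        # Each call starts a fresh pass over the wrapped iterator (one epoch).
        return iter(self.iterator)

    def __len__(self) -> int:
        # Number of batches per epoch, as reported by the wrapped iterator.
        return len(self.iterator)


if __name__ == "__main__":
    loader = DelegatingLoader(ListIterator([("img0", 0), ("img1", 1)]))
    print(len(loader), [label for _, label in loader])  # 2 [0, 1]
```

Delegation keeps the training loop unaware of whether batches come from a media pipeline or from PyTorch worker processes.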

Selecting a DataLoader for PyTorch ResNet

Habana’s PyTorch ResNet implements three DataLoaders:

  • MediaApiDataLoader

  • Native PyTorch DataLoader

  • Habana DataLoader

Note

Habana DataLoader (AeonDataLoader) is a software-accelerated data loader used on first-gen Gaudi, where Media API acceleration is not available. The PyTorch ResNet implementation of Aeon is provided by the AeonDataLoader class in data_loaders.py.

MediaApiDataLoader can be used only if the following conditions are met:

  • The habana_media_loader package is installed.

  • The workload runs on Gaudi2. First-gen Gaudi does not support Media API.

  • The --dl-worker-type argument of the PyTorch ResNet script is set to “HABANA” (the default value).
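The three conditions above can be expressed as a small selection function. This is a hypothetical sketch, not the actual choose_data_loader: the enum values, the device_name parameter, and detecting the package by its distribution name habana_media_loader are all assumptions (the real import name may differ). Falling back to Aeon when “HABANA” is requested but Media API is unusable follows the note about first-gen Gaudi.

```python
import importlib.util
from enum import Enum
from typing import Optional


class DataLoaderType(Enum):
    """Loader kinds named in build_data_loader; the integer values are arbitrary."""
    MediaAPI = 1
    Aeon = 2
    Python = 3


def media_package_installed(name: str = "habana_media_loader") -> bool:
    """True if a top-level module with this name can be imported here (assumed name)."""
    return importlib.util.find_spec(name) is not None


def choose_loader_type(dl_worker_type: str, device_name: str,
                       media_available: Optional[bool] = None) -> DataLoaderType:
    """Hypothetical selection logic for the three conditions listed above."""
    if dl_worker_type != "HABANA":
        # Any other worker type selects the native PyTorch DataLoader.
        return DataLoaderType.Python
    if media_available is None:
        media_available = media_package_installed()
    if media_available and device_name.lower() == "gaudi2":
        return DataLoaderType.MediaAPI
    # "HABANA" requested but Media API is unusable (for example, first-gen Gaudi).
    return DataLoaderType.Aeon
```

The media_available parameter lets tests bypass the package check; in normal use it is left as None so the environment is inspected.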

These conditions are checked by the choose_data_loader function in data_loaders.py. Selecting and creating the appropriate data loader is handled by build_data_loader in the same file:

import os
import torch

def build_data_loader(is_training, dl_worker_type, **kwargs):
    # Select MediaAPI, Aeon, or Python based on the conditions listed above
    data_loader_type = choose_data_loader(dl_worker_type)

    try:
        if data_loader_type == DataLoaderType.MediaAPI:
            return MediaApiDataLoader(**kwargs, is_training=is_training)
        elif data_loader_type == DataLoaderType.Aeon:
            return AeonDataLoader(**kwargs)
    except Exception as e:
        # Fall back to the native PyTorch DataLoader unless DATALOADER_FALLBACK_EN disables it
        if os.getenv('DATALOADER_FALLBACK_EN', "True") == "True":
            print(f"Failed to initialize Habana Dataloader, error: {str(e)}\nRunning with PyTorch Dataloader")
            return torch.utils.data.DataLoader(**kwargs)
        else:
            print(f"Habana dataloader configuration failed: {e}")
            raise

    if data_loader_type == DataLoaderType.Python:
        return torch.utils.data.DataLoader(**kwargs)
    else:
        raise ValueError(f"Unknown data_loader_type {data_loader_type}")

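If all conditions are met, choose_data_loader returns DataLoaderType.MediaAPI and build_data_loader creates a MediaApiDataLoader. If an exception is raised while building MediaApiDataLoader or AeonDataLoader, build_data_loader falls back to the native PyTorch DataLoader, provided that the DATALOADER_FALLBACK_EN environment variable is unset or set to “True”; otherwise, the error is printed and re-raised. Information about the failure and the fallback is printed to the output log.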
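The fallback rule can be isolated into a small, testable helper. The sketch below is illustrative: build_with_fallback and failing_loader are hypothetical names, and only the DATALOADER_FALLBACK_EN variable and its “True” default are taken from build_data_loader.

```python
import os
from typing import Callable, TypeVar

T = TypeVar("T")


def build_with_fallback(primary: Callable[[], T], fallback: Callable[[], T]) -> T:
    """Try primary(); on failure use fallback() unless DATALOADER_FALLBACK_EN disables it."""
    try:
        return primary()
    except Exception as e:
        # Same default as build_data_loader: fallback is on unless the variable is not "True".
        if os.getenv("DATALOADER_FALLBACK_EN", "True") == "True":
            print(f"Primary loader failed ({e}); using fallback")
            return fallback()
        raise


def failing_loader() -> str:
    """Hypothetical primary factory that always fails, e.g. a missing media pipeline."""
    raise RuntimeError("media pipeline unavailable")


if __name__ == "__main__":
    os.environ.pop("DATALOADER_FALLBACK_EN", None)
    print(build_with_fallback(failing_loader, lambda: "pytorch"))
```

To make configuration errors surface immediately instead of silently switching loaders, set DATALOADER_FALLBACK_EN to any value other than “True” (for example, “False”) before launching the script.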