Media Pipe for PyTorch ResNet

This section describes how to define a ResnetMediaPipe class and implement MediaApiDataLoader, a Media Pipe API data loader, for PyTorch ResNet.

Defining a ResnetMediaPipe Class

The ResnetMediaPipe class is derived from the MediaPipe class, as described in Creating and Executing Media Pipeline, and implements the Media Pipe API for ResNet. An implementation of the ResnetMediaPipe class can be found in the Intel Gaudi Model References GitHub repository.

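As the code below shows, the pipeline decodes the input JPEG images, applies a random flip during training, and normalizes the result with the configured mean and standard deviation.

The API functions for these operations are first defined in the ResnetMediaPipe constructor. The sequence of operations is then set up in the definegraph method: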

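def definegraph(self):
    # Read the encoded JPEGs and their labels, then decode the images.
    jpegs, data = self.input()
    images = self.decode(jpegs)

    # Random flip is applied only during training.
    if self.is_training:
        flip = self.random_flip_input()
        images = self.random_flip(images, flip)

    # Normalize with the configured mean and standard deviation.
    mean = self.norm_mean()
    std = self.norm_std()
    images = self.cmn(images, mean, std)

    return images, data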
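
The split between constructor and definegraph can be tried out without Gaudi hardware. The following is a minimal pure-Python sketch, not the Media Pipe API: ToyResnetPipe and its nested-list "images" are hypothetical stand-ins, and its methods only mirror the order of operations above (an optional random flip during training, then normalization by mean and standard deviation).

```python
import random


class ToyResnetPipe:
    """Hypothetical stand-in that mimics the constructor/definegraph split."""

    def __init__(self, is_training, mean, std, seed=0):
        # Configuration is set up once, in the constructor.
        self.is_training = is_training
        self.mean = mean
        self.std = std
        self._rng = random.Random(seed)

    def random_flip_input(self):
        # Decide per call whether to flip (50% chance).
        return self._rng.random() < 0.5

    @staticmethod
    def random_flip(image, flip):
        # Horizontal flip: reverse every row when the flag is set.
        return [row[::-1] for row in image] if flip else image

    def normalize(self, image):
        # Subtract the mean and divide by the standard deviation.
        return [[(px - self.mean) / self.std for px in row] for row in image]

    def definegraph(self, image):
        # Wire the operations together in the same order as the pipeline above.
        if self.is_training:
            image = self.random_flip(image, self.random_flip_input())
        return self.normalize(image)


eval_pipe = ToyResnetPipe(is_training=False, mean=1.0, std=2.0)
eval_out = eval_pipe.definegraph([[0.0, 2.0]])

train_pipe = ToyResnetPipe(is_training=True, mean=1.0, std=2.0)
train_out = train_pipe.definegraph([[0.0, 2.0]])
```

In evaluation mode the flip is skipped, so the output depends only on the normalization. In training mode the row may or may not be reversed before it is normalized.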

Implementing MediaApiDataLoader

The MediaApiDataLoader includes an HPUResnetPytorchIterator interface, which iterates over the (image, label) pairs produced by a ResnetMediaPipe instance. This interface is compatible with torch.utils.data.DataLoader, so PyTorch ResNet model code written for that loader can be reused without additional changes.

The relevant code is defined in the MediaApiDataLoader class:

class MediaApiDataLoader(torch.utils.data.DataLoader):
    def __init__(self, dataset, sampler, batch_size, num_workers, pin_memory=True, pin_memory_device=None, is_training=False):

        # setting parameters here, code removed for clarity

        from resnet_media_pipe import ResnetMediaPipe
        pipeline = ResnetMediaPipe(is_training=is_training, root=root, batch_size=batch_size,
                                   shuffle=self.shuffle, drop_last=False, queue_depth=queue_depth,
                                   num_instances=num_instances, instance_id=instance_id, device=device_string)

        from habana_frameworks.mediapipe.plugins.iterator_pytorch import HPUResnetPytorchIterator
        self.iterator = HPUResnetPytorchIterator(mediapipe=pipeline)
        print("Running with Media API DataLoader")
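
To see why wrapping an iterator keeps existing training loops unchanged, consider the following minimal sketch. It does not use the Habana packages: ToyBatchIterator, ToyLoader, and count_labels are hypothetical names, and the string "images" are placeholders. The only point it makes is that a loader which delegates __iter__ and __len__ to an iterator yielding (images, labels) batches can be consumed by an ordinary for loop.

```python
class ToyBatchIterator:
    """Hypothetical stand-in for an iterator that yields (images, labels) batches."""

    def __init__(self, samples, batch_size):
        self.samples = samples
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches, counting a final partial batch (like drop_last=False).
        return (len(self.samples) + self.batch_size - 1) // self.batch_size

    def __iter__(self):
        for start in range(0, len(self.samples), self.batch_size):
            batch = self.samples[start:start + self.batch_size]
            images = [image for image, _ in batch]
            labels = [label for _, label in batch]
            yield images, labels


class ToyLoader:
    """Loader that delegates iteration to the wrapped iterator."""

    def __init__(self, samples, batch_size):
        self.iterator = ToyBatchIterator(samples, batch_size)

    def __len__(self):
        return len(self.iterator)

    def __iter__(self):
        return iter(self.iterator)


def count_labels(loader):
    # A loop written for any loader that yields (images, labels) pairs.
    seen = 0
    for images, labels in loader:
        seen += len(labels)
    return seen


loader = ToyLoader([("img0", 0), ("img1", 1), ("img2", 0)], batch_size=2)
total = count_labels(loader)
```

Because count_labels relies only on the iteration protocol, the same loop would work with any loader that yields pairs in this shape.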

Selecting a DataLoader for PyTorch ResNet

The Intel Gaudi PyTorch ResNet model implements three DataLoaders:

  • MediaApiDataLoader

  • Native PyTorch DataLoader

  • Habana Dataloader


Habana Dataloader (AeonDataLoader) is a software-accelerated data loader used on first-gen Gaudi, where Media API acceleration is not available. The PyTorch ResNet integration of Aeon is implemented in the AeonDataLoader class.

MediaApiDataLoader can be used only if the following conditions are met:

  • habana_media_loader package is installed.

  • Workload is being run on Gaudi 2. First-gen Gaudi does not support Media API.

  • --dl-worker-type argument for PyTorch ResNet script is set to “HABANA” (the default value).
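
The conditions listed above can be expressed as a small selection function. The sketch below is an illustration under stated assumptions, not the model's actual code: the function name choose_data_loader_sketch, the device strings "GAUDI" and "GAUDI2", the enum values, and the choice to return the native PyTorch loader when the package is missing are all hypothetical. Only the three DataLoaderType member names and the "HABANA" argument value come from this section.

```python
import importlib.util
from enum import Enum


class DataLoaderType(Enum):
    MediaAPI = 1
    Aeon = 2
    Python = 3


def choose_data_loader_sketch(dl_worker_type, device_type, media_loader_installed=None):
    """Pick a loader type from the worker type, device, and package availability."""
    if media_loader_installed is None:
        # find_spec returns None when a top-level package is not installed.
        media_loader_installed = (
            importlib.util.find_spec("habana_media_loader") is not None
        )

    # Any worker type other than "HABANA" selects the native PyTorch loader.
    if dl_worker_type != "HABANA":
        return DataLoaderType.Python

    # First-gen Gaudi has no Media API support, so it uses Aeon.
    if device_type == "GAUDI":
        return DataLoaderType.Aeon

    # Gaudi 2 with the media loader package installed can use the Media API.
    if device_type == "GAUDI2" and media_loader_installed:
        return DataLoaderType.MediaAPI

    # Assumption of this sketch: otherwise use the native loader.
    return DataLoaderType.Python


selected = choose_data_loader_sketch("HABANA", "GAUDI2", media_loader_installed=True)
```

Passing media_loader_installed explicitly keeps the function testable on machines without the Habana software stack.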

The above conditions are checked by the choose_data_loader function. Choosing and creating the appropriate data loader is handled by build_data_loader in the same file:

def build_data_loader(is_training, dl_worker_type, **kwargs):
    data_loader_type = choose_data_loader(dl_worker_type)
    use_fallback = False

    try:
        if data_loader_type == DataLoaderType.MediaAPI:
            return MediaApiDataLoader(**kwargs, is_training=is_training)
        elif data_loader_type == DataLoaderType.Aeon:
            return AeonDataLoader(**kwargs)
    except Exception as e:
        if os.getenv('DATALOADER_FALLBACK_EN', "True") == "True":
            # Fallback enabled (default): log the error and switch to the native loader.
            print(f"Failed to initialize Habana Dataloader, error: {str(e)}\nRunning with PyTorch Dataloader")
            data_loader_type = DataLoaderType.Python
            use_fallback = True
        else:
            # Fallback disabled: report the failure and propagate the error.
            print(f"Habana dataloader configuration failed: {e}")
            raise

    if data_loader_type == DataLoaderType.Python:
        # Native PyTorch DataLoader.
        return torch.utils.data.DataLoader(**kwargs)
    else:
        raise ValueError(f"Unknown data_loader_type {data_loader_type}")

If all conditions are met, choose_data_loader returns DataLoaderType.MediaAPI. If any error is encountered while building MediaApiDataLoader, a fallback to the Habana Dataloader or the native PyTorch DataLoader is attempted, provided the DATALOADER_FALLBACK_EN environment variable is left at its default value of "True". Information about the failure and the fallback is written to the output log.
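
The effect of the DATALOADER_FALLBACK_EN variable can be demonstrated in isolation. The following is a minimal sketch with stub builders, not the model's code: build_with_fallback and failing_builder are hypothetical names, and re-raising the error when fallback is disabled is an assumption of this sketch.

```python
import os


def build_with_fallback(build_accelerated, build_native):
    """Try the accelerated loader first; optionally fall back to the native one."""
    try:
        return build_accelerated()
    except Exception as e:
        # Fallback stays enabled unless the variable is set to something other than "True".
        if os.getenv("DATALOADER_FALLBACK_EN", "True") == "True":
            print(f"Accelerated loader failed: {e}; falling back")
            return build_native()
        # Assumption: with fallback disabled, the original error propagates.
        raise


def failing_builder():
    # Hypothetical builder that simulates an initialization failure.
    raise RuntimeError("media pipeline unavailable")


os.environ["DATALOADER_FALLBACK_EN"] = "True"
with_fallback = build_with_fallback(failing_builder, lambda: "native loader")

os.environ["DATALOADER_FALLBACK_EN"] = "False"
try:
    build_with_fallback(failing_builder, lambda: "native loader")
    error_propagated = False
except RuntimeError:
    error_propagated = True
```

With the variable set to "True" the stub native loader is returned. With any other value the initialization error surfaces to the caller, which can be useful when a silent switch to a slower loader would hide a configuration problem.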