MediaPipe for PyTorch ResNet

This section describes how to define a ResnetMediaPipe class and implement a MediaPipe API dataloader, MediaApiDataLoader, for PyTorch ResNet.

Defining a ResnetMediaPipe Class

The ResnetMediaPipe class is derived from the MediaPipe class, as described in Creating and Executing Media Pipeline, and implements the MediaPipe API for ResNet. An implementation of the ResnetMediaPipe class can be found in the Intel Gaudi Model References GitHub repository.

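The following operations are performed on the data: reading the JPEG images and their labels, decoding, random horizontal flipping (training only), and normalization using a mean and a standard deviation.

The API functions are first defined in the ResnetMediaPipe constructor. The sequence of operations is then set up in the definegraph method: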

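def definegraph(self):
    # Read a batch of encoded JPEG images and their labels
    jpegs, data = self.input()
    # Decode the JPEG images
    images = self.decode(jpegs)

    # Random horizontal flip is applied only during training
    if self.is_training:
        flip = self.random_flip_input()
        images = self.random_flip(images, flip)

    # Normalize the images with the mean and standard deviation
    mean = self.norm_mean()
    std = self.norm_std()
    images = self.cmn(images, mean, std)

    return images, data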
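To make the order of operations in definegraph concrete without Gaudi hardware, the sketch below mirrors it in plain Python. ToyResnetPipe and its toy nested-list "images" are hypothetical; they are not the MediaPipe API, and the real decode, flip, and cmn operators run on the device with their own tensor types. The sketch only assumes the same sequence: decode, optional random horizontal flip in training, then per-channel normalization.

```python
import random
from typing import List, Sequence, Tuple

# Hypothetical toy image: rows x columns x channels, as nested lists of floats.
Image = List[List[List[float]]]


class ToyResnetPipe:
    """Hypothetical pure-Python stand-in for ResnetMediaPipe (not the MediaPipe API)."""

    def __init__(self, is_training: bool, mean: Sequence[float],
                 std: Sequence[float], seed: int = 0) -> None:
        self.is_training = is_training
        self.mean = list(mean)
        self.std = list(std)
        self.rng = random.Random(seed)

    def decode(self, raw: Image) -> Image:
        # The real step decodes JPEG bytes; the toy input is already pixel data.
        return raw

    def random_flip_input(self) -> int:
        # Random decision for one image: 1 = flip, 0 = keep.
        return self.rng.randint(0, 1)

    def random_flip(self, image: Image, flip: int) -> Image:
        # A horizontal flip reverses the column order within every row.
        return [row[::-1] for row in image] if flip else image

    def cmn(self, image: Image, mean: Sequence[float], std: Sequence[float]) -> Image:
        # Per-channel normalization: (pixel[c] - mean[c]) / std[c].
        return [[[(px[c] - mean[c]) / std[c] for c in range(len(px))]
                 for px in row]
                for row in image]

    def definegraph(self, raw: Image, label: int) -> Tuple[Image, int]:
        # Same order as ResnetMediaPipe.definegraph: decode, optional flip, normalize.
        images = self.decode(raw)
        if self.is_training:
            flip = self.random_flip_input()
            images = self.random_flip(images, flip)
        images = self.cmn(images, self.mean, self.std)
        return images, label


pipe = ToyResnetPipe(is_training=False, mean=[0.5], std=[0.5])
images, label = pipe.definegraph([[[1.0], [0.0]]], label=3)
print(images, label)  # [[[1.0], [-1.0]]] 3
```

Keeping the flip decision (random_flip_input) separate from the flip itself (random_flip) follows the structure of the original definegraph, where the random decision is its own graph node feeding the flip operator.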

Implementing MediaApiDataLoader

The MediaApiDataLoader includes an HPUResnetPytorchIterator interface, which iterates over the image and label pairs produced by the ResnetMediaPipe instance. This interface is compatible with the native PyTorch DataLoader, so PyTorch ResNet model code written for the PyTorch DataLoader can be reused without additional changes.

The relevant code is defined in the MediaApiDataLoader class located in

class MediaApiDataLoader(object):
    def __init__(self, dataset, sampler, batch_size, num_workers, pin_memory=True, pin_memory_device=None, is_training=False):

        # setting parameters here (root, queue_depth, num_instances, instance_id,
        # device_string, self.shuffle), code removed for clarity

        # Build the media pipeline that reads, decodes, augments, and normalizes images
        from resnet_media_pipe import ResnetMediaPipe
        pipeline = ResnetMediaPipe(is_training=is_training, root=root, batch_size=batch_size,
                                   shuffle=self.shuffle, drop_last=False, queue_depth=queue_depth,
                                   num_instances=num_instances, instance_id=instance_id, device=device_string)

        # Wrap the pipeline in an iterator that yields (images, labels) pairs
        from habana_frameworks.mediapipe.plugins.iterator_pytorch import HPUResnetPytorchIterator
        self.iterator = HPUResnetPytorchIterator(mediapipe=pipeline)
        print("Running with Media API DataLoader")
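The compatibility claim rests on the iteration protocol: a training loop that only calls len() and iterates over (images, labels) pairs does not care which loader object it receives. The sketch below illustrates that idea under stated assumptions. ToyMediaApiDataLoader, next_batch, and train_one_epoch are hypothetical names; this is not how HPUResnetPytorchIterator is implemented internally.

```python
from typing import Callable, Iterator, List, Tuple

# Hypothetical batch: (images, labels), with plain lists standing in for tensors.
Batch = Tuple[List[float], List[int]]


class ToyMediaApiDataLoader:
    """Hypothetical wrapper exposing the protocol a PyTorch training loop expects."""

    def __init__(self, next_batch: Callable[[], Batch], num_batches: int) -> None:
        # next_batch is a hypothetical stand-in for one run of the media pipeline.
        self.next_batch = next_batch
        self.num_batches = num_batches

    def __len__(self) -> int:
        # Training scripts often call len(loader) for progress bars or LR schedules.
        return self.num_batches

    def __iter__(self) -> Iterator[Batch]:
        # A fresh generator per call, so the loader can be iterated once per epoch.
        for _ in range(self.num_batches):
            yield self.next_batch()


def train_one_epoch(loader) -> int:
    # Uses only iteration, so it accepts this wrapper or a torch.utils.data.DataLoader.
    seen = 0
    for images, labels in loader:
        seen += len(labels)
    return seen


loader = ToyMediaApiDataLoader(lambda: ([0.1, 0.2], [0, 1]), num_batches=3)
print(len(loader), train_one_epoch(loader))  # 3 6
```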

Selecting a DataLoader for PyTorch ResNet

In Intel Gaudi’s PyTorch ResNet, three dataloaders are implemented in the model:

  • MediaApiDataLoader

  • Native PyTorch DataLoader

  • Habana Dataloader


Habana Dataloader (AeonDataLoader) is a software-accelerated dataloader used on first-gen Gaudi, where Media API acceleration is not available. The PyTorch ResNet implementation of Aeon is provided by the AeonDataLoader class in the script.

The MediaApiDataLoader can be used only if the following conditions are met:

  • The habana_media_loader package is installed.

  • The workload is run on Gaudi 2. First-gen Gaudi does not support the Media API.

  • The --dl-worker-type argument for the PyTorch ResNet script is set to “HABANA” (the default value).
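The three conditions above can be read as a small decision function. The sketch below is a hedged illustration, not the actual choose_data_loader. It assumes that checking for the habana_frameworks.mediapipe module (imported by the MediaApiDataLoader code) is an acceptable proxy for the habana_media_loader package being installed, that the device is described by a hypothetical string such as "Gaudi2", that a worker type other than "HABANA" selects the native PyTorch DataLoader, and that the Aeon-based Habana Dataloader is used when the Media API conditions fail.

```python
import importlib.util
from enum import Enum


class DataLoaderType(Enum):
    MediaAPI = "MediaAPI"
    Aeon = "Aeon"
    Python = "Python"


def module_available(name: str) -> bool:
    # find_spec imports the parent packages of a dotted name, so a missing
    # parent raises ModuleNotFoundError instead of returning None.
    try:
        return importlib.util.find_spec(name) is not None
    except (ModuleNotFoundError, ValueError):
        return False


def choose_data_loader_sketch(dl_worker_type: str, media_installed: bool,
                              device: str) -> DataLoaderType:
    # Assumption: a worker type other than "HABANA" selects the native PyTorch DataLoader.
    if dl_worker_type != "HABANA":
        return DataLoaderType.Python
    # All three listed conditions hold: package installed, Gaudi 2, HABANA worker type.
    if media_installed and device == "Gaudi2":
        return DataLoaderType.MediaAPI
    # Assumption: otherwise the Aeon-based Habana Dataloader is used.
    return DataLoaderType.Aeon


loader_type = choose_data_loader_sketch(
    "HABANA", module_available("habana_frameworks.mediapipe"), "Gaudi2")
```

On a machine without the media package, module_available returns False and the sketch selects DataLoaderType.Aeon.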

These conditions are handled by the choose_data_loader function in the file. Choosing and creating the appropriate dataloader is managed by build_data_loader in the same file:

import os

import torch


def build_data_loader(is_training, dl_worker_type, **kwargs):
    # choose_data_loader, DataLoaderType, MediaApiDataLoader and AeonDataLoader
    # are defined elsewhere in the same file
    data_loader_type = choose_data_loader(dl_worker_type)
    use_fallback = False

    try:
        if data_loader_type == DataLoaderType.MediaAPI:
            return MediaApiDataLoader(**kwargs, is_training=is_training)
        elif data_loader_type == DataLoaderType.Aeon:
            return AeonDataLoader(**kwargs)
    except Exception as e:
        # Fall back to the native PyTorch DataLoader unless DATALOADER_FALLBACK_EN disables it
        if os.getenv('DATALOADER_FALLBACK_EN', "True") == "True":
            print(f"Failed to initialize Habana Dataloader, error: {str(e)}\nRunning with PyTorch Dataloader")
            use_fallback = True
        else:
            print(f"Habana dataloader configuration failed: {e}")
            raise

    if data_loader_type == DataLoaderType.Python or use_fallback:
        return torch.utils.data.DataLoader(**kwargs)
    else:
        raise ValueError(f"Unknown data_loader_type {data_loader_type}")

If all conditions are met, choose_data_loader returns DataLoaderType.MediaAPI and build_data_loader creates a MediaApiDataLoader. If an error is encountered while building the Habana dataloader and the DATALOADER_FALLBACK_EN environment variable is “True” (the default), the script falls back to the native PyTorch DataLoader. Information about the failure and the fallback is written to the output log.
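The fallback behavior described above can be isolated into a small, testable helper. The sketch below is a minimal illustration, assuming two zero-argument factories (hypothetical primary and fallback callables) in place of the real dataloader constructors. It reuses the DATALOADER_FALLBACK_EN variable and its “True” default from the script; re-raising when fallback is disabled is a design choice of this sketch so that failures are not silently ignored.

```python
import os
from typing import Callable, TypeVar

T = TypeVar("T")


def build_with_fallback(primary: Callable[[], T], fallback: Callable[[], T],
                        env_var: str = "DATALOADER_FALLBACK_EN") -> T:
    """Hypothetical helper: try the accelerated loader, else fall back if allowed."""
    try:
        return primary()
    except Exception as e:
        # Fallback is enabled unless the variable is set to something other than "True".
        if os.getenv(env_var, "True") == "True":
            print(f"Failed to initialize Habana Dataloader, error: {e}\nRunning with PyTorch Dataloader")
            return fallback()
        print(f"Habana dataloader configuration failed: {e}")
        raise


loader = build_with_fallback(lambda: "Media API loader", lambda: "PyTorch DataLoader")
```

With the default environment, a primary factory that raises results in the fallback object being returned; setting DATALOADER_FALLBACK_EN to any other value makes the original exception propagate.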