MediaPipe for PyTorch ResNet3d

This section describes how to implement myMediaPipeTrain class and myMediaPipeEval class for Pytorch ResNet3d.

Defining a myMediaPipeTrain Class for ResNet3d

The class myMediaPipeTrain class is derived from MediaPipe, as described in Creating and Executing Media Pipeline. The following set of operations is performed on training data:

API functions are first defined in the myMediaPipeTrain constructor. Then, a sequence of operations is set up in the definegraph method:

def definegraph(self):
    files, label, resample, vid_offset = self.input()

    crop_val = self.random_crop()
    video = self.decode(files, vid_offset, resample,
                        crop_val)

    video = self.reshape_hflip(video)
    is_hflip = self.is_hflip()
    video = self.img_hflip(video, is_hflip)

    video = self.reshape_cmn(video)
    std = self.std_node()
    mean = self.mean_node()
    video = self.cmn(video, mean, std)

    video = self.reshape_pre_transpose(video)
    video = self.transp(video)
    return video, label

Defining a myMediaPipeEval Class for ResNet3d

The class myMediaPipeEval is derived from MediaPipe, as described in Creating and Executing Media Pipeline. The following set of operations is performed on validation data:

API functions are first defined in the myMediaPipeEval constructor. Then, a sequence of operations is set up in the definegraph method:

def definegraph(self):
    files, label, resample, vid_offset = self.input()
    video = self.decode(files, vid_offset, resample)

    video = self.reshape_cmn(video)
    std = self.std_node()
    mean = self.mean_node()
    video = self.cmn(video, mean, std)

    video = self.reshape_pre_transpose(video)
    video = self.transp(video)
    return video, label

The following code is the complete implementation of myMediaPipeTrain and myMediaPipeEval:

import numpy as np
import math
import os

from habana_frameworks.mediapipe import fn
from habana_frameworks.mediapipe.mediapipe import MediaPipe
from habana_frameworks.mediapipe.operators.cpu_nodes.cpu_nodes import media_function

from habana_frameworks.mediapipe.plugins.iterator_pytorch import Resnet3dPytorchIterator

from habana_frameworks.mediapipe.media_types import imgtype as it
from habana_frameworks.mediapipe.media_types import dtype as dt
from habana_frameworks.mediapipe.media_types import randomCropType as rct
from habana_frameworks.mediapipe.media_types import clipSampler as cs

g_iter = 10
g_batch_size = 16
g_num_slices = 1
g_slice_index = 0
g_seed = 100

g_resize_w = 171
g_resize_h = 128
g_crop_w = 112
g_crop_h = 112

g_vid_max_frame_rate = 30.0
g_target_frame_rate = 15
g_frame_per_clip = 16

g_flip_priv_params = {
    'prob': 0.5
}

g_rgb_mean_values = [0.43216, 0.394666, 0.37645]
g_rgb_std_values = [0.22803, 0.22145, 0.216989]
g_rgb_multiplier = 255


class random_crop_func(media_function):
    def __init__(self, params):
        self.np_shape = params['shape'][::-1]
        self.np_dtype = params['dtype']
        self.batch_size = self.np_shape[0]

        self.seed = params['seed'] + params['unique_number']
        self.priv_params = params['priv_params']
        self.resize_width = self.priv_params['input_w']
        self.resize_height = self.priv_params['input_h']
        self.crop_width = self.priv_params['crop_w']
        self.crop_height = self.priv_params['crop_h']
        self.rng = np.random.default_rng(self.seed)

    def __call__(self):
        a = np.zeros(shape=self.np_shape, dtype=self.np_dtype)
        x_val = self.rng.integers(low=0, high=(
            self.resize_width - self.crop_width + 1), size=self.batch_size, dtype=self.np_dtype)
        y_val = self.rng.integers(low=0, high=(
            self.resize_height - self.crop_height + 1), size=self.batch_size, dtype=self.np_dtype)

        for i in range(self.batch_size):
            a[i] = x_val[i], y_val[i], self.crop_width, self.crop_height
        return a


class random_flip_func(media_function):
    def __init__(self, params):
        self.np_shape = params['shape'][::-1]
        self.np_dtype = params['dtype']
        self.seed = params['seed'] + params['unique_number']
        self.priv_params = params['priv_params']
        self.prob = self.priv_params['prob']
        self.rng = np.random.default_rng(self.seed)

    def __call__(self):
        a = self.rng.choice([0, 1],
                            p=[(1-self.prob), self.prob],
                            size=self.np_shape)
        a = np.array(a, dtype=self.np_dtype)
        return a



def get_dec_max_frame(max_frame_rate, target_frame_rate, frame_per_clip):
    frame_rate_ratio = float(max_frame_rate) / target_frame_rate
    dec_max_frame = math.ceil(frame_rate_ratio * (frame_per_clip - 1)) + 1
    return dec_max_frame


def round_up(num, round_to):
    num_round = ((num + round_to - 1) // round_to) * round_to
    return num_round


class myMediaPipeTrain(MediaPipe):
    def __init__(self, device, queue_depth, batch_size, dir):
        print("Media Train Pipe")
        resize_width = g_resize_w
        resize_height = g_resize_h
        channels = 3
        super(myMediaPipeTrain, self).__init__(device=device,
                                            prefetch_depth=queue_depth,
                                            batch_size=batch_size,
                                            pipe_name=self.__class__.__name__)

        self.input = fn.ReadVideoDatasetFromDir(dir=dir,
                                                format="mp4",
                                                frames_per_clip=g_frame_per_clip,
                                                seed=g_seed,
                                                label_dtype=dt.UINT32,
                                                drop_remainder=False,
                                                clips_per_video=5,
                                                target_frame_rate=15,
                                                num_slices=g_num_slices,
                                                slice_index=g_slice_index,
                                                sampler=cs.RANDOM_SAMPLER)

        priv_params = {}
        priv_params['input_w'] = resize_width
        priv_params['input_h'] = resize_height
        priv_params['crop_w'] = g_crop_w
        priv_params['crop_h'] = g_crop_h

        self.random_crop = fn.MediaFunc(func=random_crop_func,
                                        dtype=dt.UINT32,
                                        shape=[4, batch_size],
                                        priv_params=priv_params,
                                        seed=g_seed)

        self.decode = fn.VideoDecoder(output_format=it.RGB_P,
                                    random_crop_type=rct.NO_RANDOM_CROP,
                                    resize=[resize_width, resize_height],
                                    crop_after_resize=[
                                        0, 0, g_crop_w, g_crop_h],
                                    frames_per_clip=g_frame_per_clip)

        self.reshape_hflip = fn.Reshape(
            size=[g_crop_w, g_crop_h, channels * g_frame_per_clip, batch_size], tensorDim=4, layout='')

        self.is_hflip = fn.MediaFunc(func=random_flip_func,
                                    shape=[batch_size],
                                    dtype=dt.UINT8,
                                    seed=g_seed,
                                    priv_params=g_flip_priv_params)

        self.img_hflip = fn.RandomFlip(horizontal=1, dtype=dt.UINT8)

        self.reshape_cmn = fn.Reshape(
            size=[g_crop_w, g_crop_h, channels, batch_size * g_frame_per_clip], tensorDim=4, layout='')

        mean_data = np.array(
            [m * g_rgb_multiplier for m in g_rgb_mean_values], dtype=np.float32)
        std_data = np.array([1 / (s * g_rgb_multiplier)
                            for s in g_rgb_std_values], dtype=np.float32)

        self.std_node = fn.MediaConst(
            data=std_data, shape=[1, 1, 3], batch_broadcast=False, dtype=dt.FLOAT32)
        self.mean_node = fn.MediaConst(
            data=mean_data, shape=[1, 1, 3], batch_broadcast=False, dtype=dt.FLOAT32)

        self.cmn = fn.CropMirrorNorm(
            crop_w=g_crop_w, crop_h=g_crop_h, dtype=dt.FLOAT32)
        self.reshape_pre_transpose = fn.Reshape(
            size=[g_crop_w, g_crop_h, channels, g_frame_per_clip, batch_size], tensorDim=5, layout='', dtype=dt.FLOAT32)

        self.transp = fn.Transpose(
            permutation=[0, 1, 3, 2, 4], tensorDim=5, dtype=dt.FLOAT32)

    def definegraph(self):
        files, label, resample, vid_offset = self.input()

        crop_val = self.random_crop()
        video = self.decode(files, vid_offset, resample,
                            crop_val)

        video = self.reshape_hflip(video)
        is_hflip = self.is_hflip()
        video = self.img_hflip(video, is_hflip)

        video = self.reshape_cmn(video)
        std = self.std_node()
        mean = self.mean_node()
        video = self.cmn(video, mean, std)

        video = self.reshape_pre_transpose(video)
        video = self.transp(video)
        return video, label


class myMediaPipeEval(MediaPipe):
    def __init__(self, device, queue_depth, batch_size, dir):
        print("Media Eval Pipe")
        resize_width = g_resize_w
        resize_height = g_resize_h
        channels = 3
        super(myMediaPipeEval, self).__init__(device=device,
                                            prefetch_depth=queue_depth,
                                            batch_size=batch_size,
                                            pipe_name=self.__class__.__name__)

        self.input = fn.ReadVideoDatasetFromDir(dir=dir,
                                                format="mp4",
                                                frames_per_clip=g_frame_per_clip,
                                                seed=g_seed,
                                                label_dtype=dt.UINT32,
                                                drop_remainder=False,
                                                clips_per_video=5,
                                                target_frame_rate=15,
                                                num_slices=g_num_slices,
                                                slice_index=g_slice_index,
                                                sampler=cs.UNIFORM_SAMPLER)

        dec_max_frame_per_clip = get_dec_max_frame(
            g_vid_max_frame_rate, g_target_frame_rate, g_frame_per_clip)

        resize_width = round_up(resize_width, 2)

        print("eval VideoDecoder max_frame_vid: {} resize: w {} h {}".format(
            dec_max_frame_per_clip, resize_width, resize_height))

        self.decode = fn.VideoDecoder(output_format=it.RGB_P,
                                    random_crop_type=rct.NO_RANDOM_CROP,
                                    resize=[resize_width, resize_height],
                                    frames_per_clip=g_frame_per_clip,
                                    max_frame_vid=dec_max_frame_per_clip,
                                    dpb_size=0)

        self.reshape_cmn = fn.Reshape(
            size=[resize_width, resize_height, channels, batch_size * g_frame_per_clip], tensorDim=4, layout='')

        mean_data = np.array(
            [m * g_rgb_multiplier for m in g_rgb_mean_values], dtype=np.float32)
        std_data = np.array([1 / (s * g_rgb_multiplier)
                            for s in g_rgb_std_values], dtype=np.float32)

        self.std_node = fn.MediaConst(
            data=std_data, shape=[1, 1, 3], batch_broadcast=False, dtype=dt.FLOAT32)
        self.mean_node = fn.MediaConst(
            data=mean_data, shape=[1, 1, 3], batch_broadcast=False, dtype=dt.FLOAT32)

        cmn_pos_offset_x = 0.5
        cmn_pos_offset_y = 0.5
        self.cmn = fn.CropMirrorNorm(
            crop_w=g_crop_w, crop_h=g_crop_h, crop_pos_x=cmn_pos_offset_x, crop_pos_y=cmn_pos_offset_y, dtype=dt.FLOAT32)

        self.reshape_pre_transpose = fn.Reshape(
            size=[g_crop_w, g_crop_h, channels, g_frame_per_clip, batch_size], tensorDim=5, layout='', dtype=dt.FLOAT32)

        self.transp = fn.Transpose(
            permutation=[0, 1, 3, 2, 4], tensorDim=5, dtype=dt.FLOAT32)

    def definegraph(self):
        files, label, resample, vid_offset = self.input()

        video = self.decode(files, vid_offset, resample)

        video = self.reshape_cmn(video)
        std = self.std_node()
        mean = self.mean_node()
        video = self.cmn(video, mean, std)

        video = self.reshape_pre_transpose(video)
        video = self.transp(video)
        return video, label


def main():
    batch_size = g_batch_size
    queue_depth = 3

    # Train mediapipe
    base_dir = os.environ['VID_DATASET_DIR']
    dir_train = base_dir + "/train/"
    pipe_train = myMediaPipeTrain("legacy", queue_depth, batch_size, dir_train)

    iterator_train = Resnet3dPytorchIterator(mediapipe=pipe_train)

    # Eval mediapipe
    base_dir = os.environ['VID_DATASET_DIR']
    dir_val = base_dir + "/val/"
    pipe_val = myMediaPipeEval("legacy", queue_depth, batch_size, dir_val)
    iterator_val = Resnet3dPytorchIterator(mediapipe=pipe_val)

    iter_dict = {
        iterator_train: "train",
        iterator_val: "val"
    }

    for epoch in range(2):
        for iterator in iter_dict.keys():
            print("Iter ", iter_dict[iterator])
            bcnt = 0
            for video, label in iterator:
                bcnt += 1

                video_cpu = video.to('cpu').numpy()
                label_cpu = label.to('cpu').numpy()

                print("bcnt {} video shape {} labels shape {} labels {} ".format(bcnt, video_cpu.shape,
                    label_cpu.shape, label_cpu))

                if bcnt == g_iter:
                    break

    print("End of Test")


if __name__ == "__main__":
    main()