habana_frameworks.mediapipe.fn.Split

Class:
  • habana_frameworks.mediapipe.fn.Split(**kwargs)

Define graph call:
  • __call__(input)

Parameter:
  • input - Input tensor to operator. Supported dimensions: minimum = 1, maximum = 5. Supported data types: INT32, BFLOAT16, FLOAT32.

Description:

Splits a tensor into a list of tensors along the specified ‘axis’. The length of each part is determined by the size of the corresponding output tensor.

Supported backend:
  • HPU

Keyword Arguments

kwargs

Description

axis

Axis along which the tensor is split.

  • Type: int

  • Default: 0

  • Optional: no

num_outputs

Number of split outputs.

  • Type: int

  • Default: 0

  • Optional: no

dtype

Output data type of each split output, given as a list with one entry per output.

  • Type: list of habana_frameworks.mediapipe.media_types.dtype

  • Default: UINT8

  • Optional: no

  • Supported data types:

    • INT32

    • FLOAT32

Note

  1. Input and output tensors must have the same data type.

  2. The output sizes along the split dimension must add up to the input tensor size in that dimension.

  3. The shape of each output tensor must match the input tensor shape in all dimensions except the split dimension.

  4. Currently, splitting into a maximum of 10 tensors is supported.
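
The constraints in the note above can be checked on shapes alone. The following is a minimal sketch in plain Python and is not part of the habana_frameworks API: the helper name `split_output_shapes` and its `sizes` argument are hypothetical, and `axis` here simply indexes the shape tuple as written.

```python
MAX_SPLIT_OUTPUTS = 10  # limit stated in note 4


def split_output_shapes(input_shape, axis, sizes):
    """Return the output shapes for splitting input_shape along axis.

    Hypothetical helper for illustration; it only reasons about shapes.
    """
    if not 0 <= axis < len(input_shape):
        raise ValueError("axis is out of range for the input shape")
    if not 1 <= len(sizes) <= MAX_SPLIT_OUTPUTS:
        raise ValueError("number of outputs must be between 1 and 10")
    # Note 2: the part sizes must add up to the input size along the split axis.
    if sum(sizes) != input_shape[axis]:
        raise ValueError("part sizes do not add up to the input size")
    # Note 3: every other dimension is copied from the input unchanged.
    return [
        tuple(size if i == axis else dim for i, dim in enumerate(input_shape))
        for size in sizes
    ]


shapes = split_output_shapes((2, 3, 2, 3), axis=0, sizes=[1, 1])
print(shapes)  # [(1, 3, 2, 3), (1, 3, 2, 3)]
```

A size list that does not add up to the input size, or that has more than 10 entries, raises `ValueError` in this sketch; how the operator itself reports such errors is not covered here.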

Example: Split Operator

The following code snippet shows the usage of the Split operator.

from habana_frameworks.mediapipe import fn
from habana_frameworks.mediapipe.mediapipe import MediaPipe
from habana_frameworks.mediapipe.media_types import dtype as dt
import numpy as np
import os


# Create MediaPipe derived class
class myMediaPipe(MediaPipe):
    def __init__(self, device, queue_depth, batch_size, num_threads,
                op_device, dir, num_split):
        super().__init__(device, queue_depth, batch_size, num_threads,
                         self.__class__.__name__)
        self.num_split = num_split
        self.input0 = fn.ReadNumpyDatasetFromDir(num_outputs=1,
                                                shuffle=False,
                                                dir=dir,
                                                pattern='*x*.npy',
                                                dense=True,
                                                dtype=dt.FLOAT32,
                                                device="cpu")

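        # axis=3 on this 4D input corresponds to NumPy axis 0 in the
        # reference check (compare_ref) and in the printed output shapes
        self.split = fn.Split(axis=3,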
                              num_outputs=self.num_split,
                              dtype=[dt.FLOAT32]*self.num_split,
                              device=op_device)

    def definegraph(self):
        inp = self.input0()
        out0, out1 = self.split(inp)
        return inp, out0, out1


def run(device, op_device):
    batch_size = 2
    queue_depth = 2
    num_threads = 1
    num_split = 2
    base_dir = os.environ['DATASET_DIR']
    dir = base_dir+"/npy_data/fp32/"

    # Create MediaPipe object
    pipe = myMediaPipe(device, queue_depth, batch_size,
                      num_threads, op_device, dir, num_split)
    # Build MediaPipe
    pipe.build()

    # Initialize MediaPipe iterator
    pipe.iter_init()

    # Run MediaPipe
    inp, out0, out1 = pipe.run()

    def as_cpu(tensor):
        if (callable(getattr(tensor, "as_cpu", None))):
            tensor = tensor.as_cpu()
        return tensor

    # Copy data to host from device as numpy array
    out0 = as_cpu(out0).as_nparray()
    out1 = as_cpu(out1).as_nparray()
    inp = as_cpu(inp).as_nparray()

    # Display shapes
    print('input shape:', inp.shape)
    print('output0 shape:', out0.shape)
    print('output1 shape:', out1.shape)
    return inp, out0, out1


def compare_ref(inp, out0, out1):
    num_splits = 2
    # NumPy reference: split the input into equal parts along NumPy axis 0
    ref0, ref1 = np.split(inp, num_splits, axis=0)
    if not np.array_equal(ref0, out0):
        raise ValueError("Mismatch w.r.t. reference for output 0")
    if not np.array_equal(ref1, out1):
        raise ValueError("Mismatch w.r.t. reference for output 1")


if __name__ == "__main__":
    dev_opdev = {'mixed': ['hpu'],
                'legacy': ['hpu']}
    for dev in dev_opdev.keys():
        for op_dev in dev_opdev[dev]:
            inp, out0, out1 = run(dev, op_dev)
            compare_ref(inp, out0, out1)

The following is the output of the Split operator:

input shape: (2, 3, 2, 3)
output0 shape: (1, 3, 2, 3)
output1 shape: (1, 3, 2, 3)
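
In the example, the operator is created with `axis=3`, yet the reference check splits along NumPy axis 0 and the printed shapes show the first NumPy dimension being halved. This is consistent with the operator counting axes in the reverse order of the NumPy shape. The sketch below assumes that reversed convention, which is inferred from the example rather than stated on this page. The helper name `to_numpy_axis` is hypothetical and not part of habana_frameworks.

```python
def to_numpy_axis(media_axis, ndim):
    """Map an axis counted from the last NumPy dimension to a NumPy axis.

    Assumes reversed axis ordering, inferred from the example above.
    """
    if not 0 <= media_axis < ndim:
        raise ValueError("axis is out of range")
    return ndim - 1 - media_axis


input_shape = (2, 3, 2, 3)  # NumPy shape printed by the example
print(to_numpy_axis(3, len(input_shape)))  # 0
```

Under this assumption, `axis=0` on the same 4D input would map to NumPy axis 3, the last dimension of the printed shape.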