habana_frameworks.mediapipe.fn.Where

Class:
  • habana_frameworks.mediapipe.fn.Where(**kwargs)

Define graph call:
  • __call__(predicate, input0, input1)

Parameter:
  • predicate - Predicate tensor for the operator, size=[batch_size]. Supported dimensions: minimum = 1, maximum = 1. Supported data types: UINT8.

  • input0 - Input0 tensor. Supported dimensions: minimum = 1, maximum = 5. Supported data types: FLOAT16, FLOAT32, BFLOAT16.

  • input1 - Input1 tensor. Supported dimensions: minimum = 1, maximum = 5. Supported data types: FLOAT16, FLOAT32, BFLOAT16.

Description:

Outputs a tensor whose elements are selected from either input0 or input1, depending on predicate: where predicate is true, input0 is selected; otherwise, input1 is selected.

Supported backend:
  • HPU, CPU

Keyword Arguments

kwargs

Description

dtype

Output data type.

  • Type: habana_frameworks.mediapipe.media_types.dtype

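  • Default: UINT8 (UINT8 is not one of the supported data types listed below, so dtype should be set explicitly to match input0 and input1)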

  • Optional: yes

  • Supported data type:

    • FLOAT16

    • FLOAT32

    • BFLOAT16

Note

The input0, input1, and output tensors must all be of the same data type and must have the same dimensionality.

Example: Where Operator

The following code snippet shows usage of the Where operator:

from habana_frameworks.mediapipe import fn
from habana_frameworks.mediapipe.mediapipe import MediaPipe
from habana_frameworks.mediapipe.media_types import dtype as dt
import numpy as np
import glob

# Create media pipeline derived class
class myMediaPipe(MediaPipe):
    """Example pipeline that scales a cropped image by a Where-selected factor.

    A CoinFlip op, fed with the constant 0.2, produces the predicate. Where then
    picks 0.3 (input0) when the predicate is true and 0.7 (input1) otherwise.
    The chosen scale is multiplied into the cropped image.
    """

    def __init__(self, device, queue_depth, patch_size, num_channels, batch_size, num_threads, img_list, lbl_list, seed):
        super().__init__(device, queue_depth, batch_size, num_threads,
                         self.__class__.__name__)

        def numpy_reader(file_list, data_type):
            # Image and label readers share every setting except the file
            # list and the element dtype.
            return fn.ReadNumpyDatasetFromDir(device=device,
                                              num_outputs=1,
                                              shuffle=False,
                                              shuffle_across_dataset=False,
                                              file_list=file_list,
                                              dtype=[data_type],
                                              dense=False,
                                              seed=seed,
                                              num_slices=1,
                                              slice_index=0,
                                              drop_remainder=True,
                                              pad_remainder=False
                                              )

        self.images = numpy_reader(img_list, dt.FLOAT32)
        self.labels = numpy_reader(lbl_list, dt.UINT8)

        self.crop = fn.RandomBiasedCrop(patch_size=patch_size,
                                        num_channels=num_channels,
                                        seed=seed,
                                        num_workers=4,
                                        cache_bboxes=True,
                                        device=device,
                                        )

        self.mul = fn.Mult(device=device)

        # const_val0 feeds the coin flip; const_val1 / const_val2 are the two
        # candidate scales handed to Where as input0 / input1.
        self.const_val0, self.const_val1, self.const_val2 = (
            fn.Constant(constant=value, dtype=dt.FLOAT32, device=device)
            for value in (0.2, 0.3, 0.7)
        )

        # NOTE(review): the parameter docs list UINT8 for the predicate, while
        # this CoinFlip emits INT8 - confirm which dtype Where expects.
        self.coin_flip = fn.CoinFlip(seed=seed, dtype=dt.INT8, device=device)

        self.where = fn.Where(dtype=dt.FLOAT32, device=device)

    def definegraph(self):
        """Wire the graph; returns (cropped image, selected scale, scaled image)."""
        images, labels = self.images(), self.labels()
        images, labels, _coords = self.crop(images, labels)
        # The coin flip output is the predicate (condition tensor) for Where.
        predicate = self.coin_flip(self.const_val0())
        scale_if_true = self.const_val1()
        scale_if_false = self.const_val2()
        selected_scale = self.where(predicate, scale_if_true, scale_if_false)
        scaled_images = self.mul(images, selected_scale)
        return images, selected_scale, scaled_images


def main():
    """Run the Where example pipeline for one batch and print the selected scale.

    Image files (``case_*_x.npy``) and label files (``case_*_y.npy``) are read
    from ``data_dir``; replace the placeholder path before running.

    Raises:
        FileNotFoundError: if no image files or no label files match.
        ValueError: if the numbers of image and label files differ.
    """
    import os

    batch_size = 1
    patch_size = [5, 5, 5]
    queue_depth = 2
    num_channels = 1
    num_threads = 1
    # Named data_dir (not dir) so the builtin is not shadowed.
    data_dir = "/path/to/numpy/files/"
    pattern0 = "case_*_x.npy"
    pattern1 = "case_*_y.npy"
    # os.path.join does not double the separator when data_dir ends in "/".
    image_list = np.array(sorted(glob.glob(os.path.join(data_dir, pattern0))))
    label_list = np.array(sorted(glob.glob(os.path.join(data_dir, pattern1))))
    # Fail early with a clear message rather than handing empty file lists
    # to the dataset readers.
    if len(image_list) == 0 or len(label_list) == 0:
        raise FileNotFoundError(
            f"no files matching {pattern0!r} and {pattern1!r} found in {data_dir!r}")
    # Both readers run unshuffled, so images and labels presumably pair up by
    # sorted position; mismatched counts would misalign them.
    if len(image_list) != len(label_list):
        raise ValueError(
            f"found {len(image_list)} image files but {len(label_list)} label files")
    device = 'hpu'
    seed = 1234

    # Create media pipeline object
    pipe = myMediaPipe(device, queue_depth, patch_size,
                       num_channels, batch_size, num_threads, image_list, label_list, seed)

    # Build media pipeline
    pipe.build()

    # Initialize media pipeline iterator
    pipe.iter_init()

    # Run media pipeline for a single batch
    images, scale_selected, multiplication = pipe.run()

    if device == 'cpu':
        # Copy data as numpy array
        images = images.as_nparray()
        scale_selected = scale_selected.as_nparray()
        multiplication = multiplication.as_nparray()
    else:
        # Copy data to host from device as numpy array
        images = images.as_cpu().as_nparray()
        scale_selected = scale_selected.as_cpu().as_nparray()
        multiplication = multiplication.as_cpu().as_nparray()

    print("\nwhere op tensor shape:", scale_selected.shape)
    print("\nwhere op tensor dtype:", scale_selected.dtype)
    print("\nwhere op tensor data:", scale_selected)


if __name__ == "__main__":
    main()

The following is the output of the Where operator:

where op tensor shape: (1,)

where op tensor dtype: float32

where op tensor data: [0.3]