habana_frameworks.mediapipe.fn.Where
habana_frameworks.mediapipe.fn.Where
- Class:
habana_frameworks.mediapipe.fn.Where(**kwargs)
- Define graph call:
__call__(predicate, input0, input1)
- Parameter:
predicate - Predicate tensor to operator, size=[batch_size]. Supported dimensions: minimum = 1, maximum = 1. Supported data types: UINT8.
input0 - Input0 tensor. Supported dimensions: minimum = 1, maximum = 5. Supported data types: FLOAT16, FLOAT32, BFLOAT16.
input1 - Input1 tensor. Supported dimensions: minimum = 1, maximum = 5. Supported data types: FLOAT16, FLOAT32, BFLOAT16.
- Description:
Outputs a tensor of elements selected from either input0 or input1, depending on predicate. If predicate is true then input0 is selected, else input1 is selected.
- Supported backend:
HPU, CPU
Keyword Arguments
| kwargs | Description |
|---|---|
| dtype | Output data type. |
Note
The input0, input1, and output tensors must all be of the same data type and must have the same dimensionality.
Example: Where Operator
The following code snippet shows the usage of the Where operator:
import glob
import os

import numpy as np
from habana_frameworks.mediapipe import fn
from habana_frameworks.mediapipe.media_types import dtype as dt
from habana_frameworks.mediapipe.mediapipe import MediaPipe
# Create media pipeline derived class
class myMediaPipe(MediaPipe):
    """Example pipeline that scales cropped images by a factor picked with ``fn.Where``.

    Images and labels are read from numpy files and cropped together. A coin
    flip produces the predicate that makes ``fn.Where`` choose between two
    constant scale factors, and the cropped image is multiplied by the
    chosen factor.
    """

    def __init__(
        self,
        device,
        queue_depth,
        patch_size,
        num_channels,
        batch_size,
        num_threads,
        img_list,
        lbl_list,
        seed,
    ):
        """Create the operator nodes that ``definegraph`` connects.

        :param device: device string given to the base pipeline and to every operator.
        :param queue_depth: forwarded to the ``MediaPipe`` constructor.
        :param patch_size: forwarded to ``fn.RandomBiasedCrop``.
        :param num_channels: forwarded to ``fn.RandomBiasedCrop``.
        :param batch_size: forwarded to the ``MediaPipe`` constructor.
        :param num_threads: forwarded to the ``MediaPipe`` constructor.
        :param img_list: file list for the FLOAT32 image reader.
        :param lbl_list: file list for the UINT8 label reader.
        :param seed: seed shared by the readers, the crop and the coin flip.
        """
        pipe_name = self.__class__.__name__
        super().__init__(device, queue_depth, batch_size, num_threads, pipe_name)

        # Options that are the same for the image reader and the label reader.
        reader_opts = dict(
            device=device,
            num_outputs=1,
            shuffle=False,
            shuffle_across_dataset=False,
            dense=False,
            seed=seed,
            num_slices=1,
            slice_index=0,
            drop_remainder=True,
            pad_remainder=False,
        )
        self.images = fn.ReadNumpyDatasetFromDir(
            file_list=img_list, dtype=[dt.FLOAT32], **reader_opts
        )
        self.labels = fn.ReadNumpyDatasetFromDir(
            file_list=lbl_list, dtype=[dt.UINT8], **reader_opts
        )

        self.crop = fn.RandomBiasedCrop(
            patch_size=patch_size,
            num_channels=num_channels,
            seed=seed,
            num_workers=4,
            cache_bboxes=True,
            device=device,
        )
        self.mul = fn.Mult(device=device)

        # const_val0 feeds the coin flip; const_val1 / const_val2 are the two
        # candidate scale factors handed to where as input0 / input1.
        self.const_val0 = fn.Constant(constant=0.2, dtype=dt.FLOAT32, device=device)
        self.const_val1 = fn.Constant(constant=0.3, dtype=dt.FLOAT32, device=device)
        self.const_val2 = fn.Constant(constant=0.7, dtype=dt.FLOAT32, device=device)

        # NOTE(review): the predicate produced here is INT8, while the
        # parameter description of Where lists UINT8 - confirm which is expected.
        self.coin_flip = fn.CoinFlip(seed=seed, dtype=dt.INT8, device=device)
        self.where = fn.Where(dtype=dt.FLOAT32, device=device)

    def definegraph(self):
        """Connect the operator nodes and return the graph outputs.

        :returns: tuple of the cropped image, the scale selected by ``where``,
            and the cropped image multiplied by that scale.
        """
        images = self.images()
        labels = self.labels()
        # The crop returns three values; only the image is used afterwards.
        images, labels, _coords = self.crop(images, labels)

        # coin flip used to generate the condition tensor for where
        flip_prob = self.const_val0()
        predicate = self.coin_flip(flip_prob)

        # where picks its first data input when the predicate is true,
        # otherwise the second one.
        scale_if_true = self.const_val1()
        scale_if_false = self.const_val2()
        selected_scale = self.where(predicate, scale_if_true, scale_if_false)

        return images, selected_scale, self.mul(images, selected_scale)
def main():
    """Build the example pipeline, run one batch and print the Where output.

    Image files (``case_*_x.npy``) and label files (``case_*_y.npy``) are
    collected from ``data_dir`` in sorted order, the pipeline is built and
    run once, and the three outputs are converted to numpy arrays. Only the
    scale selected by the ``Where`` operator is printed.
    """
    batch_size = 1
    patch_size = [5, 5, 5]
    queue_depth = 2
    num_channels = 1
    num_threads = 1
    # Named data_dir (not dir) so the builtin is not shadowed.
    data_dir = "/path/to/numpy/files/"
    image_pattern = "case_*_x.npy"
    label_pattern = "case_*_y.npy"
    # os.path.join copes with data_dir with or without a trailing separator,
    # so no doubled "/" ends up in the glob pattern.
    image_list = np.array(sorted(glob.glob(os.path.join(data_dir, image_pattern))))
    label_list = np.array(sorted(glob.glob(os.path.join(data_dir, label_pattern))))
    device = 'hpu'
    seed = 1234

    # Create media pipeline object
    pipe = myMediaPipe(device, queue_depth, patch_size,
                       num_channels, batch_size, num_threads,
                       image_list, label_list, seed)

    # Build media pipeline
    pipe.build()

    # Initialize media pipeline iterator
    pipe.iter_init()

    # Run media pipeline
    # (first output named images rather than input to keep the builtin usable)
    images, scale_selected, multiplication = pipe.run()

    if device == 'cpu':
        # Copy data as numpy array
        images = images.as_nparray()
        scale_selected = scale_selected.as_nparray()
        multiplication = multiplication.as_nparray()
    else:
        # Copy data to host from device as numpy array
        images = images.as_cpu().as_nparray()
        scale_selected = scale_selected.as_cpu().as_nparray()
        multiplication = multiplication.as_cpu().as_nparray()

    print("\nwhere op tensor shape:", scale_selected.shape)
    print("\nwhere op tensor dtype:", scale_selected.dtype)
    print("\nwhere op tensor data:", scale_selected)
# Run the example only when this file is executed as a script.
if __name__ == "__main__":
    main()
The following is the output of the Where operator:
where op tensor shape: (1,)
where op tensor dtype: float32
where op tensor data: [0.3]