habana_frameworks.mediapipe.fn.Where
- Class:
habana_frameworks.mediapipe.fn.Where(**kwargs)
- Define graph call:
__call__(predicate, input0, input1)
- Parameter:
predicate - Predicate tensor for the operator, size=[batch_size]. Supported dimensions: minimum = 1, maximum = 1. Supported data types: UINT8.
input0 - Input0 tensor. Supported dimensions: minimum = 1, maximum = 5. Supported data types: FLOAT16, FLOAT32, BFLOAT16.
input1 - Input1 tensor. Supported dimensions: minimum = 1, maximum = 5. Supported data types: FLOAT16, FLOAT32, BFLOAT16.
Description:
Outputs a tensor whose elements are selected from either input0 or input1, depending on predicate: if predicate is true, input0 is selected; otherwise, input1 is selected.
- Supported backend:
HPU, CPU
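To make the selection rule in the description concrete, the following host-side NumPy sketch models it. It is an illustration only, not the MediaPipe implementation: `where_reference` is a hypothetical helper, and it assumes that a non-zero UINT8 predicate value counts as true and that each predicate element selects a whole sample along the first (batch) dimension, as suggested by size=[batch_size].

```python
import numpy as np


def where_reference(predicate, input0, input1):
    """Hypothetical host-side model of fn.Where (not the MediaPipe API).

    Assumes predicate has shape [batch_size] and that input0/input1 share
    a shape whose first dimension is the batch dimension.
    """
    predicate = np.asarray(predicate, dtype=np.uint8)
    input0 = np.asarray(input0)
    input1 = np.asarray(input1)
    # Reshape predicate to [batch_size, 1, 1, ...] so it broadcasts over
    # every non-batch dimension of the inputs.
    mask = predicate.astype(bool).reshape((-1,) + (1,) * (input0.ndim - 1))
    # Assumption: non-zero predicate -> element from input0, zero -> from input1.
    return np.where(mask, input0, input1)


if __name__ == "__main__":
    pred = np.array([1, 0], dtype=np.uint8)
    a = np.full((2, 3), 0.3, dtype=np.float32)
    b = np.full((2, 3), 0.7, dtype=np.float32)
    # First sample comes from a, second from b.
    print(where_reference(pred, a, b))
```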
Keyword Arguments
| kwargs | Description |
|---|---|
| dtype | Output data type. |
Note
The input0, input1, and output tensors must all have the same data type and the same dimensionality.
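The constraints in the note, together with the parameter limits listed earlier, can be expressed as a host-side check on NumPy arrays before data is fed to a pipeline. This is a sketch only: `check_where_inputs` is a hypothetical helper, not part of MediaPipe, and BFLOAT16 is left out because NumPy has no native bfloat16 dtype.

```python
import numpy as np

# BFLOAT16 is also supported by fn.Where, but NumPy has no native bfloat16
# dtype, so this host-side sketch only accepts FLOAT16 and FLOAT32.
ALLOWED_DTYPES = {np.dtype(np.float16), np.dtype(np.float32)}


def check_where_inputs(predicate, input0, input1, out_dtype):
    """Hypothetical check of the documented fn.Where constraints (not a MediaPipe API)."""
    # predicate: 1-D UINT8 tensor of size [batch_size]
    if predicate.dtype != np.uint8 or predicate.ndim != 1:
        raise ValueError("predicate must be a 1-D UINT8 tensor")
    # input0, input1 and the output must share one data type
    if not (input0.dtype == input1.dtype == np.dtype(out_dtype)):
        raise TypeError("input0, input1 and output must have the same data type")
    if input0.dtype not in ALLOWED_DTYPES:
        raise TypeError(f"unsupported data type: {input0.dtype}")
    # input0 and input1 must have the same dimensionality, between 1 and 5
    if input0.ndim != input1.ndim or not 1 <= input0.ndim <= 5:
        raise ValueError("input0 and input1 must have the same dimensionality (1 to 5)")
```

The check raises on the first violated constraint and returns `None` when everything matches.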
Example: Where Operator
The following code snippet shows usage of the Where operator:
```python
from habana_frameworks.mediapipe import fn
from habana_frameworks.mediapipe.mediapipe import MediaPipe
from habana_frameworks.mediapipe.media_types import dtype as dt


# Create media pipeline derived class
class myMediaPipe(MediaPipe):
    def __init__(self, device, queue_depth, batch_size, num_threads, dir):
        super(myMediaPipe, self).__init__(
            device,
            queue_depth,
            batch_size,
            num_threads,
            self.__class__.__name__)

        self.inp = fn.ReadNumpyDatasetFromDir(num_outputs=1,
                                              shuffle=False,
                                              dir=dir,
                                              pattern="inp_x_*.npy",
                                              dense=True,
                                              dtype=dt.FLOAT32,
                                              device=device)

        self.const = fn.Constant(
            constant=0.7, dtype=dt.FLOAT32, device=device)

        self.scale1 = fn.Constant(
            constant=0.3, dtype=dt.FLOAT32, device=device)

        self.scale2 = fn.Constant(
            constant=0.7, dtype=dt.FLOAT32, device=device)

        self.coin_flip = fn.CoinFlip(seed=100,
                                     device=device)

        self.where = fn.Where(dtype=dt.FLOAT32,
                              device=device)

        self.mul = fn.Mult(device=device)

    def definegraph(self):
        inp = self.inp()
        scale1 = self.scale1()
        scale2 = self.scale2()
        # Probability input for the coin flip
        probability = self.const()
        # Per-sample UINT8 predicate
        predicate = self.coin_flip(probability)
        # scale1 where predicate is true, scale2 otherwise
        scale = self.where(predicate, scale1, scale2)
        out = self.mul(inp, scale)
        return inp, scale, predicate, out


def main():
    batch_size = 1
    queue_depth = 2
    num_threads = 1
    device = 'cpu'
    dir = '/path/to/numpy/files/'

    # Create media pipeline object
    pipe = myMediaPipe(device, queue_depth, batch_size, num_threads, dir)

    # Build media pipeline
    pipe.build()

    # Initialize media pipeline iterator
    pipe.iter_init()

    # Run media pipeline
    inp, scale, predicate, out = pipe.run()

    if device == 'cpu':
        # Copy data as numpy array
        inp = inp.as_nparray()
        scale = scale.as_nparray()
        predicate = predicate.as_nparray()
        out = out.as_nparray()
    else:
        # Copy data to host from device as numpy array
        inp = inp.as_cpu().as_nparray()
        scale = scale.as_cpu().as_nparray()
        predicate = predicate.as_cpu().as_nparray()
        out = out.as_cpu().as_nparray()

    print("\ninp tensor shape:", inp.shape)
    print("inp tensor dtype:", inp.dtype)
    print("inp tensor data:\n", inp)

    print("\nscale tensor shape:", scale.shape)
    print("scale tensor dtype:", scale.dtype)
    print("scale tensor data:\n", scale)

    print("\npredicate tensor shape:", predicate.shape)
    print("predicate tensor dtype:", predicate.dtype)
    print("predicate tensor data:\n", predicate)

    print("\nout tensor shape:", out.shape)
    print("out tensor dtype:", out.dtype)
    print("out tensor data:\n", out)

    pipe.del_iter()


if __name__ == "__main__":
    main()
```
The following is the output of the Where operator:
```
inp tensor shape: (1, 3, 2, 3)
inp tensor dtype: float32
inp tensor data:
 [[[[182. 227. 113.]
   [175. 128. 253.]]

  [[ 58. 140. 136.]
   [ 86.  80. 111.]]

  [[175. 196. 178.]
   [ 20. 163. 108.]]]]

scale tensor shape: (1,)
scale tensor dtype: float32
scale tensor data:
 [0.3]

predicate tensor shape: (1,)
predicate tensor dtype: uint8
predicate tensor data:
 [1]

out tensor shape: (1, 3, 2, 3)
out tensor dtype: float32
out tensor data:
 [[[[54.600002 68.100006 33.9     ]
   [52.500004 38.4      75.9     ]]

  [[17.400002 42.       40.800003]
   [25.800001 24.       33.300003]]

  [[52.500004 58.800003 53.4     ]
   [ 6.       48.9      32.4     ]]]]
```
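The printed values can be cross-checked on the host. Because the predicate came out as [1], Where selected input0 (scale1 = 0.3), so every element of out should equal the corresponding element of inp multiplied by 0.3. The NumPy sketch below repeats that arithmetic on the printed data; it assumes that Mult broadcasts the per-sample scale of shape (1,) across the remaining dimensions, and it only mimics the pipeline on the host rather than calling MediaPipe.

```python
import numpy as np

# Input values copied from the printed "inp tensor data" above.
inp = np.array([[[[182., 227., 113.],
                  [175., 128., 253.]],
                 [[58., 140., 136.],
                  [86., 80., 111.]],
                 [[175., 196., 178.],
                  [20., 163., 108.]]]], dtype=np.float32)

# The predicate printed as [1], so Where picked input0, i.e. scale1 = 0.3.
predicate = np.array([1], dtype=np.uint8)
scale1 = np.float32(0.3)
scale2 = np.float32(0.7)
scale = np.where(predicate.astype(bool), scale1, scale2)  # shape (1,), value 0.3

# Assumption: Mult broadcasts the per-sample scale over the other dimensions.
out = inp * scale.reshape(-1, 1, 1, 1)

# Output values copied from the printed "out tensor data" above.
expected = np.array([[[[54.600002, 68.100006, 33.9],
                       [52.500004, 38.4, 75.9]],
                      [[17.400002, 42., 40.800003],
                       [25.800001, 24., 33.300003]],
                      [[52.500004, 58.800003, 53.4],
                       [6., 48.9, 32.4]]]], dtype=np.float32)

print(np.allclose(out, expected))
```

The small deviations in the printed output (for example 54.600002 rather than 54.6) come from float32 rounding of 0.3.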