habana_frameworks.mediapipe.fn.Constant

Class:
  • habana_frameworks.mediapipe.fn.Constant(**kwargs)

Define graph call:
  • __call__()

Parameter:

None

Description:

The Constant operator generates a tensor holding a single scalar constant value.

Supported backend:
  • HPU

Keyword Arguments

kwargs

Description

constant

Constant value.

  • Type: float

  • Default: 0.0

  • Optional: no

dtype

Output data type.

  • Type: habana_frameworks.mediapipe.media_types.dtype

  • Default: UINT8

  • Optional: yes

  • Supported data types:

    • INT8

    • UINT8

    • BFLOAT16

    • FLOAT32

Example: Constant Operator

The following code snippet shows usage of the Constant operator:

from habana_frameworks.mediapipe import fn
from habana_frameworks.mediapipe.mediapipe import MediaPipe
from habana_frameworks.mediapipe.media_types import dtype as dt
import numpy as np
import os

# Create MediaPipe derived class


class myMediaPipe(MediaPipe):
    """Example pipeline that multiplies numpy input data by a Constant tensor.

    Graph layout:
        * ``fn.ReadNumpyDatasetFromDir`` reads ``inp_x_*.npy`` files from
          ``dir`` on the CPU as dense FLOAT32 tensors.
        * ``fn.Constant`` emits a FLOAT32 tensor holding ``0.5`` on
          ``op_device``.
        * ``fn.Mult`` multiplies the two on ``op_device``.
    """

    def __init__(self, device, queue_depth, batch_size, num_threads, op_device, dir):
        # The last positional argument is the class name (presumably used by
        # MediaPipe as the pipeline name - confirm against the base class).
        super().__init__(device, queue_depth, batch_size, num_threads,
                         type(self).__name__)

        # Reader node: deterministic order (no shuffle), executed on the CPU.
        self.inp = fn.ReadNumpyDatasetFromDir(num_outputs=1,
                                              shuffle=False,
                                              dir=dir,
                                              pattern="inp_x_*.npy",
                                              dense=True,
                                              dtype=dt.FLOAT32,
                                              device="cpu")

        # Scalar constant operand for the multiplication.
        self.const = fn.Constant(constant=0.5,
                                 dtype=dt.FLOAT32,
                                 device=op_device)

        # Element-wise product of the input tensor and the constant.
        self.mul = fn.Mult(device=op_device)

    def definegraph(self):
        """Connect the nodes; return ``(product, input, constant)`` tensors."""
        data = self.inp()
        scale = self.const()
        return self.mul(data, scale), data, scale


def run(device, op_device):
    """Build ``myMediaPipe``, run it once and print the resulting tensors.

    Args:
        device: Pipeline device mode forwarded to ``myMediaPipe`` (the
            ``__main__`` section of this example uses ``'mixed'`` and
            ``'legacy'``).
        op_device: Device on which the Constant and Mult nodes execute.

    Returns:
        Tuple ``(inp, const, out)`` of numpy arrays copied to the host.

    Raises:
        KeyError: If the ``DATASET_DIR`` environment variable is not set.
    """
    batch_size = 1
    queue_depth = 2
    num_threads = 1

    try:
        base_dir = os.environ['DATASET_DIR']
    except KeyError:
        # Same exception type as a plain lookup, but with an actionable message.
        raise KeyError("DATASET_DIR environment variable must be set to the "
                       "dataset root directory") from None
    # Trailing "" keeps the trailing separator of the original path string.
    data_dir = os.path.join(base_dir, "npy_data", "fp32", "")

    # Create MediaPipe object
    pipe = myMediaPipe(device, queue_depth, batch_size,
                       num_threads, op_device, data_dir)

    # Build MediaPipe
    pipe.build()

    # Initialize MediaPipe iterator
    pipe.iter_init()

    # Run MediaPipe
    out, inp, const = pipe.run()

    def as_cpu(tensor):
        # Device tensors expose as_cpu(); anything without it is used as-is.
        if callable(getattr(tensor, "as_cpu", None)):
            tensor = tensor.as_cpu()
        return tensor

    # Copy data to host from device as numpy array
    const = as_cpu(const).as_nparray()
    inp = as_cpu(inp).as_nparray()
    out = as_cpu(out).as_nparray()

    for name, arr in (("const", const), ("inp", inp), ("out", out)):
        print(f"\n{name} tensor shape:", arr.shape)
        print(f"{name} tensor dtype:", arr.dtype)
        print(f"{name} tensor data:\n", arr)
    return inp, const, out

def compare_ref(inp, const, out, device=None):
    """Verify the pipeline output against a host-side NumPy reference.

    The reference is ``inp * const`` evaluated with NumPy broadcasting,
    mirroring the ``fn.Mult`` node of the pipeline graph.

    Args:
        inp: Input tensor returned by ``run``, as a numpy array.
        const: Constant tensor returned by ``run``, as a numpy array.
        out: Product tensor returned by ``run``, as a numpy array.
        device: Optional device label included in the error message.

    Raises:
        ValueError: If ``out`` differs from the reference in shape or value.
    """
    ref = inp * const
    # np.array_equal demands identical shapes and element values (no tolerance).
    if not np.array_equal(ref, out):
        suffix = f" for device {device}" if device is not None else ""
        raise ValueError(f"Mismatch w.r.t ref{suffix}")


if __name__ == "__main__":
    # Pipeline device mode -> op devices to exercise under that mode.
    dev_opdev = {'mixed': ['hpu'],
                 'legacy': ['hpu']}
    for dev, op_devices in dev_opdev.items():
        for op_dev in op_devices:
            inp, const, out = run(dev, op_dev)
            compare_ref(inp, const, out)

The following is the output of the Constant operator:

const tensor shape: (1,)
const tensor dtype: float32
const tensor data:
[0.5]

inp tensor shape: (1, 3, 2, 3)
inp tensor dtype: float32
inp tensor data:
[[[[182. 227. 113.]
  [175. 128. 253.]]

  [[ 58. 140. 136.]
  [ 86.  80. 111.]]

  [[175. 196. 178.]
  [ 20. 163. 108.]]]]

out tensor shape: (1, 3, 2, 3)
out tensor dtype: float32
out tensor data:
[[[[ 91.  113.5  56.5]
  [ 87.5  64.  126.5]]

  [[ 29.   70.   68. ]
  [ 43.   40.   55.5]]

  [[ 87.5  98.   89. ]
  [ 10.   81.5  54. ]]]]