class habana::custom_op::UserCustomOpDescriptor

Overview

The descriptor contains all the information necessary to define a custom TPC op that runs on PyTorch with Gaudi:

/**
 * Descriptor for a custom op, holding all the information needed to
 * define a user HPU TPC kernel.
 *
 * @param schema Name of the registered torch operator
 * @param guid Name of the TPC kernel called by the operator
 * @param output_meta_fn Callback for output metadata calculation
 * @param fill_params_fn Callback filling the TPC kernel params structure,
 * if one is needed; may be left as nullptr (the default)
 */
class UserCustomOpDescriptor {
public:
    UserCustomOpDescriptor(
        const std::string& schema,
        const std::string& guid,
        OutputMetaFn output_meta_fn,
        FillParamsFn fill_params_fn = nullptr)
        : schema_(schema),
          guid_(guid),
          // The callbacks arrive by value, so move them into the members
          // rather than paying for a second copy.
          output_meta_fn_(std::move(output_meta_fn)),
          fill_params_fn_(std::move(fill_params_fn)) {}

    // Leaves every member default-constructed.
    UserCustomOpDescriptor() = default;

    /**
     * Invoke the op from user C++ code.
     *
     * @param inputs Input values for the op, in the order the op expects them
     *
     * @return Vector of op results.
     */
    std::vector<at::Tensor> execute(const std::vector<c10::IValue>& inputs);

    /**
     * Look up the custom op descriptor for an op.
     *
     * @param op Schema registration name, as passed to registerUserCustomOp
     *
     * @return Reference to the custom op descriptor.
     */
    static const UserCustomOpDescriptor& getUserCustomOpDescriptor(
        const std::string& op);

    /**
     * Schema name, as used in TORCH_LIBRARY.
     */
    const std::string& getSchemaName() const;

    /**
     * GUID of the TPC kernel.
     */
    const std::string& getGuid() const;

    /**
     * Callback that calculates the metadata of the output tensors.
     */
    const OutputMetaFn& getOutputMetaFn() const;

    /**
     * Callback that allocates and sets the user params.
     */
    const FillParamsFn& getFillParamsFn() const;

private:
    // Storage for the values handed to the constructor; these are the members
    // named in the constructor's initializer list above.
    std::string schema_;
    std::string guid_;
    OutputMetaFn output_meta_fn_;
    FillParamsFn fill_params_fn_;
};