class habana::custom_op::UserCustomOpDescriptor
On this Page
class habana::custom_op::UserCustomOpDescriptor
Overview
The descriptor contains all the information needed to define a custom TPC op that runs in PyTorch on Gaudi:
/**
 * Descriptor for a custom op, containing all the information needed to
 * bind a registered torch operator to a user HPU TPC kernel.
 *
 * @param schema Name of the registered torch operator (as used in
 *               TORCH_LIBRARY).
 * @param guid Name of the TPC kernel called by the operator.
 * @param output_meta_fn Callback for output metadata calculation.
 * @param fill_params_fn Callback filling the TPC kernel params structure;
 *                       may be left as nullptr (the default) when no
 *                       params are necessary.
 */
class UserCustomOpDescriptor {
 public:
  UserCustomOpDescriptor(
      const std::string& schema,
      const std::string& guid,
      OutputMetaFn output_meta_fn,
      FillParamsFn fill_params_fn = nullptr)
      : schema_(schema),
        guid_(guid),
        output_meta_fn_(output_meta_fn),
        fill_params_fn_(fill_params_fn) {}

  /**
   * Creates an empty descriptor. Both names are empty strings and both
   * callbacks are value-initialized through the member initializers below,
   * so they are never left indeterminate.
   */
  UserCustomOpDescriptor() = default;

  /**
   * Runs the op directly from user C++ code.
   *
   * @param inputs All argument values for the op, in the order the op
   *               expects them.
   *
   * @return Vector of op results.
   */
  std::vector<at::Tensor> execute(const std::vector<c10::IValue>& inputs);

  /**
   * Looks up the custom op descriptor registered under the given name.
   *
   * @param op Schema registration name, as used in registerUserCustomOp.
   *
   * @return Custom op descriptor.
   */
  [[nodiscard]] static const UserCustomOpDescriptor& getUserCustomOpDescriptor(
      const std::string& op);

  /**
   * Schema name as used in TORCH_LIBRARY.
   */
  [[nodiscard]] const std::string& getSchemaName() const;

  /**
   * TPC kernel GUID.
   */
  [[nodiscard]] const std::string& getGuid() const;

  /**
   * Callback to calculate the output tensors' metadata.
   */
  [[nodiscard]] const OutputMetaFn& getOutputMetaFn() const;

  /**
   * Callback to allocate and set user params.
   */
  [[nodiscard]] const FillParamsFn& getFillParamsFn() const;

 private:
  // Declared in the same order as the constructor's initializer list.
  // The brace initializers make the defaulted constructor value-initialize
  // the callbacks. That matters if OutputMetaFn/FillParamsFn are raw
  // function-pointer types rather than std::function (not visible here).
  std::string schema_{};
  std::string guid_{};
  OutputMetaFn output_meta_fn_{};
  FillParamsFn fill_params_fn_{};
};