fix buf of param init in custom op (#345)
Co-authored-by: zhouheng.zheng <zhouheng.zheng@ouotlook.com>
This commit is contained in:
parent
70d2f410a8
commit
b4091318ea
|
|
@ -45,13 +45,6 @@ static vx_status derive_kernel_init(vx_node node, const vx_reference* param,
|
|||
|
||||
static std::map<void*, CustomOpBase*> node_base_map_;
|
||||
|
||||
namespace {
|
||||
typedef struct DynamicArrayOfKernelParam_ {
|
||||
uint32_t size_;
|
||||
vsi_nn_kernel_node_param_t params_[1];
|
||||
} DynamicArrayOfKernelParam;
|
||||
}
|
||||
|
||||
CustomOpBase::CustomOpBase(Graph* graph, uint32_t input_num,
|
||||
uint32_t output_num, int32_t kernel_id,
|
||||
const char* kernel_name)
|
||||
|
|
@ -169,33 +162,32 @@ vsi_bool op_compute(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs,
|
|||
|
||||
auto node = vsi_nn_KernelCreateNodeExt(self->graph, kernel, resource);
|
||||
if (node) {
|
||||
vsi_nn_kernel_node_param_t* node_params = nullptr;
|
||||
DynamicArrayOfKernelParam* node_params_array = (DynamicArrayOfKernelParam*)malloc(sizeof(DynamicArrayOfKernelParam) + sizeof(vsi_nn_kernel_node_param_t)*(param_num - 1) );
|
||||
node_params_array->size_ = param_num;
|
||||
node_params = &node_params_array->params_[0];
|
||||
vsi_nn_kernel_node_pack_io(node_params, param_num, inputs,
|
||||
uint32_t input_start = op_this->input_num_ + op_this->output_num_;
|
||||
|
||||
std::vector<vsi_nn_kernel_node_param_t> node_params(param_num + input_start);
|
||||
vsi_nn_kernel_node_param_t* node_params_ptr = node_params.data();
|
||||
vsi_nn_kernel_node_pack_io(node_params_ptr, param_num, inputs,
|
||||
op_this->input_num_, outputs,
|
||||
op_this->output_num_);
|
||||
|
||||
uint32_t input_start = op_this->input_num_ + op_this->output_num_;
|
||||
for (uint32_t i = 0; i < op_this->param_list_.size(); i++) {
|
||||
if (op_this->param_list_[i].type == tim::vx::DataType::FLOAT32) {
|
||||
node_params[input_start++] = vsi_nn_kernelScalarCreate(
|
||||
node_params_ptr[input_start++] = vsi_nn_kernelScalarCreate(
|
||||
self->graph, F32, &(op_this->param_list_[i].data.f));
|
||||
} else if (op_this->param_list_[i].type == tim::vx::DataType::UINT32) {
|
||||
node_params[input_start++] = vsi_nn_kernelScalarCreate(
|
||||
node_params_ptr[input_start++] = vsi_nn_kernelScalarCreate(
|
||||
self->graph, U32, &(op_this->param_list_[i].data.ui));
|
||||
} else if (op_this->param_list_[i].type == tim::vx::DataType::INT32) {
|
||||
node_params[input_start++] = vsi_nn_kernelScalarCreate(
|
||||
node_params_ptr[input_start++] = vsi_nn_kernelScalarCreate(
|
||||
self->graph, I32, &(op_this->param_list_[i].data.i));
|
||||
} else if (op_this->param_list_[i].type == tim::vx::DataType::BOOL8) {
|
||||
node_params[input_start++] = vsi_nn_kernelScalarCreate(
|
||||
node_params_ptr[input_start++] = vsi_nn_kernelScalarCreate(
|
||||
self->graph, BOOL8, &(op_this->param_list_[i].data.b));
|
||||
}else if (op_this->param_list_[i].type == tim::vx::DataType::UINT8) {
|
||||
node_params[input_start++] = vsi_nn_kernelScalarCreate(
|
||||
node_params_ptr[input_start++] = vsi_nn_kernelScalarCreate(
|
||||
self->graph, U8, &(op_this->param_list_[i].data.b));
|
||||
} else if (op_this->param_list_[i].type == tim::vx::DataType::INT8) {
|
||||
node_params[input_start++] = vsi_nn_kernelScalarCreate(
|
||||
node_params_ptr[input_start++] = vsi_nn_kernelScalarCreate(
|
||||
self->graph, I8, &(op_this->param_list_[i].data.b));
|
||||
} else{
|
||||
std::cout << "Can not find scalar type in op compute" << std::endl;
|
||||
|
|
@ -204,12 +196,11 @@ vsi_bool op_compute(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs,
|
|||
}
|
||||
|
||||
input_start = op_this->input_num_ + op_this->output_num_;
|
||||
status = vsi_nn_KernelNodePassParam(node, node_params, param_num);
|
||||
status = vsi_nn_KernelNodePassParam(node, node_params_ptr, param_num);
|
||||
for (uint32_t i = 0; i < param_num; i++) {
|
||||
vsi_nn_kernel_scalar_release(&node_params[input_start + i]);
|
||||
vsi_nn_kernel_scalar_release(&node_params_ptr[input_start + i]);
|
||||
}
|
||||
|
||||
free(node_params_array);
|
||||
}
|
||||
self->n = (vx_node)node;
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue