From b4091318ea860800c0375812188e17e2eaebefcf Mon Sep 17 00:00:00 2001 From: Zhouheng Zheng Date: Wed, 6 Apr 2022 17:21:54 +0800 Subject: [PATCH] fix buf of param init in custom op (#345) Co-authored-by: zhouheng.zheng --- src/tim/vx/ops/custom_base.cc | 35 +++++++++++++---------------------- 1 file changed, 13 insertions(+), 22 deletions(-) diff --git a/src/tim/vx/ops/custom_base.cc b/src/tim/vx/ops/custom_base.cc index d8e748c..2ff185d 100644 --- a/src/tim/vx/ops/custom_base.cc +++ b/src/tim/vx/ops/custom_base.cc @@ -45,13 +45,6 @@ static vx_status derive_kernel_init(vx_node node, const vx_reference* param, static std::map node_base_map_; -namespace { - typedef struct DynamicArrayOfKernelParam_ { - uint32_t size_; - vsi_nn_kernel_node_param_t params_[1]; - } DynamicArrayOfKernelParam; -} - CustomOpBase::CustomOpBase(Graph* graph, uint32_t input_num, uint32_t output_num, int32_t kernel_id, const char* kernel_name) @@ -169,33 +162,32 @@ vsi_bool op_compute(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs, auto node = vsi_nn_KernelCreateNodeExt(self->graph, kernel, resource); if (node) { - vsi_nn_kernel_node_param_t* node_params = nullptr; - DynamicArrayOfKernelParam* node_params_array = (DynamicArrayOfKernelParam*)malloc(sizeof(DynamicArrayOfKernelParam) + sizeof(vsi_nn_kernel_node_param_t)*(param_num - 1) ); - node_params_array->size_ = param_num; - node_params = &node_params_array->params_[0]; - vsi_nn_kernel_node_pack_io(node_params, param_num, inputs, + uint32_t input_start = op_this->input_num_ + op_this->output_num_; + + std::vector node_params(param_num + input_start); + vsi_nn_kernel_node_param_t* node_params_ptr = node_params.data(); + vsi_nn_kernel_node_pack_io(node_params_ptr, param_num, inputs, op_this->input_num_, outputs, op_this->output_num_); - uint32_t input_start = op_this->input_num_ + op_this->output_num_; for (uint32_t i = 0; i < op_this->param_list_.size(); i++) { if (op_this->param_list_[i].type == tim::vx::DataType::FLOAT32) { - node_params[input_start++] = vsi_nn_kernelScalarCreate( + node_params_ptr[input_start++] = vsi_nn_kernelScalarCreate( self->graph, F32, &(op_this->param_list_[i].data.f)); } else if (op_this->param_list_[i].type == tim::vx::DataType::UINT32) { - node_params[input_start++] = vsi_nn_kernelScalarCreate( + node_params_ptr[input_start++] = vsi_nn_kernelScalarCreate( self->graph, U32, &(op_this->param_list_[i].data.ui)); } else if (op_this->param_list_[i].type == tim::vx::DataType::INT32) { - node_params[input_start++] = vsi_nn_kernelScalarCreate( + node_params_ptr[input_start++] = vsi_nn_kernelScalarCreate( self->graph, I32, &(op_this->param_list_[i].data.i)); } else if (op_this->param_list_[i].type == tim::vx::DataType::BOOL8) { - node_params[input_start++] = vsi_nn_kernelScalarCreate( + node_params_ptr[input_start++] = vsi_nn_kernelScalarCreate( self->graph, BOOL8, &(op_this->param_list_[i].data.b)); }else if (op_this->param_list_[i].type == tim::vx::DataType::UINT8) { - node_params[input_start++] = vsi_nn_kernelScalarCreate( + node_params_ptr[input_start++] = vsi_nn_kernelScalarCreate( self->graph, U8, &(op_this->param_list_[i].data.b)); } else if (op_this->param_list_[i].type == tim::vx::DataType::INT8) { - node_params[input_start++] = vsi_nn_kernelScalarCreate( + node_params_ptr[input_start++] = vsi_nn_kernelScalarCreate( self->graph, I8, &(op_this->param_list_[i].data.b)); } else{ std::cout << "Can not find scalar type in op compute" << std::endl; @@ -204,12 +196,11 @@ vsi_bool op_compute(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs, } input_start = op_this->input_num_ + op_this->output_num_; - status = vsi_nn_KernelNodePassParam(node, node_params, param_num); + status = vsi_nn_KernelNodePassParam(node, node_params_ptr, param_num); for (uint32_t i = 0; i < param_num; i++) { - vsi_nn_kernel_scalar_release(&node_params[input_start + i]); + vsi_nn_kernel_scalar_release(&node_params_ptr[input_start + i]); } - free(node_params_array); } self->n = (vx_node)node;