diff --git a/include/tim/vx/context.h b/include/tim/vx/context.h index 4922094..684e890 100644 --- a/include/tim/vx/context.h +++ b/include/tim/vx/context.h @@ -29,7 +29,7 @@ namespace tim { namespace vx { -struct CompileOption; +class CompileOption; class Graph; class Context { public: diff --git a/src/tim/vx/ops/custom_base.cc b/src/tim/vx/ops/custom_base.cc index 0f8b1b7..28c223f 100644 --- a/src/tim/vx/ops/custom_base.cc +++ b/src/tim/vx/ops/custom_base.cc @@ -45,6 +45,13 @@ static vx_status derive_kernel_init(vx_node node, const vx_reference* param, static std::map node_base_map_; +namespace { + typedef struct DynamicArrayOfKernelParam_ { + uint32_t size_; + vsi_nn_kernel_node_param_t params_[1]; + } DynamicArrayOfKernelParam; +} + CustomOpBase::CustomOpBase(Graph* graph, uint32_t input_num, uint32_t output_num, int32_t kernel_id, const char* kernel_name) @@ -162,7 +169,10 @@ vsi_bool op_compute(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs, auto node = vsi_nn_KernelCreateNodeExt(self->graph, kernel, resource); if (node) { - vsi_nn_kernel_node_param_t node_params[param_num] = {NULL}; + vsi_nn_kernel_node_param_t* node_params = nullptr; + DynamicArrayOfKernelParam* node_params_array = (DynamicArrayOfKernelParam*)malloc(sizeof(DynamicArrayOfKernelParam) + sizeof(vsi_nn_kernel_node_param_t)*(param_num - 1) ); + node_params_array->size_ = param_num; + node_params = &node_params_array->params_[0]; vsi_nn_kernel_node_pack_io(node_params, param_num, inputs, op_this->input_num_, outputs, op_this->output_num_); @@ -198,6 +208,8 @@ vsi_bool op_compute(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs, for (uint32_t i = 0; i < param_num; i++) { vsi_nn_kernel_scalar_release(&node_params[input_start + i]); } + + free(node_params_array); } self->n = (vx_node)node;