diff --git a/src/tim/vx/ops/custom_base.cc b/src/tim/vx/ops/custom_base.cc index 2ff185d..d26aa13 100644 --- a/src/tim/vx/ops/custom_base.cc +++ b/src/tim/vx/ops/custom_base.cc @@ -100,6 +100,7 @@ vsi_bool op_compute(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs, reinterpret_cast(self->nn_param.client_param); uint32_t param_num = op_this->param_list_.size(); + uint32_t input_start = op_this->input_num_ + op_this->output_num_; std::vector input_types; for (uint32_t i = 0; i < op_this->input_num_; i++) { @@ -127,7 +128,7 @@ vsi_bool op_compute(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs, snprintf(kernel->info.name, VX_MAX_KERNEL_NAME, "%s", op_this->func_name_); kernel->unique_id = std::hash()(std::string(op_this->func_name_)); - vx_param_description_t kernel_param_def[param_num]; + vx_param_description_t kernel_param_def[param_num + input_start]; for (uint32_t i = 0; i < op_this->input_num_; i++) { kernel_param_def[i] = {VX_INPUT, VX_TYPE_TENSOR, @@ -145,7 +146,7 @@ vsi_bool op_compute(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs, kernel->info.parameters = kernel_param_def; kernel->info.enumeration = KERNEL_ID_PLACEHOLDER; - kernel->info.numParams = param_num; + kernel->info.numParams = param_num + input_start; kernel->info.initialize = reinterpret_cast(op_this->init_kernel_); @@ -162,11 +163,10 @@ vsi_bool op_compute(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs, auto node = vsi_nn_KernelCreateNodeExt(self->graph, kernel, resource); if (node) { - uint32_t input_start = op_this->input_num_ + op_this->output_num_; std::vector node_params(param_num + input_start); vsi_nn_kernel_node_param_t* node_params_ptr = node_params.data(); - vsi_nn_kernel_node_pack_io(node_params_ptr, param_num, inputs, + vsi_nn_kernel_node_pack_io(node_params_ptr, param_num + input_start, inputs, op_this->input_num_, outputs, op_this->output_num_); @@ -196,7 +196,7 @@ vsi_bool op_compute(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs, } input_start = op_this->input_num_ + op_this->output_num_; - status = vsi_nn_KernelNodePassParam(node, node_params_ptr, param_num); + status = vsi_nn_KernelNodePassParam(node, node_params_ptr, param_num + input_start); for (uint32_t i = 0; i < param_num; i++) { vsi_nn_kernel_scalar_release(&node_params_ptr[input_start + i]); }