fix bug of param num in custom op (#385)

ref to:https://github.com/VeriSilicon/TIM-VX/issues/378

Co-authored-by: zhouheng.zheng <zhouheng.zheng@ouotlook.com>
This commit is contained in:
Zhouheng Zheng 2022-05-05 17:04:38 +08:00 committed by GitHub
parent 3f629d3910
commit c09cdf79ad
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 5 additions and 5 deletions

View File

@ -100,6 +100,7 @@ vsi_bool op_compute(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs,
reinterpret_cast<CustomOpBase*>(self->nn_param.client_param);
uint32_t param_num = op_this->param_list_.size();
uint32_t input_start = op_this->input_num_ + op_this->output_num_;
std::vector<tim::vx::DataType> input_types;
for (uint32_t i = 0; i < op_this->input_num_; i++) {
@ -127,7 +128,7 @@ vsi_bool op_compute(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs,
snprintf(kernel->info.name, VX_MAX_KERNEL_NAME, "%s", op_this->func_name_);
kernel->unique_id =
std::hash<std::string>()(std::string(op_this->func_name_));
vx_param_description_t kernel_param_def[param_num];
vx_param_description_t kernel_param_def[param_num + input_start];
for (uint32_t i = 0; i < op_this->input_num_; i++) {
kernel_param_def[i] = {VX_INPUT, VX_TYPE_TENSOR,
@ -145,7 +146,7 @@ vsi_bool op_compute(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs,
kernel->info.parameters = kernel_param_def;
kernel->info.enumeration = KERNEL_ID_PLACEHOLDER;
kernel->info.numParams = param_num;
kernel->info.numParams = param_num + input_start;
kernel->info.initialize =
reinterpret_cast<vx_kernel_initialize_f>(op_this->init_kernel_);
@ -162,11 +163,10 @@ vsi_bool op_compute(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs,
auto node = vsi_nn_KernelCreateNodeExt(self->graph, kernel, resource);
if (node) {
uint32_t input_start = op_this->input_num_ + op_this->output_num_;
std::vector<vsi_nn_kernel_node_param_t> node_params(param_num + input_start);
vsi_nn_kernel_node_param_t* node_params_ptr = node_params.data();
vsi_nn_kernel_node_pack_io(node_params_ptr, param_num, inputs,
vsi_nn_kernel_node_pack_io(node_params_ptr, param_num + input_start, inputs,
op_this->input_num_, outputs,
op_this->output_num_);
@ -196,7 +196,7 @@ vsi_bool op_compute(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs,
}
input_start = op_this->input_num_ + op_this->output_num_;
status = vsi_nn_KernelNodePassParam(node, node_params_ptr, param_num);
status = vsi_nn_KernelNodePassParam(node, node_params_ptr, param_num + input_start);
for (uint32_t i = 0; i < param_num; i++) {
vsi_nn_kernel_scalar_release(&node_params_ptr[input_start + i]);
}