fix bug of param num in custom op (#385)
ref to:https://github.com/VeriSilicon/TIM-VX/issues/378 Co-authored-by: zhouheng.zheng <zhouheng.zheng@ouotlook.com>
This commit is contained in:
parent
3f629d3910
commit
c09cdf79ad
|
|
@ -100,6 +100,7 @@ vsi_bool op_compute(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs,
|
|||
reinterpret_cast<CustomOpBase*>(self->nn_param.client_param);
|
||||
|
||||
uint32_t param_num = op_this->param_list_.size();
|
||||
uint32_t input_start = op_this->input_num_ + op_this->output_num_;
|
||||
|
||||
std::vector<tim::vx::DataType> input_types;
|
||||
for (uint32_t i = 0; i < op_this->input_num_; i++) {
|
||||
|
|
@ -127,7 +128,7 @@ vsi_bool op_compute(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs,
|
|||
snprintf(kernel->info.name, VX_MAX_KERNEL_NAME, "%s", op_this->func_name_);
|
||||
kernel->unique_id =
|
||||
std::hash<std::string>()(std::string(op_this->func_name_));
|
||||
vx_param_description_t kernel_param_def[param_num];
|
||||
vx_param_description_t kernel_param_def[param_num + input_start];
|
||||
|
||||
for (uint32_t i = 0; i < op_this->input_num_; i++) {
|
||||
kernel_param_def[i] = {VX_INPUT, VX_TYPE_TENSOR,
|
||||
|
|
@ -145,7 +146,7 @@ vsi_bool op_compute(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs,
|
|||
|
||||
kernel->info.parameters = kernel_param_def;
|
||||
kernel->info.enumeration = KERNEL_ID_PLACEHOLDER;
|
||||
kernel->info.numParams = param_num;
|
||||
kernel->info.numParams = param_num + input_start;
|
||||
kernel->info.initialize =
|
||||
reinterpret_cast<vx_kernel_initialize_f>(op_this->init_kernel_);
|
||||
|
||||
|
|
@ -162,11 +163,10 @@ vsi_bool op_compute(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs,
|
|||
|
||||
auto node = vsi_nn_KernelCreateNodeExt(self->graph, kernel, resource);
|
||||
if (node) {
|
||||
uint32_t input_start = op_this->input_num_ + op_this->output_num_;
|
||||
|
||||
std::vector<vsi_nn_kernel_node_param_t> node_params(param_num + input_start);
|
||||
vsi_nn_kernel_node_param_t* node_params_ptr = node_params.data();
|
||||
vsi_nn_kernel_node_pack_io(node_params_ptr, param_num, inputs,
|
||||
vsi_nn_kernel_node_pack_io(node_params_ptr, param_num + input_start, inputs,
|
||||
op_this->input_num_, outputs,
|
||||
op_this->output_num_);
|
||||
|
||||
|
|
@ -196,7 +196,7 @@ vsi_bool op_compute(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs,
|
|||
}
|
||||
|
||||
input_start = op_this->input_num_ + op_this->output_num_;
|
||||
status = vsi_nn_KernelNodePassParam(node, node_params_ptr, param_num);
|
||||
status = vsi_nn_KernelNodePassParam(node, node_params_ptr, param_num + input_start);
|
||||
for (uint32_t i = 0; i < param_num; i++) {
|
||||
vsi_nn_kernel_scalar_release(&node_params_ptr[input_start + i]);
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue