fix bug of param num in custom op (#385)
ref to:https://github.com/VeriSilicon/TIM-VX/issues/378 Co-authored-by: zhouheng.zheng <zhouheng.zheng@ouotlook.com>
This commit is contained in:
parent
3f629d3910
commit
c09cdf79ad
|
|
@ -100,6 +100,7 @@ vsi_bool op_compute(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs,
|
||||||
reinterpret_cast<CustomOpBase*>(self->nn_param.client_param);
|
reinterpret_cast<CustomOpBase*>(self->nn_param.client_param);
|
||||||
|
|
||||||
uint32_t param_num = op_this->param_list_.size();
|
uint32_t param_num = op_this->param_list_.size();
|
||||||
|
uint32_t input_start = op_this->input_num_ + op_this->output_num_;
|
||||||
|
|
||||||
std::vector<tim::vx::DataType> input_types;
|
std::vector<tim::vx::DataType> input_types;
|
||||||
for (uint32_t i = 0; i < op_this->input_num_; i++) {
|
for (uint32_t i = 0; i < op_this->input_num_; i++) {
|
||||||
|
|
@ -127,7 +128,7 @@ vsi_bool op_compute(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs,
|
||||||
snprintf(kernel->info.name, VX_MAX_KERNEL_NAME, "%s", op_this->func_name_);
|
snprintf(kernel->info.name, VX_MAX_KERNEL_NAME, "%s", op_this->func_name_);
|
||||||
kernel->unique_id =
|
kernel->unique_id =
|
||||||
std::hash<std::string>()(std::string(op_this->func_name_));
|
std::hash<std::string>()(std::string(op_this->func_name_));
|
||||||
vx_param_description_t kernel_param_def[param_num];
|
vx_param_description_t kernel_param_def[param_num + input_start];
|
||||||
|
|
||||||
for (uint32_t i = 0; i < op_this->input_num_; i++) {
|
for (uint32_t i = 0; i < op_this->input_num_; i++) {
|
||||||
kernel_param_def[i] = {VX_INPUT, VX_TYPE_TENSOR,
|
kernel_param_def[i] = {VX_INPUT, VX_TYPE_TENSOR,
|
||||||
|
|
@ -145,7 +146,7 @@ vsi_bool op_compute(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs,
|
||||||
|
|
||||||
kernel->info.parameters = kernel_param_def;
|
kernel->info.parameters = kernel_param_def;
|
||||||
kernel->info.enumeration = KERNEL_ID_PLACEHOLDER;
|
kernel->info.enumeration = KERNEL_ID_PLACEHOLDER;
|
||||||
kernel->info.numParams = param_num;
|
kernel->info.numParams = param_num + input_start;
|
||||||
kernel->info.initialize =
|
kernel->info.initialize =
|
||||||
reinterpret_cast<vx_kernel_initialize_f>(op_this->init_kernel_);
|
reinterpret_cast<vx_kernel_initialize_f>(op_this->init_kernel_);
|
||||||
|
|
||||||
|
|
@ -162,11 +163,10 @@ vsi_bool op_compute(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs,
|
||||||
|
|
||||||
auto node = vsi_nn_KernelCreateNodeExt(self->graph, kernel, resource);
|
auto node = vsi_nn_KernelCreateNodeExt(self->graph, kernel, resource);
|
||||||
if (node) {
|
if (node) {
|
||||||
uint32_t input_start = op_this->input_num_ + op_this->output_num_;
|
|
||||||
|
|
||||||
std::vector<vsi_nn_kernel_node_param_t> node_params(param_num + input_start);
|
std::vector<vsi_nn_kernel_node_param_t> node_params(param_num + input_start);
|
||||||
vsi_nn_kernel_node_param_t* node_params_ptr = node_params.data();
|
vsi_nn_kernel_node_param_t* node_params_ptr = node_params.data();
|
||||||
vsi_nn_kernel_node_pack_io(node_params_ptr, param_num, inputs,
|
vsi_nn_kernel_node_pack_io(node_params_ptr, param_num + input_start, inputs,
|
||||||
op_this->input_num_, outputs,
|
op_this->input_num_, outputs,
|
||||||
op_this->output_num_);
|
op_this->output_num_);
|
||||||
|
|
||||||
|
|
@ -196,7 +196,7 @@ vsi_bool op_compute(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs,
|
||||||
}
|
}
|
||||||
|
|
||||||
input_start = op_this->input_num_ + op_this->output_num_;
|
input_start = op_this->input_num_ + op_this->output_num_;
|
||||||
status = vsi_nn_KernelNodePassParam(node, node_params_ptr, param_num);
|
status = vsi_nn_KernelNodePassParam(node, node_params_ptr, param_num + input_start);
|
||||||
for (uint32_t i = 0; i < param_num; i++) {
|
for (uint32_t i = 0; i < param_num; i++) {
|
||||||
vsi_nn_kernel_scalar_release(&node_params_ptr[input_start + i]);
|
vsi_nn_kernel_scalar_release(&node_params_ptr[input_start + i]);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue