fix buf of param init in custom op (#345)

Co-authored-by: zhouheng.zheng <zhouheng.zheng@ouotlook.com>
This commit is contained in:
Zhouheng Zheng 2022-04-06 17:21:54 +08:00 committed by GitHub
parent 70d2f410a8
commit b4091318ea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 13 additions and 22 deletions

View File

@ -45,13 +45,6 @@ static vx_status derive_kernel_init(vx_node node, const vx_reference* param,
static std::map<void*, CustomOpBase*> node_base_map_;
namespace {
typedef struct DynamicArrayOfKernelParam_ {
uint32_t size_;
vsi_nn_kernel_node_param_t params_[1];
} DynamicArrayOfKernelParam;
}
CustomOpBase::CustomOpBase(Graph* graph, uint32_t input_num,
uint32_t output_num, int32_t kernel_id,
const char* kernel_name)
@ -169,33 +162,32 @@ vsi_bool op_compute(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs,
auto node = vsi_nn_KernelCreateNodeExt(self->graph, kernel, resource);
if (node) {
vsi_nn_kernel_node_param_t* node_params = nullptr;
DynamicArrayOfKernelParam* node_params_array = (DynamicArrayOfKernelParam*)malloc(sizeof(DynamicArrayOfKernelParam) + sizeof(vsi_nn_kernel_node_param_t)*(param_num - 1) );
node_params_array->size_ = param_num;
node_params = &node_params_array->params_[0];
vsi_nn_kernel_node_pack_io(node_params, param_num, inputs,
uint32_t input_start = op_this->input_num_ + op_this->output_num_;
std::vector<vsi_nn_kernel_node_param_t> node_params(param_num + input_start);
vsi_nn_kernel_node_param_t* node_params_ptr = node_params.data();
vsi_nn_kernel_node_pack_io(node_params_ptr, param_num, inputs,
op_this->input_num_, outputs,
op_this->output_num_);
uint32_t input_start = op_this->input_num_ + op_this->output_num_;
for (uint32_t i = 0; i < op_this->param_list_.size(); i++) {
if (op_this->param_list_[i].type == tim::vx::DataType::FLOAT32) {
node_params[input_start++] = vsi_nn_kernelScalarCreate(
node_params_ptr[input_start++] = vsi_nn_kernelScalarCreate(
self->graph, F32, &(op_this->param_list_[i].data.f));
} else if (op_this->param_list_[i].type == tim::vx::DataType::UINT32) {
node_params[input_start++] = vsi_nn_kernelScalarCreate(
node_params_ptr[input_start++] = vsi_nn_kernelScalarCreate(
self->graph, U32, &(op_this->param_list_[i].data.ui));
} else if (op_this->param_list_[i].type == tim::vx::DataType::INT32) {
node_params[input_start++] = vsi_nn_kernelScalarCreate(
node_params_ptr[input_start++] = vsi_nn_kernelScalarCreate(
self->graph, I32, &(op_this->param_list_[i].data.i));
} else if (op_this->param_list_[i].type == tim::vx::DataType::BOOL8) {
node_params[input_start++] = vsi_nn_kernelScalarCreate(
node_params_ptr[input_start++] = vsi_nn_kernelScalarCreate(
self->graph, BOOL8, &(op_this->param_list_[i].data.b));
}else if (op_this->param_list_[i].type == tim::vx::DataType::UINT8) {
node_params[input_start++] = vsi_nn_kernelScalarCreate(
node_params_ptr[input_start++] = vsi_nn_kernelScalarCreate(
self->graph, U8, &(op_this->param_list_[i].data.b));
} else if (op_this->param_list_[i].type == tim::vx::DataType::INT8) {
node_params[input_start++] = vsi_nn_kernelScalarCreate(
node_params_ptr[input_start++] = vsi_nn_kernelScalarCreate(
self->graph, I8, &(op_this->param_list_[i].data.b));
} else{
std::cout << "Can not find scalar type in op compute" << std::endl;
@ -204,12 +196,11 @@ vsi_bool op_compute(vsi_nn_node_t* self, vsi_nn_tensor_t** inputs,
}
input_start = op_this->input_num_ + op_this->output_num_;
status = vsi_nn_KernelNodePassParam(node, node_params, param_num);
status = vsi_nn_KernelNodePassParam(node, node_params_ptr, param_num);
for (uint32_t i = 0; i < param_num; i++) {
vsi_nn_kernel_scalar_release(&node_params[input_start + i]);
vsi_nn_kernel_scalar_release(&node_params_ptr[input_start + i]);
}
free(node_params_array);
}
self->n = (vx_node)node;