fixed groupconv2d params in internal

Type: Bug Fix
Signed-off-by: Feiyue Chen <Feiyue.Chen@verisilicon.com>
This commit is contained in:
Feiyue Chen 2022-12-16 13:32:24 +08:00 committed by Sven
parent 7582b57edc
commit c6919248e1
1 changed files with 17 additions and 17 deletions

View File

@ -156,28 +156,28 @@ static vsi_status op_compute
p_ext = &p_ext2->ext; p_ext = &p_ext2->ext;
//set ext relative parameters //set ext relative parameters
p_ext->khr.padding_x = self->nn_param.conv2d.pad[0]; p_ext->khr.padding_x = self->nn_param.grouped_conv2d.pad[0];
p_ext->khr.padding_y = self->nn_param.conv2d.pad[2]; p_ext->khr.padding_y = self->nn_param.grouped_conv2d.pad[2];
if (self->nn_param.conv2d.dilation[0] > 0) if (self->nn_param.grouped_conv2d.dilation[0] > 0)
{ {
p_ext->khr.dilation_x = self->nn_param.conv2d.dilation[0] - 1; p_ext->khr.dilation_x = self->nn_param.grouped_conv2d.dilation[0] - 1;
} }
if (self->nn_param.conv2d.dilation[1] > 0) if (self->nn_param.grouped_conv2d.dilation[1] > 0)
{ {
p_ext->khr.dilation_y = self->nn_param.conv2d.dilation[1] - 1; p_ext->khr.dilation_y = self->nn_param.grouped_conv2d.dilation[1] - 1;
} }
p_ext->khr.overflow_policy = self->vx_param.overflow_policy; p_ext->khr.overflow_policy = self->vx_param.overflow_policy;
p_ext->khr.rounding_policy = self->vx_param.rounding_policy; p_ext->khr.rounding_policy = self->vx_param.rounding_policy;
p_ext->khr.down_scale_size_rounding = self->vx_param.down_scale_size_rounding; p_ext->khr.down_scale_size_rounding = self->vx_param.down_scale_size_rounding;
p_ext->padding_x_right = self->nn_param.conv2d.pad[1]; p_ext->padding_x_right = self->nn_param.grouped_conv2d.pad[1];
p_ext->padding_y_bottom = self->nn_param.conv2d.pad[3]; p_ext->padding_y_bottom = self->nn_param.grouped_conv2d.pad[3];
p_ext->pad_mode = vsi_nn_get_vx_pad_mode(nn_param->pad_mode); p_ext->pad_mode = vsi_nn_get_vx_pad_mode(nn_param->pad_mode);
//set ext2 relative parameters //set ext2 relative parameters
p_ext2->depth_multiplier = self->nn_param.conv2d.multiplier; p_ext2->depth_multiplier = self->nn_param.grouped_conv2d.multiplier;
p_ext2->stride_x = self->nn_param.conv2d.stride[0]; p_ext2->stride_x = self->nn_param.grouped_conv2d.stride[0];
p_ext2->stride_y = self->nn_param.conv2d.stride[1]; p_ext2->stride_y = self->nn_param.grouped_conv2d.stride[1];
if( inputs[2] == NULL ) if( inputs[2] == NULL )
{ {
@ -259,7 +259,7 @@ static vsi_bool op_setup
vsi_size_t i, pad[_cnt_of_array(nn_param->pad)] = {0}; vsi_size_t i, pad[_cnt_of_array(nn_param->pad)] = {0};
for(i = 0; i < _cnt_of_array(nn_param->pad); i++) for(i = 0; i < _cnt_of_array(nn_param->pad); i++)
{ {
pad[i] = self->nn_param.conv2d.pad[i]; pad[i] = self->nn_param.grouped_conv2d.pad[i];
} }
vsi_nn_compute_padding( vsi_nn_compute_padding(
inputs[0]->attr.size, inputs[0]->attr.size,
@ -271,7 +271,7 @@ static vsi_bool op_setup
); );
for(i = 0; i < _cnt_of_array(nn_param->pad); i++) for(i = 0; i < _cnt_of_array(nn_param->pad); i++)
{ {
self->nn_param.conv2d.pad[i] = (uint32_t)pad[i]; self->nn_param.grouped_conv2d.pad[i] = (uint32_t)pad[i];
} }
} }
@ -295,13 +295,13 @@ static vsi_bool op_setup
nn_param->dilation[1], nn_param->dilation[1],
VSI_NN_ROUND_FLOOR VSI_NN_ROUND_FLOOR
); );
if(self->nn_param.conv2d.weights > 0) if(self->nn_param.grouped_conv2d.weights > 0)
{ {
outputs[0]->attr.size[2] = self->nn_param.conv2d.weights; outputs[0]->attr.size[2] = self->nn_param.grouped_conv2d.weights;
} }
else if(self->nn_param.conv2d.multiplier > 0) else if(self->nn_param.grouped_conv2d.multiplier > 0)
{ {
outputs[0]->attr.size[2] = inputs[0]->attr.size[2] * self->nn_param.conv2d.multiplier; outputs[0]->attr.size[2] = inputs[0]->attr.size[2] * self->nn_param.grouped_conv2d.multiplier;
} }
else else
{ {