diff --git a/src/tim/vx/internal/include/ops/vsi_nn_op_expand_broadcast.h b/src/tim/vx/internal/include/ops/vsi_nn_op_expand_broadcast.h index 93e1d3c..e9ce86a 100644 --- a/src/tim/vx/internal/include/ops/vsi_nn_op_expand_broadcast.h +++ b/src/tim/vx/internal/include/ops/vsi_nn_op_expand_broadcast.h @@ -34,6 +34,8 @@ typedef struct _vsi_nn_expand_broadcast_param { uint32_t *shape; uint32_t dim_num; + uint32_t *dimensions; + uint32_t dimensions_num; } vsi_nn_expand_broadcast_param; #ifdef __cplusplus diff --git a/src/tim/vx/internal/src/ops/vsi_nn_op_expand_broadcast.c b/src/tim/vx/internal/src/ops/vsi_nn_op_expand_broadcast.c index 4ba10e9..df4aa95 100644 --- a/src/tim/vx/internal/src/ops/vsi_nn_op_expand_broadcast.c +++ b/src/tim/vx/internal/src/ops/vsi_nn_op_expand_broadcast.c @@ -95,8 +95,11 @@ static vsi_bool op_setup { uint32_t i; vsi_nn_tensor_attr_t attr; - vsi_nn_internal_tensor_t *input_1; + vsi_nn_internal_tensor_t* input_0 = NULL; + vsi_nn_internal_tensor_t *input_1 = NULL; vsi_nn_internal_node_t* mul_node = NULL; + vsi_nn_tensor_t* mul_input = NULL; + int32_t use_virtual_tensor = 1; vsi_nn_expand_broadcast_param *p = &self->nn_param.expand_broadcast; vsi_nn_internal_init_node_wksp(self); @@ -112,8 +115,35 @@ static vsi_bool op_setup } input_1 = vsi_nn_internal_new_tensor( self, &attr, 1.0f ); + if (p->dimensions_num > 0) { + vsi_nn_internal_node_t* reshape_node = NULL; + vsi_size_t* reshape_input_size = NULL; + memset(&attr, 0, sizeof(vsi_nn_tensor_attr_t)); + vsi_nn_internal_init_tensor_attr(&attr, &inputs[0]->attr.dtype, use_virtual_tensor); + input_0 = vsi_nn_internal_new_tensor( self, &attr, 0.0f ); + reshape_node = vsi_nn_internal_new_node( self, VSI_NN_OP_RESHAPE2, 0, 0 ); + reshape_input_size = (vsi_size_t*)vsi_nn_internal_new_node_param(reshape_node, + VSI_NN_MAX_DIM_NUM * sizeof(vsi_size_t)); + + for(i = 0; i < p->dim_num; i++) { + reshape_input_size[i] = 1; + } + for (i = 0; i < p->dimensions_num; i++) { + reshape_input_size[p->dimensions[i]] = p->shape[p->dimensions[i]]; + } + + reshape_node->node->nn_param.reshape2.size = reshape_input_size; + reshape_node->node->nn_param.reshape2.dim_num = p->dim_num; + reshape_node->inputs[0] = inputs[0]; + reshape_node->outputs[0] = input_0->t; + vsi_nn_internal_setup_node( self, reshape_node ); + mul_input = input_0->t; + } else { + mul_input = inputs[0]; + } + mul_node = vsi_nn_internal_new_node(self, VSI_NN_OP_MULTIPLY, 0, 0 ); - mul_node->inputs[0] = inputs[0]; + mul_node->inputs[0] = mul_input; mul_node->inputs[1] = input_1->t; mul_node->node->nn_param.multiply.scale = 1.0f; mul_node->node->vx_param.overflow_policy = VX_CONVERT_POLICY_SATURATE;