add BroadcastInDim to internal expand_broadcast op (#364)

This commit is contained in:
Antkillerfarm 2022-04-18 13:59:18 +08:00 committed by GitHub
parent eb21143987
commit 954d264108
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 34 additions and 2 deletions

View File

@ -34,6 +34,8 @@ typedef struct _vsi_nn_expand_broadcast_param
{
uint32_t *shape;
uint32_t dim_num;
uint32_t *dimensions;
uint32_t dimensions_num;
} vsi_nn_expand_broadcast_param;
#ifdef __cplusplus

View File

@ -95,8 +95,11 @@ static vsi_bool op_setup
{
uint32_t i;
vsi_nn_tensor_attr_t attr;
vsi_nn_internal_tensor_t *input_1;
vsi_nn_internal_tensor_t* input_0 = NULL;
vsi_nn_internal_tensor_t *input_1 = NULL;
vsi_nn_internal_node_t* mul_node = NULL;
vsi_nn_tensor_t* mul_input = NULL;
int32_t use_virtual_tensor = 1;
vsi_nn_expand_broadcast_param *p = &self->nn_param.expand_broadcast;
vsi_nn_internal_init_node_wksp(self);
@ -112,8 +115,35 @@ static vsi_bool op_setup
}
input_1 = vsi_nn_internal_new_tensor( self, &attr, 1.0f );
if (p->dimensions_num > 0) {
vsi_nn_internal_node_t* reshape_node = NULL;
vsi_size_t* reshape_input_size = NULL;
memset(&attr, 0, sizeof(vsi_nn_tensor_attr_t));
vsi_nn_internal_init_tensor_attr(&attr, &inputs[0]->attr.dtype, use_virtual_tensor);
input_0 = vsi_nn_internal_new_tensor( self, &attr, 0.0f );
reshape_node = vsi_nn_internal_new_node( self, VSI_NN_OP_RESHAPE2, 0, 0 );
reshape_input_size = (vsi_size_t*)vsi_nn_internal_new_node_param(reshape_node,
VSI_NN_MAX_DIM_NUM * sizeof(vsi_size_t));
for(i = 0; i < p->dim_num; i++) {
reshape_input_size[i] = 1;
}
for (i = 0; i < p->dimensions_num; i++) {
reshape_input_size[p->dimensions[i]] = p->shape[p->dimensions[i]];
}
reshape_node->node->nn_param.reshape2.size = reshape_input_size;
reshape_node->node->nn_param.reshape2.dim_num = p->dim_num;
reshape_node->inputs[0] = inputs[0];
reshape_node->outputs[0] = input_0->t;
vsi_nn_internal_setup_node( self, reshape_node );
mul_input = input_0->t;
} else {
mul_input = inputs[0];
}
mul_node = vsi_nn_internal_new_node(self, VSI_NN_OP_MULTIPLY, 0, 0 );
mul_node->inputs[0] = inputs[0];
mul_node->inputs[0] = mul_input;
mul_node->inputs[1] = input_1->t;
mul_node->node->nn_param.multiply.scale = 1.0f;
mul_node->node->vx_param.overflow_policy = VX_CONVERT_POLICY_SATURATE;