add BroadcastInDim to internal expand_broadcast op (#364)
This commit is contained in:
parent
eb21143987
commit
954d264108
|
|
@ -34,6 +34,8 @@ typedef struct _vsi_nn_expand_broadcast_param
|
|||
{
|
||||
uint32_t *shape;
|
||||
uint32_t dim_num;
|
||||
uint32_t *dimensions;
|
||||
uint32_t dimensions_num;
|
||||
} vsi_nn_expand_broadcast_param;
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
|
|
|||
|
|
@ -95,8 +95,11 @@ static vsi_bool op_setup
|
|||
{
|
||||
uint32_t i;
|
||||
vsi_nn_tensor_attr_t attr;
|
||||
vsi_nn_internal_tensor_t *input_1;
|
||||
vsi_nn_internal_tensor_t* input_0 = NULL;
|
||||
vsi_nn_internal_tensor_t *input_1 = NULL;
|
||||
vsi_nn_internal_node_t* mul_node = NULL;
|
||||
vsi_nn_tensor_t* mul_input = NULL;
|
||||
int32_t use_virtual_tensor = 1;
|
||||
vsi_nn_expand_broadcast_param *p = &self->nn_param.expand_broadcast;
|
||||
|
||||
vsi_nn_internal_init_node_wksp(self);
|
||||
|
|
@ -112,8 +115,35 @@ static vsi_bool op_setup
|
|||
}
|
||||
input_1 = vsi_nn_internal_new_tensor( self, &attr, 1.0f );
|
||||
|
||||
if (p->dimensions_num > 0) {
|
||||
vsi_nn_internal_node_t* reshape_node = NULL;
|
||||
vsi_size_t* reshape_input_size = NULL;
|
||||
memset(&attr, 0, sizeof(vsi_nn_tensor_attr_t));
|
||||
vsi_nn_internal_init_tensor_attr(&attr, &inputs[0]->attr.dtype, use_virtual_tensor);
|
||||
input_0 = vsi_nn_internal_new_tensor( self, &attr, 0.0f );
|
||||
reshape_node = vsi_nn_internal_new_node( self, VSI_NN_OP_RESHAPE2, 0, 0 );
|
||||
reshape_input_size = (vsi_size_t*)vsi_nn_internal_new_node_param(reshape_node,
|
||||
VSI_NN_MAX_DIM_NUM * sizeof(vsi_size_t));
|
||||
|
||||
for(i = 0; i < p->dim_num; i++) {
|
||||
reshape_input_size[i] = 1;
|
||||
}
|
||||
for (i = 0; i < p->dimensions_num; i++) {
|
||||
reshape_input_size[p->dimensions[i]] = p->shape[p->dimensions[i]];
|
||||
}
|
||||
|
||||
reshape_node->node->nn_param.reshape2.size = reshape_input_size;
|
||||
reshape_node->node->nn_param.reshape2.dim_num = p->dim_num;
|
||||
reshape_node->inputs[0] = inputs[0];
|
||||
reshape_node->outputs[0] = input_0->t;
|
||||
vsi_nn_internal_setup_node( self, reshape_node );
|
||||
mul_input = input_0->t;
|
||||
} else {
|
||||
mul_input = inputs[0];
|
||||
}
|
||||
|
||||
mul_node = vsi_nn_internal_new_node(self, VSI_NN_OP_MULTIPLY, 0, 0 );
|
||||
mul_node->inputs[0] = inputs[0];
|
||||
mul_node->inputs[0] = mul_input;
|
||||
mul_node->inputs[1] = input_1->t;
|
||||
mul_node->node->nn_param.multiply.scale = 1.0f;
|
||||
mul_node->node->vx_param.overflow_policy = VX_CONVERT_POLICY_SATURATE;
|
||||
|
|
|
|||
Loading…
Reference in New Issue