Set RNN internal dtype

Init RNN internal dtype to avoid the
internal FC OP to go to the CPU path

Type:Code Improvement

Signed-off-by: Kee <xuke537@hotmail.com>
This commit is contained in:
Kee 2022-11-11 11:16:24 +00:00 committed by Sven
parent b53fd14375
commit 4db479ece4
1 changed files with 12 additions and 0 deletions

View File

@ -233,6 +233,18 @@ static vsi_bool op_setup
curr = vsi_nn_internal_new_node( self, VSI_NN_OP_RNNCELL_OVXLIB, 0, 0 );
curr->node->nn_param.rnncell_ovxlib.activation = curr_param->activation;
if ( reshape_output->attr.dtype.vx_type == VSI_NN_TYPE_BFLOAT16 ||
reshape_output->attr.dtype.vx_type == VSI_NN_TYPE_FLOAT32 )
{
int32_t k = 0;
for (k = 0; k < _cnt_of_array( curr_param->internal_dtype ); k++)
{
if (curr_param->internal_dtype[k].vx_type == VSI_NN_TYPE_NONE)
{
curr_param->internal_dtype[k] = reshape_output->attr.dtype;
}
}
}
memcpy( curr->node->nn_param.rnncell_ovxlib.internal_dtype,
curr_param->internal_dtype, sizeof( curr_param->internal_dtype ) );
curr->inputs[RNNCELL_INPUT_INPUT] = reshape_output;