From 4db479ece4c4fc849521057294baf365d7334b97 Mon Sep 17 00:00:00 2001 From: Kee Date: Fri, 11 Nov 2022 11:16:24 +0000 Subject: [PATCH] Set RNN internal dtype Init RNN internal dtype to avoid the internal FC OP to go to the CPU path Type:Code Improvement Signed-off-by: Kee --- .../src/ops/vsi_nn_op_unidirectional_sequence_rnn.c | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/tim/vx/internal/src/ops/vsi_nn_op_unidirectional_sequence_rnn.c b/src/tim/vx/internal/src/ops/vsi_nn_op_unidirectional_sequence_rnn.c index bf12b96..6c04ec6 100644 --- a/src/tim/vx/internal/src/ops/vsi_nn_op_unidirectional_sequence_rnn.c +++ b/src/tim/vx/internal/src/ops/vsi_nn_op_unidirectional_sequence_rnn.c @@ -233,6 +233,18 @@ static vsi_bool op_setup curr = vsi_nn_internal_new_node( self, VSI_NN_OP_RNNCELL_OVXLIB, 0, 0 ); curr->node->nn_param.rnncell_ovxlib.activation = curr_param->activation; + if ( reshape_output->attr.dtype.vx_type == VSI_NN_TYPE_BFLOAT16 || + reshape_output->attr.dtype.vx_type == VSI_NN_TYPE_FLOAT32 ) + { + int32_t k = 0; + for (k = 0; k < _cnt_of_array( curr_param->internal_dtype ); k++) + { + if (curr_param->internal_dtype[k].vx_type == VSI_NN_TYPE_NONE) + { + curr_param->internal_dtype[k] = reshape_output->attr.dtype; + } + } + } memcpy( curr->node->nn_param.rnncell_ovxlib.internal_dtype, curr_param->internal_dtype, sizeof( curr_param->internal_dtype ) ); curr->inputs[RNNCELL_INPUT_INPUT] = reshape_output;