From c231c54a662620f43b1057092e236e6388fb7ebd Mon Sep 17 00:00:00 2001 From: Feiyue Chen Date: Tue, 22 Nov 2022 14:20:21 +0800 Subject: [PATCH] Fixed BidirectionalSequenceRnn bugs Added layout inference for BidirectionalRnn Fixed wrong datatype and wrong output order of internal about backward rnn Corrected golden in BidirectionalRnn&BidirectionalRnnExt unit test Modified copyright and log message Type: Bug Fix Signed-off-by: Feiyue Chen --- src/tim/transform/layout_inference.cc | 2 + .../ops/bidirectional_rnn_layout_inference.h | 99 +++++++++++++++++++ .../vsi_nn_op_bidirectional_sequence_rnn.c | 27 +++-- src/tim/vx/ops/bidirectional_sequence_rnn.cc | 4 +- .../vx/ops/bidirectional_sequence_rnn_ext.cc | 2 +- .../bidirectional_sequence_rnn_ext_test.cc | 16 +-- .../vx/ops/bidirectional_sequence_rnn_test.cc | 32 +++--- 7 files changed, 146 insertions(+), 36 deletions(-) create mode 100644 src/tim/transform/ops/bidirectional_rnn_layout_inference.h diff --git a/src/tim/transform/layout_inference.cc b/src/tim/transform/layout_inference.cc index 4ad680f..6e25694 100644 --- a/src/tim/transform/layout_inference.cc +++ b/src/tim/transform/layout_inference.cc @@ -65,6 +65,7 @@ #include "ops/unidirectional_lstm_layout_inference.h" #include "ops/broadcast_layout_inference.h" #include "ops/unidirectional_rnn_layout_inference.h" +#include "ops/bidirectional_rnn_layout_inference.h" #include #include @@ -269,6 +270,7 @@ std::vector> HandleLayoutInfer( REGIST_LAYOUT_INFERENCE(VSI_NN_OP_LSTM_OVXLIB, UnidirectionalLstm); REGIST_LAYOUT_INFERENCE(VSI_NN_OP_EXPAND_BROADCAST, Broadcast); REGIST_LAYOUT_INFERENCE(VSI_NN_OP_UNIDIRECTIONAL_SEQUENCE_RNN, UnidirectionalRnn); + REGIST_LAYOUT_INFERENCE(VSI_NN_OP_BIDIRECTIONAL_SEQUENCE_RNN, BidirectionalRnn); REGIST_LOGICAL_LAYOUT_INFERENCE(VSI_NN_OP_LOGICAL_OPS); REGIST_REDUCE_LAYOUT_INFERENCE(VSI_NN_OP_REDUCE); // use default layout inference diff --git a/src/tim/transform/ops/bidirectional_rnn_layout_inference.h b/src/tim/transform/ops/bidirectional_rnn_layout_inference.h new file mode 100644 index 0000000..a08d30a --- /dev/null +++ b/src/tim/transform/ops/bidirectional_rnn_layout_inference.h @@ -0,0 +1,99 @@ +/**************************************************************************** + * + * Copyright (c) 2022 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#ifndef TIM_LAYOUT_INFER_BIDIRECTIONAL_RNN_LAYOUT_INFERENCE_H_ +#define TIM_LAYOUT_INFER_BIDIRECTIONAL_RNN_LAYOUT_INFERENCE_H_ + +#include "tim/vx/ops/reshape.h" +#include "tim/vx/ops/nbg.h" +#include "tim/vx/ops/transpose.h" +#include "tim/vx/ops/batchnorm.h" +#include "tim/vx/ops/clip.h" + +#include "ops/op_layout_inference.h" +#include "permute_vector.h" +#include "builtin_op_impl.h" + +namespace tim { +namespace transform { + +class BidirectionalRnnLayoutInfer : public OpLayoutInfer { + public: + BidirectionalRnnLayoutInfer( + const std::shared_ptr op, + std::shared_ptr& context) + : OpLayoutInfer(op, context) {} + + // reverse any applied permute on it's input tensor + void OnInputs( + std::vector>& next_tensors) override { + ReverseInputsPermuteVector(); + + auto cloned_op = op_->Clone(context_->infer_graph_); + + for (const auto& i_src : op_->impl()->InputsTensor()) { + std::shared_ptr infer_tensor; + std::shared_ptr required_pv; + if ((i_src->IsConstTensor() && + !(i_src->GetSpec().attr_ & vx::TensorAttribute::INPUT))) { + infer_tensor = context_->infer_graph_->CreateTensor( + i_src->GetSpec(), i_src->GetDataRef()); + context_->UpdateTensorMap(i_src, infer_tensor); + } + if (i_src->GetId() == (uint32_t)-1) { + infer_tensor = context_->infer_graph_->CreateTensorPlaceHolder(); + context_->UpdateTensorMap(i_src, infer_tensor); + } + required_pv = MakeShared(i_src->GetShape().size()); + context_->SetPermuteVector(i_src, required_pv); + } + + + for (const auto& i_src : op_->impl()->InputsTensor()) { + (*cloned_op).BindInput(context_->GetMapedTensor(i_src)); + } + + + std::vector> required_pv_lst; + for (auto out_tensor : op_->impl()->OutputsTensor()) { + std::shared_ptr infer_tensor; + if (out_tensor->GetId() == (uint32_t)-1) { + out_tensor = context_->infer_graph_->CreateTensorPlaceHolder(); + } + required_pv_lst.push_back(MakeShared(out_tensor->GetShape().size())); + } + auto out_infer = CreateOutputsTensor(required_pv_lst); + + (*cloned_op).BindOutputs(out_infer); + uint32_t i = 0; + for (auto out_tensor : op_->impl()->OutputsTensor()) { + context_->SetPermuteVector(out_tensor, required_pv_lst[i++]); + next_tensors.push_back(out_tensor); + } + } +}; + +} // namespace transform +} // namespace tim + +#endif \ No newline at end of file diff --git a/src/tim/vx/internal/src/ops/vsi_nn_op_bidirectional_sequence_rnn.c b/src/tim/vx/internal/src/ops/vsi_nn_op_bidirectional_sequence_rnn.c index da6af26..19bc82f 100644 --- a/src/tim/vx/internal/src/ops/vsi_nn_op_bidirectional_sequence_rnn.c +++ b/src/tim/vx/internal/src/ops/vsi_nn_op_bidirectional_sequence_rnn.c @@ -339,8 +339,8 @@ static vsi_bool op_setup output_tensor = vsi_nn_internal_new_tensor( self, &attr, 0.0f ); rnncell_out1 = output_tensor->t; - if (reshape_output_tensors[time_step - 1 - i]->attr.dtype.vx_type == VSI_NN_TYPE_FLOAT32 && - inputs[BI_RNN_BW_INPUT_WEIGHT_I]->attr.dtype.vx_type == VSI_NN_TYPE_FLOAT32 && + if (reshape_output_tensors[i]->attr.dtype.vx_type == VSI_NN_TYPE_FLOAT32 && + inputs[BI_RNN_FW_INPUT_WEIGHT_I]->attr.dtype.vx_type == VSI_NN_TYPE_FLOAT32 && curr_param->internal_dtype[RNNCELL_QUANTIZE_PARAM_I].qnt_type == VSI_NN_QNT_TYPE_NONE && curr_param->internal_dtype[RNNCELL_QUANTIZE_PARAM_I].vx_type == VSI_NN_TYPE_NONE) { @@ -349,16 +349,16 @@ static vsi_bool op_setup if (last_step_h_state_fw && last_step_h_state_fw->attr.dtype.vx_type == VSI_NN_TYPE_FLOAT32 && - inputs[BI_RNN_BW_INPUT_WEIGHT_H]->attr.dtype.vx_type == VSI_NN_TYPE_FLOAT32 && + inputs[BI_RNN_FW_INPUT_WEIGHT_H]->attr.dtype.vx_type == VSI_NN_TYPE_FLOAT32 && curr_param->internal_dtype[RNNCELL_QUANTIZE_PARAM_H].qnt_type == VSI_NN_QNT_TYPE_NONE && curr_param->internal_dtype[RNNCELL_QUANTIZE_PARAM_H].vx_type == VSI_NN_TYPE_NONE) { curr_param->internal_dtype[RNNCELL_QUANTIZE_PARAM_H].vx_type = VSI_NN_TYPE_FLOAT32; } - if (has_aux_input&& - aux_reshape_output_tensors[time_step - 1 - i]->attr.dtype.vx_type == VSI_NN_TYPE_FLOAT32 && - inputs[BI_RNN_BW_AUX_INPUT_WEIGHT]->attr.dtype.vx_type == VSI_NN_TYPE_FLOAT32 && + if (has_aux_input && + aux_reshape_output_tensors[i]->attr.dtype.vx_type == VSI_NN_TYPE_FLOAT32 && + inputs[BI_RNN_FW_AUX_INPUT_WEIGHT]->attr.dtype.vx_type == VSI_NN_TYPE_FLOAT32 && curr_param->internal_dtype[RNNCELL_QUANTIZE_PARAM_AUX].qnt_type == VSI_NN_QNT_TYPE_NONE && curr_param->internal_dtype[RNNCELL_QUANTIZE_PARAM_AUX].vx_type == VSI_NN_TYPE_NONE) { @@ -410,8 +410,17 @@ static vsi_bool op_setup vsi_nn_tensor_t* rnncell_out1 = NULL; /* rnncell output */ - vsi_nn_internal_init_tensor_attr(&attr, + + if(curr_param->merge_outputs) + { + vsi_nn_internal_init_tensor_attr(&attr, + &outputs[BI_RNN_FW_OUTPUT_OUTPUT]->attr.dtype, use_virtual_tensor); + } + else + { + vsi_nn_internal_init_tensor_attr(&attr, &outputs[BI_RNN_BW_OUTPUT_OUTPUT]->attr.dtype, use_virtual_tensor); + } output_tensor = vsi_nn_internal_new_tensor( self, &attr, 0.0f ); rnncell_out0 = output_tensor->t; @@ -438,7 +447,7 @@ static vsi_bool op_setup curr_param->internal_dtype[RNNCELL_QUANTIZE_PARAM_H].vx_type = VSI_NN_TYPE_FLOAT32; } - if (has_aux_input&& + if (has_aux_input && aux_reshape_output_tensors[time_step - 1 - i]->attr.dtype.vx_type == VSI_NN_TYPE_FLOAT32 && inputs[BI_RNN_BW_AUX_INPUT_WEIGHT]->attr.dtype.vx_type == VSI_NN_TYPE_FLOAT32 && curr_param->internal_dtype[RNNCELL_QUANTIZE_PARAM_AUX].qnt_type == VSI_NN_QNT_TYPE_NONE && @@ -481,7 +490,7 @@ static vsi_bool op_setup /* reshape output to 3-dims */ output_tensor = vsi_nn_rnn_reshape_cell_output(self, rnncell_out0, (uint32_t)batch_size, use_virtual_tensor); - rnncell_reshape_output_tensors_bw[i] = output_tensor->t; + rnncell_reshape_output_tensors_bw[time_step - 1 - i] = output_tensor->t; } if(curr_param->merge_outputs) diff --git a/src/tim/vx/ops/bidirectional_sequence_rnn.cc b/src/tim/vx/ops/bidirectional_sequence_rnn.cc index 8308f9a..a891864 100644 --- a/src/tim/vx/ops/bidirectional_sequence_rnn.cc +++ b/src/tim/vx/ops/bidirectional_sequence_rnn.cc @@ -1,6 +1,6 @@ /**************************************************************************** * -* Copyright (c) 2020 Vivante Corporation +* Copyright (c) 2022 Vivante Corporation * * Permission is hereby granted, free of charge, to any person obtaining a * copy of this software and associated documentation files (the "Software"), @@ -45,7 +45,7 @@ vsi_nn_activation_e downcast_act_type(BidirectionalSequenceRnn::ActivationType a case BidirectionalSequenceRnn::ActivationType::kHARDSIGMOID: return VSI_NN_ACT_HARD_SIGMOID; default: { - VSILOGW("Not supported activition type for RNN = %d", static_cast(act)); + VSILOGW("Not supported activition type for BidirectionalSequenceRNN = %d", static_cast(act)); return VSI_NN_ACT_NONE; } } diff --git a/src/tim/vx/ops/bidirectional_sequence_rnn_ext.cc b/src/tim/vx/ops/bidirectional_sequence_rnn_ext.cc index 78e3f54..a4a5da5 100644 --- a/src/tim/vx/ops/bidirectional_sequence_rnn_ext.cc +++ b/src/tim/vx/ops/bidirectional_sequence_rnn_ext.cc @@ -1,6 +1,6 @@ /**************************************************************************** * -* Copyright (c) 2021 Vivante Corporation +* Copyright (c) 2022 Vivante Corporation * * Permission is hereby granted, free of charge, to any person obtaining a * copy of this software and associated documentation files (the "Software"), diff --git a/src/tim/vx/ops/bidirectional_sequence_rnn_ext_test.cc b/src/tim/vx/ops/bidirectional_sequence_rnn_ext_test.cc index 3cc38a7..7834550 100644 --- a/src/tim/vx/ops/bidirectional_sequence_rnn_ext_test.cc +++ b/src/tim/vx/ops/bidirectional_sequence_rnn_ext_test.cc @@ -29,7 +29,7 @@ #include "test_utils.h" -TEST(BidirectionalSequenceRnnExt, shape_2_3_4_float_sigmoid) { +TEST(BidirectionalSequenceRnnExt, shape_2_3_2_float_sigmoid) { auto ctx = tim::vx::Context::Create(); auto graph = ctx->CreateGraph(); @@ -98,7 +98,7 @@ TEST(BidirectionalSequenceRnnExt, shape_2_3_4_float_sigmoid) { std::vector bias_data = { 0.1, 0.1, 0.1, 0.1, 0.0, 0.0, 0.0, 0.0, - 0.1, 0.1, 0.1, 0.1, //bug 不能被获取到 + 0.1, 0.1, 0.1, 0.1, 0.0, 0.0, 0.0, 0.0, }; std::vector state_in_data = { @@ -113,15 +113,15 @@ TEST(BidirectionalSequenceRnnExt, shape_2_3_4_float_sigmoid) { 0.5986, 0.5986, 0.5986, 0.5986, 0.6899, 0.6899, 0.6899, 0.6899, 0.7685, 0.7685, 0.7685, 0.7685, - 0.8320, 0.8320, 0.8320, 0.8320, - 0.8807, 0.8807, 0.8807, 0.8807, - 0.9168, 0.9168, 0.9168, 0.9168, + 0.6754, 0.6754, 0.6754, 0.6754, + 0.7599, 0.7599, 0.7599, 0.7599, + 0.8273, 0.8273, 0.8273, 0.8273, 0.8628, 0.8628, 0.8628, 0.8628, 0.9068, 0.9068, 0.9068, 0.9068, 0.9374, 0.9374, 0.9374, 0.9374, - 0.6754, 0.6754, 0.6754, 0.6754, - 0.7599, 0.7599, 0.7599, 0.7599, - 0.8273, 0.8273, 0.8273, 0.8273 + 0.8320, 0.8320, 0.8320, 0.8320, + 0.8807, 0.8807, 0.8807, 0.8807, + 0.9168, 0.9168, 0.9168, 0.9168, }; std::vector state_out_golden = { 0.8628, 0.8628, 0.8628, 0.8628, diff --git a/src/tim/vx/ops/bidirectional_sequence_rnn_test.cc b/src/tim/vx/ops/bidirectional_sequence_rnn_test.cc index fdf887f..c416bfb 100644 --- a/src/tim/vx/ops/bidirectional_sequence_rnn_test.cc +++ b/src/tim/vx/ops/bidirectional_sequence_rnn_test.cc @@ -1,6 +1,6 @@ /**************************************************************************** * -* Copyright (c) 2021 Vivante Corporation +* Copyright (c) 2022 Vivante Corporation * * Permission is hereby granted, free of charge, to any person obtaining a * copy of this software and associated documentation files (the "Software"), @@ -29,19 +29,19 @@ #include "test_utils.h" -TEST(BidirectionalSequenceRnn, shape_2_3_4_float_sigmoid) { +TEST(BidirectionalSequenceRnn, shape_2_3_2_float_sigmoid) { auto ctx = tim::vx::Context::Create(); auto graph = ctx->CreateGraph(); - uint32_t input_size = 2, batch_size = 3, num_units = 4; + uint32_t input_size = 2, batch_size = 3, time_step = 2, num_units = 4; - tim::vx::ShapeType input_shape({input_size, batch_size, 2}); + tim::vx::ShapeType input_shape({input_size, batch_size, time_step}); tim::vx::ShapeType weights_shape({input_size, num_units}); tim::vx::ShapeType recurrent_weights_shape({num_units, num_units}); tim::vx::ShapeType bias_shape({num_units}); tim::vx::ShapeType recurrent_bias_shape({num_units}); tim::vx::ShapeType state_in_shape({num_units, batch_size}); - tim::vx::ShapeType output_shape({num_units, batch_size, 2}); + tim::vx::ShapeType output_shape({num_units, batch_size, time_step}); tim::vx::ShapeType state_out_shape({num_units, batch_size}); tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, @@ -123,12 +123,12 @@ TEST(BidirectionalSequenceRnn, shape_2_3_4_float_sigmoid) { 0.9374, 0.9374, 0.9374, 0.9374, }; std::vector bw_output_golden = { - 0.8320, 0.8320, 0.8320, 0.8320, - 0.8807, 0.8807, 0.8807, 0.8807, - 0.9168, 0.9168, 0.9168, 0.9168, 0.6754, 0.6754, 0.6754, 0.6754, 0.7599, 0.7599, 0.7599, 0.7599, - 0.8273, 0.8273, 0.8273, 0.8273 + 0.8273, 0.8273, 0.8273, 0.8273, + 0.8320, 0.8320, 0.8320, 0.8320, + 0.8807, 0.8807, 0.8807, 0.8807, + 0.9168, 0.9168, 0.9168, 0.9168, }; std::vector bw_state_out_golden = { 0.6754, 0.6754, 0.6754, 0.6754, @@ -183,19 +183,19 @@ TEST(BidirectionalSequenceRnn, shape_2_3_4_float_sigmoid) { EXPECT_TRUE(ArraysMatch(bw_state_out_golden, bw_state_out,1e-3f)); } -TEST(BidirectionalSequenceRnn, shape_2_3_4_float_relu) { +TEST(BidirectionalSequenceRnn, shape_2_3_2_float_relu) { auto ctx = tim::vx::Context::Create(); auto graph = ctx->CreateGraph(); - uint32_t input_size = 2, batch_size = 3, num_units = 4; + uint32_t input_size = 2, batch_size = 3, num_units = 4, time_step = 2; - tim::vx::ShapeType input_shape({input_size, batch_size, 2}); + tim::vx::ShapeType input_shape({input_size, batch_size, time_step}); tim::vx::ShapeType weights_shape({input_size, num_units}); tim::vx::ShapeType recurrent_weights_shape({num_units, num_units}); tim::vx::ShapeType bias_shape({num_units}); tim::vx::ShapeType recurrent_bias_shape({num_units}); tim::vx::ShapeType state_in_shape({num_units, batch_size}); - tim::vx::ShapeType output_shape({num_units, batch_size, 2}); + tim::vx::ShapeType output_shape({num_units, batch_size, time_step}); tim::vx::ShapeType state_out_shape({num_units, batch_size}); tim::vx::TensorSpec input_spec(tim::vx::DataType::FLOAT32, @@ -277,12 +277,12 @@ TEST(BidirectionalSequenceRnn, shape_2_3_4_float_relu) { 2.88, 2.88, 2.88, 2.88, }; std::vector bw_output_golden = { - 1.6, 1.6, 1.6, 1.6, - 2.0, 2.0, 2.0, 2.0, - 2.4, 2.4, 2.4, 2.4, 1.04, 1.04, 1.04, 1.04, 1.6, 1.6, 1.6, 1.6, 2.16, 2.16, 2.16, 2.16, + 1.6, 1.6, 1.6, 1.6, + 2.0, 2.0, 2.0, 2.0, + 2.4, 2.4, 2.4, 2.4, }; std::vector bw_state_out_golden = { 1.04, 1.04, 1.04, 1.04,