From 11fd278d7a48e56fcca73f4610df18e283e3f2aa Mon Sep 17 00:00:00 2001
From: Feiyue Chen
Date: Wed, 16 Nov 2022 10:50:03 +0800
Subject: [PATCH] Fixed BidirectionalSequenceLSTM bug

Fixed input error of the backward direction
Fixed golden error of unit test

Type: Bug Fix

Signed-off-by: Feiyue Chen
---
 include/tim/vx/ops.h                           |  1 +
 src/tim/vx/ops/bidirectional_sequence_lstm.cc  | 66 +++++++++++++++----
 .../vx/ops/bidirectional_sequence_lstm_test.cc |  6 +-
 3 files changed, 58 insertions(+), 15 deletions(-)

diff --git a/include/tim/vx/ops.h b/include/tim/vx/ops.h
index b9b16fc..a7ed369 100644
--- a/include/tim/vx/ops.h
+++ b/include/tim/vx/ops.h
@@ -93,5 +93,6 @@
 #include "tim/vx/ops/conv3d.h"
 #include "tim/vx/ops/custom_base.h"
 #include "tim/vx/ops/topk.h"
+#include "tim/vx/ops/bidirectional_sequence_lstm.h"
 
 #endif /* TIM_VX_OPS_H_ */
diff --git a/src/tim/vx/ops/bidirectional_sequence_lstm.cc b/src/tim/vx/ops/bidirectional_sequence_lstm.cc
index 11c0c4b..c9bc0f1 100644
--- a/src/tim/vx/ops/bidirectional_sequence_lstm.cc
+++ b/src/tim/vx/ops/bidirectional_sequence_lstm.cc
@@ -23,6 +23,7 @@
  *****************************************************************************/
 #include "tim/vx/ops/bidirectional_sequence_lstm.h"
 #include "tim/vx/ops/unidirectional_sequence_lstm.h"
+#include "tim/vx/ops/reverse.h"
 #include "vsi_nn_pub.h"
 
 #include "op_impl.h"
@@ -122,15 +123,22 @@ class BidirectionalSequenceLstmImpl : public OpImpl {
   };
 
   BidirectionalSequenceLstmImpl(Graph* graph, int input_cnt, int output_cnt,
-                                DataLayout layout = DataLayout::ANY)
+                                float cell_clip, float proj_clip,
+                                tim::vx::ops::UnidirectionalSequenceLstm::ActivationType act_type,
+                                float forget_bias, bool time_major,
+                                tim::vx::ops::UnidirectionalSequenceLstm::ActivationType recurrent_act_type,
+                                bool return_sequences, DataLayout layout = DataLayout::ANY)
       : OpImpl(graph, -1, input_cnt, output_cnt, layout) {
-    lstm_forward_ = graph->CreateOperation<UnidirectionalSequenceLstm>(
-        0.0, 0.0, UnidirectionalSequenceLstm::kTANH, 0.0, false,
-        UnidirectionalSequenceLstm::kSIGMOID, true);
-    lstm_backward_ =
-        graph->CreateOperation<UnidirectionalSequenceLstm>(
-            0.0, 0.0, UnidirectionalSequenceLstm::kTANH, 0.0, false,
-            UnidirectionalSequenceLstm::kSIGMOID, true);
+    lstm_forward_ = graph->CreateOperation<UnidirectionalSequenceLstm>(
+        cell_clip, proj_clip, act_type, forget_bias, time_major,
+        recurrent_act_type, return_sequences);
+    lstm_backward_ = graph->CreateOperation<UnidirectionalSequenceLstm>(
+        cell_clip, proj_clip, act_type, forget_bias, time_major,
+        recurrent_act_type, return_sequences);
+    reverse_input_ = graph->CreateOperation<Reverse>(
+        time_major ? std::vector<int32_t>({2}) : std::vector<int32_t>({1}));
+    reverse_output_ = graph->CreateOperation<Reverse>(
+        time_major ? std::vector<int32_t>({2}) : std::vector<int32_t>({1}));
   }
 
   ~BidirectionalSequenceLstmImpl() {}
@@ -142,6 +150,12 @@ class BidirectionalSequenceLstmImpl : public OpImpl {
     if (this->input_tensor_index == INPUT_CNT - 1) {
       // Get all input tensor
       lstm_forward_->BindInput(in_tensors_[BI_LSTM_INPUT_INPUT]);
+      reverse_input_->BindInput(in_tensors_[BI_LSTM_INPUT_INPUT]);
+      TensorSpec bw_input_spec(in_tensors_[BI_LSTM_INPUT_INPUT]->GetDataType(),
+                               in_tensors_[BI_LSTM_INPUT_INPUT]->GetShape(),
+                               TensorAttribute::TRANSIENT);
+      bw_input_tensor_ = graph_->CreateTensor(bw_input_spec);
+      reverse_input_->BindOutput(bw_input_tensor_);
 
       lstm_forward_->BindInput(in_tensors_[BI_LSTM_FW_INPUT_H_STATE]);
       lstm_forward_->BindInput(in_tensors_[BI_LSTM_FW_INPUT_C_STATE]);
@@ -172,7 +186,7 @@ class BidirectionalSequenceLstmImpl : public OpImpl {
       lstm_forward_->BindInput(in_tensors_[BI_LSTM_FW_INPUT_LAYERNORM_C]);
       lstm_forward_->BindInput(in_tensors_[BI_LSTM_FW_INPUT_LAYERNORM_O]);
 
-      lstm_backward_->BindInput(in_tensors_[BI_LSTM_INPUT_INPUT]);
+      lstm_backward_->BindInput(bw_input_tensor_);
 
       lstm_backward_->BindInput(in_tensors_[BI_LSTM_BW_INPUT_H_STATE]);
       lstm_backward_->BindInput(in_tensors_[BI_LSTM_BW_INPUT_C_STATE]);
@@ -216,7 +230,10 @@ class BidirectionalSequenceLstmImpl : public OpImpl {
       lstm_forward_->BindOutput(out_tensors_[BI_LSTM_FW_OUTPUT_H_STATE]);
       lstm_forward_->BindOutput(out_tensors_[BI_LSTM_FW_OUTPUT_C_STATE]);
 
-      lstm_backward_->BindOutput(out_tensors_[BI_LSTM_BW_OUTPUT_OUTPUT]);
+      bw_output_tensor_ = graph_->CreateTensor(out_tensors_[BI_LSTM_BW_OUTPUT_OUTPUT]->GetSpec());
+      lstm_backward_->BindOutput(bw_output_tensor_);
+      reverse_output_->BindInput(bw_output_tensor_);
+      reverse_output_->BindOutput(out_tensors_[BI_LSTM_BW_OUTPUT_OUTPUT]);
       lstm_backward_->BindOutput(out_tensors_[BI_LSTM_BW_OUTPUT_H_STATE]);
       lstm_backward_->BindOutput(out_tensors_[BI_LSTM_BW_OUTPUT_C_STATE]);
     }
@@ -236,11 +253,33 @@ class BidirectionalSequenceLstmImpl : public OpImpl {
 
  private:
   std::shared_ptr<Operation> lstm_forward_;
   std::shared_ptr<Operation> lstm_backward_;
+  std::shared_ptr<Operation> reverse_input_;
+  std::shared_ptr<Operation> reverse_output_;
   std::array<std::shared_ptr<Tensor>, INPUT_CNT> in_tensors_;
   std::array<std::shared_ptr<Tensor>, OUTPUT_CNT> out_tensors_;
+  std::shared_ptr<Tensor> bw_input_tensor_;
+  std::shared_ptr<Tensor> bw_output_tensor_;
 };
 
+UnidirectionalSequenceLstm::ActivationType interpreter(
+    BidirectionalSequenceLstm::ActivationType act) {
+  switch (act) {
+    case BidirectionalSequenceLstm::ActivationType::kRELU:
+      return UnidirectionalSequenceLstm::ActivationType::kRELU;
+    case BidirectionalSequenceLstm::ActivationType::kRELU6:
+      return UnidirectionalSequenceLstm::ActivationType::kRELU6;
+    case BidirectionalSequenceLstm::ActivationType::kTANH:
+      return UnidirectionalSequenceLstm::ActivationType::kTANH;
+    case BidirectionalSequenceLstm::ActivationType::kSIGMOID:
+      return UnidirectionalSequenceLstm::ActivationType::kSIGMOID;
+    case BidirectionalSequenceLstm::ActivationType::kHARDSIGMOID:
+      return UnidirectionalSequenceLstm::ActivationType::kHARDSIGMOID;
+    default:
+      return UnidirectionalSequenceLstm::ActivationType::kNONE;
+  }
+}
+
 BidirectionalSequenceLstm::BidirectionalSequenceLstm(
     Graph* graph, float cell_clip, float proj_clip, ActivationType act_type,
     float forget_bias, bool time_major, ActivationType recurrent_act_type,
@@ -252,8 +291,11 @@ BidirectionalSequenceLstm::BidirectionalSequenceLstm(
       time_major_(time_major),
       recurrent_act_type_(recurrent_act_type),
       return_sequences_(return_sequences) {
-  impl_ = std::make_unique<BidirectionalSequenceLstmImpl>(graph, 0, 0,
-                                                          DataLayout::ANY);
+  impl_ = std::make_unique<BidirectionalSequenceLstmImpl>(graph, 0, 0, cell_clip_,
+                                                          proj_clip_, interpreter(act_type_),
+                                                          forget_bias_, time_major_,
+                                                          interpreter(recurrent_act_type_),
+                                                          return_sequences_, DataLayout::ANY);
 }
 
 std::shared_ptr<Operation> BidirectionalSequenceLstm::Clone(
diff --git a/src/tim/vx/ops/bidirectional_sequence_lstm_test.cc b/src/tim/vx/ops/bidirectional_sequence_lstm_test.cc
index 40235c1..926f3e3 100644
--- a/src/tim/vx/ops/bidirectional_sequence_lstm_test.cc
+++ b/src/tim/vx/ops/bidirectional_sequence_lstm_test.cc
@@ -288,9 +288,9 @@ TEST(Bidirectional_LSTM_CELL, shape_in_2_cell_4_out_4_float32) {
       -0.03716109, 0.12507336, 0.41193449, -0.20860538,
       -0.15053082, 0.09120187, 0.24278517, -0.12222792};
   std::vector<float> lstm_bw_golden_output = {
-      -0.02973187, 0.1229473,  0.20885126, -0.15358765,
-      -0.03716109, 0.12507336, 0.41193449, -0.20860538,
-      -0.15053082, 0.09120187, 0.24278517, -0.12222792};
+      -0.0806187,  0.139077,   0.400476,   -0.197842,
+      -0.0332076,  0.123838,   0.309777,   -0.17621,
+      -0.0490733,  0.0739237,  0.067706,   -0.0208124};
   std::vector<float> fw_output(lstm_fw_golden_output.size());
   std::vector<float> bw_output(lstm_bw_golden_output.size());
   fw_output_tensor->CopyDataFromTensor(fw_output.data());
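
For reference, a minimal usage sketch of the op this patch fixes. It is not
part of the patch: tensor creation and the full BI_LSTM_* binding sequence
are elided (the enum order in bidirectional_sequence_lstm.h and the unit
test are authoritative), and the constructor arguments shown simply mirror
the defaults that the old hard-coded constructor used.

#include "tim/vx/context.h"
#include "tim/vx/graph.h"
#include "tim/vx/ops.h"

int main() {
  auto ctx = tim::vx::Context::Create();
  auto graph = ctx->CreateGraph();

  using BiLstm = tim::vx::ops::BidirectionalSequenceLstm;
  // cell_clip/proj_clip = 0.0 (disabled), kTANH cell activation,
  // forget_bias = 0.0, batch-major input, kSIGMOID recurrent activation,
  // return_sequences = true.
  auto bilstm = graph->CreateOperation<BiLstm>(
      /*cell_clip=*/0.0f, /*proj_clip=*/0.0f,
      BiLstm::ActivationType::kTANH,
      /*forget_bias=*/0.0f, /*time_major=*/false,
      BiLstm::ActivationType::kSIGMOID,
      /*return_sequences=*/true);

  // Bind the shared input plus the per-direction weights/states in the
  // BI_LSTM_* enum order, then the forward and backward outputs, e.g.:
  //   bilstm->BindInput(input);      // BI_LSTM_INPUT_INPUT
  //   ...                            // weights, biases, h/c states
  //   bilstm->BindOutput(fw_out);    // BI_LSTM_FW_OUTPUT_OUTPUT
  //   bilstm->BindOutput(bw_out);    // BI_LSTM_BW_OUTPUT_OUTPUT
  //
  // With this fix, bw_out is computed from the time-reversed input and
  // reversed back, so both outputs are in the original time order.

  return graph->Compile() ? 0 : 1;
}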