Fixed BidirectionalSequenceLSTM bug

Fixed the input of the backward direction: the backward LSTM now runs on a time-reversed copy of the input, and its output is reversed back to the original time order.
Fixed the golden values of the unit test accordingly.
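
Before this fix the backward LSTM was bound directly to the original input, so both directions consumed the sequence in the same time order (the old unit test even reused the forward golden values for the backward output). The corrected data flow is reverse -> backward LSTM -> reverse, built from the same TIM-VX calls this patch adds. A minimal sketch, assuming a graph and bound input/bw_output tensors already exist; tensor names are illustrative, and axis 2 assumes time-major input (see the note after the second file below):

    // Reverse the input along the time axis before feeding the backward cell.
    auto reverse_in = graph->CreateOperation<tim::vx::ops::Reverse>(std::vector<int32_t>({2}));
    auto reversed_input = graph->CreateTensor(tim::vx::TensorSpec(
        input->GetDataType(), input->GetShape(), tim::vx::TensorAttribute::TRANSIENT));
    reverse_in->BindInput(input);
    reverse_in->BindOutput(reversed_input);

    // Run the backward cell on the reversed sequence into a transient buffer...
    lstm_backward->BindInput(reversed_input);
    auto raw_bw_output = graph->CreateTensor(bw_output->GetSpec());
    lstm_backward->BindOutput(raw_bw_output);

    // ...then flip its output back so it aligns with the forward direction in time.
    auto reverse_out = graph->CreateOperation<tim::vx::ops::Reverse>(std::vector<int32_t>({2}));
    reverse_out->BindInput(raw_bw_output);
    reverse_out->BindOutput(bw_output);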

Type: Bug Fix
Signed-off-by: Feiyue Chen <Feiyue.Chen@verisilicon.com>
Feiyue Chen 2022-11-16 10:50:03 +08:00 committed by Sven
parent 4db479ece4
commit 11fd278d7a
3 changed files with 58 additions and 15 deletions


@@ -93,5 +93,6 @@
 #include "tim/vx/ops/conv3d.h"
 #include "tim/vx/ops/custom_base.h"
 #include "tim/vx/ops/topk.h"
+#include "tim/vx/ops/bidirectional_sequence_lstm.h"
 #endif /* TIM_VX_OPS_H_ */
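
With the header now exported through the umbrella tim/vx/ops.h, the op can be constructed the usual TIM-VX way. A minimal usage sketch, assuming the standard Context/Graph setup and the constructor parameter order shown later in this commit; the parameter values are illustrative:

    #include "tim/vx/context.h"
    #include "tim/vx/graph.h"
    #include "tim/vx/ops.h"

    auto ctx = tim::vx::Context::Create();
    auto graph = ctx->CreateGraph();
    // Arguments: cell_clip, proj_clip, act_type, forget_bias, time_major,
    // recurrent_act_type, return_sequences; the graph is passed implicitly.
    auto bi_lstm = graph->CreateOperation<tim::vx::ops::BidirectionalSequenceLstm>(
        0.0f, 0.0f,
        tim::vx::ops::BidirectionalSequenceLstm::ActivationType::kTANH,
        0.0f, /*time_major=*/false,
        tim::vx::ops::BidirectionalSequenceLstm::ActivationType::kSIGMOID,
        /*return_sequences=*/true);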


@@ -23,6 +23,7 @@
 *****************************************************************************/
 #include "tim/vx/ops/bidirectional_sequence_lstm.h"
 #include "tim/vx/ops/unidirectional_sequence_lstm.h"
+#include "tim/vx/ops/reverse.h"
 #include "vsi_nn_pub.h"
 #include "op_impl.h"
@@ -122,15 +123,22 @@ class BidirectionalSequenceLstmImpl : public OpImpl {
 };
 BidirectionalSequenceLstmImpl(Graph* graph, int input_cnt, int output_cnt,
-                              DataLayout layout = DataLayout::ANY)
+                              float cell_clip, float proj_clip,
+                              tim::vx::ops::UnidirectionalSequenceLstm::ActivationType act_type,
+                              float forget_bias, bool time_major,
+                              tim::vx::ops::UnidirectionalSequenceLstm::ActivationType recurrent_act_type,
+                              bool return_sequences, DataLayout layout = DataLayout::ANY)
     : OpImpl(graph, -1, input_cnt, output_cnt, layout) {
-  lstm_forward_ = graph->CreateOperation<UnidirectionalSequenceLstm>(
-      0.0, 0.0, UnidirectionalSequenceLstm::kTANH, 0.0, false,
-      UnidirectionalSequenceLstm::kSIGMOID, true);
-  lstm_backward_ =
-      graph->CreateOperation<tim::vx::ops::UnidirectionalSequenceLstm>(
-          0.0, 0.0, UnidirectionalSequenceLstm::kTANH, 0.0, false,
-          UnidirectionalSequenceLstm::kSIGMOID, true);
+  lstm_forward_ = graph->CreateOperation<UnidirectionalSequenceLstm>(
+      cell_clip, proj_clip, act_type, forget_bias, time_major,
+      recurrent_act_type, return_sequences);
+  lstm_backward_ = graph->CreateOperation<UnidirectionalSequenceLstm>(
+      cell_clip, proj_clip, act_type, forget_bias, time_major,
+      recurrent_act_type, return_sequences);
+  reverse_input_ = graph->CreateOperation<Reverse>(time_major ? std::vector<int32_t>({2})
+                                                              : std::vector<int32_t>({1}));
+  reverse_output_ = graph->CreateOperation<Reverse>(time_major ? std::vector<int32_t>({2})
+                                                               : std::vector<int32_t>({1}));
 }
 ~BidirectionalSequenceLstmImpl() {}
@@ -142,6 +150,12 @@ class BidirectionalSequenceLstmImpl : public OpImpl {
 if (this->input_tensor_index == INPUT_CNT - 1) {
   // Get all input tensor
   lstm_forward_->BindInput(in_tensors_[BI_LSTM_INPUT_INPUT]);
+  reverse_input_->BindInput(in_tensors_[BI_LSTM_INPUT_INPUT]);
+  TensorSpec bw_input_spec(in_tensors_[BI_LSTM_INPUT_INPUT]->GetDataType(),
+                           in_tensors_[BI_LSTM_INPUT_INPUT]->GetShape(),
+                           TensorAttribute::TRANSIENT);
+  bw_input_tensor_ = graph_->CreateTensor(bw_input_spec);
+  reverse_input_->BindOutput(bw_input_tensor_);
   lstm_forward_->BindInput(in_tensors_[BI_LSTM_FW_INPUT_H_STATE]);
   lstm_forward_->BindInput(in_tensors_[BI_LSTM_FW_INPUT_C_STATE]);
@@ -172,7 +186,7 @@ class BidirectionalSequenceLstmImpl : public OpImpl {
   lstm_forward_->BindInput(in_tensors_[BI_LSTM_FW_INPUT_LAYERNORM_C]);
   lstm_forward_->BindInput(in_tensors_[BI_LSTM_FW_INPUT_LAYERNORM_O]);
-  lstm_backward_->BindInput(in_tensors_[BI_LSTM_INPUT_INPUT]);
+  lstm_backward_->BindInput(bw_input_tensor_);
   lstm_backward_->BindInput(in_tensors_[BI_LSTM_BW_INPUT_H_STATE]);
   lstm_backward_->BindInput(in_tensors_[BI_LSTM_BW_INPUT_C_STATE]);
@@ -216,7 +230,10 @@ class BidirectionalSequenceLstmImpl : public OpImpl {
   lstm_forward_->BindOutput(out_tensors_[BI_LSTM_FW_OUTPUT_H_STATE]);
   lstm_forward_->BindOutput(out_tensors_[BI_LSTM_FW_OUTPUT_C_STATE]);
-  lstm_backward_->BindOutput(out_tensors_[BI_LSTM_BW_OUTPUT_OUTPUT]);
+  bw_output_tensor_ = graph_->CreateTensor(out_tensors_[BI_LSTM_BW_OUTPUT_OUTPUT]->GetSpec());
+  lstm_backward_->BindOutput(bw_output_tensor_);
+  reverse_output_->BindInput(bw_output_tensor_);
+  reverse_output_->BindOutput(out_tensors_[BI_LSTM_BW_OUTPUT_OUTPUT]);
   lstm_backward_->BindOutput(out_tensors_[BI_LSTM_BW_OUTPUT_H_STATE]);
   lstm_backward_->BindOutput(out_tensors_[BI_LSTM_BW_OUTPUT_C_STATE]);
 }
@@ -236,11 +253,33 @@ class BidirectionalSequenceLstmImpl : public OpImpl {
  private:
   std::shared_ptr<tim::vx::Operation> lstm_forward_;
   std::shared_ptr<tim::vx::Operation> lstm_backward_;
+  std::shared_ptr<tim::vx::Operation> reverse_input_;
+  std::shared_ptr<tim::vx::Operation> reverse_output_;
   std::array<std::shared_ptr<tim::vx::Tensor>, INPUT_CNT> in_tensors_;
   std::array<std::shared_ptr<tim::vx::Tensor>, OUTPUT_CNT> out_tensors_;
+  std::shared_ptr<Tensor> bw_input_tensor_;
+  std::shared_ptr<Tensor> bw_output_tensor_;
 };
+UnidirectionalSequenceLstm::ActivationType interpreter(BidirectionalSequenceLstm::ActivationType act) {
+  switch (act) {
+    case BidirectionalSequenceLstm::ActivationType::kRELU:
+      return UnidirectionalSequenceLstm::ActivationType::kRELU;
+    case BidirectionalSequenceLstm::ActivationType::kRELU6:
+      return UnidirectionalSequenceLstm::ActivationType::kRELU6;
+    case BidirectionalSequenceLstm::ActivationType::kTANH:
+      return UnidirectionalSequenceLstm::ActivationType::kTANH;
+    case BidirectionalSequenceLstm::ActivationType::kSIGMOID:
+      return UnidirectionalSequenceLstm::ActivationType::kSIGMOID;
+    case BidirectionalSequenceLstm::ActivationType::kHARDSIGMOID:
+      return UnidirectionalSequenceLstm::ActivationType::kHARDSIGMOID;
+    default: {
+      return UnidirectionalSequenceLstm::ActivationType::kNONE;
+    }
+  }
+}
 BidirectionalSequenceLstm::BidirectionalSequenceLstm(
     Graph* graph, float cell_clip, float proj_clip, ActivationType act_type,
     float forget_bias, bool time_major, ActivationType recurrent_act_type,
@@ -252,8 +291,11 @@ BidirectionalSequenceLstm::BidirectionalSequenceLstm(
       time_major_(time_major),
       recurrent_act_type_(recurrent_act_type),
       return_sequences_(return_sequences) {
-  impl_ = std::make_unique<BidirectionalSequenceLstmImpl>(graph, 0, 0,
-                                                          DataLayout::ANY);
+  impl_ = std::make_unique<BidirectionalSequenceLstmImpl>(graph, 0, 0, cell_clip_,
+                                                          proj_clip_, interpreter(act_type_),
+                                                          forget_bias_, time_major_,
+                                                          interpreter(recurrent_act_type_),
+                                                          return_sequences_, DataLayout::ANY);
 }
 std::shared_ptr<Operation> BidirectionalSequenceLstm::Clone(
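
One subtlety in the constructor above is the Reverse axis: TIM-VX shapes list the innermost dimension first, so if a time-major sequence [time, batch, input] is stored as {input, batch, time}, the time axis is 2, while a batch-major [batch, time, input] sequence stored as {input, time, batch} has its time axis at 1. That is an inference from the reversed-dimension convention, not stated in the patch; a hypothetical helper spelling it out:

    // Pick the time axis under TIM-VX's innermost-first shape order.
    // time-major:  [time, batch, input] -> {input, batch, time} -> axis 2
    // batch-major: [batch, time, input] -> {input, time, batch} -> axis 1
    std::vector<int32_t> TimeAxis(bool time_major) {
      return time_major ? std::vector<int32_t>{2} : std::vector<int32_t>{1};
    }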


@@ -288,9 +288,9 @@ TEST(Bidirectional_LSTM_CELL, shape_in_2_cell_4_out_4_float32) {
       -0.03716109, 0.12507336, 0.41193449, -0.20860538,
       -0.15053082, 0.09120187, 0.24278517, -0.12222792};
   std::vector<float> lstm_bw_golden_output = {
-      -0.02973187, 0.1229473,  0.20885126, -0.15358765,
-      -0.03716109, 0.12507336, 0.41193449, -0.20860538,
-      -0.15053082, 0.09120187, 0.24278517, -0.12222792};
+      -0.0806187,  0.139077,   0.400476,   -0.197842,
+      -0.0332076,  0.123838,   0.309777,   -0.17621,
+      -0.0490733,  0.0739237,  0.067706,   -0.0208124};
   std::vector<float> fw_output(lstm_fw_golden_output.size());
   std::vector<float> bw_output(lstm_bw_golden_output.size());
   fw_output_tensor->CopyDataFromTensor(fw_output.data());
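
The backward goldens now differ from the forward ones, as expected once the backward pass actually sees the reversed sequence (the old values were a copy of the forward goldens). The assertions themselves are outside this hunk; a typical tolerance check over the copied-out buffers would look like this gtest-style sketch, with the threshold being an assumption:

    ASSERT_EQ(lstm_bw_golden_output.size(), bw_output.size());
    for (size_t i = 0; i < bw_output.size(); ++i) {
      // Compare against the recomputed backward goldens within float tolerance.
      EXPECT_NEAR(lstm_bw_golden_output[i], bw_output[i], 1e-5f);
    }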