Fixed BidirectionalSequenceLSTM bug
Fixed input error of the backward direction Fixed golden error of unit test Type: Bug Fix Signed-off-by: Feiyue Chen <Feiyue.Chen@verisilicon.com>
This commit is contained in:
parent
4db479ece4
commit
11fd278d7a
|
|
@ -93,5 +93,6 @@
|
|||
#include "tim/vx/ops/conv3d.h"
|
||||
#include "tim/vx/ops/custom_base.h"
|
||||
#include "tim/vx/ops/topk.h"
|
||||
#include "tim/vx/ops/bidirectional_sequence_lstm.h"
|
||||
|
||||
#endif /* TIM_VX_OPS_H_ */
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@
|
|||
*****************************************************************************/
|
||||
#include "tim/vx/ops/bidirectional_sequence_lstm.h"
|
||||
#include "tim/vx/ops/unidirectional_sequence_lstm.h"
|
||||
#include "tim/vx/ops/reverse.h"
|
||||
#include "vsi_nn_pub.h"
|
||||
#include "op_impl.h"
|
||||
|
||||
|
|
@ -122,15 +123,22 @@ class BidirectionalSequenceLstmImpl : public OpImpl {
|
|||
};
|
||||
|
||||
BidirectionalSequenceLstmImpl(Graph* graph, int input_cnt, int output_cnt,
|
||||
DataLayout layout = DataLayout::ANY)
|
||||
float cell_clip, float proj_clip,
|
||||
tim::vx::ops::UnidirectionalSequenceLstm::ActivationType act_type,
|
||||
float forget_bias, bool time_major,
|
||||
tim::vx::ops::UnidirectionalSequenceLstm::ActivationType recurrent_act_type,
|
||||
bool return_sequences, DataLayout layout = DataLayout::ANY)
|
||||
: OpImpl(graph, -1, input_cnt, output_cnt, layout) {
|
||||
lstm_forward_ = graph->CreateOperation<UnidirectionalSequenceLstm>(
|
||||
0.0, 0.0, UnidirectionalSequenceLstm::kTANH, 0.0, false,
|
||||
UnidirectionalSequenceLstm::kSIGMOID, true);
|
||||
lstm_backward_ =
|
||||
graph->CreateOperation<tim::vx::ops::UnidirectionalSequenceLstm>(
|
||||
0.0, 0.0, UnidirectionalSequenceLstm::kTANH, 0.0, false,
|
||||
UnidirectionalSequenceLstm::kSIGMOID, true);
|
||||
lstm_forward_ = graph->CreateOperation<UnidirectionalSequenceLstm>(
|
||||
cell_clip, proj_clip, act_type, forget_bias, time_major,
|
||||
recurrent_act_type, return_sequences);
|
||||
lstm_backward_ = graph->CreateOperation<UnidirectionalSequenceLstm>(
|
||||
cell_clip, proj_clip, act_type, forget_bias, time_major,
|
||||
recurrent_act_type, return_sequences);
|
||||
reverse_input_ = graph->CreateOperation<Reverse>(time_major ? std::vector<int32_t> ({2}) :
|
||||
std::vector<int32_t> ({1}));
|
||||
reverse_output_ = graph->CreateOperation<Reverse>(time_major ? std::vector<int32_t> ({2}) :
|
||||
std::vector<int32_t> ({1}));
|
||||
}
|
||||
|
||||
~BidirectionalSequenceLstmImpl() {}
|
||||
|
|
@ -142,6 +150,12 @@ class BidirectionalSequenceLstmImpl : public OpImpl {
|
|||
if (this->input_tensor_index == INPUT_CNT - 1) {
|
||||
// Get all input tensor
|
||||
lstm_forward_->BindInput(in_tensors_[BI_LSTM_INPUT_INPUT]);
|
||||
reverse_input_->BindInput(in_tensors_[BI_LSTM_INPUT_INPUT]);
|
||||
TensorSpec bw_input_spec (in_tensors_[BI_LSTM_INPUT_INPUT]->GetDataType(),
|
||||
in_tensors_[BI_LSTM_INPUT_INPUT]->GetShape(),
|
||||
TensorAttribute::TRANSIENT);
|
||||
bw_input_tensor_ = graph_->CreateTensor(bw_input_spec);
|
||||
reverse_input_->BindOutput(bw_input_tensor_);
|
||||
lstm_forward_->BindInput(in_tensors_[BI_LSTM_FW_INPUT_H_STATE]);
|
||||
lstm_forward_->BindInput(in_tensors_[BI_LSTM_FW_INPUT_C_STATE]);
|
||||
|
||||
|
|
@ -172,7 +186,7 @@ class BidirectionalSequenceLstmImpl : public OpImpl {
|
|||
lstm_forward_->BindInput(in_tensors_[BI_LSTM_FW_INPUT_LAYERNORM_C]);
|
||||
lstm_forward_->BindInput(in_tensors_[BI_LSTM_FW_INPUT_LAYERNORM_O]);
|
||||
|
||||
lstm_backward_->BindInput(in_tensors_[BI_LSTM_INPUT_INPUT]);
|
||||
lstm_backward_->BindInput(bw_input_tensor_);
|
||||
lstm_backward_->BindInput(in_tensors_[BI_LSTM_BW_INPUT_H_STATE]);
|
||||
lstm_backward_->BindInput(in_tensors_[BI_LSTM_BW_INPUT_C_STATE]);
|
||||
|
||||
|
|
@ -216,7 +230,10 @@ class BidirectionalSequenceLstmImpl : public OpImpl {
|
|||
lstm_forward_->BindOutput(out_tensors_[BI_LSTM_FW_OUTPUT_H_STATE]);
|
||||
lstm_forward_->BindOutput(out_tensors_[BI_LSTM_FW_OUTPUT_C_STATE]);
|
||||
|
||||
lstm_backward_->BindOutput(out_tensors_[BI_LSTM_BW_OUTPUT_OUTPUT]);
|
||||
bw_output_tensor_ = graph_->CreateTensor(out_tensors_[BI_LSTM_BW_OUTPUT_OUTPUT]->GetSpec());
|
||||
lstm_backward_->BindOutput(bw_output_tensor_);
|
||||
reverse_output_->BindInput(bw_output_tensor_);
|
||||
reverse_output_->BindOutput(out_tensors_[BI_LSTM_BW_OUTPUT_OUTPUT]);
|
||||
lstm_backward_->BindOutput(out_tensors_[BI_LSTM_BW_OUTPUT_H_STATE]);
|
||||
lstm_backward_->BindOutput(out_tensors_[BI_LSTM_BW_OUTPUT_C_STATE]);
|
||||
}
|
||||
|
|
@ -236,11 +253,33 @@ class BidirectionalSequenceLstmImpl : public OpImpl {
|
|||
private:
|
||||
std::shared_ptr<tim::vx::Operation> lstm_forward_;
|
||||
std::shared_ptr<tim::vx::Operation> lstm_backward_;
|
||||
std::shared_ptr<tim::vx::Operation> reverse_input_;
|
||||
std::shared_ptr<tim::vx::Operation> reverse_output_;
|
||||
|
||||
std::array<std::shared_ptr<tim::vx::Tensor>, INPUT_CNT> in_tensors_;
|
||||
std::array<std::shared_ptr<tim::vx::Tensor>, OUTPUT_CNT> out_tensors_;
|
||||
std::shared_ptr<Tensor> bw_input_tensor_;
|
||||
std::shared_ptr<Tensor> bw_output_tensor_;
|
||||
};
|
||||
|
||||
UnidirectionalSequenceLstm::ActivationType interpreter(BidirectionalSequenceLstm::ActivationType act){
|
||||
switch (act){
|
||||
|
||||
case BidirectionalSequenceLstm::ActivationType::kRELU:
|
||||
return UnidirectionalSequenceLstm::ActivationType::kRELU;
|
||||
case BidirectionalSequenceLstm::ActivationType::kRELU6:
|
||||
return UnidirectionalSequenceLstm::ActivationType::kRELU6;
|
||||
case BidirectionalSequenceLstm::ActivationType::kTANH:
|
||||
return UnidirectionalSequenceLstm::ActivationType::kTANH;
|
||||
case BidirectionalSequenceLstm::ActivationType::kSIGMOID:
|
||||
return UnidirectionalSequenceLstm::ActivationType::kSIGMOID;
|
||||
case BidirectionalSequenceLstm::ActivationType::kHARDSIGMOID:
|
||||
return UnidirectionalSequenceLstm::ActivationType::kHARDSIGMOID;
|
||||
default: {
|
||||
return UnidirectionalSequenceLstm::ActivationType::kNONE;
|
||||
}
|
||||
}
|
||||
}
|
||||
BidirectionalSequenceLstm::BidirectionalSequenceLstm(
|
||||
Graph* graph, float cell_clip, float proj_clip, ActivationType act_type,
|
||||
float forget_bias, bool time_major, ActivationType recurrent_act_type,
|
||||
|
|
@ -252,8 +291,11 @@ BidirectionalSequenceLstm::BidirectionalSequenceLstm(
|
|||
time_major_(time_major),
|
||||
recurrent_act_type_(recurrent_act_type),
|
||||
return_sequences_(return_sequences) {
|
||||
impl_ = std::make_unique<BidirectionalSequenceLstmImpl>(graph, 0, 0,
|
||||
DataLayout::ANY);
|
||||
impl_ = std::make_unique<BidirectionalSequenceLstmImpl>(graph, 0, 0, cell_clip_,
|
||||
proj_clip_, interpreter(act_type_),
|
||||
forget_bias_,time_major_,
|
||||
interpreter(recurrent_act_type_),
|
||||
return_sequences_, DataLayout::ANY);
|
||||
}
|
||||
|
||||
std::shared_ptr<Operation> BidirectionalSequenceLstm::Clone(
|
||||
|
|
|
|||
|
|
@ -288,9 +288,9 @@ TEST(Bidirectional_LSTM_CELL, shape_in_2_cell_4_out_4_float32) {
|
|||
-0.03716109, 0.12507336, 0.41193449, -0.20860538,
|
||||
-0.15053082, 0.09120187, 0.24278517, -0.12222792};
|
||||
std::vector<float> lstm_bw_golden_output = {
|
||||
-0.02973187, 0.1229473, 0.20885126, -0.15358765,
|
||||
-0.03716109, 0.12507336, 0.41193449, -0.20860538,
|
||||
-0.15053082, 0.09120187, 0.24278517, -0.12222792};
|
||||
-0.0806187, 0.139077, 0.400476, -0.197842,
|
||||
-0.0332076, 0.123838, 0.309777, -0.17621,
|
||||
-0.0490733, 0.0739237, 0.067706, -0.0208124};
|
||||
std::vector<float> fw_output(lstm_fw_golden_output.size());
|
||||
std::vector<float> bw_output(lstm_bw_golden_output.size());
|
||||
fw_output_tensor->CopyDataFromTensor(fw_output.data());
|
||||
|
|
|
|||
Loading…
Reference in New Issue