Fixed BidirectionalSequenceLSTM bug
Fixed the input of the backward direction: the backward LSTM now consumes the time-reversed sequence and its output is reversed back into natural time order.
Fixed the golden output of the unit test accordingly.

Type: Bug Fix
Signed-off-by: Feiyue Chen <Feiyue.Chen@verisilicon.com>
parent 4db479ece4
commit 11fd278d7a
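Background: before this change, BidirectionalSequenceLstmImpl bound the original input tensor to both the forward and the backward UnidirectionalSequenceLstm, so the "backward" pass actually ran over the sequence in forward order and produced forward results. The fix sandwiches the backward LSTM between two Reverse ops: one reverses the input along the time axis before the backward pass, the other flips the backward output back afterwards. The hunks below touch the ops umbrella header (the file guarded by TIM_VX_OPS_H_), the BidirectionalSequenceLstm implementation, and its unit test. A minimal sketch of the sandwich, using only the TIM-VX calls that appear in this diff (the tensor and op handles `input`, `reversed_input`, `reversed_output`, `lstm_backward`, and `bw_output` are hypothetical, assumed to be built elsewhere):

    #include "tim/vx/context.h"
    #include "tim/vx/graph.h"
    #include "tim/vx/ops/reverse.h"
    #include <vector>

    auto ctx = tim::vx::Context::Create();
    auto graph = ctx->CreateGraph();
    // Reverse along the time axis (index 2 for a time-major layout here).
    auto reverse_in  = graph->CreateOperation<tim::vx::ops::Reverse>(std::vector<int32_t>({2}));
    auto reverse_out = graph->CreateOperation<tim::vx::ops::Reverse>(std::vector<int32_t>({2}));
    reverse_in->BindInput(input);               // original sequence
    reverse_in->BindOutput(reversed_input);     // transient, time-reversed copy
    lstm_backward->BindInput(reversed_input);   // backward pass sees reversed time
    lstm_backward->BindOutput(reversed_output);
    reverse_out->BindInput(reversed_output);    // flip result back to natural order
    reverse_out->BindOutput(bw_output);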
@@ -93,5 +93,6 @@
 #include "tim/vx/ops/conv3d.h"
 #include "tim/vx/ops/custom_base.h"
 #include "tim/vx/ops/topk.h"
+#include "tim/vx/ops/bidirectional_sequence_lstm.h"
 
 #endif /* TIM_VX_OPS_H_ */

@@ -23,6 +23,7 @@
 *****************************************************************************/
 #include "tim/vx/ops/bidirectional_sequence_lstm.h"
 #include "tim/vx/ops/unidirectional_sequence_lstm.h"
+#include "tim/vx/ops/reverse.h"
 #include "vsi_nn_pub.h"
 #include "op_impl.h"
 
@@ -122,15 +123,22 @@ class BidirectionalSequenceLstmImpl : public OpImpl {
   };
 
   BidirectionalSequenceLstmImpl(Graph* graph, int input_cnt, int output_cnt,
-                                DataLayout layout = DataLayout::ANY)
+                                float cell_clip, float proj_clip,
+                                tim::vx::ops::UnidirectionalSequenceLstm::ActivationType act_type,
+                                float forget_bias, bool time_major,
+                                tim::vx::ops::UnidirectionalSequenceLstm::ActivationType recurrent_act_type,
+                                bool return_sequences, DataLayout layout = DataLayout::ANY)
       : OpImpl(graph, -1, input_cnt, output_cnt, layout) {
     lstm_forward_ = graph->CreateOperation<UnidirectionalSequenceLstm>(
-        0.0, 0.0, UnidirectionalSequenceLstm::kTANH, 0.0, false,
-        UnidirectionalSequenceLstm::kSIGMOID, true);
-    lstm_backward_ =
-        graph->CreateOperation<tim::vx::ops::UnidirectionalSequenceLstm>(
-            0.0, 0.0, UnidirectionalSequenceLstm::kTANH, 0.0, false,
-            UnidirectionalSequenceLstm::kSIGMOID, true);
+        cell_clip, proj_clip, act_type, forget_bias, time_major,
+        recurrent_act_type, return_sequences);
+    lstm_backward_ = graph->CreateOperation<UnidirectionalSequenceLstm>(
+        cell_clip, proj_clip, act_type, forget_bias, time_major,
+        recurrent_act_type, return_sequences);
+    reverse_input_ = graph->CreateOperation<Reverse>(
+        time_major ? std::vector<int32_t>({2}) : std::vector<int32_t>({1}));
+    reverse_output_ = graph->CreateOperation<Reverse>(
+        time_major ? std::vector<int32_t>({2}) : std::vector<int32_t>({1}));
   }
 
   ~BidirectionalSequenceLstmImpl() {}

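Note on the reverse axis: TIM-VX shapes appear to list dimensions from the fastest-varying one first (the reverse of framework order), which would put the time dimension of a time-major sequence at index 2 and of a batch-major sequence at index 1, matching the axis choice above. A small illustration with hypothetical extents (the ordering convention is an assumption inferred from this diff):

    // Framework order [seq_len, batch, input_size] (time-major) becomes:
    tim::vx::ShapeType time_major_shape{input_size, batch, seq_len};   // time = axis 2
    // Framework order [batch, seq_len, input_size] (batch-major) becomes:
    tim::vx::ShapeType batch_major_shape{input_size, seq_len, batch};  // time = axis 1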
@@ -142,6 +150,12 @@ class BidirectionalSequenceLstmImpl : public OpImpl {
     if (this->input_tensor_index == INPUT_CNT - 1) {
       // Get all input tensor
       lstm_forward_->BindInput(in_tensors_[BI_LSTM_INPUT_INPUT]);
+      reverse_input_->BindInput(in_tensors_[BI_LSTM_INPUT_INPUT]);
+      TensorSpec bw_input_spec(in_tensors_[BI_LSTM_INPUT_INPUT]->GetDataType(),
+                               in_tensors_[BI_LSTM_INPUT_INPUT]->GetShape(),
+                               TensorAttribute::TRANSIENT);
+      bw_input_tensor_ = graph_->CreateTensor(bw_input_spec);
+      reverse_input_->BindOutput(bw_input_tensor_);
       lstm_forward_->BindInput(in_tensors_[BI_LSTM_FW_INPUT_H_STATE]);
       lstm_forward_->BindInput(in_tensors_[BI_LSTM_FW_INPUT_C_STATE]);
 
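The reversed copy of the input is created as a TRANSIENT tensor: it exists only as an internal edge between the Reverse op and the backward LSTM, with no externally visible buffer. A minimal sketch of that pattern (`graph` and `input` are hypothetical handles; the calls are the ones used in the hunk above):

    // Intermediate tensor matching the input's dtype/shape, owned by the graph:
    tim::vx::TensorSpec spec(input->GetDataType(), input->GetShape(),
                             tim::vx::TensorAttribute::TRANSIENT);
    auto intermediate = graph->CreateTensor(spec);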
@@ -172,7 +186,7 @@ class BidirectionalSequenceLstmImpl : public OpImpl {
       lstm_forward_->BindInput(in_tensors_[BI_LSTM_FW_INPUT_LAYERNORM_C]);
       lstm_forward_->BindInput(in_tensors_[BI_LSTM_FW_INPUT_LAYERNORM_O]);
 
-      lstm_backward_->BindInput(in_tensors_[BI_LSTM_INPUT_INPUT]);
+      lstm_backward_->BindInput(bw_input_tensor_);
       lstm_backward_->BindInput(in_tensors_[BI_LSTM_BW_INPUT_H_STATE]);
       lstm_backward_->BindInput(in_tensors_[BI_LSTM_BW_INPUT_C_STATE]);
 
@@ -216,7 +230,10 @@ class BidirectionalSequenceLstmImpl : public OpImpl {
       lstm_forward_->BindOutput(out_tensors_[BI_LSTM_FW_OUTPUT_H_STATE]);
       lstm_forward_->BindOutput(out_tensors_[BI_LSTM_FW_OUTPUT_C_STATE]);
 
-      lstm_backward_->BindOutput(out_tensors_[BI_LSTM_BW_OUTPUT_OUTPUT]);
+      bw_output_tensor_ = graph_->CreateTensor(out_tensors_[BI_LSTM_BW_OUTPUT_OUTPUT]->GetSpec());
+      lstm_backward_->BindOutput(bw_output_tensor_);
+      reverse_output_->BindInput(bw_output_tensor_);
+      reverse_output_->BindOutput(out_tensors_[BI_LSTM_BW_OUTPUT_OUTPUT]);
       lstm_backward_->BindOutput(out_tensors_[BI_LSTM_BW_OUTPUT_H_STATE]);
       lstm_backward_->BindOutput(out_tensors_[BI_LSTM_BW_OUTPUT_C_STATE]);
     }
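The output side mirrors the input side: the backward LSTM writes into a transient clone of the declared backward-output spec, and reverse_output_ flips that result back so the tensor the caller bound as BI_LSTM_BW_OUTPUT_OUTPUT is in natural time order again. The H/C state outputs are deliberately not reversed; they are final states and carry no time axis.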
@@ -236,11 +253,33 @@ class BidirectionalSequenceLstmImpl : public OpImpl {
  private:
   std::shared_ptr<tim::vx::Operation> lstm_forward_;
   std::shared_ptr<tim::vx::Operation> lstm_backward_;
+  std::shared_ptr<tim::vx::Operation> reverse_input_;
+  std::shared_ptr<tim::vx::Operation> reverse_output_;
 
   std::array<std::shared_ptr<tim::vx::Tensor>, INPUT_CNT> in_tensors_;
   std::array<std::shared_ptr<tim::vx::Tensor>, OUTPUT_CNT> out_tensors_;
+  std::shared_ptr<Tensor> bw_input_tensor_;
+  std::shared_ptr<Tensor> bw_output_tensor_;
 };
 
+UnidirectionalSequenceLstm::ActivationType interpreter(BidirectionalSequenceLstm::ActivationType act) {
+  switch (act) {
+    case BidirectionalSequenceLstm::ActivationType::kRELU:
+      return UnidirectionalSequenceLstm::ActivationType::kRELU;
+    case BidirectionalSequenceLstm::ActivationType::kRELU6:
+      return UnidirectionalSequenceLstm::ActivationType::kRELU6;
+    case BidirectionalSequenceLstm::ActivationType::kTANH:
+      return UnidirectionalSequenceLstm::ActivationType::kTANH;
+    case BidirectionalSequenceLstm::ActivationType::kSIGMOID:
+      return UnidirectionalSequenceLstm::ActivationType::kSIGMOID;
+    case BidirectionalSequenceLstm::ActivationType::kHARDSIGMOID:
+      return UnidirectionalSequenceLstm::ActivationType::kHARDSIGMOID;
+    default: {
+      return UnidirectionalSequenceLstm::ActivationType::kNONE;
+    }
+  }
+}
+
 BidirectionalSequenceLstm::BidirectionalSequenceLstm(
     Graph* graph, float cell_clip, float proj_clip, ActivationType act_type,
     float forget_bias, bool time_major, ActivationType recurrent_act_type,

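The interpreter() helper exists because BidirectionalSequenceLstm and UnidirectionalSequenceLstm each declare their own ActivationType enum, so the outer op's values must be translated before they can be forwarded to the inner ops; anything unrecognized falls back to kNONE. For example:

    // The outer enum value maps one-to-one onto the inner one:
    auto inner = interpreter(BidirectionalSequenceLstm::ActivationType::kTANH);
    // inner == UnidirectionalSequenceLstm::ActivationType::kTANH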
@@ -252,8 +291,11 @@ BidirectionalSequenceLstm::BidirectionalSequenceLstm(
       time_major_(time_major),
       recurrent_act_type_(recurrent_act_type),
       return_sequences_(return_sequences) {
-  impl_ = std::make_unique<BidirectionalSequenceLstmImpl>(graph, 0, 0,
-                                                          DataLayout::ANY);
+  impl_ = std::make_unique<BidirectionalSequenceLstmImpl>(graph, 0, 0, cell_clip_,
+                                                          proj_clip_, interpreter(act_type_),
+                                                          forget_bias_, time_major_,
+                                                          interpreter(recurrent_act_type_),
+                                                          return_sequences_, DataLayout::ANY);
 }
 
 std::shared_ptr<Operation> BidirectionalSequenceLstm::Clone(

@@ -288,9 +288,9 @@ TEST(Bidirectional_LSTM_CELL, shape_in_2_cell_4_out_4_float32) {
       -0.03716109, 0.12507336, 0.41193449, -0.20860538,
       -0.15053082, 0.09120187, 0.24278517, -0.12222792};
   std::vector<float> lstm_bw_golden_output = {
-      -0.02973187, 0.1229473,  0.20885126, -0.15358765,
-      -0.03716109, 0.12507336, 0.41193449, -0.20860538,
-      -0.15053082, 0.09120187, 0.24278517, -0.12222792};
+      -0.0806187,  0.139077,   0.400476,   -0.197842,
+      -0.0332076,  0.123838,   0.309777,   -0.17621,
+      -0.0490733,  0.0739237,  0.067706,   -0.0208124};
   std::vector<float> fw_output(lstm_fw_golden_output.size());
   std::vector<float> bw_output(lstm_bw_golden_output.size());
   fw_output_tensor->CopyDataFromTensor(fw_output.data());
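The old backward goldens were a verbatim copy of the forward goldens, which only "passed" because the backward LSTM was incorrectly fed the forward-ordered input. With the reversal in place, the backward pass produces genuinely different values, hence the new reference data. A typical way the comparison would be written (a sketch assuming googletest, which TIM-VX unit tests use; `bw_output_tensor` and the tolerance are assumptions mirroring the fw_output path shown above):

    bw_output_tensor->CopyDataFromTensor(bw_output.data());
    for (size_t i = 0; i < bw_output.size(); ++i) {
      EXPECT_NEAR(lstm_bw_golden_output[i], bw_output[i], 1e-5f);
    }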