diff --git a/src/tim/vx/ops/bidirectional_sequence_lstm.cc b/src/tim/vx/ops/bidirectional_sequence_lstm.cc index ebf9d8e..7a8745e 100644 --- a/src/tim/vx/ops/bidirectional_sequence_lstm.cc +++ b/src/tim/vx/ops/bidirectional_sequence_lstm.cc @@ -124,7 +124,7 @@ class BidirectionalSequenceLstmImpl : public OpImpl { BidirectionalSequenceLstmImpl(Graph* graph, int input_cnt, int output_cnt, float cell_clip, float proj_clip, - tim::vx::ops::UnidirectionalSequenceLstm::ActivationType act_type, + tim::vx::ops::UnidirectionalSequenceLstm::ActivationType act_type, float forget_bias, bool time_major, tim::vx::ops::UnidirectionalSequenceLstm::ActivationType recurrent_act_type, bool return_sequences, DataLayout layout = DataLayout::ANY) @@ -135,9 +135,9 @@ class BidirectionalSequenceLstmImpl : public OpImpl { lstm_backward_ = graph->CreateOperation( cell_clip, proj_clip, act_type, forget_bias, time_major, recurrent_act_type, return_sequences); - reverse_input_ = graph->CreateOperation(time_major ? std::vector ({2}) : + reverse_input_ = graph->CreateOperation(time_major ? std::vector ({2}) : std::vector ({1})); - reverse_output_ = graph->CreateOperation(time_major ? std::vector ({2}) : + reverse_output_ = graph->CreateOperation(time_major ? std::vector ({2}) : std::vector ({1})); } @@ -151,10 +151,8 @@ class BidirectionalSequenceLstmImpl : public OpImpl { // Get all input tensor lstm_forward_->BindInput(in_tensors_[BI_LSTM_INPUT_INPUT]); reverse_input_->BindInput(in_tensors_[BI_LSTM_INPUT_INPUT]); - TensorSpec bw_input_spec (in_tensors_[BI_LSTM_INPUT_INPUT]->GetDataType(), - in_tensors_[BI_LSTM_INPUT_INPUT]->GetShape(), - TensorAttribute::TRANSIENT); - bw_input_tensor_ = graph_->CreateTensor(bw_input_spec); + TensorSpec bw_input_spec (in_tensors_[BI_LSTM_INPUT_INPUT]->GetSpec()); + bw_input_tensor_ = graph_->CreateTensor(bw_input_spec.AsTransientSpec()); reverse_input_->BindOutput(bw_input_tensor_); lstm_forward_->BindInput(in_tensors_[BI_LSTM_FW_INPUT_H_STATE]); lstm_forward_->BindInput(in_tensors_[BI_LSTM_FW_INPUT_C_STATE]); @@ -232,7 +230,7 @@ class BidirectionalSequenceLstmImpl : public OpImpl { bw_output_tensor_ = graph_->CreateTensor(out_tensors_[BI_LSTM_BW_OUTPUT_OUTPUT]->GetSpec()); lstm_backward_->BindOutput(bw_output_tensor_); - reverse_output_->BindInput(bw_output_tensor_); + reverse_output_->BindInput(bw_output_tensor_); reverse_output_->BindOutput(out_tensors_[BI_LSTM_BW_OUTPUT_OUTPUT]); lstm_backward_->BindOutput(out_tensors_[BI_LSTM_BW_OUTPUT_H_STATE]); lstm_backward_->BindOutput(out_tensors_[BI_LSTM_BW_OUTPUT_C_STATE]); @@ -264,7 +262,7 @@ class BidirectionalSequenceLstmImpl : public OpImpl { UnidirectionalSequenceLstm::ActivationType interpreter(BidirectionalSequenceLstm::ActivationType act){ switch (act){ - + case BidirectionalSequenceLstm::ActivationType::kRELU: return UnidirectionalSequenceLstm::ActivationType::kRELU; case BidirectionalSequenceLstm::ActivationType::kRELU6: @@ -291,7 +289,7 @@ BidirectionalSequenceLstm::BidirectionalSequenceLstm( time_major_(time_major), recurrent_act_type_(recurrent_act_type), return_sequences_(return_sequences) { - impl_ = std::make_unique(graph, 0, 0, cell_clip_, + impl_ = std::make_unique(graph, 0, 0, cell_clip_, proj_clip_, interpreter(act_type_), forget_bias_,time_major_, interpreter(recurrent_act_type_),