Fixed quant param lost in Bidirectional lstm (#649)

https://github.com/VeriSilicon/TIM-VX/issues/647

Type: Bug fix

Signed-off-by: Chen <jack.chen@verisilicon.com>
This commit is contained in:
chxin66 2023-09-19 22:08:34 +08:00 committed by GitHub
parent 61ea0091ca
commit 363c369bf6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 8 additions and 10 deletions

View File

@ -151,10 +151,8 @@ class BidirectionalSequenceLstmImpl : public OpImpl {
// Get all input tensor
lstm_forward_->BindInput(in_tensors_[BI_LSTM_INPUT_INPUT]);
reverse_input_->BindInput(in_tensors_[BI_LSTM_INPUT_INPUT]);
TensorSpec bw_input_spec (in_tensors_[BI_LSTM_INPUT_INPUT]->GetDataType(),
in_tensors_[BI_LSTM_INPUT_INPUT]->GetShape(),
TensorAttribute::TRANSIENT);
bw_input_tensor_ = graph_->CreateTensor(bw_input_spec);
TensorSpec bw_input_spec (in_tensors_[BI_LSTM_INPUT_INPUT]->GetSpec());
bw_input_tensor_ = graph_->CreateTensor(bw_input_spec.AsTransientSpec());
reverse_input_->BindOutput(bw_input_tensor_);
lstm_forward_->BindInput(in_tensors_[BI_LSTM_FW_INPUT_H_STATE]);
lstm_forward_->BindInput(in_tensors_[BI_LSTM_FW_INPUT_C_STATE]);