Fixed quant param lost in Bidirectional lstm (#649)
https://github.com/VeriSilicon/TIM-VX/issues/647 Type: Bug fix Signed-off-by: Chen <jack.chen@verisilicon.com>
This commit is contained in:
parent
61ea0091ca
commit
363c369bf6
|
|
@ -151,10 +151,8 @@ class BidirectionalSequenceLstmImpl : public OpImpl {
|
||||||
// Get all input tensor
|
// Get all input tensor
|
||||||
lstm_forward_->BindInput(in_tensors_[BI_LSTM_INPUT_INPUT]);
|
lstm_forward_->BindInput(in_tensors_[BI_LSTM_INPUT_INPUT]);
|
||||||
reverse_input_->BindInput(in_tensors_[BI_LSTM_INPUT_INPUT]);
|
reverse_input_->BindInput(in_tensors_[BI_LSTM_INPUT_INPUT]);
|
||||||
TensorSpec bw_input_spec (in_tensors_[BI_LSTM_INPUT_INPUT]->GetDataType(),
|
TensorSpec bw_input_spec (in_tensors_[BI_LSTM_INPUT_INPUT]->GetSpec());
|
||||||
in_tensors_[BI_LSTM_INPUT_INPUT]->GetShape(),
|
bw_input_tensor_ = graph_->CreateTensor(bw_input_spec.AsTransientSpec());
|
||||||
TensorAttribute::TRANSIENT);
|
|
||||||
bw_input_tensor_ = graph_->CreateTensor(bw_input_spec);
|
|
||||||
reverse_input_->BindOutput(bw_input_tensor_);
|
reverse_input_->BindOutput(bw_input_tensor_);
|
||||||
lstm_forward_->BindInput(in_tensors_[BI_LSTM_FW_INPUT_H_STATE]);
|
lstm_forward_->BindInput(in_tensors_[BI_LSTM_FW_INPUT_H_STATE]);
|
||||||
lstm_forward_->BindInput(in_tensors_[BI_LSTM_FW_INPUT_C_STATE]);
|
lstm_forward_->BindInput(in_tensors_[BI_LSTM_FW_INPUT_C_STATE]);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue