Fixed quant param lost in Bidirectional lstm (#649)

https://github.com/VeriSilicon/TIM-VX/issues/647

Type: Bug fix

Signed-off-by: Chen <jack.chen@verisilicon.com>
This commit is contained in:
chxin66 2023-09-19 22:08:34 +08:00 committed by GitHub
parent 61ea0091ca
commit 363c369bf6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 8 additions and 10 deletions

View File

@ -124,7 +124,7 @@ class BidirectionalSequenceLstmImpl : public OpImpl {
// Constructor: builds the bidirectional LSTM out of two unidirectional LSTM
// ops (forward + backward) plus two Reverse ops that flip the backward path's
// input and output along the sequence axis — axis 2 when time_major, axis 1
// otherwise.
// NOTE(review): this span is a rendered diff hunk, not clean source — several
// lines below appear twice (the duplicate `act_type` parameter and the
// duplicate `reverse_*` creations). These look like removed/added pairs whose
// -/+ markers were lost in rendering; confirm against the real file before
// treating this as compilable code.
BidirectionalSequenceLstmImpl(Graph* graph, int input_cnt, int output_cnt,
float cell_clip, float proj_clip,
tim::vx::ops::UnidirectionalSequenceLstm::ActivationType act_type,
tim::vx::ops::UnidirectionalSequenceLstm::ActivationType act_type,
float forget_bias, bool time_major,
tim::vx::ops::UnidirectionalSequenceLstm::ActivationType recurrent_act_type,
bool return_sequences, DataLayout layout = DataLayout::ANY)
@ -135,9 +135,9 @@ class BidirectionalSequenceLstmImpl : public OpImpl {
// Backward LSTM consumes the (to-be-reversed) sequence; same hyperparameters
// as the forward direction.
lstm_backward_ = graph->CreateOperation<UnidirectionalSequenceLstm>(
cell_clip, proj_clip, act_type, forget_bias, time_major,
recurrent_act_type, return_sequences);
// Reverse along the time axis: dim 2 for time-major layout, dim 1 otherwise.
reverse_input_ = graph->CreateOperation<Reverse>(time_major ? std::vector<int32_t> ({2}) :
reverse_input_ = graph->CreateOperation<Reverse>(time_major ? std::vector<int32_t> ({2}) :
std::vector<int32_t> ({1}));
reverse_output_ = graph->CreateOperation<Reverse>(time_major ? std::vector<int32_t> ({2}) :
reverse_output_ = graph->CreateOperation<Reverse>(time_major ? std::vector<int32_t> ({2}) :
std::vector<int32_t> ({1}));
}
@ -151,10 +151,8 @@ class BidirectionalSequenceLstmImpl : public OpImpl {
// Get all input tensor
lstm_forward_->BindInput(in_tensors_[BI_LSTM_INPUT_INPUT]);
reverse_input_->BindInput(in_tensors_[BI_LSTM_INPUT_INPUT]);
TensorSpec bw_input_spec (in_tensors_[BI_LSTM_INPUT_INPUT]->GetDataType(),
in_tensors_[BI_LSTM_INPUT_INPUT]->GetShape(),
TensorAttribute::TRANSIENT);
bw_input_tensor_ = graph_->CreateTensor(bw_input_spec);
TensorSpec bw_input_spec (in_tensors_[BI_LSTM_INPUT_INPUT]->GetSpec());
bw_input_tensor_ = graph_->CreateTensor(bw_input_spec.AsTransientSpec());
reverse_input_->BindOutput(bw_input_tensor_);
lstm_forward_->BindInput(in_tensors_[BI_LSTM_FW_INPUT_H_STATE]);
lstm_forward_->BindInput(in_tensors_[BI_LSTM_FW_INPUT_C_STATE]);
@ -232,7 +230,7 @@ class BidirectionalSequenceLstmImpl : public OpImpl {
bw_output_tensor_ = graph_->CreateTensor(out_tensors_[BI_LSTM_BW_OUTPUT_OUTPUT]->GetSpec());
lstm_backward_->BindOutput(bw_output_tensor_);
reverse_output_->BindInput(bw_output_tensor_);
reverse_output_->BindInput(bw_output_tensor_);
reverse_output_->BindOutput(out_tensors_[BI_LSTM_BW_OUTPUT_OUTPUT]);
lstm_backward_->BindOutput(out_tensors_[BI_LSTM_BW_OUTPUT_H_STATE]);
lstm_backward_->BindOutput(out_tensors_[BI_LSTM_BW_OUTPUT_C_STATE]);
@ -264,7 +262,7 @@ class BidirectionalSequenceLstmImpl : public OpImpl {
UnidirectionalSequenceLstm::ActivationType interpreter(BidirectionalSequenceLstm::ActivationType act){
switch (act){
case BidirectionalSequenceLstm::ActivationType::kRELU:
return UnidirectionalSequenceLstm::ActivationType::kRELU;
case BidirectionalSequenceLstm::ActivationType::kRELU6:
@ -291,7 +289,7 @@ BidirectionalSequenceLstm::BidirectionalSequenceLstm(
time_major_(time_major),
recurrent_act_type_(recurrent_act_type),
return_sequences_(return_sequences) {
impl_ = std::make_unique<BidirectionalSequenceLstmImpl>(graph, 0, 0, cell_clip_,
impl_ = std::make_unique<BidirectionalSequenceLstmImpl>(graph, 0, 0, cell_clip_,
proj_clip_, interpreter(act_type_),
forget_bias_,time_major_,
interpreter(recurrent_act_type_),