Fixed quant param lost in Bidirectional lstm (#649)
https://github.com/VeriSilicon/TIM-VX/issues/647 Type: Bug fix Signed-off-by: Chen <jack.chen@verisilicon.com>
This commit is contained in:
parent
61ea0091ca
commit
363c369bf6
|
|
@ -124,7 +124,7 @@ class BidirectionalSequenceLstmImpl : public OpImpl {
|
|||
|
||||
BidirectionalSequenceLstmImpl(Graph* graph, int input_cnt, int output_cnt,
|
||||
float cell_clip, float proj_clip,
|
||||
tim::vx::ops::UnidirectionalSequenceLstm::ActivationType act_type,
|
||||
tim::vx::ops::UnidirectionalSequenceLstm::ActivationType act_type,
|
||||
float forget_bias, bool time_major,
|
||||
tim::vx::ops::UnidirectionalSequenceLstm::ActivationType recurrent_act_type,
|
||||
bool return_sequences, DataLayout layout = DataLayout::ANY)
|
||||
|
|
@ -135,9 +135,9 @@ class BidirectionalSequenceLstmImpl : public OpImpl {
|
|||
lstm_backward_ = graph->CreateOperation<UnidirectionalSequenceLstm>(
|
||||
cell_clip, proj_clip, act_type, forget_bias, time_major,
|
||||
recurrent_act_type, return_sequences);
|
||||
reverse_input_ = graph->CreateOperation<Reverse>(time_major ? std::vector<int32_t> ({2}) :
|
||||
reverse_input_ = graph->CreateOperation<Reverse>(time_major ? std::vector<int32_t> ({2}) :
|
||||
std::vector<int32_t> ({1}));
|
||||
reverse_output_ = graph->CreateOperation<Reverse>(time_major ? std::vector<int32_t> ({2}) :
|
||||
reverse_output_ = graph->CreateOperation<Reverse>(time_major ? std::vector<int32_t> ({2}) :
|
||||
std::vector<int32_t> ({1}));
|
||||
}
|
||||
|
||||
|
|
@ -151,10 +151,8 @@ class BidirectionalSequenceLstmImpl : public OpImpl {
|
|||
// Get all input tensor
|
||||
lstm_forward_->BindInput(in_tensors_[BI_LSTM_INPUT_INPUT]);
|
||||
reverse_input_->BindInput(in_tensors_[BI_LSTM_INPUT_INPUT]);
|
||||
TensorSpec bw_input_spec (in_tensors_[BI_LSTM_INPUT_INPUT]->GetDataType(),
|
||||
in_tensors_[BI_LSTM_INPUT_INPUT]->GetShape(),
|
||||
TensorAttribute::TRANSIENT);
|
||||
bw_input_tensor_ = graph_->CreateTensor(bw_input_spec);
|
||||
TensorSpec bw_input_spec (in_tensors_[BI_LSTM_INPUT_INPUT]->GetSpec());
|
||||
bw_input_tensor_ = graph_->CreateTensor(bw_input_spec.AsTransientSpec());
|
||||
reverse_input_->BindOutput(bw_input_tensor_);
|
||||
lstm_forward_->BindInput(in_tensors_[BI_LSTM_FW_INPUT_H_STATE]);
|
||||
lstm_forward_->BindInput(in_tensors_[BI_LSTM_FW_INPUT_C_STATE]);
|
||||
|
|
@ -232,7 +230,7 @@ class BidirectionalSequenceLstmImpl : public OpImpl {
|
|||
|
||||
bw_output_tensor_ = graph_->CreateTensor(out_tensors_[BI_LSTM_BW_OUTPUT_OUTPUT]->GetSpec());
|
||||
lstm_backward_->BindOutput(bw_output_tensor_);
|
||||
reverse_output_->BindInput(bw_output_tensor_);
|
||||
reverse_output_->BindInput(bw_output_tensor_);
|
||||
reverse_output_->BindOutput(out_tensors_[BI_LSTM_BW_OUTPUT_OUTPUT]);
|
||||
lstm_backward_->BindOutput(out_tensors_[BI_LSTM_BW_OUTPUT_H_STATE]);
|
||||
lstm_backward_->BindOutput(out_tensors_[BI_LSTM_BW_OUTPUT_C_STATE]);
|
||||
|
|
@ -264,7 +262,7 @@ class BidirectionalSequenceLstmImpl : public OpImpl {
|
|||
|
||||
UnidirectionalSequenceLstm::ActivationType interpreter(BidirectionalSequenceLstm::ActivationType act){
|
||||
switch (act){
|
||||
|
||||
|
||||
case BidirectionalSequenceLstm::ActivationType::kRELU:
|
||||
return UnidirectionalSequenceLstm::ActivationType::kRELU;
|
||||
case BidirectionalSequenceLstm::ActivationType::kRELU6:
|
||||
|
|
@ -291,7 +289,7 @@ BidirectionalSequenceLstm::BidirectionalSequenceLstm(
|
|||
time_major_(time_major),
|
||||
recurrent_act_type_(recurrent_act_type),
|
||||
return_sequences_(return_sequences) {
|
||||
impl_ = std::make_unique<BidirectionalSequenceLstmImpl>(graph, 0, 0, cell_clip_,
|
||||
impl_ = std::make_unique<BidirectionalSequenceLstmImpl>(graph, 0, 0, cell_clip_,
|
||||
proj_clip_, interpreter(act_type_),
|
||||
forget_bias_,time_major_,
|
||||
interpreter(recurrent_act_type_),
|
||||
|
|
|
|||
Loading…
Reference in New Issue