From 535c9da8674d0676ac19fa1594326661d63d9072 Mon Sep 17 00:00:00 2001 From: Chen Xin Date: Wed, 28 Sep 2022 14:53:05 +0800 Subject: [PATCH] Fixed bug when input's index is not 0 Signed-off-by: Chen Xin --- src/tim/transform/ops/stack_layout_inference.h | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/tim/transform/ops/stack_layout_inference.h b/src/tim/transform/ops/stack_layout_inference.h index 7de26bd..f8e8fc4 100644 --- a/src/tim/transform/ops/stack_layout_inference.h +++ b/src/tim/transform/ops/stack_layout_inference.h @@ -41,9 +41,12 @@ class StackLayoutInfer : public OpLayoutInfer { : OpLayoutInfer(op, context) {} void OnInputs( std::vector>& next_tensors) override { - auto src_input = op_->impl()->InputsTensor()[0]; - auto input_pv = context_->GetPermuteVector(src_input); - + auto src_inputs = op_->impl()->InputsTensor(); + std::shared_ptr normal_input; + int input_cnt=0; + for(; src_inputs[input_cnt]->IsConstTensor(); ++input_cnt); + normal_input = src_inputs[input_cnt]; + auto input_pv = context_->GetPermuteVector(src_inputs[input_cnt]); int32_t axis = op_->impl()->node()->nn_param.stack.axis; auto stack = context_->infer_graph_->CreateOperation( axis, op_->impl()->input_cnt_); @@ -54,7 +57,7 @@ class StackLayoutInfer : public OpLayoutInfer { } if (axis < 0) { - axis += src_input->GetShape().size(); + axis += normal_input->GetShape().size(); } auto output_pv = MakeShared(input_pv->Rank() + 1);