diff --git a/src/tim/transform/ops/stack_layout_inference.h b/src/tim/transform/ops/stack_layout_inference.h index 7de26bd..f8e8fc4 100644 --- a/src/tim/transform/ops/stack_layout_inference.h +++ b/src/tim/transform/ops/stack_layout_inference.h @@ -41,9 +41,12 @@ class StackLayoutInfer : public OpLayoutInfer { : OpLayoutInfer(op, context) {} void OnInputs( std::vector>& next_tensors) override { - auto src_input = op_->impl()->InputsTensor()[0]; - auto input_pv = context_->GetPermuteVector(src_input); - + auto src_inputs = op_->impl()->InputsTensor(); + std::shared_ptr normal_input; + int input_cnt=0; + for(; src_inputs[input_cnt]->IsConstTensor(); ++input_cnt); + normal_input = src_inputs[input_cnt]; + auto input_pv = context_->GetPermuteVector(src_inputs[input_cnt]); int32_t axis = op_->impl()->node()->nn_param.stack.axis; auto stack = context_->infer_graph_->CreateOperation( axis, op_->impl()->input_cnt_); @@ -54,7 +57,7 @@ class StackLayoutInfer : public OpLayoutInfer { } if (axis < 0) { - axis += src_input->GetShape().size(); + axis += normal_input->GetShape().size(); } auto output_pv = MakeShared(input_pv->Rank() + 1);