Fixed bug when input's index is not 0

Signed-off-by: Chen Xin <jack.chen@verisilicon.com>
This commit is contained in:
Chen Xin 2022-09-28 14:53:05 +08:00 committed by Sven
parent 4c6299e7fd
commit 535c9da867
1 changed files with 7 additions and 4 deletions

View File

@ -41,9 +41,12 @@ class StackLayoutInfer : public OpLayoutInfer {
: OpLayoutInfer(op, context) {}
void OnInputs(
std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) override {
auto src_input = op_->impl()->InputsTensor()[0];
auto input_pv = context_->GetPermuteVector(src_input);
auto src_inputs = op_->impl()->InputsTensor();
std::shared_ptr<tim::vx::Tensor> normal_input;
int input_cnt=0;
for(; src_inputs[input_cnt]->IsConstTensor(); ++input_cnt);
normal_input = src_inputs[input_cnt];
auto input_pv = context_->GetPermuteVector(src_inputs[input_cnt]);
int32_t axis = op_->impl()->node()->nn_param.stack.axis;
auto stack = context_->infer_graph_->CreateOperation<vx::ops::Stack>(
axis, op_->impl()->input_cnt_);
@ -54,7 +57,7 @@ class StackLayoutInfer : public OpLayoutInfer {
}
if (axis < 0) {
axis += src_input->GetShape().size();
axis += normal_input->GetShape().size();
}
auto output_pv = MakeShared(input_pv->Rank() + 1);