Fixed bug when input's index is not 0
Signed-off-by: Chen Xin <jack.chen@verisilicon.com>
This commit is contained in:
parent
4c6299e7fd
commit
535c9da867
|
|
@ -41,9 +41,12 @@ class StackLayoutInfer : public OpLayoutInfer {
|
|||
: OpLayoutInfer(op, context) {}
|
||||
void OnInputs(
|
||||
std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) override {
|
||||
auto src_input = op_->impl()->InputsTensor()[0];
|
||||
auto input_pv = context_->GetPermuteVector(src_input);
|
||||
|
||||
auto src_inputs = op_->impl()->InputsTensor();
|
||||
std::shared_ptr<tim::vx::Tensor> normal_input;
|
||||
int input_cnt=0;
|
||||
for(; src_inputs[input_cnt]->IsConstTensor(); ++input_cnt);
|
||||
normal_input = src_inputs[input_cnt];
|
||||
auto input_pv = context_->GetPermuteVector(src_inputs[input_cnt]);
|
||||
int32_t axis = op_->impl()->node()->nn_param.stack.axis;
|
||||
auto stack = context_->infer_graph_->CreateOperation<vx::ops::Stack>(
|
||||
axis, op_->impl()->input_cnt_);
|
||||
|
|
@ -54,7 +57,7 @@ class StackLayoutInfer : public OpLayoutInfer {
|
|||
}
|
||||
|
||||
if (axis < 0) {
|
||||
axis += src_input->GetShape().size();
|
||||
axis += normal_input->GetShape().size();
|
||||
}
|
||||
|
||||
auto output_pv = MakeShared(input_pv->Rank() + 1);
|
||||
|
|
|
|||
Loading…
Reference in New Issue