Fixed bug when input's index is not 0
Signed-off-by: Chen Xin <jack.chen@verisilicon.com>
This commit is contained in:
parent
4c6299e7fd
commit
535c9da867
|
|
@ -41,9 +41,12 @@ class StackLayoutInfer : public OpLayoutInfer {
|
||||||
: OpLayoutInfer(op, context) {}
|
: OpLayoutInfer(op, context) {}
|
||||||
void OnInputs(
|
void OnInputs(
|
||||||
std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) override {
|
std::vector<std::shared_ptr<vx::Tensor>>& next_tensors) override {
|
||||||
auto src_input = op_->impl()->InputsTensor()[0];
|
auto src_inputs = op_->impl()->InputsTensor();
|
||||||
auto input_pv = context_->GetPermuteVector(src_input);
|
std::shared_ptr<tim::vx::Tensor> normal_input;
|
||||||
|
int input_cnt=0;
|
||||||
|
for(; src_inputs[input_cnt]->IsConstTensor(); ++input_cnt);
|
||||||
|
normal_input = src_inputs[input_cnt];
|
||||||
|
auto input_pv = context_->GetPermuteVector(src_inputs[input_cnt]);
|
||||||
int32_t axis = op_->impl()->node()->nn_param.stack.axis;
|
int32_t axis = op_->impl()->node()->nn_param.stack.axis;
|
||||||
auto stack = context_->infer_graph_->CreateOperation<vx::ops::Stack>(
|
auto stack = context_->infer_graph_->CreateOperation<vx::ops::Stack>(
|
||||||
axis, op_->impl()->input_cnt_);
|
axis, op_->impl()->input_cnt_);
|
||||||
|
|
@ -54,7 +57,7 @@ class StackLayoutInfer : public OpLayoutInfer {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (axis < 0) {
|
if (axis < 0) {
|
||||||
axis += src_input->GetShape().size();
|
axis += normal_input->GetShape().size();
|
||||||
}
|
}
|
||||||
|
|
||||||
auto output_pv = MakeShared(input_pv->Rank() + 1);
|
auto output_pv = MakeShared(input_pv->Rank() + 1);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue