From 11572140d2c8cf7d446f3c825201f9ff2ce88bc3 Mon Sep 17 00:00:00 2001 From: chxin66 <57057788+chxin66@users.noreply.github.com> Date: Thu, 5 May 2022 17:18:09 +0800 Subject: [PATCH] Fixed layout inference bug for stack (#375) Signed-off-by: Chen Xin --- .../transform/ops/stack_layout_inference.h | 38 +++++++++---------- .../transform/stack_layout_inference_test.cc | 4 +- 2 files changed, 20 insertions(+), 22 deletions(-) diff --git a/src/tim/transform/ops/stack_layout_inference.h b/src/tim/transform/ops/stack_layout_inference.h index 6a5338c..71e4a10 100644 --- a/src/tim/transform/ops/stack_layout_inference.h +++ b/src/tim/transform/ops/stack_layout_inference.h @@ -53,30 +53,28 @@ class StackLayoutInfer : public OpLayoutInfer { (*stack).BindInput(context_->GetMapedTensor(i_src)); } - std::vector v; - uint32_t dim_num = src_input->GetShape().size(); if (axis < 0) { - axis += dim_num; - } - for (uint32_t i = 0; i < src_input->GetShape().size(); ++i) { - if (input_pv->At(i) > (uint32_t)axis) { - v.push_back(input_pv->At(i) + 1); - } else if (input_pv->At(i) == (uint32_t)axis) { - v.push_back(input_pv->At(i)); - v.push_back(input_pv->At(i) + 1); - } else { - v.push_back(input_pv->At(i)); - } - } - auto out_pv = - MakeShared(op_->impl()->OutputsTensor()[0]->GetShape().size()); - for (uint32_t i = 0; i < out_pv->Rank(); ++i) { - out_pv->At(i) = v[i]; + axis += src_input->GetShape().size(); } - auto out_infer = CreateOutputsTensor(out_pv); + auto output_pv = MakeShared(input_pv->Rank() + 1); + if (!input_pv->IsAligned()) { + output_pv->At(axis) = (uint32_t)axis; + for (uint32_t i = 0, j = 0; i < input_pv->Rank(); ++i, ++j) { + if ((uint32_t)axis == i) { + ++j; + } + if (input_pv->At(i) < (uint32_t)axis) { + output_pv->At(j) = input_pv->At(i); + } else { + output_pv->At(j) = input_pv->At(i) + 1; + } + } + } + + auto out_infer = CreateOutputsTensor(output_pv); (*stack).BindOutput(out_infer[0]); - context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], out_pv); + context_->SetPermuteVector(op_->impl()->OutputsTensor()[0], output_pv); // Add out tensor of src_graph into next_tensor next_tensors.push_back(op_->impl()->OutputsTensor()[0]); } diff --git a/src/tim/transform/stack_layout_inference_test.cc b/src/tim/transform/stack_layout_inference_test.cc index b045246..b84c3a5 100644 --- a/src/tim/transform/stack_layout_inference_test.cc +++ b/src/tim/transform/stack_layout_inference_test.cc @@ -38,7 +38,7 @@ TEST(Stack, LayoutinferernceTest_1) { 1, 1, 1, 1, 2, 1, 1, 5, 3, 1, 2, 3, 1, 1, 2, 1, 1, 1, }; std::vector golden = { - 64, 77, 49, 44, 81, 97, 64, 77, 49, 44, 81, 97 + 64, 64, 49, 49, 81, 81, 77, 77, 44, 44, 97, 97 }; auto kernel_tensor = graph->CreateTensor(kernel_spec, kernel_data.data()); @@ -180,7 +180,7 @@ TEST(Stack, LayoutinferernceTest_3) { 1, 1, 1, 1, 2, 1, 1, 5, 3, 1, 2, 3, 1, 1, 2, 1, 1, 1, }; std::vector golden = { - 55, 39, 21, 28, 37, 41, 49, 55, 28, 24, 40, 41 + 55, 49, 21, 28, 37, 40, 39, 55, 28, 24, 41, 41, }; auto kernel_tensor = graph->CreateTensor(kernel_spec, kernel_data.data()); auto kernel2_tensor = graph->CreateTensor(kernel_spec, kernel_data.data());