From 680e8d59cb9762035c7f05af033477ee7c7b0c8f Mon Sep 17 00:00:00 2001
From: chxin66 <57057788+chxin66@users.noreply.github.com>
Date: Mon, 24 Jul 2023 17:10:24 +0800
Subject: [PATCH] Fixed conv2d grouped_conv2d deconv2d layoutinfer bug (#622)

Signed-off-by: Chen
Co-authored-by: Chen
---
 src/tim/transform/layout_inference_test.cc            |  3 ++-
 src/tim/transform/ops/conv2d_layout_inference.h       |  4 ++--
 src/tim/transform/ops/deconv2d_layout_inference.h     |  4 ++--
 .../transform/ops/grouped_conv2d_layout_inference.h   |  4 ++--
 src/tim/transform/pad_layout_inference_test.cc        |  3 ++-
 src/tim/transform/stack_layout_inference_test.cc      | 12 ++++++++----
 .../transform/stridedslice_layout_inference_test.cc   | 12 ++++++++----
 7 files changed, 26 insertions(+), 16 deletions(-)

diff --git a/src/tim/transform/layout_inference_test.cc b/src/tim/transform/layout_inference_test.cc
index d33e370..be00b68 100644
--- a/src/tim/transform/layout_inference_test.cc
+++ b/src/tim/transform/layout_inference_test.cc
@@ -36,7 +36,8 @@ TEST(LayoutInference, simple_conv2d) {
       kernel_shape[0], tim::vx::PadType::AUTO,
       std::array<uint32_t, 2>({kernel_shape[2], kernel_shape[1]}),
       std::array<uint32_t, 2>({1, 1}), std::array<uint32_t, 2>({0, 0}),
-      std::array<uint32_t, 4>({0, 0, 0, 0}), 0, tim::vx::DataLayout::CWHN);
+      std::array<uint32_t, 4>({0, 0, 0, 0}), 0, tim::vx::DataLayout::CWHN,
+      tim::vx::DataLayout::IcWHOc);
   (*conv2d).BindInputs({input, kernel, bias}).BindOutput(output);
   // Do layout inference
   auto transform = tim::transform::LayoutInference(src_graph, ctx);
diff --git a/src/tim/transform/ops/conv2d_layout_inference.h b/src/tim/transform/ops/conv2d_layout_inference.h
index 5b5b3a4..96b46ab 100644
--- a/src/tim/transform/ops/conv2d_layout_inference.h
+++ b/src/tim/transform/ops/conv2d_layout_inference.h
@@ -67,8 +67,8 @@ class Conv2dLayoutInfer : public OpLayoutInfer {
       case vx::DataLayout::IcWHOc:  // Support nnapi & tflite Kernel Layout
         weight_required_pv = std::make_shared<PermuteVector<4>>(kIcWHOc2WHIcOc);
         break;
-      default:  // Default set to IWHO for compatibility with previous APIs
-        weight_required_pv = std::make_shared<PermuteVector<4>>(kIcWHOc2WHIcOc);
+      default:
+        weight_required_pv = std::make_shared<PermuteVector<4>>();
         break;
     }

diff --git a/src/tim/transform/ops/deconv2d_layout_inference.h b/src/tim/transform/ops/deconv2d_layout_inference.h
index 8fce5ab..8788c1d 100644
--- a/src/tim/transform/ops/deconv2d_layout_inference.h
+++ b/src/tim/transform/ops/deconv2d_layout_inference.h
@@ -67,8 +67,8 @@ class DeConv2dLayoutInfer : public OpLayoutInfer {
       case vx::DataLayout::IcWHOc:  // Support nnapi & tflite Kernel Layout
         weight_required_pv = std::make_shared<PermuteVector<4>>(kIcWHOc2WHIcOc);
         break;
-      default:  // Default set to IWHO for compatibility with previous APIs
-        weight_required_pv = std::make_shared<PermuteVector<4>>(kIcWHOc2WHIcOc);
+      default:
+        weight_required_pv = std::make_shared<PermuteVector<4>>();
         break;
     }

diff --git a/src/tim/transform/ops/grouped_conv2d_layout_inference.h b/src/tim/transform/ops/grouped_conv2d_layout_inference.h
index f55e76d..b2df948 100644
--- a/src/tim/transform/ops/grouped_conv2d_layout_inference.h
+++ b/src/tim/transform/ops/grouped_conv2d_layout_inference.h
@@ -67,8 +67,8 @@ class GroupedConv2dLayoutInfer : public OpLayoutInfer {
      case vx::DataLayout::IcWHOc:  // Support nnapi & tflite Kernel Layout
        weight_required_pv = std::make_shared<PermuteVector<4>>(kIcWHOc2WHIcOc);
        break;
-      default:  // Default set to IWHO for compatibility with previous APIs
-        weight_required_pv = std::make_shared<PermuteVector<4>>(kIcWHOc2WHIcOc);
+      default:
+        weight_required_pv = std::make_shared<PermuteVector<4>>();
        break;
     }

diff --git a/src/tim/transform/pad_layout_inference_test.cc b/src/tim/transform/pad_layout_inference_test.cc
index c354eb7..7c8b474 100644
--- a/src/tim/transform/pad_layout_inference_test.cc
+++ b/src/tim/transform/pad_layout_inference_test.cc
@@ -42,7 +42,8 @@ TEST(Pad, layout_inference) {
   std::array<uint32_t, 2> dilation({1, 1});

   auto op1 = graph->CreateOperation<tim::vx::ops::Conv2d>(
-      tim::vx::PadType::VALID, stride, dilation, 0, tim::vx::DataLayout::CWHN);
+      tim::vx::PadType::VALID, stride, dilation, 0, tim::vx::DataLayout::CWHN,
+      tim::vx::DataLayout::IcWHOc);
   (*op1)
       .BindInputs({input_tensor, kernel_tensor})
       .BindOutputs({conv2dout_tensor});
diff --git a/src/tim/transform/stack_layout_inference_test.cc b/src/tim/transform/stack_layout_inference_test.cc
index b101750..9d51106 100644
--- a/src/tim/transform/stack_layout_inference_test.cc
+++ b/src/tim/transform/stack_layout_inference_test.cc
@@ -46,7 +46,8 @@ TEST(Stack, DISABLED_LayoutinferernceTest_1) {
   std::array<uint32_t, 2> dilation({1, 1});

   auto op1 = graph->CreateOperation<tim::vx::ops::Conv2d>(
-      tim::vx::PadType::VALID, stride, dilation, 0, tim::vx::DataLayout::CWHN);
+      tim::vx::PadType::VALID, stride, dilation, 0, tim::vx::DataLayout::CWHN,
+      tim::vx::DataLayout::IcWHOc);
   (*op1)
       .BindInputs({input_tensor, kernel_tensor})
       .BindOutputs({conv2dout_tensor});
@@ -115,7 +116,8 @@ TEST(Stack, LayoutinferernceTest_2) {
   std::array<uint32_t, 2> dilation({1, 1});

   auto op1 = graph->CreateOperation<tim::vx::ops::Conv2d>(
-      tim::vx::PadType::VALID, stride, dilation, 0, tim::vx::DataLayout::CWHN);
+      tim::vx::PadType::VALID, stride, dilation, 0, tim::vx::DataLayout::CWHN,
+      tim::vx::DataLayout::IcWHOc);
   (*op1)
       .BindInputs({input_tensor, kernel_tensor})
       .BindOutputs({conv2dout_tensor});
@@ -188,12 +190,14 @@ TEST(Stack, LayoutinferernceTest_3) {
   std::array<uint32_t, 2> stride({1, 1});
   std::array<uint32_t, 2> dilation({1, 1});
   auto op1 = graph->CreateOperation<tim::vx::ops::Conv2d>(
-      tim::vx::PadType::VALID, stride, dilation, 0, tim::vx::DataLayout::CWHN);
+      tim::vx::PadType::VALID, stride, dilation, 0, tim::vx::DataLayout::CWHN,
+      tim::vx::DataLayout::IcWHOc);
   (*op1)
       .BindInputs({input_tensor, kernel_tensor})
       .BindOutputs({conv2dout_tensor});
   auto op11 = graph->CreateOperation<tim::vx::ops::Conv2d>(
-      tim::vx::PadType::VALID, stride, dilation, 0, tim::vx::DataLayout::CWHN);
+      tim::vx::PadType::VALID, stride, dilation, 0, tim::vx::DataLayout::CWHN,
+      tim::vx::DataLayout::IcWHOc);
   (*op11)
       .BindInputs({input2_tensor, kernel2_tensor})
       .BindOutputs({conv2dout2_tensor});
diff --git a/src/tim/transform/stridedslice_layout_inference_test.cc b/src/tim/transform/stridedslice_layout_inference_test.cc
index 52afd26..932e816 100644
--- a/src/tim/transform/stridedslice_layout_inference_test.cc
+++ b/src/tim/transform/stridedslice_layout_inference_test.cc
@@ -52,7 +52,8 @@ TEST(StridedSlice, endmask_2_shrinkmask_2) {
   std::array<uint32_t, 2> dilation({1, 1});

   auto op1 = graph->CreateOperation<tim::vx::ops::Conv2d>(
-      tim::vx::PadType::VALID, stride, dilation, 0, tim::vx::DataLayout::CWHN);
+      tim::vx::PadType::VALID, stride, dilation, 0, tim::vx::DataLayout::CWHN,
+      tim::vx::DataLayout::IcWHOc);
   (*op1)
       .BindInputs({input_tensor, kernel_tensor})
       .BindOutputs({conv2dout_tensor});
@@ -118,7 +119,8 @@ TEST(StridedSlice, endmask_6_shrinkmask_5) {
   std::array<uint32_t, 2> dilation({1, 1});

   auto op1 = graph->CreateOperation<tim::vx::ops::Conv2d>(
-      tim::vx::PadType::VALID, stride, dilation, 0, tim::vx::DataLayout::CWHN);
+      tim::vx::PadType::VALID, stride, dilation, 0, tim::vx::DataLayout::CWHN,
+      tim::vx::DataLayout::IcWHOc);
   (*op1)
       .BindInputs({input_tensor, kernel_tensor})
       .BindOutputs({conv2dout_tensor});
@@ -187,7 +189,8 @@ TEST(StridedSlice, endmask_1_shrinkmask_1) {
   std::array<uint32_t, 2> dilation({1, 1});

   auto op1 = graph->CreateOperation<tim::vx::ops::Conv2d>(
-      tim::vx::PadType::VALID, stride, dilation, 0, tim::vx::DataLayout::CWHN);
+      tim::vx::PadType::VALID, stride, dilation, 0, tim::vx::DataLayout::CWHN,
+      tim::vx::DataLayout::IcWHOc);
   (*op1)
       .BindInputs({input_tensor, kernel_tensor})
       .BindOutputs({conv2dout_tensor});
@@ -254,7 +257,8 @@ TEST(StridedSlice, beginmask_9_endmask_15) {
   std::array<uint32_t, 2> dilation({1, 1});

   auto op1 = graph->CreateOperation<tim::vx::ops::Conv2d>(
-      tim::vx::PadType::VALID, stride, dilation, 0, tim::vx::DataLayout::CWHN);
+      tim::vx::PadType::VALID, stride, dilation, 0, tim::vx::DataLayout::CWHN,
+      tim::vx::DataLayout::IcWHOc);
   (*op1)
       .BindInputs({input_tensor, kernel_tensor})
       .BindOutputs({conv2dout_tensor});