Fixed conv2d grouped_conv2d deconv2d layoutinfer bug (#622)

Signed-off-by: Chen <jack.chen@verisilicon.com>
Co-authored-by: Chen <jack.chen@verisilicon.com>
This commit is contained in:
chxin66 2023-07-24 17:10:24 +08:00 committed by GitHub
parent 315adcf076
commit 680e8d59cb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 26 additions and 16 deletions

View File

@ -36,7 +36,8 @@ TEST(LayoutInference, simple_conv2d) {
kernel_shape[0], tim::vx::PadType::AUTO,
std::array<uint32_t, 2>({kernel_shape[2], kernel_shape[1]}),
std::array<uint32_t, 2>({1, 1}), std::array<uint32_t, 2>({0, 0}),
std::array<uint32_t, 4>({0, 0, 0, 0}), 0, tim::vx::DataLayout::CWHN);
std::array<uint32_t, 4>({0, 0, 0, 0}), 0, tim::vx::DataLayout::CWHN,
tim::vx::DataLayout::IcWHOc);
(*conv2d).BindInputs({input, kernel, bias}).BindOutput(output);
// Do layout inference
auto transform = tim::transform::LayoutInference(src_graph, ctx);

View File

@ -67,8 +67,8 @@ class Conv2dLayoutInfer : public OpLayoutInfer {
case vx::DataLayout::IcWHOc: // Support nnapi & tflite Kernel Layout
weight_required_pv = std::make_shared<PermuteVector<4>>(kIcWHOc2WHIcOc);
break;
default: // Default set to IWHO for compatibility with previous APIs
weight_required_pv = std::make_shared<PermuteVector<4>>(kIcWHOc2WHIcOc);
default:
weight_required_pv = std::make_shared<PermuteVector<4>>();
break;
}

View File

@ -67,8 +67,8 @@ class DeConv2dLayoutInfer : public OpLayoutInfer {
case vx::DataLayout::IcWHOc: // Support nnapi & tflite Kernel Layout
weight_required_pv = std::make_shared<PermuteVector<4>>(kIcWHOc2WHIcOc);
break;
default: // Default set to IWHO for compatibility with previous APIs
weight_required_pv = std::make_shared<PermuteVector<4>>(kIcWHOc2WHIcOc);
default:
weight_required_pv = std::make_shared<PermuteVector<4>>();
break;
}

View File

@ -67,8 +67,8 @@ class GroupedConv2dLayoutInfer : public OpLayoutInfer {
case vx::DataLayout::IcWHOc: // Support nnapi & tflite Kernel Layout
weight_required_pv = std::make_shared<PermuteVector<4>>(kIcWHOc2WHIcOc);
break;
default: // Default set to IWHO for compatibility with previous APIs
weight_required_pv = std::make_shared<PermuteVector<4>>(kIcWHOc2WHIcOc);
default:
weight_required_pv = std::make_shared<PermuteVector<4>>();
break;
}

View File

@ -42,7 +42,8 @@ TEST(Pad, layout_inference) {
std::array<uint32_t, 2> dilation({1, 1});
auto op1 = graph->CreateOperation<tim::vx::ops::Conv2d>(
tim::vx::PadType::VALID, stride, dilation, 0, tim::vx::DataLayout::CWHN);
tim::vx::PadType::VALID, stride, dilation, 0, tim::vx::DataLayout::CWHN,
tim::vx::DataLayout::IcWHOc);
(*op1)
.BindInputs({input_tensor, kernel_tensor})
.BindOutputs({conv2dout_tensor});

View File

@ -46,7 +46,8 @@ TEST(Stack, DISABLED_LayoutinferernceTest_1) {
std::array<uint32_t, 2> dilation({1, 1});
auto op1 = graph->CreateOperation<tim::vx::ops::Conv2d>(
tim::vx::PadType::VALID, stride, dilation, 0, tim::vx::DataLayout::CWHN);
tim::vx::PadType::VALID, stride, dilation, 0, tim::vx::DataLayout::CWHN,
tim::vx::DataLayout::IcWHOc);
(*op1)
.BindInputs({input_tensor, kernel_tensor})
.BindOutputs({conv2dout_tensor});
@ -115,7 +116,8 @@ TEST(Stack, LayoutinferernceTest_2) {
std::array<uint32_t, 2> dilation({1, 1});
auto op1 = graph->CreateOperation<tim::vx::ops::Conv2d>(
tim::vx::PadType::VALID, stride, dilation, 0, tim::vx::DataLayout::CWHN);
tim::vx::PadType::VALID, stride, dilation, 0, tim::vx::DataLayout::CWHN,
tim::vx::DataLayout::IcWHOc);
(*op1)
.BindInputs({input_tensor, kernel_tensor})
.BindOutputs({conv2dout_tensor});
@ -188,12 +190,14 @@ TEST(Stack, LayoutinferernceTest_3) {
std::array<uint32_t, 2> stride({1, 1});
std::array<uint32_t, 2> dilation({1, 1});
auto op1 = graph->CreateOperation<tim::vx::ops::Conv2d>(
tim::vx::PadType::VALID, stride, dilation, 0, tim::vx::DataLayout::CWHN);
tim::vx::PadType::VALID, stride, dilation, 0, tim::vx::DataLayout::CWHN,
tim::vx::DataLayout::IcWHOc);
(*op1)
.BindInputs({input_tensor, kernel_tensor})
.BindOutputs({conv2dout_tensor});
auto op11 = graph->CreateOperation<tim::vx::ops::Conv2d>(
tim::vx::PadType::VALID, stride, dilation, 0, tim::vx::DataLayout::CWHN);
tim::vx::PadType::VALID, stride, dilation, 0, tim::vx::DataLayout::CWHN,
tim::vx::DataLayout::IcWHOc);
(*op11)
.BindInputs({input2_tensor, kernel2_tensor})
.BindOutputs({conv2dout2_tensor});

View File

@ -52,7 +52,8 @@ TEST(StridedSlice, endmask_2_shrinkmask_2) {
std::array<uint32_t, 2> dilation({1, 1});
auto op1 = graph->CreateOperation<tim::vx::ops::Conv2d>(
tim::vx::PadType::VALID, stride, dilation, 0, tim::vx::DataLayout::CWHN);
tim::vx::PadType::VALID, stride, dilation, 0, tim::vx::DataLayout::CWHN,
tim::vx::DataLayout::IcWHOc);
(*op1)
.BindInputs({input_tensor, kernel_tensor})
.BindOutputs({conv2dout_tensor});
@ -118,7 +119,8 @@ TEST(StridedSlice, endmask_6_shrinkmask_5) {
std::array<uint32_t, 2> dilation({1, 1});
auto op1 = graph->CreateOperation<tim::vx::ops::Conv2d>(
tim::vx::PadType::VALID, stride, dilation, 0, tim::vx::DataLayout::CWHN);
tim::vx::PadType::VALID, stride, dilation, 0, tim::vx::DataLayout::CWHN,
tim::vx::DataLayout::IcWHOc);
(*op1)
.BindInputs({input_tensor, kernel_tensor})
.BindOutputs({conv2dout_tensor});
@ -187,7 +189,8 @@ TEST(StridedSlice, endmask_1_shrinkmask_1) {
std::array<uint32_t, 2> dilation({1, 1});
auto op1 = graph->CreateOperation<tim::vx::ops::Conv2d>(
tim::vx::PadType::VALID, stride, dilation, 0, tim::vx::DataLayout::CWHN);
tim::vx::PadType::VALID, stride, dilation, 0, tim::vx::DataLayout::CWHN,
tim::vx::DataLayout::IcWHOc);
(*op1)
.BindInputs({input_tensor, kernel_tensor})
.BindOutputs({conv2dout_tensor});
@ -254,7 +257,8 @@ TEST(StridedSlice, beginmask_9_endmask_15) {
std::array<uint32_t, 2> dilation({1, 1});
auto op1 = graph->CreateOperation<tim::vx::ops::Conv2d>(
tim::vx::PadType::VALID, stride, dilation, 0, tim::vx::DataLayout::CWHN);
tim::vx::PadType::VALID, stride, dilation, 0, tim::vx::DataLayout::CWHN,
tim::vx::DataLayout::IcWHOc);
(*op1)
.BindInputs({input_tensor, kernel_tensor})
.BindOutputs({conv2dout_tensor});