From 094be4f37ade82494650eae525cb1efe6831dd61 Mon Sep 17 00:00:00 2001 From: Gheorghe-Teodor Bercea Date: Tue, 11 Feb 2020 11:53:13 -0500 Subject: [PATCH] Add support for strides when emitting convolution loop nest. (#76) * Add support for strides when emitting convolution loop nest. * Only emit stride multiplication if strides is greater than one. * Add test. --- src/pass/lower_frontend_to_krnl.cpp | 31 +++++++++++++++---- test/backend/test.py | 1 + test/mlir/onnx/onnx_lowering.mlir | 47 +++++++++++++++++++++++++++++ 3 files changed, 73 insertions(+), 6 deletions(-) diff --git a/src/pass/lower_frontend_to_krnl.cpp b/src/pass/lower_frontend_to_krnl.cpp index 265d92c..7ef4add 100644 --- a/src/pass/lower_frontend_to_krnl.cpp +++ b/src/pass/lower_frontend_to_krnl.cpp @@ -1808,6 +1808,8 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern { // // The loop nest will look as follows: // + // strides = [s1, s2] + // // kernelsPerGroup = M / group; // for n = 0 .. N: // for g = 0 .. group: @@ -1820,9 +1822,14 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern { // for k1 = 0 .. KH: // for k2 = 0 .. KW: // R[n][kernel][r1][r2] = - // D[n][g * (C / group) + c][r1 + k1][r2 + k2] * + // D[n][g * (C / group) + c][s1 * r1 + k1][s2 * r2 + k2] * // K[kernel][c][k1][k2]; // + // Naming: + // n, g, m: outer loop nest indices + // r1, r2: spatial loop nest indices + // c, k1, k2: inner loop nest indices + // // TODO: handle padding. // // In the general case: @@ -1969,7 +1976,7 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern { { // 4. Emit inner loop body // R[n][kernel][r1][r2] = - // D[n][g * (C / group) + c][r1 + k1][r2 + k2] * + // D[n][g * (C / group) + c][s1 * r1 + k1][s2 * r2 + k2] * // K[kernel][c][k1][k2]; // 4.1 Prepare indices for accesing the data tensor. @@ -1983,12 +1990,24 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern { rewriter.create(loc, subchannels, outerIterationBlock.getArguments()[1])); dataIndices.emplace_back(channelDepth); - // rX + kX - for (int i = 0; i < kernelShape.size() - 2; ++i) + // sX * rX + kX + auto stridesAttribute = convOp.stridesAttr(); + // Read strides attribute + SmallVector strides; + if (stridesAttribute) + for (auto stride : stridesAttribute.getValue()) + strides.emplace_back(stride.cast().getInt()); + for (int i = 0; i < kernelShape.size() - 2; ++i) { + Value spatialIndex = spatialIterationBlock.getArguments()[i]; + // If strides are present then emit the correct access index. + if (stridesAttribute && strides[i] > 1) + spatialIndex = rewriter.create(loc, + rewriter.create(loc, strides[i]), + spatialIterationBlock.getArguments()[i]); dataIndices.emplace_back( - rewriter.create(loc, - spatialIterationBlock.getArguments()[i], + rewriter.create(loc, spatialIndex, innerIterationBlock.getArguments()[i+1])); + } // 4.2 Prepare indices for accessing the kernel tensor. SmallVector kernelIndices; diff --git a/test/backend/test.py b/test/backend/test.py index 4a375a6..8550733 100644 --- a/test/backend/test.py +++ b/test/backend/test.py @@ -244,6 +244,7 @@ test_to_enable = [ # Conv "test_basic_conv_without_padding_cpu", + "test_conv_with_strides_no_padding_cpu", # Sign Op: "test_sign_cpu", diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir index ff16551..09d48cb 100644 --- a/test/mlir/onnx/onnx_lowering.mlir +++ b/test/mlir/onnx/onnx_lowering.mlir @@ -1019,3 +1019,50 @@ func @test_conv_no_bias_no_pad_w_group(%arg0 : tensor<1x9x32x64xf32>, %arg1 : te // CHECK: return [[RES]] : memref<1x5x27x58xf32> } + +func @test_conv_no_bias_no_pad_w_strides(%arg0 : tensor<1x9x32x64xf32>, %arg1 : tensor<5x9x6x7xf32>) -> tensor<*xf32> { + %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, strides = [2, 2]} : (tensor<1x9x32x64xf32>, tensor<5x9x6x7xf32>) -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_conv_no_bias_no_pad_w_strides + // CHECK: [[RES:%.+]] = alloc() : memref<1x5x14x29xf32> + // CHECK: [[CONST0:%.+]] = constant 5 : index + // CHECK: [[CONST1:%.+]] = constant 0.000000e+00 : f32 + // CHECK: [[CONST2:%.+]] = constant 9 : index + // CHECK: [[OUTER_LOOPS:%.+]]:2 = krnl.define_loops 2 + // CHECK: [[OPT_OUTER_LOOPS:%.+]]:2 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[OUTER_LOOPS]]#0, [[OUTER_LOOPS]]#1 + // CHECK: } : () -> (!krnl.loop, !krnl.loop) + + // CHECK: krnl.iterate([[OPT_OUTER_LOOPS]]#0, [[OPT_OUTER_LOOPS]]#1) with ([[OUTER_LOOPS]]#0 -> %arg2 = 0 to 1, [[OUTER_LOOPS]]#1 -> %arg3 = 0 to 5) { + // CHECK: [[SPATIAL_LOOPS:%.+]]:2 = krnl.define_loops 2 + // CHECK: [[OPT_SPATIAL_LOOPS:%.+]]:2 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[SPATIAL_LOOPS]]#0, [[SPATIAL_LOOPS]]#1 + // CHECK: } : () -> (!krnl.loop, !krnl.loop) + + // CHECK: krnl.iterate([[OPT_SPATIAL_LOOPS]]#0, [[OPT_SPATIAL_LOOPS]]#1) with ([[SPATIAL_LOOPS]]#0 -> %arg4 = 0 to 14, [[SPATIAL_LOOPS]]#1 -> %arg5 = 0 to 29) { + // CHECK: store [[CONST1]], [[RES]][%arg2, %arg3, %arg4, %arg5] : memref<1x5x14x29xf32> + // CHECK: [[INNER_LOOPS:%.+]]:3 = krnl.define_loops 3 + // CHECK: [[OPT_INNER_LOOPS:%.+]]:3 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[INNER_LOOPS]]#0, [[INNER_LOOPS]]#1, [[INNER_LOOPS]]#2 + // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop) + + // CHECK: krnl.iterate([[OPT_INNER_LOOPS]]#0, [[OPT_INNER_LOOPS]]#1, [[OPT_INNER_LOOPS]]#2) with ([[INNER_LOOPS]]#0 -> %arg6 = 0 to 9, [[INNER_LOOPS]]#1 -> %arg7 = 0 to 6, [[INNER_LOOPS]]#2 -> %arg8 = 0 to 7) { + // CHECK: [[CONST_STRIDE1:%.+]] = constant 2 : index + // CHECK: [[MUL1:%.+]] = muli [[CONST_STRIDE1]], %arg4 : index + // CHECK: [[R1PLUSK1:%.+]] = addi [[MUL1]], %arg7 : index + // CHECK: [[CONST_STRIDE2:%.+]] = constant 2 : index + // CHECK: [[MUL2:%.+]] = muli [[CONST_STRIDE2]], %arg5 : index + // CHECK: [[R2PLUSK2:%.+]] = addi [[MUL2]], %arg8 : index + // CHECK: [[DATA:%.+]] = load %arg0[%arg2, %arg6, [[R1PLUSK1]], [[R2PLUSK2]]] : memref<1x9x32x64xf32> + // CHECK: [[KERNEL:%.+]] = load %arg1[%arg3, %arg6, %arg7, %arg8] : memref<5x9x6x7xf32> + // CHECK: [[ACC_RES:%.+]] = load %0[%arg2, %arg3, %arg4, %arg5] : memref<1x5x14x29xf32> + // CHECK: [[MUL:%.+]] = mulf [[DATA]], [[KERNEL]] : f32 + // CHECK: [[ADD:%.+]] = addf [[ACC_RES]], [[MUL]] : f32 + // CHECK: store [[ADD]], [[RES]][%arg2, %arg3, %arg4, %arg5] : memref<1x5x14x29xf32> + // CHECK: } + // CHECK: } + // CHECK: } + + // CHECK: return [[RES]] : memref<1x5x14x29xf32> +}