Add support for strides when emitting convolution loop nest. (#76)

* Add support for strides when emitting convolution loop nest.

* Only emit stride multiplication if strides is greater than one.

* Add test.
This commit is contained in:
Gheorghe-Teodor Bercea 2020-02-11 11:53:13 -05:00 committed by GitHub
parent adad9e24bd
commit 094be4f37a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 73 additions and 6 deletions

View File

@ -1808,6 +1808,8 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
// //
// The loop nest will look as follows: // The loop nest will look as follows:
// //
// strides = [s1, s2]
//
// kernelsPerGroup = M / group; // kernelsPerGroup = M / group;
// for n = 0 .. N: // for n = 0 .. N:
// for g = 0 .. group: // for g = 0 .. group:
@ -1820,9 +1822,14 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
// for k1 = 0 .. KH: // for k1 = 0 .. KH:
// for k2 = 0 .. KW: // for k2 = 0 .. KW:
// R[n][kernel][r1][r2] = // R[n][kernel][r1][r2] =
// D[n][g * (C / group) + c][r1 + k1][r2 + k2] * // D[n][g * (C / group) + c][s1 * r1 + k1][s2 * r2 + k2] *
// K[kernel][c][k1][k2]; // K[kernel][c][k1][k2];
// //
// Naming:
// n, g, m: outer loop nest indices
// r1, r2: spatial loop nest indices
// c, k1, k2: inner loop nest indices
//
// TODO: handle padding. // TODO: handle padding.
// //
// In the general case: // In the general case:
@ -1969,7 +1976,7 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
{ {
// 4. Emit inner loop body // 4. Emit inner loop body
// R[n][kernel][r1][r2] = // R[n][kernel][r1][r2] =
// D[n][g * (C / group) + c][r1 + k1][r2 + k2] * // D[n][g * (C / group) + c][s1 * r1 + k1][s2 * r2 + k2] *
// K[kernel][c][k1][k2]; // K[kernel][c][k1][k2];
// 4.1 Prepare indices for accesing the data tensor. // 4.1 Prepare indices for accesing the data tensor.
@ -1983,12 +1990,24 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
rewriter.create<MulIOp>(loc, subchannels, rewriter.create<MulIOp>(loc, subchannels,
outerIterationBlock.getArguments()[1])); outerIterationBlock.getArguments()[1]));
dataIndices.emplace_back(channelDepth); dataIndices.emplace_back(channelDepth);
// rX + kX // sX * rX + kX
for (int i = 0; i < kernelShape.size() - 2; ++i) auto stridesAttribute = convOp.stridesAttr();
// Read strides attribute
SmallVector<int, 4> strides;
if (stridesAttribute)
for (auto stride : stridesAttribute.getValue())
strides.emplace_back(stride.cast<IntegerAttr>().getInt());
for (int i = 0; i < kernelShape.size() - 2; ++i) {
Value spatialIndex = spatialIterationBlock.getArguments()[i];
// If strides are present then emit the correct access index.
if (stridesAttribute && strides[i] > 1)
spatialIndex = rewriter.create<MulIOp>(loc,
rewriter.create<ConstantIndexOp>(loc, strides[i]),
spatialIterationBlock.getArguments()[i]);
dataIndices.emplace_back( dataIndices.emplace_back(
rewriter.create<AddIOp>(loc, rewriter.create<AddIOp>(loc, spatialIndex,
spatialIterationBlock.getArguments()[i],
innerIterationBlock.getArguments()[i+1])); innerIterationBlock.getArguments()[i+1]));
}
// 4.2 Prepare indices for accessing the kernel tensor. // 4.2 Prepare indices for accessing the kernel tensor.
SmallVector<Value, 4> kernelIndices; SmallVector<Value, 4> kernelIndices;

View File

@ -244,6 +244,7 @@ test_to_enable = [
# Conv # Conv
"test_basic_conv_without_padding_cpu", "test_basic_conv_without_padding_cpu",
"test_conv_with_strides_no_padding_cpu",
# Sign Op: # Sign Op:
"test_sign_cpu", "test_sign_cpu",

View File

@ -1019,3 +1019,50 @@ func @test_conv_no_bias_no_pad_w_group(%arg0 : tensor<1x9x32x64xf32>, %arg1 : te
// CHECK: return [[RES]] : memref<1x5x27x58xf32> // CHECK: return [[RES]] : memref<1x5x27x58xf32>
} }
func @test_conv_no_bias_no_pad_w_strides(%arg0 : tensor<1x9x32x64xf32>, %arg1 : tensor<5x9x6x7xf32>) -> tensor<*xf32> {
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, strides = [2, 2]} : (tensor<1x9x32x64xf32>, tensor<5x9x6x7xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_conv_no_bias_no_pad_w_strides
// CHECK: [[RES:%.+]] = alloc() : memref<1x5x14x29xf32>
// CHECK: [[CONST0:%.+]] = constant 5 : index
// CHECK: [[CONST1:%.+]] = constant 0.000000e+00 : f32
// CHECK: [[CONST2:%.+]] = constant 9 : index
// CHECK: [[OUTER_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_OUTER_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[OUTER_LOOPS]]#0, [[OUTER_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_OUTER_LOOPS]]#0, [[OPT_OUTER_LOOPS]]#1) with ([[OUTER_LOOPS]]#0 -> %arg2 = 0 to 1, [[OUTER_LOOPS]]#1 -> %arg3 = 0 to 5) {
// CHECK: [[SPATIAL_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_SPATIAL_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[SPATIAL_LOOPS]]#0, [[SPATIAL_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_SPATIAL_LOOPS]]#0, [[OPT_SPATIAL_LOOPS]]#1) with ([[SPATIAL_LOOPS]]#0 -> %arg4 = 0 to 14, [[SPATIAL_LOOPS]]#1 -> %arg5 = 0 to 29) {
// CHECK: store [[CONST1]], [[RES]][%arg2, %arg3, %arg4, %arg5] : memref<1x5x14x29xf32>
// CHECK: [[INNER_LOOPS:%.+]]:3 = krnl.define_loops 3
// CHECK: [[OPT_INNER_LOOPS:%.+]]:3 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[INNER_LOOPS]]#0, [[INNER_LOOPS]]#1, [[INNER_LOOPS]]#2
// CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_INNER_LOOPS]]#0, [[OPT_INNER_LOOPS]]#1, [[OPT_INNER_LOOPS]]#2) with ([[INNER_LOOPS]]#0 -> %arg6 = 0 to 9, [[INNER_LOOPS]]#1 -> %arg7 = 0 to 6, [[INNER_LOOPS]]#2 -> %arg8 = 0 to 7) {
// CHECK: [[CONST_STRIDE1:%.+]] = constant 2 : index
// CHECK: [[MUL1:%.+]] = muli [[CONST_STRIDE1]], %arg4 : index
// CHECK: [[R1PLUSK1:%.+]] = addi [[MUL1]], %arg7 : index
// CHECK: [[CONST_STRIDE2:%.+]] = constant 2 : index
// CHECK: [[MUL2:%.+]] = muli [[CONST_STRIDE2]], %arg5 : index
// CHECK: [[R2PLUSK2:%.+]] = addi [[MUL2]], %arg8 : index
// CHECK: [[DATA:%.+]] = load %arg0[%arg2, %arg6, [[R1PLUSK1]], [[R2PLUSK2]]] : memref<1x9x32x64xf32>
// CHECK: [[KERNEL:%.+]] = load %arg1[%arg3, %arg6, %arg7, %arg8] : memref<5x9x6x7xf32>
// CHECK: [[ACC_RES:%.+]] = load %0[%arg2, %arg3, %arg4, %arg5] : memref<1x5x14x29xf32>
// CHECK: [[MUL:%.+]] = mulf [[DATA]], [[KERNEL]] : f32
// CHECK: [[ADD:%.+]] = addf [[ACC_RES]], [[MUL]] : f32
// CHECK: store [[ADD]], [[RES]][%arg2, %arg3, %arg4, %arg5] : memref<1x5x14x29xf32>
// CHECK: }
// CHECK: }
// CHECK: }
// CHECK: return [[RES]] : memref<1x5x14x29xf32>
}