Add support for strides when emitting convolution loop nest. (#76)
* Add support for strides when emitting the convolution loop nest. * Only emit the stride multiplication if the stride is greater than one. * Add a test.
This commit is contained in:
parent
adad9e24bd
commit
094be4f37a
|
@ -1808,6 +1808,8 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
|
|||
//
|
||||
// The loop nest will look as follows:
|
||||
//
|
||||
// strides = [s1, s2]
|
||||
//
|
||||
// kernelsPerGroup = M / group;
|
||||
// for n = 0 .. N:
|
||||
// for g = 0 .. group:
|
||||
|
@ -1820,9 +1822,14 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
|
|||
// for k1 = 0 .. KH:
|
||||
// for k2 = 0 .. KW:
|
||||
// R[n][kernel][r1][r2] =
|
||||
// D[n][g * (C / group) + c][r1 + k1][r2 + k2] *
|
||||
// D[n][g * (C / group) + c][s1 * r1 + k1][s2 * r2 + k2] *
|
||||
// K[kernel][c][k1][k2];
|
||||
//
|
||||
// Naming:
|
||||
// n, g, m: outer loop nest indices
|
||||
// r1, r2: spatial loop nest indices
|
||||
// c, k1, k2: inner loop nest indices
|
||||
//
|
||||
// TODO: handle padding.
|
||||
//
|
||||
// In the general case:
|
||||
|
@ -1969,7 +1976,7 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
|
|||
{
|
||||
// 4. Emit inner loop body
|
||||
// R[n][kernel][r1][r2] =
|
||||
// D[n][g * (C / group) + c][r1 + k1][r2 + k2] *
|
||||
// D[n][g * (C / group) + c][s1 * r1 + k1][s2 * r2 + k2] *
|
||||
// K[kernel][c][k1][k2];
|
||||
|
||||
// 4.1 Prepare indices for accessing the data tensor.
|
||||
|
@ -1983,12 +1990,24 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
|
|||
rewriter.create<MulIOp>(loc, subchannels,
|
||||
outerIterationBlock.getArguments()[1]));
|
||||
dataIndices.emplace_back(channelDepth);
|
||||
// rX + kX
|
||||
for (int i = 0; i < kernelShape.size() - 2; ++i)
|
||||
// sX * rX + kX
|
||||
auto stridesAttribute = convOp.stridesAttr();
|
||||
// Read strides attribute
|
||||
SmallVector<int, 4> strides;
|
||||
if (stridesAttribute)
|
||||
for (auto stride : stridesAttribute.getValue())
|
||||
strides.emplace_back(stride.cast<IntegerAttr>().getInt());
|
||||
for (int i = 0; i < kernelShape.size() - 2; ++i) {
|
||||
Value spatialIndex = spatialIterationBlock.getArguments()[i];
|
||||
// If strides are present then emit the correct access index.
|
||||
if (stridesAttribute && strides[i] > 1)
|
||||
spatialIndex = rewriter.create<MulIOp>(loc,
|
||||
rewriter.create<ConstantIndexOp>(loc, strides[i]),
|
||||
spatialIterationBlock.getArguments()[i]);
|
||||
dataIndices.emplace_back(
|
||||
rewriter.create<AddIOp>(loc,
|
||||
spatialIterationBlock.getArguments()[i],
|
||||
rewriter.create<AddIOp>(loc, spatialIndex,
|
||||
innerIterationBlock.getArguments()[i+1]));
|
||||
}
|
||||
|
||||
// 4.2 Prepare indices for accessing the kernel tensor.
|
||||
SmallVector<Value, 4> kernelIndices;
|
||||
|
|
|
@ -244,6 +244,7 @@ test_to_enable = [
|
|||
|
||||
# Conv
|
||||
"test_basic_conv_without_padding_cpu",
|
||||
"test_conv_with_strides_no_padding_cpu",
|
||||
|
||||
# Sign Op:
|
||||
"test_sign_cpu",
|
||||
|
|
|
@ -1019,3 +1019,50 @@ func @test_conv_no_bias_no_pad_w_group(%arg0 : tensor<1x9x32x64xf32>, %arg1 : te
|
|||
|
||||
// CHECK: return [[RES]] : memref<1x5x27x58xf32>
|
||||
}
|
||||
|
||||
// Lowering test for onnx.ConvNoBias with strides = [2, 2], group = 1 and
// auto_pad = "NOTSET": a 1x9x32x64 input convolved with a 5x9x6x7 kernel.
// The expected result buffer is 1x5x14x29, consistent with
// floor((32 - 6) / 2) + 1 = 14 and floor((64 - 7) / 2) + 1 = 29.
// The directives below expect each spatial induction variable (%arg4, %arg5)
// to be multiplied by a constant 2 before the kernel offset (%arg7, %arg8) is
// added to form the data-tensor load index.
// NOTE(review): the accumulator load is matched against the literal SSA name
// %0 rather than the captured RES variable -- presumably the same buffer; confirm.
func @test_conv_no_bias_no_pad_w_strides(%arg0 : tensor<1x9x32x64xf32>, %arg1 : tensor<5x9x6x7xf32>) -> tensor<*xf32> {
|
||||
%0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, strides = [2, 2]} : (tensor<1x9x32x64xf32>, tensor<5x9x6x7xf32>) -> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-LABEL: test_conv_no_bias_no_pad_w_strides
|
||||
// CHECK: [[RES:%.+]] = alloc() : memref<1x5x14x29xf32>
|
||||
// CHECK: [[CONST0:%.+]] = constant 5 : index
|
||||
// CHECK: [[CONST1:%.+]] = constant 0.000000e+00 : f32
|
||||
// CHECK: [[CONST2:%.+]] = constant 9 : index
|
||||
// CHECK: [[OUTER_LOOPS:%.+]]:2 = krnl.define_loops 2
|
||||
// CHECK: [[OPT_OUTER_LOOPS:%.+]]:2 = krnl.optimize_loops {
|
||||
// CHECK: krnl.return_loops [[OUTER_LOOPS]]#0, [[OUTER_LOOPS]]#1
|
||||
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
|
||||
|
||||
// CHECK: krnl.iterate([[OPT_OUTER_LOOPS]]#0, [[OPT_OUTER_LOOPS]]#1) with ([[OUTER_LOOPS]]#0 -> %arg2 = 0 to 1, [[OUTER_LOOPS]]#1 -> %arg3 = 0 to 5) {
|
||||
// CHECK: [[SPATIAL_LOOPS:%.+]]:2 = krnl.define_loops 2
|
||||
// CHECK: [[OPT_SPATIAL_LOOPS:%.+]]:2 = krnl.optimize_loops {
|
||||
// CHECK: krnl.return_loops [[SPATIAL_LOOPS]]#0, [[SPATIAL_LOOPS]]#1
|
||||
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
|
||||
|
||||
// CHECK: krnl.iterate([[OPT_SPATIAL_LOOPS]]#0, [[OPT_SPATIAL_LOOPS]]#1) with ([[SPATIAL_LOOPS]]#0 -> %arg4 = 0 to 14, [[SPATIAL_LOOPS]]#1 -> %arg5 = 0 to 29) {
|
||||
// CHECK: store [[CONST1]], [[RES]][%arg2, %arg3, %arg4, %arg5] : memref<1x5x14x29xf32>
|
||||
// CHECK: [[INNER_LOOPS:%.+]]:3 = krnl.define_loops 3
|
||||
// CHECK: [[OPT_INNER_LOOPS:%.+]]:3 = krnl.optimize_loops {
|
||||
// CHECK: krnl.return_loops [[INNER_LOOPS]]#0, [[INNER_LOOPS]]#1, [[INNER_LOOPS]]#2
|
||||
// CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
|
||||
|
||||
// CHECK: krnl.iterate([[OPT_INNER_LOOPS]]#0, [[OPT_INNER_LOOPS]]#1, [[OPT_INNER_LOOPS]]#2) with ([[INNER_LOOPS]]#0 -> %arg6 = 0 to 9, [[INNER_LOOPS]]#1 -> %arg7 = 0 to 6, [[INNER_LOOPS]]#2 -> %arg8 = 0 to 7) {
|
||||
// CHECK: [[CONST_STRIDE1:%.+]] = constant 2 : index
|
||||
// CHECK: [[MUL1:%.+]] = muli [[CONST_STRIDE1]], %arg4 : index
|
||||
// CHECK: [[R1PLUSK1:%.+]] = addi [[MUL1]], %arg7 : index
|
||||
// CHECK: [[CONST_STRIDE2:%.+]] = constant 2 : index
|
||||
// CHECK: [[MUL2:%.+]] = muli [[CONST_STRIDE2]], %arg5 : index
|
||||
// CHECK: [[R2PLUSK2:%.+]] = addi [[MUL2]], %arg8 : index
|
||||
// CHECK: [[DATA:%.+]] = load %arg0[%arg2, %arg6, [[R1PLUSK1]], [[R2PLUSK2]]] : memref<1x9x32x64xf32>
|
||||
// CHECK: [[KERNEL:%.+]] = load %arg1[%arg3, %arg6, %arg7, %arg8] : memref<5x9x6x7xf32>
|
||||
// CHECK: [[ACC_RES:%.+]] = load %0[%arg2, %arg3, %arg4, %arg5] : memref<1x5x14x29xf32>
|
||||
// CHECK: [[MUL:%.+]] = mulf [[DATA]], [[KERNEL]] : f32
|
||||
// CHECK: [[ADD:%.+]] = addf [[ACC_RES]], [[MUL]] : f32
|
||||
// CHECK: store [[ADD]], [[RES]][%arg2, %arg3, %arg4, %arg5] : memref<1x5x14x29xf32>
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
|
||||
// CHECK: return [[RES]] : memref<1x5x14x29xf32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue