From 2814ea38982e888eaae409b09446ad72103e8452 Mon Sep 17 00:00:00 2001
From: "Tung D. Le"
Date: Wed, 18 Mar 2020 22:55:50 +0900
Subject: [PATCH] Support dilations and enable the remaining e2e tests for
 MaxPoolSingleOut (#31)

* Support dilations and enable e2e tests

* Fix allocating memory for dynamic shape

* Edit comments

* Do dilation by computing an offset from kernel index

* Correct dilation formula, add an example of out-of-bound, and add a test
  for dilation

Co-authored-by: Gheorghe-Teodor Bercea
---
 src/conversion/onnx_to_krnl/nn/pooling.cpp | 90 ++++++++++++++++------
 test/backend/test.py                       |  7 ++
 test/mlir/onnx/onnx_lowering.mlir          | 74 ++++++++++++++----
 3 files changed, 135 insertions(+), 36 deletions(-)

diff --git a/src/conversion/onnx_to_krnl/nn/pooling.cpp b/src/conversion/onnx_to_krnl/nn/pooling.cpp
index 17e5b9d..2ab761d 100644
--- a/src/conversion/onnx_to_krnl/nn/pooling.cpp
+++ b/src/conversion/onnx_to_krnl/nn/pooling.cpp
@@ -108,7 +108,7 @@ struct ONNXMaxPoolSingleOutOpLowering : public ConversionPattern {
     for (int i = batchRank; i < resultShape.size(); ++i) {
       if (resultShape[i] < 0) {
         // dim =
-        //   let numerator = (input + pad - (kernel - 1) * dilation + 1)
+        //   let numerator = (input + pad - (kernel - 1) * dilation - 1)
         //   in let denominator = stride
         //   in
         //     if (ceilMode)
@@ -117,13 +117,13 @@ struct ONNXMaxPoolSingleOutOpLowering : public ConversionPattern {
         //       floor(numerator / denominator) + 1
         int spatialIndex = i - batchRank;

-        // numerator = (input + pad - (kernel - 1) * dilation + 1)
+        // numerator = (input + pad - (kernel - 1) * dilation - 1)
         auto inputDim = rewriter.create<DimOp>(loc, inputOperand, i);
         auto inputVal = rewriter.create<IndexCastOp>(
             loc, inputDim, rewriter.getIntegerType(64));
         int64_t padKernelDilation =
             (pads[spatialIndex] + pads[spatialIndex + spatialRank]) -
-            (kernelShape[spatialIndex] - 1) * dilations[spatialIndex] + 1;
+            (kernelShape[spatialIndex] - 1) * dilations[spatialIndex] - 1;
         auto padKernelDilationVal = rewriter.create<ConstantOp>(
             loc, rewriter.getIntegerAttr(
                      rewriter.getIntegerType(64), padKernelDilation));
@@ -177,12 +177,14 @@ struct ONNXMaxPoolSingleOutOpLowering : public ConversionPattern {
   //         R[n][c][r1][r2] = negative_infinity;
   //         for k1 = 0 .. KH:
   //           for k2 = 0 .. KW:
-  //             t = D[n][c][s1 * r1 + k1][s2 * r2 + k2];
+  //             t = D[n][c][s1 * r1 + k1 * d1][s2 * r2 + k2 * d2];
   //             R[n][c][r1][r2] = max(R[n][c][r1][r2], t);
   //
   // Naming:
   //   n, c, r1, r2: outer loop nest indices
   //   k1, k2: inner loop nest indices
+  //   s1, s2: strides
+  //   d1, d2: dilations
   //
   // TODO: handle padding.
   //
@@ -217,45 +219,87 @@ struct ONNXMaxPoolSingleOutOpLowering : public ConversionPattern {
     rewriter.setInsertionPointToStart(innerLoops.getIterateBlock());
     {
       // 3. Emit inner loop body
-      //      t = D[n][c][s1 * r1 + k1][s2 * r2 + k2];
+      //      t = D[n][c][s1 * r1 + k1 * d1][s2 * r2 + k2 * d2];
       //      R[n][c][r1][r2] = max(R[n][c][r1][r2], t);

       // 3.1 Prepare indices for accessing the data tensor.
       SmallVector<Value, 4> dataIndices;
-      // Batch indices: n, c
+      // 3.1.1 Batch indices: n, c
       for (int i = 0; i < batchRank; ++i)
         dataIndices.emplace_back(outerLoops.getInductionVar(i));
-      // Spatial indices: sX * rX + kX
+      // 3.1.2 Insert spatial indices: sX * rX + kX * dX
       for (int i = batchRank; i < nOuterLoops; ++i) {
+        // Get index along the inner loop's induction variables.
+        // It is used to obtain kernel/pad/stride/dilation index.
+        int j = i - batchRank;
+
         Value spatialIndex = outerLoops.getInductionVar(i);
-        // If strides are present then emit the correct access index.
-        if (stridesAttribute && strides[i - batchRank] > 1) {
-          spatialIndex = rewriter.create<MulIOp>(loc,
-              rewriter.create<ConstantIndexOp>(loc, strides[i - batchRank]),
-              outerLoops.getInductionVar(i));
+        // If strides are present (not default) then emit the correct access
+        // index.
+        // sX *= rX
+        if (strides[i - batchRank] > 1) {
+          auto strideIndex = emitConstantOp(
+              rewriter, loc, rewriter.getIndexType(), strides[j]);
+          spatialIndex = rewriter.create<MulIOp>(
+              loc, strideIndex, outerLoops.getInductionVar(i));
         }
-        spatialIndex = rewriter.create<AddIOp>(
-            loc, spatialIndex, innerLoops.getInductionVar(i - batchRank));
-        // If ceil mode is enabled, then the calculated access index may
-        // exceed its dimension. In such a case, we will use the maximum
-        // index, which causes multiple visits to the element of the
+
+        // Dilate the kernel index only if the dilation value is not one (not
+        // default). Otherwise, just add kernelIndex.
+        auto kernelIndex = innerLoops.getInductionVar(j);
+        if (dilations[j] > 1) {
+          // sX += dX * kX
+          auto dilationIndex = emitConstantOp(
+              rewriter, loc, rewriter.getIndexType(), dilations[j]);
+          auto dilationKernelIndex =
+              rewriter.create<MulIOp>(loc, dilationIndex, kernelIndex);
+          spatialIndex =
+              rewriter.create<AddIOp>(loc, spatialIndex, dilationKernelIndex);
+        } else {
+          // sX += kX
+          spatialIndex =
+              rewriter.create<AddIOp>(loc, spatialIndex, kernelIndex);
+        }
+
+        // If ceil mode or dilation is enabled, then the calculated access
+        // index may exceed its dimension. In such a case, we will use the
+        // maximum index, which causes multiple visits to the element of the
         // maximum index.
         // TODO: Avoid multiple visits.
-        if (ceilMode) {
-          Value inputIndex;
+        // Example of out-of-bound.
+        // - Given a 5x5 input X
+        //   X = [[0, 0, 0, 0, 0],
+        //        [1, 1, 1, 1, 1],
+        //        [2, 2, 2, 2, 2],
+        //        [3, 3, 3, 3, 3],
+        //        [4, 4, 4, 4, 4]]
+        // - Do MaxPool with strides=[2, 2], kernel=[2, 2], ceilMode=true,
+        //   output is a 3x3 array:
+        //   Y = [[1, 1, 1],
+        //        [3, 3, 3],
+        //        [4, 4, 4]]
+        // - When computing Y[2, 0]:
+        //   - In case of kernelIndex = 1, stride = 2
+        //     - No dilation: spatialIndex = 2 * 2 + 1 = 5
+        //       => out of bound
+        //     - dilation = 2: spatialIndex = 2 * 2 + 2 * 1 = 6
+        //       => out of bound
+        if (dilations[j] > 1 or ceilMode) {
+          Value upperIndex;
           if (inputShape[i] < 0) {
             Value inputDim = rewriter.create<DimOp>(loc, inputOperand, i);
             Value one = rewriter.create<ConstantIndexOp>(loc, 1);
-            inputIndex = rewriter.create<SubIOp>(loc, inputDim, one);
+            upperIndex = rewriter.create<SubIOp>(loc, inputDim, one);
          } else {
-            inputIndex =
+            upperIndex =
                rewriter.create<ConstantIndexOp>(loc, inputShape[i] - 1);
          }
          auto greaterCondition = rewriter.create<CmpIOp>(
-              loc, CmpIPredicate::sgt, spatialIndex, inputIndex);
+              loc, CmpIPredicate::sgt, spatialIndex, upperIndex);
          spatialIndex = rewriter.create<SelectOp>(
-              loc, greaterCondition, inputIndex, spatialIndex);
+              loc, greaterCondition, upperIndex, spatialIndex);
        }
+
        dataIndices.emplace_back(spatialIndex);
      }

diff --git a/test/backend/test.py b/test/backend/test.py
index 95c29ce..902f7e5 100644
--- a/test/backend/test.py
+++ b/test/backend/test.py
@@ -312,6 +312,13 @@ test_to_enable = [
     "test_maxpool_1d_default_cpu",
     "test_maxpool_2d_ceil_cpu",
     "test_maxpool_2d_default_cpu",
+    "test_maxpool_2d_dilations_cpu",
+    "test_maxpool_2d_pads_cpu",
+    "test_maxpool_2d_precomputed_pads_cpu",
+    "test_maxpool_2d_precomputed_same_upper_cpu",
+    "test_maxpool_2d_precomputed_strides_cpu",
+    "test_maxpool_2d_same_lower_cpu",
+    "test_maxpool_2d_same_upper_cpu",
     "test_maxpool_2d_strides_cpu",
     "test_maxpool_3d_default_cpu",

diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir
index 3f70f26..df5239c 100644
--- a/test/mlir/onnx/onnx_lowering.mlir
+++ b/test/mlir/onnx/onnx_lowering.mlir
@@ -1429,15 +1429,17 @@ func @test_maxpooling_singleout_no_pad_w_strides_w_ceil_mode(%arg0 : tensor<1x3x
   // CHECK: [[STRIDE_0:%.+]] = constant 2 : index
   // CHECK: [[MUL_0:%.+]] = muli [[STRIDE_0]], %arg3 : index
   // CHECK: [[SPATIAL_H:%.+]] = addi [[MUL_0]], %arg5 : index
-  // CHECK: [[INPUT_INDEX_0:%.+]] = constant 31 : index
-  // CHECK: [[CMP_0:%.+]] = cmpi "sgt", [[SPATIAL_H]], [[INPUT_INDEX_0]] : index
-  // CHECK: [[H:%.+]] = select [[CMP_0]], [[INPUT_INDEX_0]], [[SPATIAL_H]] : index
+  // CHECK: [[UPPER_INDEX_0:%.+]] = constant 31 : index
+  // CHECK: [[GREATER_THAN_UPPER_0:%.+]] = cmpi "sgt", [[SPATIAL_H]], [[UPPER_INDEX_0]] : index
+  // CHECK: [[H:%.+]] = select [[GREATER_THAN_UPPER_0]], [[UPPER_INDEX_0]], [[SPATIAL_H]] : index
+
   // CHECK: [[STRIDE_1:%.+]] = constant 2 : index
   // CHECK: [[MUL_1:%.+]] = muli [[STRIDE_1]], %arg4 : index
   // CHECK: [[SPATIAL_W:%.+]] = addi [[MUL_1]], %arg6 : index
-  // CHECK: [[INPUT_INDEX_1:%.+]] = constant 31 : index
-  // CHECK: [[CMP_1:%.+]] = cmpi "sgt", [[SPATIAL_W]], [[INPUT_INDEX_1]] : index
-  // CHECK: [[W:%.+]] = select [[CMP_1]], [[INPUT_INDEX_1]], [[SPATIAL_W]] : index
+  // CHECK: [[UPPER_INDEX_1:%.+]] = constant 31 : index
+  // CHECK: [[GREATER_THAN_UPPER_1:%.+]] = cmpi "sgt", [[SPATIAL_W]], [[UPPER_INDEX_1]] : index
+  // CHECK: [[W:%.+]] = select [[GREATER_THAN_UPPER_1]], [[UPPER_INDEX_1]], [[SPATIAL_W]] : index
+
   // CHECK: [[LOAD_X:%.+]] = load %arg0[%arg1, %arg2, [[H]], [[W]]] : memref<1x3x32x32xf32>
   // CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x16x16xf32>
   // CHECK: [[CMP_2:%.+]] = cmpf "ogt", [[LOAD_Y]], [[LOAD_X]] : f32
@@ -1448,6 +1450,52 @@ func @test_maxpooling_singleout_no_pad_w_strides_w_ceil_mode(%arg0 : tensor<1x3x
   // CHECK: return [[RES]] : memref<1x3x16x16xf32>
 }

+func @test_maxpooling_singleout_no_pad_w_strides_w_dilation(%arg0 : tensor<1x3x32x32xf32>) -> tensor<*xf32> {
+  %0 = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", kernel_shape = [3, 3], strides = [2, 2], dilations = [2, 2]} : (tensor<1x3x32x32xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_maxpooling_singleout_no_pad_w_strides_w_dilation
+  // CHECK: [[RES:%.+]] = alloc() : memref<1x3x14x14xf32>
+  // CHECK: [[DEF_LOOPS_0:%.+]]:4 = krnl.define_loops 4
+  // CHECK: [[OPT_LOOPS_0:%.+]]:4 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS_0]]#0, [[DEF_LOOPS_0]]#1, [[DEF_LOOPS_0]]#2, [[DEF_LOOPS_0]]#3
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop)
+  // CHECK: krnl.iterate([[OPT_LOOPS_0]]#0, [[OPT_LOOPS_0]]#1, [[OPT_LOOPS_0]]#2, [[OPT_LOOPS_0]]#3) with ([[DEF_LOOPS_0]]#0 -> %arg1 = 0 to 1, [[DEF_LOOPS_0]]#1 -> %arg2 = 0 to 3, [[DEF_LOOPS_0]]#2 -> %arg3 = 0 to 14, [[DEF_LOOPS_0]]#3 -> %arg4 = 0 to 14) {
+  // CHECK: [[NEGATIVE_INFINITY:%.+]] = constant 0xFF800000 : f32
+  // CHECK: store [[NEGATIVE_INFINITY]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x14x14xf32>
+  // CHECK: [[DEF_LOOPS_1:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS_1:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS_1]]#0, [[DEF_LOOPS_1]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: krnl.iterate([[OPT_LOOPS_1]]#0, [[OPT_LOOPS_1]]#1) with ([[DEF_LOOPS_1]]#0 -> %arg5 = 0 to 3, [[DEF_LOOPS_1]]#1 -> %arg6 = 0 to 3) {
+  // CHECK: [[STRIDE_0:%.+]] = constant 2 : index
+  // CHECK: [[MUL_0:%.+]] = muli [[STRIDE_0]], %arg3 : index
+  // CHECK: [[STRIDE_1:%.+]] = constant 2 : index
+  // CHECK: [[MUL_1:%.+]] = muli [[STRIDE_1]], %arg5 : index
+  // CHECK: [[SPATIAL_H:%.+]] = addi [[MUL_0]], [[MUL_1]] : index
+  // CHECK: [[UPPER_INDEX_0:%.+]] = constant 31 : index
+  // CHECK: [[GREATER_THAN_UPPER_0:%.+]] = cmpi "sgt", [[SPATIAL_H]], [[UPPER_INDEX_0]] : index
+  // CHECK: [[H:%.+]] = select [[GREATER_THAN_UPPER_0]], [[UPPER_INDEX_0]], [[SPATIAL_H]] : index
+
+  // CHECK: [[STRIDE_0_1:%.+]] = constant 2 : index
+  // CHECK: [[MUL_0_1:%.+]] = muli [[STRIDE_0_1]], %arg4 : index
+  // CHECK: [[STRIDE_1_1:%.+]] = constant 2 : index
+  // CHECK: [[MUL_1_1:%.+]] = muli [[STRIDE_1_1]], %arg6 : index
+  // CHECK: [[SPATIAL_W:%.+]] = addi [[MUL_0_1]], [[MUL_1_1]] : index
+  // CHECK: [[UPPER_INDEX_1:%.+]] = constant 31 : index
+  // CHECK: [[GREATER_THAN_UPPER_1:%.+]] = cmpi "sgt", [[SPATIAL_W]], [[UPPER_INDEX_1]] : index
+  // CHECK: [[W:%.+]] = select [[GREATER_THAN_UPPER_1]], [[UPPER_INDEX_1]], [[SPATIAL_W]] : index
+
+  // CHECK: [[LOAD_X:%.+]] = load %arg0[%arg1, %arg2, [[H]], [[W]]] : memref<1x3x32x32xf32>
+  // CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x14x14xf32>
+  // CHECK: [[CMP_2:%.+]] = cmpf "ogt", [[LOAD_Y]], [[LOAD_X]] : f32
+  // CHECK: [[SELECT:%.+]] = select [[CMP_2]], [[LOAD_Y]], [[LOAD_X]] : f32
+  // CHECK: store [[SELECT]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x14x14xf32>
+  // CHECK: }
+  // CHECK: }
+  // CHECK: return [[RES]] : memref<1x3x14x14xf32>
+}
+
 func @test_maxpooling_singleout_no_pad_w_strides_w_ceil_mode_w_unknown_dims(%arg0 : tensor<?x3x?x32xf32>) -> tensor<*xf32> {
   %0 = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", kernel_shape = [3, 3], strides = [2, 2], ceil_mode = 1} : (tensor<?x3x?x32xf32>) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()
@@ -1459,7 +1507,7 @@ func @test_maxpooling_singleout_no_pad_w_strides_w_ceil_mode_w_unknown_dims(%arg
   // CHECK: [[ONE:%.+]] = constant 1 : i64
   // CHECK: [[DIM_1:%.+]] = dim %arg0, 2 : memref<?x3x?x32xf32>
   // CHECK: [[DIM_1_i64:%.+]] = index_cast [[DIM_1]] : index to i64
-  // CHECK: [[KERNEL_PAD_DILATION:%.+]] = constant -1 : i64
+  // CHECK: [[KERNEL_PAD_DILATION:%.+]] = constant -3 : i64
   // CHECK: [[NUMERATOR:%.+]] = addi [[DIM_1_i64]], [[KERNEL_PAD_DILATION]] : i64
   // CHECK: [[DENOMINATOR:%.+]] = constant 2 : i64
   // CHECK: [[DIV:%.+]] = divi_signed [[NUMERATOR]], [[DENOMINATOR]] : i64
@@ -1490,16 +1538,16 @@ func @test_maxpooling_singleout_no_pad_w_strides_w_ceil_mode_w_unknown_dims(%arg
   // CHECK: [[SPATIAL_H:%.+]] = addi [[MUL_0]], %arg5 : index
   // CHECK: [[DIM_0_0:%.+]] = dim %arg0, 2 : memref<?x3x?x32xf32>
   // CHECK: [[ONE_INDEX:%.+]] = constant 1 : index
-  // CHECK: [[INPUT_INDEX_0:%.+]] = subi [[DIM_0_0]], [[ONE_INDEX]] : index
-  // CHECK: [[CMP_0:%.+]] = cmpi "sgt", [[SPATIAL_H]], [[INPUT_INDEX_0]] : index
-  // CHECK: [[H:%.+]] = select [[CMP_0]], [[INPUT_INDEX_0]], [[SPATIAL_H]] : index
+  // CHECK: [[UPPER_INDEX_0:%.+]] = subi [[DIM_0_0]], [[ONE_INDEX]] : index
+  // CHECK: [[GREATER_THAN_UPPER_0:%.+]] = cmpi "sgt", [[SPATIAL_H]], [[UPPER_INDEX_0]] : index
+  // CHECK: [[H:%.+]] = select [[GREATER_THAN_UPPER_0]], [[UPPER_INDEX_0]], [[SPATIAL_H]] : index
   // CHECK: [[STRIDE_1:%.+]] = constant 2 : index
   // CHECK: [[MUL_1:%.+]] = muli [[STRIDE_1]], %arg4 : index
   // CHECK: [[SPATIAL_W:%.+]] = addi [[MUL_1]], %arg6 : index
-  // CHECK: [[INPUT_INDEX_1:%.+]] = constant 31 : index
-  // CHECK: [[CMP_1:%.+]] = cmpi "sgt", [[SPATIAL_W]], [[INPUT_INDEX_1]] : index
-  // CHECK: [[W:%.+]] = select [[CMP_1]], [[INPUT_INDEX_1]], [[SPATIAL_W]] : index
+  // CHECK: [[UPPER_INDEX_1:%.+]] = constant 31 : index
+  // CHECK: [[GREATER_THAN_UPPER_1:%.+]] = cmpi "sgt", [[SPATIAL_W]], [[UPPER_INDEX_1]] : index
+  // CHECK: [[W:%.+]] = select [[GREATER_THAN_UPPER_1]], [[UPPER_INDEX_1]], [[SPATIAL_W]] : index
   // CHECK: [[LOAD_X:%.+]] = load %arg0[%arg1, %arg2, [[H]], [[W]]] : memref<?x3x?x32xf32>
   // CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<?x3x?x16xf32>
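
For reference, the index arithmetic this lowering emits can be cross-checked against a small pure-Python model of MaxPoolSingleOut. The sketch below is not part of the patch; it assumes no padding (matching the TODO in pooling.cpp), and the helper names (out_dim, ref_maxpool_2d) are illustrative only. It implements the corrected output-size formula and the clamped access index stride * rX + dilation * kX, and reproduces the out-of-bound example from the comment as well as the 14x14 and 16x16 output sizes checked in onnx_lowering.mlir.

    import math

    def out_dim(in_dim, kernel, stride, dilation, ceil_mode, pad=0):
        # numerator = input + pad - (kernel - 1) * dilation - 1
        numerator = in_dim + pad - (kernel - 1) * dilation - 1
        if ceil_mode:
            return math.ceil(numerator / stride) + 1
        return numerator // stride + 1

    def ref_maxpool_2d(x, kernel, strides, dilations, ceil_mode=False):
        H, W = len(x), len(x[0])
        (KH, KW), (s1, s2), (d1, d2) = kernel, strides, dilations
        OH = out_dim(H, KH, s1, d1, ceil_mode)
        OW = out_dim(W, KW, s2, d2, ceil_mode)
        y = [[-math.inf] * OW for _ in range(OH)]
        for r1 in range(OH):
            for r2 in range(OW):
                for k1 in range(KH):
                    for k2 in range(KW):
                        # spatialIndex = stride * rX + dilation * kX, clamped to
                        # the last valid index, as in the ceil-mode/dilation
                        # guard emitted by the lowering.
                        h = min(s1 * r1 + d1 * k1, H - 1)
                        w = min(s2 * r2 + d2 * k2, W - 1)
                        y[r1][r2] = max(y[r1][r2], x[h][w])
        return y

    # Out-of-bound example from the comment in pooling.cpp:
    X = [[0] * 5, [1] * 5, [2] * 5, [3] * 5, [4] * 5]
    assert ref_maxpool_2d(X, (2, 2), (2, 2), (1, 1), ceil_mode=True) == \
        [[1, 1, 1], [3, 3, 3], [4, 4, 4]]

    # Output sizes checked in onnx_lowering.mlir:
    assert out_dim(32, 3, 2, 2, ceil_mode=False) == 14  # dilation test
    assert out_dim(32, 3, 2, 1, ceil_mode=True) == 16   # ceil-mode tests

Running the sketch exercises both the clamping behaviour (the repeated visit to the last row in the 5x5 example) and the output-shape formula used to allocate the result buffers.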