Support dilations and enable the remaining e2e tests for MaxPoolSingleOut (#31)
* Support dilations and enable e2e tests * Fix allocating memory for dynamic shape * Edit comments * Do dilation by computing an offset from kernel index * Correct dilation formula, add an example of out-of-bound, and add a test for dilation Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
parent
4763e8a8bc
commit
2814ea3898
|
@ -108,7 +108,7 @@ struct ONNXMaxPoolSingleOutOpLowering : public ConversionPattern {
|
|||
for (int i = batchRank; i < resultShape.size(); ++i) {
|
||||
if (resultShape[i] < 0) {
|
||||
// dim =
|
||||
// let numerator = (input + pad - (kernel - 1) * dilation + 1)
|
||||
// let numerator = (input + pad - (kernel - 1) * dilation - 1)
|
||||
// in let denomitor = stride
|
||||
// in
|
||||
// if (ceilMode)
|
||||
|
@ -117,13 +117,13 @@ struct ONNXMaxPoolSingleOutOpLowering : public ConversionPattern {
|
|||
// floor(numerator / denominator) + 1
|
||||
int spatialIndex = i - batchRank;
|
||||
|
||||
// numerator = (input + pad - (kernel - 1) * dilation + 1)
|
||||
// numerator = (input + pad - (kernel - 1) * dilation - 1)
|
||||
auto inputDim = rewriter.create<DimOp>(loc, inputOperand, i);
|
||||
auto inputVal = rewriter.create<IndexCastOp>(
|
||||
loc, inputDim, rewriter.getIntegerType(64));
|
||||
int64_t padKernelDilation =
|
||||
(pads[spatialIndex] + pads[spatialIndex + spatialRank]) -
|
||||
(kernelShape[spatialIndex] - 1) * dilations[spatialIndex] + 1;
|
||||
(kernelShape[spatialIndex] - 1) * dilations[spatialIndex] - 1;
|
||||
auto padKernelDilationVal = rewriter.create<ConstantOp>(
|
||||
loc, rewriter.getIntegerAttr(
|
||||
rewriter.getIntegerType(64), padKernelDilation));
|
||||
|
@ -177,12 +177,14 @@ struct ONNXMaxPoolSingleOutOpLowering : public ConversionPattern {
|
|||
// R[n][c][r1][r2] = negative_infinity;
|
||||
// for k1 = 0 .. KH:
|
||||
// for k2 = 0 .. KW:
|
||||
// t = D[n][c][s1 * r1 + k1][s2 * r2 + k2];
|
||||
// t = D[n][c][s1 * r1 + k1 * d1][s2 * r2 + k2 * d2];
|
||||
// R[n][c][r1][r2] = max(R[n][c][r1][r2], t);
|
||||
//
|
||||
// Naming:
|
||||
// n, c, r1, r2: outer loop nest indices
|
||||
// k1, k2: inner loop nest indices
|
||||
// s1, s2: strides
|
||||
// d1, d2: dilations
|
||||
//
|
||||
// TODO: handle padding.
|
||||
//
|
||||
|
@ -217,45 +219,87 @@ struct ONNXMaxPoolSingleOutOpLowering : public ConversionPattern {
|
|||
rewriter.setInsertionPointToStart(innerLoops.getIterateBlock());
|
||||
{
|
||||
// 3. Emit inner loop body
|
||||
// t = D[n][c][s1 * r1 + k1][s2 * r2 + k2];
|
||||
// t = D[n][c][s1 * r1 + k1 * d1][s2 * r2 + k2 * d2];
|
||||
// R[n][c][r1][r2] = max(R[n][c][r1][r2], t);
|
||||
|
||||
// 3.1 Prepare indices for accesing the data tensor.
|
||||
SmallVector<Value, 4> dataIndices;
|
||||
// Batch indices: n, c
|
||||
// 3.1.1 Batch indices: n, c
|
||||
for (int i = 0; i < batchRank; ++i)
|
||||
dataIndices.emplace_back(outerLoops.getInductionVar(i));
|
||||
// Spatial indices: sX * rX + kX
|
||||
// 3.1.2 Insert spatial indices: sX * rX + kX * dX
|
||||
for (int i = batchRank; i < nOuterLoops; ++i) {
|
||||
// Get index along the inner loop's induction variables.
|
||||
// It is used to obtain kernel/pad/stride/dilation index.
|
||||
int j = i - batchRank;
|
||||
|
||||
Value spatialIndex = outerLoops.getInductionVar(i);
|
||||
// If strides are present then emit the correct access index.
|
||||
if (stridesAttribute && strides[i - batchRank] > 1) {
|
||||
spatialIndex = rewriter.create<MulIOp>(loc,
|
||||
rewriter.create<ConstantIndexOp>(loc, strides[i - batchRank]),
|
||||
outerLoops.getInductionVar(i));
|
||||
// If strides are present (not default) then emit the correct access
|
||||
// index.
|
||||
// sX *= rX
|
||||
if (strides[i - batchRank] > 1) {
|
||||
auto strideIndex = emitConstantOp(
|
||||
rewriter, loc, rewriter.getIndexType(), strides[j]);
|
||||
spatialIndex = rewriter.create<MulIOp>(
|
||||
loc, strideIndex, outerLoops.getInductionVar(i));
|
||||
}
|
||||
spatialIndex = rewriter.create<AddIOp>(
|
||||
loc, spatialIndex, innerLoops.getInductionVar(i - batchRank));
|
||||
// If ceil mode is enabled, then the calculated access index may
|
||||
// exceed its dimension. In such a case, we will use the maximum
|
||||
// index, which causes multiple visits to the element of the
|
||||
|
||||
// Dilate the kernel index only if the dilation value is not one (not
|
||||
// default). Otherwise, just add kernelIndex.
|
||||
auto kernelIndex = innerLoops.getInductionVar(j);
|
||||
if (dilations[j] > 1) {
|
||||
// sX += dX * kW
|
||||
auto dilationIndex = emitConstantOp(
|
||||
rewriter, loc, rewriter.getIndexType(), dilations[j]);
|
||||
auto dilationKernelIndex =
|
||||
rewriter.create<MulIOp>(loc, dilationIndex, kernelIndex);
|
||||
spatialIndex =
|
||||
rewriter.create<AddIOp>(loc, spatialIndex, dilationKernelIndex);
|
||||
} else {
|
||||
// sX += kX
|
||||
spatialIndex =
|
||||
rewriter.create<AddIOp>(loc, spatialIndex, kernelIndex);
|
||||
}
|
||||
|
||||
// If ceil mode or dilation is enabled, then the calculated access
|
||||
// index may exceed its dimension. In such a case, we will use the
|
||||
// maximum index, which causes multiple visits to the element of the
|
||||
// maximum index.
|
||||
// TODO: Avoid multiple visits.
|
||||
if (ceilMode) {
|
||||
Value inputIndex;
|
||||
// Example of out-of-bound.
|
||||
// - Given a 5x5 input X
|
||||
// X = [[0, 0, 0, 0, 0],
|
||||
// [1, 1, 1, 1, 1],
|
||||
// [2, 2, 2, 2, 2],
|
||||
// [3, 3, 3, 3, 3],
|
||||
// [4, 4, 4, 4, 4]]
|
||||
// - Do MaxPool with strides=[2, 2], kernel=[2, 2], ceilMode=true,
|
||||
// output is a 3x3 array:
|
||||
// Y = [[1, 1, 1],
|
||||
// [3, 3, 3],
|
||||
// [4, 4, 4]]
|
||||
// - When computing Y[2, 0]:
|
||||
// - In case of kernelIndex = 1, stride = 2
|
||||
// - No dilation: spatialIndex = 2 * 2 + 1 = 5
|
||||
// => out of bound
|
||||
// - dilation = 2: spatialIndex = 2 * 2 + 2 * 1 = 6
|
||||
// => out of bound
|
||||
if (dilations[j] > 1 or ceilMode) {
|
||||
Value upperIndex;
|
||||
if (inputShape[i] < 0) {
|
||||
Value inputDim = rewriter.create<DimOp>(loc, inputOperand, i);
|
||||
Value one = rewriter.create<ConstantIndexOp>(loc, 1);
|
||||
inputIndex = rewriter.create<SubIOp>(loc, inputDim, one);
|
||||
upperIndex = rewriter.create<SubIOp>(loc, inputDim, one);
|
||||
} else {
|
||||
inputIndex =
|
||||
upperIndex =
|
||||
rewriter.create<ConstantIndexOp>(loc, inputShape[i] - 1);
|
||||
}
|
||||
auto greaterCondition = rewriter.create<CmpIOp>(
|
||||
loc, CmpIPredicate::sgt, spatialIndex, inputIndex);
|
||||
loc, CmpIPredicate::sgt, spatialIndex, upperIndex);
|
||||
spatialIndex = rewriter.create<SelectOp>(
|
||||
loc, greaterCondition, inputIndex, spatialIndex);
|
||||
loc, greaterCondition, upperIndex, spatialIndex);
|
||||
}
|
||||
|
||||
dataIndices.emplace_back(spatialIndex);
|
||||
}
|
||||
|
||||
|
|
|
@ -312,6 +312,13 @@ test_to_enable = [
|
|||
"test_maxpool_1d_default_cpu",
|
||||
"test_maxpool_2d_ceil_cpu",
|
||||
"test_maxpool_2d_default_cpu",
|
||||
"test_maxpool_2d_dilations_cpu",
|
||||
"test_maxpool_2d_pads_cpu",
|
||||
"test_maxpool_2d_precomputed_pads_cpu",
|
||||
"test_maxpool_2d_precomputed_same_upper_cpu",
|
||||
"test_maxpool_2d_precomputed_strides_cpu",
|
||||
"test_maxpool_2d_same_lower_cpu",
|
||||
"test_maxpool_2d_same_upper_cpu",
|
||||
"test_maxpool_2d_strides_cpu",
|
||||
"test_maxpool_3d_default_cpu",
|
||||
|
||||
|
|
|
@ -1429,15 +1429,17 @@ func @test_maxpooling_singleout_no_pad_w_strides_w_ceil_mode(%arg0 : tensor<1x3x
|
|||
// CHECK: [[STRIDE_0:%.+]] = constant 2 : index
|
||||
// CHECK: [[MUL_0:%.+]] = muli [[STRIDE_0]], %arg3 : index
|
||||
// CHECK: [[SPATIAL_H:%.+]] = addi [[MUL_0]], %arg5 : index
|
||||
// CHECK: [[INPUT_INDEX_0:%.+]] = constant 31 : index
|
||||
// CHECK: [[CMP_0:%.+]] = cmpi "sgt", [[SPATIAL_H]], [[INPUT_INDEX_0]] : index
|
||||
// CHECK: [[H:%.+]] = select [[CMP_0]], [[INPUT_INDEX_0]], [[SPATIAL_H]] : index
|
||||
// CHECK: [[UPPER_INDEX_0:%.+]] = constant 31 : index
|
||||
// CHECK: [[GREATER_THAN_UPPER_0:%.+]] = cmpi "sgt", [[SPATIAL_H]], [[UPPER_INDEX_0]] : index
|
||||
// CHECK: [[H:%.+]] = select [[GREATER_THAN_UPPER_0]], [[UPPER_INDEX_0]], [[SPATIAL_H]] : index
|
||||
|
||||
// CHECK: [[STRIDE_1:%.+]] = constant 2 : index
|
||||
// CHECK: [[MUL_1:%.+]] = muli [[STRIDE_1]], %arg4 : index
|
||||
// CHECK: [[SPATIAL_W:%.+]] = addi [[MUL_1]], %arg6 : index
|
||||
// CHECK: [[INPUT_INDEX_1:%.+]] = constant 31 : index
|
||||
// CHECK: [[CMP_1:%.+]] = cmpi "sgt", [[SPATIAL_W]], [[INPUT_INDEX_1]] : index
|
||||
// CHECK: [[W:%.+]] = select [[CMP_1]], [[INPUT_INDEX_1]], [[SPATIAL_W]] : index
|
||||
// CHECK: [[UPPER_INDEX_1:%.+]] = constant 31 : index
|
||||
// CHECK: [[GREATER_THAN_UPPER_1:%.+]] = cmpi "sgt", [[SPATIAL_W]], [[UPPER_INDEX_1]] : index
|
||||
// CHECK: [[W:%.+]] = select [[GREATER_THAN_UPPER_1]], [[UPPER_INDEX_1]], [[SPATIAL_W]] : index
|
||||
|
||||
// CHECK: [[LOAD_X:%.+]] = load %arg0[%arg1, %arg2, [[H]], [[W]]] : memref<1x3x32x32xf32>
|
||||
// CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x16x16xf32>
|
||||
// CHECK: [[CMP_2:%.+]] = cmpf "ogt", [[LOAD_Y]], [[LOAD_X]] : f32
|
||||
|
@ -1448,6 +1450,52 @@ func @test_maxpooling_singleout_no_pad_w_strides_w_ceil_mode(%arg0 : tensor<1x3x
|
|||
// CHECK: return [[RES]] : memref<1x3x16x16xf32>
|
||||
}
|
||||
|
||||
func @test_maxpooling_singleout_no_pad_w_strides_w_dilation(%arg0 : tensor<1x3x32x32xf32>) -> tensor<*xf32> {
|
||||
%0 = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", kernel_shape = [3, 3], strides = [2, 2], dilations = [2, 2]} : (tensor<1x3x32x32xf32>) -> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-LABEL: test_maxpooling_singleout_no_pad_w_strides_w_dilation
|
||||
// CHECK: [[RES:%.+]] = alloc() : memref<1x3x14x14xf32>
|
||||
// CHECK: [[DEF_LOOPS_0:%.+]]:4 = krnl.define_loops 4
|
||||
// CHECK: [[OPT_LOOPS_0:%.+]]:4 = krnl.optimize_loops {
|
||||
// CHECK: krnl.return_loops [[DEF_LOOPS_0]]#0, [[DEF_LOOPS_0]]#1, [[DEF_LOOPS_0]]#2, [[DEF_LOOPS_0]]#3
|
||||
// CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop)
|
||||
// CHECK: krnl.iterate([[OPT_LOOPS_0]]#0, [[OPT_LOOPS_0]]#1, [[OPT_LOOPS_0]]#2, [[OPT_LOOPS_0]]#3) with ([[DEF_LOOPS_0]]#0 -> %arg1 = 0 to 1, [[DEF_LOOPS_0]]#1 -> %arg2 = 0 to 3, [[DEF_LOOPS_0]]#2 -> %arg3 = 0 to 14, [[DEF_LOOPS_0]]#3 -> %arg4 = 0 to 14) {
|
||||
// CHECK: [[NEGATIVE_INFINITY:%.+]] = constant 0xFF800000 : f32
|
||||
// CHECK: store [[NEGATIVE_INFINITY]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x14x14xf32>
|
||||
// CHECK: [[DEF_LOOPS_1:%.+]]:2 = krnl.define_loops 2
|
||||
// CHECK: [[OPT_LOOPS_1:%.+]]:2 = krnl.optimize_loops {
|
||||
// CHECK: krnl.return_loops [[DEF_LOOPS_1]]#0, [[DEF_LOOPS_1]]#1
|
||||
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
|
||||
// CHECK: krnl.iterate([[OPT_LOOPS_1]]#0, [[OPT_LOOPS_1]]#1) with ([[DEF_LOOPS_1]]#0 -> %arg5 = 0 to 3, [[DEF_LOOPS_1]]#1 -> %arg6 = 0 to 3) {
|
||||
// CHECK: [[STRIDE_0:%.+]] = constant 2 : index
|
||||
// CHECK: [[MUL_0:%.+]] = muli [[STRIDE_0]], %arg3 : index
|
||||
// CHECK: [[STRIDE_1:%.+]] = constant 2 : index
|
||||
// CHECK: [[MUL_1:%.+]] = muli [[STRIDE_1]], %arg5 : index
|
||||
// CHECK: [[SPATIAL_H:%.+]] = addi [[MUL_0]], [[MUL_1]] : index
|
||||
// CHECK: [[UPPER_INDEX_0:%.+]] = constant 31 : index
|
||||
// CHECK: [[GREATER_THAN_UPPER_0:%.+]] = cmpi "sgt", [[SPATIAL_H]], [[UPPER_INDEX_0]] : index
|
||||
// CHECK: [[H:%.+]] = select [[GREATER_THAN_UPPER_0]], [[UPPER_INDEX_0]], [[SPATIAL_H]] : index
|
||||
|
||||
// CHECK: [[STRIDE_0_1:%.+]] = constant 2 : index
|
||||
// CHECK: [[MUL_0_1:%.+]] = muli [[STRIDE_0_1]], %arg4 : index
|
||||
// CHECK: [[STRIDE_1_1:%.+]] = constant 2 : index
|
||||
// CHECK: [[MUL_1_1:%.+]] = muli [[STRIDE_1_1]], %arg6 : index
|
||||
// CHECK: [[SPATIAL_W:%.+]] = addi [[MUL_0_1]], [[MUL_1_1]] : index
|
||||
// CHECK: [[UPPER_INDEX_1:%.+]] = constant 31 : index
|
||||
// CHECK: [[GREATER_THAN_UPPER_1:%.+]] = cmpi "sgt", [[SPATIAL_W]], [[UPPER_INDEX_1]] : index
|
||||
// CHECK: [[W:%.+]] = select [[GREATER_THAN_UPPER_1]], [[UPPER_INDEX_1]], [[SPATIAL_W]] : index
|
||||
|
||||
// CHECK: [[LOAD_X:%.+]] = load %arg0[%arg1, %arg2, [[H]], [[W]]] : memref<1x3x32x32xf32>
|
||||
// CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x14x14xf32>
|
||||
// CHECK: [[CMP_2:%.+]] = cmpf "ogt", [[LOAD_Y]], [[LOAD_X]] : f32
|
||||
// CHECK: [[SELECT:%.+]] = select [[CMP_2]], [[LOAD_Y]], [[LOAD_X]] : f32
|
||||
// CHECK: store [[SELECT]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x14x14xf32>
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
// CHECK: return [[RES]] : memref<1x3x14x14xf32>
|
||||
}
|
||||
|
||||
func @test_maxpooling_singleout_no_pad_w_strides_w_ceil_mode_w_unknown_dims(%arg0 : tensor<?x3x?x32xf32>) -> tensor<*xf32> {
|
||||
%0 = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", kernel_shape = [3, 3], strides = [2, 2], ceil_mode = 1} : (tensor<?x3x?x32xf32>) -> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
|
@ -1459,7 +1507,7 @@ func @test_maxpooling_singleout_no_pad_w_strides_w_ceil_mode_w_unknown_dims(%arg
|
|||
// CHECK: [[ONE:%.+]] = constant 1 : i64
|
||||
// CHECK: [[DIM_1:%.+]] = dim %arg0, 2 : memref<?x3x?x32xf32>
|
||||
// CHECK: [[DIM_1_i64:%.+]] = index_cast [[DIM_1]] : index to i64
|
||||
// CHECK: [[KERNEL_PAD_DILATION:%.+]] = constant -1 : i64
|
||||
// CHECK: [[KERNEL_PAD_DILATION:%.+]] = constant -3 : i64
|
||||
// CHECK: [[NUMERATOR:%.+]] = addi [[DIM_1_i64]], [[KERNEL_PAD_DILATION]] : i64
|
||||
// CHECK: [[DENOMINATOR:%.+]] = constant 2 : i64
|
||||
// CHECK: [[DIV:%.+]] = divi_signed [[NUMERATOR]], [[DENOMINATOR]] : i64
|
||||
|
@ -1490,16 +1538,16 @@ func @test_maxpooling_singleout_no_pad_w_strides_w_ceil_mode_w_unknown_dims(%arg
|
|||
// CHECK: [[SPATIAL_H:%.+]] = addi [[MUL_0]], %arg5 : index
|
||||
// CHECK: [[DIM_0_0:%.+]] = dim %arg0, 2 : memref<?x3x?x32xf32>
|
||||
// CHECK: [[ONE_INDEX:%.+]] = constant 1 : index
|
||||
// CHECK: [[INPUT_INDEX_0:%.+]] = subi [[DIM_0_0]], [[ONE_INDEX]] : index
|
||||
// CHECK: [[CMP_0:%.+]] = cmpi "sgt", [[SPATIAL_H]], [[INPUT_INDEX_0]] : index
|
||||
// CHECK: [[H:%.+]] = select [[CMP_0]], [[INPUT_INDEX_0]], [[SPATIAL_H]] : index
|
||||
// CHECK: [[UPPER_INDEX_0:%.+]] = subi [[DIM_0_0]], [[ONE_INDEX]] : index
|
||||
// CHECK: [[GREATER_THAN_UPPER_0:%.+]] = cmpi "sgt", [[SPATIAL_H]], [[UPPER_INDEX_0]] : index
|
||||
// CHECK: [[H:%.+]] = select [[GREATER_THAN_UPPER_0]], [[UPPER_INDEX_0]], [[SPATIAL_H]] : index
|
||||
|
||||
// CHECK: [[STRIDE_1:%.+]] = constant 2 : index
|
||||
// CHECK: [[MUL_1:%.+]] = muli [[STRIDE_1]], %arg4 : index
|
||||
// CHECK: [[SPATIAL_W:%.+]] = addi [[MUL_1]], %arg6 : index
|
||||
// CHECK: [[INPUT_INDEX_1:%.+]] = constant 31 : index
|
||||
// CHECK: [[CMP_1:%.+]] = cmpi "sgt", [[SPATIAL_W]], [[INPUT_INDEX_1]] : index
|
||||
// CHECK: [[W:%.+]] = select [[CMP_1]], [[INPUT_INDEX_1]], [[SPATIAL_W]] : index
|
||||
// CHECK: [[UPPER_INDEX_1:%.+]] = constant 31 : index
|
||||
// CHECK: [[GREATER_THAN_UPPER_1:%.+]] = cmpi "sgt", [[SPATIAL_W]], [[UPPER_INDEX_1]] : index
|
||||
// CHECK: [[W:%.+]] = select [[GREATER_THAN_UPPER_1]], [[UPPER_INDEX_1]], [[SPATIAL_W]] : index
|
||||
|
||||
// CHECK: [[LOAD_X:%.+]] = load %arg0[%arg1, %arg2, [[H]], [[W]]] : memref<?x3x?x32xf32>
|
||||
// CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<?x3x?x16xf32>
|
||||
|
|
Loading…
Reference in New Issue