Support dilations and enable the remaining e2e tests for MaxPoolSingleOut (#31)

* Support dilations and enable e2e tests

* Fix allocating memory for dynamic shape

* Edit comments

* Do dilation by computing an offset from kernel index

* Correct dilation formula, add an example of out-of-bound, and add a test for dilation

Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
Tung D. Le 2020-03-18 22:55:50 +09:00 committed by GitHub
parent 4763e8a8bc
commit 2814ea3898
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 135 additions and 36 deletions

View File

@ -108,7 +108,7 @@ struct ONNXMaxPoolSingleOutOpLowering : public ConversionPattern {
for (int i = batchRank; i < resultShape.size(); ++i) { for (int i = batchRank; i < resultShape.size(); ++i) {
if (resultShape[i] < 0) { if (resultShape[i] < 0) {
// dim = // dim =
// let numerator = (input + pad - (kernel - 1) * dilation + 1) // let numerator = (input + pad - (kernel - 1) * dilation - 1)
// in let denomitor = stride // in let denomitor = stride
// in // in
// if (ceilMode) // if (ceilMode)
@ -117,13 +117,13 @@ struct ONNXMaxPoolSingleOutOpLowering : public ConversionPattern {
// floor(numerator / denominator) + 1 // floor(numerator / denominator) + 1
int spatialIndex = i - batchRank; int spatialIndex = i - batchRank;
// numerator = (input + pad - (kernel - 1) * dilation + 1) // numerator = (input + pad - (kernel - 1) * dilation - 1)
auto inputDim = rewriter.create<DimOp>(loc, inputOperand, i); auto inputDim = rewriter.create<DimOp>(loc, inputOperand, i);
auto inputVal = rewriter.create<IndexCastOp>( auto inputVal = rewriter.create<IndexCastOp>(
loc, inputDim, rewriter.getIntegerType(64)); loc, inputDim, rewriter.getIntegerType(64));
int64_t padKernelDilation = int64_t padKernelDilation =
(pads[spatialIndex] + pads[spatialIndex + spatialRank]) - (pads[spatialIndex] + pads[spatialIndex + spatialRank]) -
(kernelShape[spatialIndex] - 1) * dilations[spatialIndex] + 1; (kernelShape[spatialIndex] - 1) * dilations[spatialIndex] - 1;
auto padKernelDilationVal = rewriter.create<ConstantOp>( auto padKernelDilationVal = rewriter.create<ConstantOp>(
loc, rewriter.getIntegerAttr( loc, rewriter.getIntegerAttr(
rewriter.getIntegerType(64), padKernelDilation)); rewriter.getIntegerType(64), padKernelDilation));
@ -177,12 +177,14 @@ struct ONNXMaxPoolSingleOutOpLowering : public ConversionPattern {
// R[n][c][r1][r2] = negative_infinity; // R[n][c][r1][r2] = negative_infinity;
// for k1 = 0 .. KH: // for k1 = 0 .. KH:
// for k2 = 0 .. KW: // for k2 = 0 .. KW:
// t = D[n][c][s1 * r1 + k1][s2 * r2 + k2]; // t = D[n][c][s1 * r1 + k1 * d1][s2 * r2 + k2 * d2];
// R[n][c][r1][r2] = max(R[n][c][r1][r2], t); // R[n][c][r1][r2] = max(R[n][c][r1][r2], t);
// //
// Naming: // Naming:
// n, c, r1, r2: outer loop nest indices // n, c, r1, r2: outer loop nest indices
// k1, k2: inner loop nest indices // k1, k2: inner loop nest indices
// s1, s2: strides
// d1, d2: dilations
// //
// TODO: handle padding. // TODO: handle padding.
// //
@ -217,45 +219,87 @@ struct ONNXMaxPoolSingleOutOpLowering : public ConversionPattern {
rewriter.setInsertionPointToStart(innerLoops.getIterateBlock()); rewriter.setInsertionPointToStart(innerLoops.getIterateBlock());
{ {
// 3. Emit inner loop body // 3. Emit inner loop body
// t = D[n][c][s1 * r1 + k1][s2 * r2 + k2]; // t = D[n][c][s1 * r1 + k1 * d1][s2 * r2 + k2 * d2];
// R[n][c][r1][r2] = max(R[n][c][r1][r2], t); // R[n][c][r1][r2] = max(R[n][c][r1][r2], t);
// 3.1 Prepare indices for accesing the data tensor. // 3.1 Prepare indices for accesing the data tensor.
SmallVector<Value, 4> dataIndices; SmallVector<Value, 4> dataIndices;
// Batch indices: n, c // 3.1.1 Batch indices: n, c
for (int i = 0; i < batchRank; ++i) for (int i = 0; i < batchRank; ++i)
dataIndices.emplace_back(outerLoops.getInductionVar(i)); dataIndices.emplace_back(outerLoops.getInductionVar(i));
// Spatial indices: sX * rX + kX // 3.1.2 Insert spatial indices: sX * rX + kX * dX
for (int i = batchRank; i < nOuterLoops; ++i) { for (int i = batchRank; i < nOuterLoops; ++i) {
// Get index along the inner loop's induction variables.
// It is used to obtain kernel/pad/stride/dilation index.
int j = i - batchRank;
Value spatialIndex = outerLoops.getInductionVar(i); Value spatialIndex = outerLoops.getInductionVar(i);
// If strides are present then emit the correct access index. // If strides are present (not default) then emit the correct access
if (stridesAttribute && strides[i - batchRank] > 1) { // index.
spatialIndex = rewriter.create<MulIOp>(loc, // sX *= rX
rewriter.create<ConstantIndexOp>(loc, strides[i - batchRank]), if (strides[i - batchRank] > 1) {
outerLoops.getInductionVar(i)); auto strideIndex = emitConstantOp(
rewriter, loc, rewriter.getIndexType(), strides[j]);
spatialIndex = rewriter.create<MulIOp>(
loc, strideIndex, outerLoops.getInductionVar(i));
} }
spatialIndex = rewriter.create<AddIOp>(
loc, spatialIndex, innerLoops.getInductionVar(i - batchRank)); // Dilate the kernel index only if the dilation value is not one (not
// If ceil mode is enabled, then the calculated access index may // default). Otherwise, just add kernelIndex.
// exceed its dimension. In such a case, we will use the maximum auto kernelIndex = innerLoops.getInductionVar(j);
// index, which causes multiple visits to the element of the if (dilations[j] > 1) {
// sX += dX * kW
auto dilationIndex = emitConstantOp(
rewriter, loc, rewriter.getIndexType(), dilations[j]);
auto dilationKernelIndex =
rewriter.create<MulIOp>(loc, dilationIndex, kernelIndex);
spatialIndex =
rewriter.create<AddIOp>(loc, spatialIndex, dilationKernelIndex);
} else {
// sX += kX
spatialIndex =
rewriter.create<AddIOp>(loc, spatialIndex, kernelIndex);
}
// If ceil mode or dilation is enabled, then the calculated access
// index may exceed its dimension. In such a case, we will use the
// maximum index, which causes multiple visits to the element of the
// maximum index. // maximum index.
// TODO: Avoid multiple visits. // TODO: Avoid multiple visits.
if (ceilMode) { // Example of out-of-bound.
Value inputIndex; // - Given a 5x5 input X
// X = [[0, 0, 0, 0, 0],
// [1, 1, 1, 1, 1],
// [2, 2, 2, 2, 2],
// [3, 3, 3, 3, 3],
// [4, 4, 4, 4, 4]]
// - Do MaxPool with strides=[2, 2], kernel=[2, 2], ceilMode=true,
// output is a 3x3 array:
// Y = [[1, 1, 1],
// [3, 3, 3],
// [4, 4, 4]]
// - When computing Y[2, 0]:
// - In case of kernelIndex = 1, stride = 2
// - No dilation: spatialIndex = 2 * 2 + 1 = 5
// => out of bound
// - dilation = 2: spatialIndex = 2 * 2 + 2 * 1 = 6
// => out of bound
if (dilations[j] > 1 or ceilMode) {
Value upperIndex;
if (inputShape[i] < 0) { if (inputShape[i] < 0) {
Value inputDim = rewriter.create<DimOp>(loc, inputOperand, i); Value inputDim = rewriter.create<DimOp>(loc, inputOperand, i);
Value one = rewriter.create<ConstantIndexOp>(loc, 1); Value one = rewriter.create<ConstantIndexOp>(loc, 1);
inputIndex = rewriter.create<SubIOp>(loc, inputDim, one); upperIndex = rewriter.create<SubIOp>(loc, inputDim, one);
} else { } else {
inputIndex = upperIndex =
rewriter.create<ConstantIndexOp>(loc, inputShape[i] - 1); rewriter.create<ConstantIndexOp>(loc, inputShape[i] - 1);
} }
auto greaterCondition = rewriter.create<CmpIOp>( auto greaterCondition = rewriter.create<CmpIOp>(
loc, CmpIPredicate::sgt, spatialIndex, inputIndex); loc, CmpIPredicate::sgt, spatialIndex, upperIndex);
spatialIndex = rewriter.create<SelectOp>( spatialIndex = rewriter.create<SelectOp>(
loc, greaterCondition, inputIndex, spatialIndex); loc, greaterCondition, upperIndex, spatialIndex);
} }
dataIndices.emplace_back(spatialIndex); dataIndices.emplace_back(spatialIndex);
} }

View File

@ -312,6 +312,13 @@ test_to_enable = [
"test_maxpool_1d_default_cpu", "test_maxpool_1d_default_cpu",
"test_maxpool_2d_ceil_cpu", "test_maxpool_2d_ceil_cpu",
"test_maxpool_2d_default_cpu", "test_maxpool_2d_default_cpu",
"test_maxpool_2d_dilations_cpu",
"test_maxpool_2d_pads_cpu",
"test_maxpool_2d_precomputed_pads_cpu",
"test_maxpool_2d_precomputed_same_upper_cpu",
"test_maxpool_2d_precomputed_strides_cpu",
"test_maxpool_2d_same_lower_cpu",
"test_maxpool_2d_same_upper_cpu",
"test_maxpool_2d_strides_cpu", "test_maxpool_2d_strides_cpu",
"test_maxpool_3d_default_cpu", "test_maxpool_3d_default_cpu",

View File

@ -1429,15 +1429,17 @@ func @test_maxpooling_singleout_no_pad_w_strides_w_ceil_mode(%arg0 : tensor<1x3x
// CHECK: [[STRIDE_0:%.+]] = constant 2 : index // CHECK: [[STRIDE_0:%.+]] = constant 2 : index
// CHECK: [[MUL_0:%.+]] = muli [[STRIDE_0]], %arg3 : index // CHECK: [[MUL_0:%.+]] = muli [[STRIDE_0]], %arg3 : index
// CHECK: [[SPATIAL_H:%.+]] = addi [[MUL_0]], %arg5 : index // CHECK: [[SPATIAL_H:%.+]] = addi [[MUL_0]], %arg5 : index
// CHECK: [[INPUT_INDEX_0:%.+]] = constant 31 : index // CHECK: [[UPPER_INDEX_0:%.+]] = constant 31 : index
// CHECK: [[CMP_0:%.+]] = cmpi "sgt", [[SPATIAL_H]], [[INPUT_INDEX_0]] : index // CHECK: [[GREATER_THAN_UPPER_0:%.+]] = cmpi "sgt", [[SPATIAL_H]], [[UPPER_INDEX_0]] : index
// CHECK: [[H:%.+]] = select [[CMP_0]], [[INPUT_INDEX_0]], [[SPATIAL_H]] : index // CHECK: [[H:%.+]] = select [[GREATER_THAN_UPPER_0]], [[UPPER_INDEX_0]], [[SPATIAL_H]] : index
// CHECK: [[STRIDE_1:%.+]] = constant 2 : index // CHECK: [[STRIDE_1:%.+]] = constant 2 : index
// CHECK: [[MUL_1:%.+]] = muli [[STRIDE_1]], %arg4 : index // CHECK: [[MUL_1:%.+]] = muli [[STRIDE_1]], %arg4 : index
// CHECK: [[SPATIAL_W:%.+]] = addi [[MUL_1]], %arg6 : index // CHECK: [[SPATIAL_W:%.+]] = addi [[MUL_1]], %arg6 : index
// CHECK: [[INPUT_INDEX_1:%.+]] = constant 31 : index // CHECK: [[UPPER_INDEX_1:%.+]] = constant 31 : index
// CHECK: [[CMP_1:%.+]] = cmpi "sgt", [[SPATIAL_W]], [[INPUT_INDEX_1]] : index // CHECK: [[GREATER_THAN_UPPER_1:%.+]] = cmpi "sgt", [[SPATIAL_W]], [[UPPER_INDEX_1]] : index
// CHECK: [[W:%.+]] = select [[CMP_1]], [[INPUT_INDEX_1]], [[SPATIAL_W]] : index // CHECK: [[W:%.+]] = select [[GREATER_THAN_UPPER_1]], [[UPPER_INDEX_1]], [[SPATIAL_W]] : index
// CHECK: [[LOAD_X:%.+]] = load %arg0[%arg1, %arg2, [[H]], [[W]]] : memref<1x3x32x32xf32> // CHECK: [[LOAD_X:%.+]] = load %arg0[%arg1, %arg2, [[H]], [[W]]] : memref<1x3x32x32xf32>
// CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x16x16xf32> // CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x16x16xf32>
// CHECK: [[CMP_2:%.+]] = cmpf "ogt", [[LOAD_Y]], [[LOAD_X]] : f32 // CHECK: [[CMP_2:%.+]] = cmpf "ogt", [[LOAD_Y]], [[LOAD_X]] : f32
@ -1448,6 +1450,52 @@ func @test_maxpooling_singleout_no_pad_w_strides_w_ceil_mode(%arg0 : tensor<1x3x
// CHECK: return [[RES]] : memref<1x3x16x16xf32> // CHECK: return [[RES]] : memref<1x3x16x16xf32>
} }
func @test_maxpooling_singleout_no_pad_w_strides_w_dilation(%arg0 : tensor<1x3x32x32xf32>) -> tensor<*xf32> {
%0 = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", kernel_shape = [3, 3], strides = [2, 2], dilations = [2, 2]} : (tensor<1x3x32x32xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_maxpooling_singleout_no_pad_w_strides_w_dilation
// CHECK: [[RES:%.+]] = alloc() : memref<1x3x14x14xf32>
// CHECK: [[DEF_LOOPS_0:%.+]]:4 = krnl.define_loops 4
// CHECK: [[OPT_LOOPS_0:%.+]]:4 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS_0]]#0, [[DEF_LOOPS_0]]#1, [[DEF_LOOPS_0]]#2, [[DEF_LOOPS_0]]#3
// CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_LOOPS_0]]#0, [[OPT_LOOPS_0]]#1, [[OPT_LOOPS_0]]#2, [[OPT_LOOPS_0]]#3) with ([[DEF_LOOPS_0]]#0 -> %arg1 = 0 to 1, [[DEF_LOOPS_0]]#1 -> %arg2 = 0 to 3, [[DEF_LOOPS_0]]#2 -> %arg3 = 0 to 14, [[DEF_LOOPS_0]]#3 -> %arg4 = 0 to 14) {
// CHECK: [[NEGATIVE_INFINITY:%.+]] = constant 0xFF800000 : f32
// CHECK: store [[NEGATIVE_INFINITY]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x14x14xf32>
// CHECK: [[DEF_LOOPS_1:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS_1:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS_1]]#0, [[DEF_LOOPS_1]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_LOOPS_1]]#0, [[OPT_LOOPS_1]]#1) with ([[DEF_LOOPS_1]]#0 -> %arg5 = 0 to 3, [[DEF_LOOPS_1]]#1 -> %arg6 = 0 to 3) {
// CHECK: [[STRIDE_0:%.+]] = constant 2 : index
// CHECK: [[MUL_0:%.+]] = muli [[STRIDE_0]], %arg3 : index
// CHECK: [[STRIDE_1:%.+]] = constant 2 : index
// CHECK: [[MUL_1:%.+]] = muli [[STRIDE_1]], %arg5 : index
// CHECK: [[SPATIAL_H:%.+]] = addi [[MUL_0]], [[MUL_1]] : index
// CHECK: [[UPPER_INDEX_0:%.+]] = constant 31 : index
// CHECK: [[GREATER_THAN_UPPER_0:%.+]] = cmpi "sgt", [[SPATIAL_H]], [[UPPER_INDEX_0]] : index
// CHECK: [[H:%.+]] = select [[GREATER_THAN_UPPER_0]], [[UPPER_INDEX_0]], [[SPATIAL_H]] : index
// CHECK: [[STRIDE_0_1:%.+]] = constant 2 : index
// CHECK: [[MUL_0_1:%.+]] = muli [[STRIDE_0_1]], %arg4 : index
// CHECK: [[STRIDE_1_1:%.+]] = constant 2 : index
// CHECK: [[MUL_1_1:%.+]] = muli [[STRIDE_1_1]], %arg6 : index
// CHECK: [[SPATIAL_W:%.+]] = addi [[MUL_0_1]], [[MUL_1_1]] : index
// CHECK: [[UPPER_INDEX_1:%.+]] = constant 31 : index
// CHECK: [[GREATER_THAN_UPPER_1:%.+]] = cmpi "sgt", [[SPATIAL_W]], [[UPPER_INDEX_1]] : index
// CHECK: [[W:%.+]] = select [[GREATER_THAN_UPPER_1]], [[UPPER_INDEX_1]], [[SPATIAL_W]] : index
// CHECK: [[LOAD_X:%.+]] = load %arg0[%arg1, %arg2, [[H]], [[W]]] : memref<1x3x32x32xf32>
// CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x14x14xf32>
// CHECK: [[CMP_2:%.+]] = cmpf "ogt", [[LOAD_Y]], [[LOAD_X]] : f32
// CHECK: [[SELECT:%.+]] = select [[CMP_2]], [[LOAD_Y]], [[LOAD_X]] : f32
// CHECK: store [[SELECT]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x14x14xf32>
// CHECK: }
// CHECK: }
// CHECK: return [[RES]] : memref<1x3x14x14xf32>
}
func @test_maxpooling_singleout_no_pad_w_strides_w_ceil_mode_w_unknown_dims(%arg0 : tensor<?x3x?x32xf32>) -> tensor<*xf32> { func @test_maxpooling_singleout_no_pad_w_strides_w_ceil_mode_w_unknown_dims(%arg0 : tensor<?x3x?x32xf32>) -> tensor<*xf32> {
%0 = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", kernel_shape = [3, 3], strides = [2, 2], ceil_mode = 1} : (tensor<?x3x?x32xf32>) -> tensor<*xf32> %0 = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", kernel_shape = [3, 3], strides = [2, 2], ceil_mode = 1} : (tensor<?x3x?x32xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()
@ -1459,7 +1507,7 @@ func @test_maxpooling_singleout_no_pad_w_strides_w_ceil_mode_w_unknown_dims(%arg
// CHECK: [[ONE:%.+]] = constant 1 : i64 // CHECK: [[ONE:%.+]] = constant 1 : i64
// CHECK: [[DIM_1:%.+]] = dim %arg0, 2 : memref<?x3x?x32xf32> // CHECK: [[DIM_1:%.+]] = dim %arg0, 2 : memref<?x3x?x32xf32>
// CHECK: [[DIM_1_i64:%.+]] = index_cast [[DIM_1]] : index to i64 // CHECK: [[DIM_1_i64:%.+]] = index_cast [[DIM_1]] : index to i64
// CHECK: [[KERNEL_PAD_DILATION:%.+]] = constant -1 : i64 // CHECK: [[KERNEL_PAD_DILATION:%.+]] = constant -3 : i64
// CHECK: [[NUMERATOR:%.+]] = addi [[DIM_1_i64]], [[KERNEL_PAD_DILATION]] : i64 // CHECK: [[NUMERATOR:%.+]] = addi [[DIM_1_i64]], [[KERNEL_PAD_DILATION]] : i64
// CHECK: [[DENOMINATOR:%.+]] = constant 2 : i64 // CHECK: [[DENOMINATOR:%.+]] = constant 2 : i64
// CHECK: [[DIV:%.+]] = divi_signed [[NUMERATOR]], [[DENOMINATOR]] : i64 // CHECK: [[DIV:%.+]] = divi_signed [[NUMERATOR]], [[DENOMINATOR]] : i64
@ -1490,16 +1538,16 @@ func @test_maxpooling_singleout_no_pad_w_strides_w_ceil_mode_w_unknown_dims(%arg
// CHECK: [[SPATIAL_H:%.+]] = addi [[MUL_0]], %arg5 : index // CHECK: [[SPATIAL_H:%.+]] = addi [[MUL_0]], %arg5 : index
// CHECK: [[DIM_0_0:%.+]] = dim %arg0, 2 : memref<?x3x?x32xf32> // CHECK: [[DIM_0_0:%.+]] = dim %arg0, 2 : memref<?x3x?x32xf32>
// CHECK: [[ONE_INDEX:%.+]] = constant 1 : index // CHECK: [[ONE_INDEX:%.+]] = constant 1 : index
// CHECK: [[INPUT_INDEX_0:%.+]] = subi [[DIM_0_0]], [[ONE_INDEX]] : index // CHECK: [[UPPER_INDEX_0:%.+]] = subi [[DIM_0_0]], [[ONE_INDEX]] : index
// CHECK: [[CMP_0:%.+]] = cmpi "sgt", [[SPATIAL_H]], [[INPUT_INDEX_0]] : index // CHECK: [[GREATER_THAN_UPPER_0:%.+]] = cmpi "sgt", [[SPATIAL_H]], [[UPPER_INDEX_0]] : index
// CHECK: [[H:%.+]] = select [[CMP_0]], [[INPUT_INDEX_0]], [[SPATIAL_H]] : index // CHECK: [[H:%.+]] = select [[GREATER_THAN_UPPER_0]], [[UPPER_INDEX_0]], [[SPATIAL_H]] : index
// CHECK: [[STRIDE_1:%.+]] = constant 2 : index // CHECK: [[STRIDE_1:%.+]] = constant 2 : index
// CHECK: [[MUL_1:%.+]] = muli [[STRIDE_1]], %arg4 : index // CHECK: [[MUL_1:%.+]] = muli [[STRIDE_1]], %arg4 : index
// CHECK: [[SPATIAL_W:%.+]] = addi [[MUL_1]], %arg6 : index // CHECK: [[SPATIAL_W:%.+]] = addi [[MUL_1]], %arg6 : index
// CHECK: [[INPUT_INDEX_1:%.+]] = constant 31 : index // CHECK: [[UPPER_INDEX_1:%.+]] = constant 31 : index
// CHECK: [[CMP_1:%.+]] = cmpi "sgt", [[SPATIAL_W]], [[INPUT_INDEX_1]] : index // CHECK: [[GREATER_THAN_UPPER_1:%.+]] = cmpi "sgt", [[SPATIAL_W]], [[UPPER_INDEX_1]] : index
// CHECK: [[W:%.+]] = select [[CMP_1]], [[INPUT_INDEX_1]], [[SPATIAL_W]] : index // CHECK: [[W:%.+]] = select [[GREATER_THAN_UPPER_1]], [[UPPER_INDEX_1]], [[SPATIAL_W]] : index
// CHECK: [[LOAD_X:%.+]] = load %arg0[%arg1, %arg2, [[H]], [[W]]] : memref<?x3x?x32xf32> // CHECK: [[LOAD_X:%.+]] = load %arg0[%arg1, %arg2, [[H]], [[W]]] : memref<?x3x?x32xf32>
// CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<?x3x?x16xf32> // CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<?x3x?x16xf32>