Lower MaxPoolSingleOutOp to Krnl dialect (#1)
* Lower MaxPoolSingleOutOp to Krnl dialect
* Edit comments
* Update changes according to the new folder structure
* Add MLIR tests
* Support ceil_mode
* Merge the first two krnl loops into one krnl loop; remove attribute checks
* Dynamically allocate memory for the result if the result has unknown dimensions

Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
		
							parent
							
								
									e97df0b343
								
							
						
					
					
						commit
						e4c23da4fd
					
				|  | @ -85,6 +85,7 @@ add_library(onnf_lower_frontend | |||
|         conversion/onnx_to_krnl/math/softmax.cpp | ||||
|         conversion/onnx_to_krnl/nn/conv.cpp | ||||
|         conversion/onnx_to_krnl/nn/normalization.cpp | ||||
|         conversion/onnx_to_krnl/nn/pooling.cpp | ||||
|         conversion/onnx_to_krnl/tensor/identity.cpp | ||||
|         conversion/onnx_to_krnl/tensor/reshape.cpp | ||||
|         conversion/onnx_to_krnl/tensor/transpose.cpp | ||||
|  |  | |||
|  | @ -99,6 +99,7 @@ void FrontendToKrnlLoweringPass::runOnModule() { | |||
|   // Neural network
 | ||||
|   populateLoweringONNXConvOpPattern(patterns, &getContext()); | ||||
|   populateLoweringONNXNormalizationOpPattern(patterns, &getContext()); | ||||
|   populateLoweringONNXPoolingOpPattern(patterns, &getContext()); | ||||
|   // Entry point
 | ||||
|   patterns.insert<ONNXEntryPointLowering>(&getContext()); | ||||
| 
 | ||||
|  |  | |||
|  | @ -0,0 +1,294 @@ | |||
| //===----- pooling.cpp - Lowering Pooling Ops -----------------------------===//
 | ||||
| //
 | ||||
| // Copyright 2019 The IBM Research Authors.
 | ||||
| //
 | ||||
| // =============================================================================
 | ||||
| //
 | ||||
| // This file lowers the ONNX Pooling Operators to Krnl dialect.
 | ||||
| //
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| 
 | ||||
| #include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp" | ||||
| 
 | ||||
| using namespace mlir; | ||||
| 
 | ||||
| // Identity values
 | ||||
| template <> | ||||
| float getIdentityValue<float, ONNXMaxPoolSingleOutOp>() { | ||||
|   return (float)-std::numeric_limits<float>::infinity(); | ||||
| } | ||||
| 
 | ||||
| template <> | ||||
| int getIdentityValue<int, ONNXMaxPoolSingleOutOp>() { | ||||
|   return std::numeric_limits<int>::min(); | ||||
| } | ||||
| 
 | ||||
| template <> | ||||
| Value mapToLowerScalarOp<ONNXMaxPoolSingleOutOp>(Operation *op, | ||||
|     ArrayRef<Type> result_types, ArrayRef<Value> operands, | ||||
|     ConversionPatternRewriter &rewriter) { | ||||
|   auto loc = op->getLoc(); | ||||
|   Value lhs = operands[0]; | ||||
|   Value rhs = operands[1]; | ||||
|   auto max = rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs); | ||||
|   auto result = rewriter.create<SelectOp>(loc, max, lhs, rhs); | ||||
|   return result; | ||||
| } | ||||
| 
 | ||||
| struct ONNXMaxPoolSingleOutOpLowering : public ConversionPattern { | ||||
|   ONNXMaxPoolSingleOutOpLowering(MLIRContext *ctx) | ||||
|       : ConversionPattern( | ||||
|             mlir::ONNXMaxPoolSingleOutOp::getOperationName(), 1, ctx) {} | ||||
| 
 | ||||
|   PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands, | ||||
|       ConversionPatternRewriter &rewriter) const final { | ||||
|     auto loc = op->getLoc(); | ||||
| 
 | ||||
|     // Match
 | ||||
|     ONNXMaxPoolSingleOutOp poolOp = llvm::dyn_cast<ONNXMaxPoolSingleOutOp>(op); | ||||
| 
 | ||||
|     // Read kernel_shape attribute
 | ||||
|     SmallVector<int, 4> kernelShape; | ||||
|     auto kernelShapeAttribute = poolOp.kernel_shapeAttr(); | ||||
|     for (auto dim : kernelShapeAttribute.getValue()) | ||||
|       kernelShape.emplace_back(dim.cast<IntegerAttr>().getInt()); | ||||
| 
 | ||||
|     // Read strides attribute
 | ||||
|     SmallVector<int, 4> strides; | ||||
|     auto stridesAttribute = poolOp.stridesAttr(); | ||||
|     for (auto stride : stridesAttribute.getValue()) | ||||
|       strides.emplace_back(stride.cast<IntegerAttr>().getInt()); | ||||
| 
 | ||||
|     // Read ceil_mode attribute
 | ||||
|     auto ceilMode = poolOp.ceil_mode().getSExtValue(); | ||||
| 
 | ||||
|     // Read pads attribute
 | ||||
|     SmallVector<int, 4> pads; | ||||
|     auto padsAttribute = poolOp.padsAttr(); | ||||
|     for (auto pad : padsAttribute.getValue()) | ||||
|       pads.emplace_back(pad.cast<IntegerAttr>().getInt()); | ||||
| 
 | ||||
|     // Read dilations attribute
 | ||||
|     SmallVector<int, 4> dilations; | ||||
|     auto dilationsAttribute = poolOp.dilationsAttr(); | ||||
|     for (auto dilation : dilationsAttribute.getValue()) | ||||
|       dilations.emplace_back(dilation.cast<IntegerAttr>().getInt()); | ||||
| 
 | ||||
|     // Type information about the input and result of this operation.
 | ||||
|     auto &inputOperand = operands[0]; | ||||
|     auto inputShape = inputOperand.getType().cast<MemRefType>().getShape(); | ||||
|     auto memRefType = convertToMemRefType(*op->result_type_begin()); | ||||
|     auto resultShape = memRefType.getShape(); | ||||
|     auto resultElementType = memRefType.getElementType(); | ||||
| 
 | ||||
|     // Batch indices: N and C dimensions
 | ||||
|     int batchRank = 2; | ||||
| 
 | ||||
|     // Insert an allocation and deallocation for the result of this operation.
 | ||||
|     Value alloc; | ||||
|     bool insertDealloc = checkInsertDealloc(op); | ||||
| 
 | ||||
|     if (hasAllConstantDimensions(memRefType)) | ||||
|       alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); | ||||
|     else { | ||||
|       // Compute dimensions of the result of this operation.
 | ||||
|       SmallVector<Value, 2> allocOperands; | ||||
|       for (int i = 0; i < batchRank; ++i) { | ||||
|         if (resultShape[i] < 0) { | ||||
|           auto dim = rewriter.create<DimOp>(loc, inputOperand, i); | ||||
|           allocOperands.emplace_back(dim); | ||||
|         } | ||||
|       } | ||||
| 
 | ||||
|       Value zero, one; | ||||
|       if (ceilMode) { | ||||
|         zero = rewriter.create<ConstantOp>( | ||||
|             loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); | ||||
|       } | ||||
|       one = rewriter.create<ConstantOp>( | ||||
|           loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1)); | ||||
| 
 | ||||
|       int spatialRank = resultShape.size() - batchRank; | ||||
|       for (int i = batchRank; i < resultShape.size(); ++i) { | ||||
|         if (resultShape[i] < 0) { | ||||
|           // dim =
 | ||||
|           //   let numerator = (input + pad - (kernel - 1) * dilation + 1)
 | ||||
|           //   in let denomitor = stride
 | ||||
|           //      in
 | ||||
|           //        if (ceilMode)
 | ||||
|           //          ceil(numerator / denominator) + 1
 | ||||
|           //        else
 | ||||
|           //          floor(numerator / denominator) + 1
 | ||||
|           int spatialIndex = i - batchRank; | ||||
| 
 | ||||
|           // numerator = (input + pad - (kernel - 1) * dilation + 1)
 | ||||
|           auto inputDim = rewriter.create<DimOp>(loc, inputOperand, i); | ||||
|           auto inputVal = rewriter.create<IndexCastOp>( | ||||
|               loc, inputDim, rewriter.getIntegerType(64)); | ||||
|           int64_t padKernelDilation = | ||||
|               (pads[spatialIndex] + pads[spatialIndex + spatialRank]) - | ||||
|               (kernelShape[spatialIndex] - 1) * dilations[spatialIndex] + 1; | ||||
|           auto padKernelDilationVal = rewriter.create<ConstantOp>( | ||||
|               loc, rewriter.getIntegerAttr( | ||||
|                        rewriter.getIntegerType(64), padKernelDilation)); | ||||
|           auto numeratorVal = | ||||
|               rewriter.create<AddIOp>(loc, inputVal, padKernelDilationVal); | ||||
|           // denominator
 | ||||
|           auto denominatorVal = rewriter.create<ConstantOp>( | ||||
|               loc, rewriter.getIntegerAttr( | ||||
|                        rewriter.getIntegerType(64), strides[spatialIndex])); | ||||
| 
 | ||||
|           // numerator / denominator
 | ||||
|           Value dimVal = | ||||
|               rewriter.create<SignedDivIOp>(loc, numeratorVal, denominatorVal); | ||||
| 
 | ||||
|           if (ceilMode) { | ||||
|             auto remainder = rewriter.create<SignedRemIOp>( | ||||
|                 loc, numeratorVal, denominatorVal); | ||||
|             auto isZero = rewriter.create<CmpIOp>( | ||||
|                 loc, CmpIPredicate::eq, remainder, zero); | ||||
|             auto dimPlusOne = rewriter.create<AddIOp>(loc, dimVal, one); | ||||
|             dimVal = rewriter.create<SelectOp>(loc, isZero, dimVal, dimPlusOne); | ||||
|           } | ||||
| 
 | ||||
|           dimVal = rewriter.create<AddIOp>(loc, dimVal, one); | ||||
|           allocOperands.emplace_back(rewriter.create<IndexCastOp>( | ||||
|               loc, dimVal, rewriter.getIndexType())); | ||||
|         } | ||||
|       } | ||||
|       alloc = rewriter.create<AllocOp>(loc, memRefType, allocOperands); | ||||
|       if (insertDealloc) { | ||||
|         auto *parentBlock = alloc.getDefiningOp()->getBlock(); | ||||
|         auto dealloc = rewriter.create<DeallocOp>(loc, alloc); | ||||
|         dealloc.getOperation()->moveBefore(&parentBlock->back()); | ||||
|       } | ||||
|     } | ||||
| 
 | ||||
|     // R = MaxPool(D)
 | ||||
|     //
 | ||||
|     // The input/output shapes will look like this:
 | ||||
|     //
 | ||||
|     // D (NxCxHxW) -> R (NxCxRHxRW)
 | ||||
|     //
 | ||||
|     // The loop nest will look as follows:
 | ||||
|     //
 | ||||
|     // strides = [s1, s2]
 | ||||
|     //
 | ||||
|     // for n = 0 .. N:
 | ||||
|     //   for c = 0 .. C:
 | ||||
|     //     for r1 = 0 .. RH:
 | ||||
|     //       for r2 = 0 .. RW:
 | ||||
|     //         R[n][c][r1][r2] = negative_infinity;
 | ||||
|     //         for k1 = 0 .. KH:
 | ||||
|     //           for k2 = 0 .. KW:
 | ||||
|     //             t = D[n][c][s1 * r1 + k1][s2 * r2 + k2];
 | ||||
|     //             R[n][c][r1][r2] = max(R[n][c][r1][r2], t);
 | ||||
|     //
 | ||||
|     // Naming:
 | ||||
|     //   n, c, r1, r2: outer loop nest indices
 | ||||
|     //   k1, k2: inner loop nest indices
 | ||||
|     //
 | ||||
|     // TODO: handle padding.
 | ||||
|     //
 | ||||
| 
 | ||||
|     // 1. Define outer loops and emit empty optimization block.
 | ||||
|     auto nOuterLoops = resultShape.size(); | ||||
|     BuildKrnlLoop outerLoops(rewriter, loc, nOuterLoops); | ||||
|     outerLoops.createDefineOptimizeAndIterateOp(alloc); | ||||
| 
 | ||||
|     rewriter.setInsertionPointToStart(outerLoops.getIterateBlock()); | ||||
|     { | ||||
|       // 2. Emit the body of the outer loop nest.
 | ||||
|       SmallVector<Value, 4> resultIndices; | ||||
|       for (int i = 0; i < nOuterLoops; ++i) | ||||
|         resultIndices.emplace_back(outerLoops.getInductionVar(i)); | ||||
| 
 | ||||
|       // 2.1 Emit: R[n][c][r1][r2] = negative_infinity;
 | ||||
|       Value identity; | ||||
|       if (resultElementType.isa<FloatType>()) { | ||||
|         identity = rewriter.create<ConstantOp>( | ||||
|             loc, FloatAttr::get(resultElementType, | ||||
|                      getIdentityValue<float, ONNXMaxPoolSingleOutOp>())); | ||||
|       } else if (resultElementType.isa<IntegerType>()) { | ||||
|         identity = rewriter.create<ConstantOp>( | ||||
|             loc, IntegerAttr::get(resultElementType, | ||||
|                      getIdentityValue<int, ONNXMaxPoolSingleOutOp>())); | ||||
|       } else { | ||||
|         emitError(loc, "unsupported element type"); | ||||
|       } | ||||
|       rewriter.create<StoreOp>(loc, identity, alloc, resultIndices); | ||||
| 
 | ||||
|       // 2.2 Define inner loops.
 | ||||
|       int nInnerLoops = kernelShape.size(); | ||||
|       BuildKrnlLoop innerLoops(rewriter, loc, nInnerLoops); | ||||
|       innerLoops.createDefineAndOptimizeOp(); | ||||
|       //   for Kx = 0 .. KX
 | ||||
|       for (int i = 0; i < nInnerLoops; ++i) | ||||
|         innerLoops.pushBounds(0, kernelShape[i]); | ||||
| 
 | ||||
|       // 2.3 Emit inner loop nest.
 | ||||
|       innerLoops.createIterateOp(); | ||||
|       rewriter.setInsertionPointToStart(innerLoops.getIterateBlock()); | ||||
|       { | ||||
|         // 3. Emit inner loop body
 | ||||
|         // t = D[n][c][s1 * r1 + k1][s2 * r2 + k2];
 | ||||
|         // R[n][c][r1][r2] = max(R[n][c][r1][r2], t);
 | ||||
| 
 | ||||
|         // 3.1 Prepare indices for accesing the data tensor.
 | ||||
|         SmallVector<Value, 4> dataIndices; | ||||
|         // Batch indices: n, c
 | ||||
|         for (int i = 0; i < batchRank; ++i) | ||||
|           dataIndices.emplace_back(outerLoops.getInductionVar(i)); | ||||
|         // Spatial indices: sX * rX + kX
 | ||||
|         for (int i = batchRank; i < nOuterLoops; ++i) { | ||||
|           Value spatialIndex = outerLoops.getInductionVar(i); | ||||
|           // If strides are present then emit the correct access index.
 | ||||
|           if (stridesAttribute && strides[i - batchRank] > 1) { | ||||
|             spatialIndex = rewriter.create<MulIOp>(loc, | ||||
|                 rewriter.create<ConstantIndexOp>(loc, strides[i - batchRank]), | ||||
|                 outerLoops.getInductionVar(i)); | ||||
|           } | ||||
|           spatialIndex = rewriter.create<AddIOp>( | ||||
|               loc, spatialIndex, innerLoops.getInductionVar(i - batchRank)); | ||||
|           // If ceil mode is enabled, then the calculated access index may
 | ||||
|           // exceed its dimension. In such a case, we will use the maximum
 | ||||
|           // index, which causes multiple visits to the element of the
 | ||||
|           // maximum index.
 | ||||
|           // TODO: Avoid multiple visits.
 | ||||
|           if (ceilMode) { | ||||
|             Value inputIndex; | ||||
|             if (inputShape[i] < 0) { | ||||
|               Value inputDim = rewriter.create<DimOp>(loc, inputOperand, i); | ||||
|               Value one = rewriter.create<ConstantIndexOp>(loc, 1); | ||||
|               inputIndex = rewriter.create<SubIOp>(loc, inputDim, one); | ||||
|             } else { | ||||
|               inputIndex = | ||||
|                   rewriter.create<ConstantIndexOp>(loc, inputShape[i] - 1); | ||||
|             } | ||||
|             auto greaterCondition = rewriter.create<CmpIOp>( | ||||
|                 loc, CmpIPredicate::sgt, spatialIndex, inputIndex); | ||||
|             spatialIndex = rewriter.create<SelectOp>( | ||||
|                 loc, greaterCondition, inputIndex, spatialIndex); | ||||
|           } | ||||
|           dataIndices.emplace_back(spatialIndex); | ||||
|         } | ||||
| 
 | ||||
|         // 3.2 Do pooling.
 | ||||
|         auto loadData = rewriter.create<LoadOp>(loc, inputOperand, dataIndices); | ||||
|         auto loadPartialResult = | ||||
|             rewriter.create<LoadOp>(loc, alloc, resultIndices); | ||||
|         Value result = mapToLowerScalarOp<ONNXMaxPoolSingleOutOp>( | ||||
|             op, resultElementType, {loadPartialResult, loadData}, rewriter); | ||||
|         rewriter.create<StoreOp>(loc, result, alloc, resultIndices); | ||||
|       } | ||||
|     } | ||||
|     rewriter.replaceOp(op, alloc); | ||||
| 
 | ||||
|     return matchSuccess(); | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
// Adds the pooling lowering pattern defined in this file
// (ONNXMaxPoolSingleOutOpLowering) to `patterns`, constructing it with `ctx`.
| void populateLoweringONNXPoolingOpPattern( | ||||
|     OwningRewritePatternList &patterns, MLIRContext *ctx) { | ||||
|   patterns.insert<ONNXMaxPoolSingleOutOpLowering>(ctx); | ||||
| } | ||||
|  | @ -202,6 +202,9 @@ void populateLoweringONNXConvOpPattern( | |||
| void populateLoweringONNXNormalizationOpPattern( | ||||
|     OwningRewritePatternList &patterns, MLIRContext *ctx); | ||||
| 
 | ||||
| void populateLoweringONNXPoolingOpPattern( | ||||
|     OwningRewritePatternList &patterns, MLIRContext *ctx); | ||||
| 
 | ||||
| // `tensor` directory methods:
 | ||||
| 
 | ||||
| void populateLoweringONNXUnsqueezeOpPattern( | ||||
|  |  | |||
|  | @ -305,6 +305,13 @@ test_to_enable = [ | |||
|     "test_batchnorm_epsilon_cpu", | ||||
|     "test_batchnorm_example_cpu", | ||||
| 
 | ||||
|     # Pooling | ||||
|     "test_maxpool_1d_default_cpu", | ||||
|     "test_maxpool_2d_ceil_cpu", | ||||
|     "test_maxpool_2d_default_cpu", | ||||
|     "test_maxpool_2d_strides_cpu", | ||||
|     "test_maxpool_3d_default_cpu", | ||||
| 
 | ||||
| ] | ||||
| 
 | ||||
| # Extract name of all test cases. | ||||
|  |  | |||
|  | @ -1344,3 +1344,169 @@ func @test_batchnorm_testmode_1d(%arg0: tensor<10xf32>, %arg1: tensor<1xf32>, %a | |||
|   // CHECK: return [[RES]] : memref<10xf32> | ||||
| } | ||||
| 
 | ||||
| func @test_maxpooling_singleout_no_pad(%arg0 : tensor<1x3x32x32xf32>) -> tensor<*xf32> { | ||||
|   %0 = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", kernel_shape = [2, 2]} : (tensor<1x3x32x32xf32>) -> tensor<*xf32> | ||||
|   "std.return"(%0) : (tensor<*xf32>) -> () | ||||
| 
 | ||||
|   // CHECK-LABEL: test_maxpooling_singleout_no_pad | ||||
|   // CHECK: [[RES:%.+]] = alloc() : memref<1x3x31x31xf32> | ||||
|   // CHECK: [[DEF_LOOPS_0:%.+]]:4 = krnl.define_loops 4 | ||||
|   // CHECK: [[OPT_LOOPS_0:%.+]]:4 = krnl.optimize_loops  { | ||||
|   // CHECK:   krnl.return_loops [[DEF_LOOPS_0]]#0, [[DEF_LOOPS_0]]#1, [[DEF_LOOPS_0]]#2, [[DEF_LOOPS_0]]#3 | ||||
|   // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop) | ||||
|   // CHECK: krnl.iterate([[OPT_LOOPS_0]]#0, [[OPT_LOOPS_0]]#1, [[OPT_LOOPS_0]]#2, [[OPT_LOOPS_0]]#3) with ([[DEF_LOOPS_0]]#0 -> %arg1 = 0 to 1, [[DEF_LOOPS_0]]#1 -> %arg2 = 0 to 3, [[DEF_LOOPS_0]]#2 -> %arg3 = 0 to 31, [[DEF_LOOPS_0]]#3 -> %arg4 = 0 to 31) { | ||||
|   // CHECK:   [[NEGATIVE_INFINITY:%.+]] = constant 0xFF800000 : f32 | ||||
|   // CHECK:   store [[NEGATIVE_INFINITY]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x31x31xf32> | ||||
|   // CHECK:   [[DEF_LOOPS_1:%.+]]:2 = krnl.define_loops 2 | ||||
|   // CHECK:   [[OPT_LOOPS_1:%.+]]:2 = krnl.optimize_loops  { | ||||
|   // CHECK:     krnl.return_loops [[DEF_LOOPS_1]]#0, [[DEF_LOOPS_1]]#1 | ||||
|   // CHECK:   } : () -> (!krnl.loop, !krnl.loop) | ||||
|   // CHECK:   krnl.iterate([[OPT_LOOPS_1]]#0, [[OPT_LOOPS_1]]#1) with ([[DEF_LOOPS_1]]#0 -> %arg5 = 0 to 2, [[DEF_LOOPS_1]]#1 -> %arg6 = 0 to 2) { | ||||
|   // CHECK:     [[H:%.+]] = addi %arg3, %arg5 : index | ||||
|   // CHECK:     [[W:%.+]] = addi %arg4, %arg6 : index | ||||
|   // CHECK:     [[LOAD_X:%.+]] = load %arg0[%arg1, %arg2, [[H]], [[W]]] : memref<1x3x32x32xf32> | ||||
|   // CHECK:     [[LOAD_Y:%.+]] = load [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x31x31xf32> | ||||
|   // CHECK:     [[COMPARE:%.+]] = cmpf "ogt", [[LOAD_Y]], [[LOAD_X]] : f32 | ||||
|   // CHECK:     [[SELECT:%.+]] = select [[COMPARE]], [[LOAD_Y]], [[LOAD_X]] : f32 | ||||
|   // CHECK:     store [[SELECT]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x31x31xf32> | ||||
|   // CHECK:   } | ||||
|   // CHECK: } | ||||
|   // CHECK: return [[RES]] : memref<1x3x31x31xf32> | ||||
| } | ||||
| 
 | ||||
| func @test_maxpooling_singleout_no_pad_w_strides(%arg0 : tensor<1x3x32x32xf32>) -> tensor<*xf32> { | ||||
|   %0 = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", kernel_shape = [2, 2], strides = [2, 2]} : (tensor<1x3x32x32xf32>) -> tensor<*xf32> | ||||
|   "std.return"(%0) : (tensor<*xf32>) -> () | ||||
| 
 | ||||
|   // CHECK-LABEL: test_maxpooling_singleout_no_pad_w_strides | ||||
|   // CHECK: [[RES:%.+]] = alloc() : memref<1x3x16x16xf32> | ||||
|   // CHECK: [[DEF_LOOPS_0:%.+]]:4 = krnl.define_loops 4 | ||||
|   // CHECK: [[OPT_LOOPS_0:%.+]]:4 = krnl.optimize_loops  { | ||||
|   // CHECK:   krnl.return_loops [[DEF_LOOPS_0]]#0, [[DEF_LOOPS_0]]#1, [[DEF_LOOPS_0]]#2, [[DEF_LOOPS_0]]#3 | ||||
|   // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop) | ||||
|   // CHECK: krnl.iterate([[OPT_LOOPS_0]]#0, [[OPT_LOOPS_0]]#1, [[OPT_LOOPS_0]]#2, [[OPT_LOOPS_0]]#3) with ([[DEF_LOOPS_0]]#0 -> %arg1 = 0 to 1, [[DEF_LOOPS_0]]#1 -> %arg2 = 0 to 3, [[DEF_LOOPS_0]]#2 -> %arg3 = 0 to 16, [[DEF_LOOPS_0]]#3 -> %arg4 = 0 to 16) { | ||||
|   // CHECK:   [[NEGATIVE_INFINITY:%.+]] = constant 0xFF800000 : f32 | ||||
|   // CHECK:   store [[NEGATIVE_INFINITY]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x16x16xf32> | ||||
|   // CHECK:   [[DEF_LOOPS_1:%.+]]:2 = krnl.define_loops 2 | ||||
|   // CHECK:   [[OPT_LOOPS_1:%.+]]:2 = krnl.optimize_loops  { | ||||
|   // CHECK:     krnl.return_loops [[DEF_LOOPS_1]]#0, [[DEF_LOOPS_1]]#1 | ||||
|   // CHECK:   } : () -> (!krnl.loop, !krnl.loop) | ||||
|   // CHECK:   krnl.iterate([[OPT_LOOPS_1]]#0, [[OPT_LOOPS_1]]#1) with ([[DEF_LOOPS_1]]#0 -> %arg5 = 0 to 2, [[DEF_LOOPS_1]]#1 -> %arg6 = 0 to 2) { | ||||
|   // CHECK:     [[STRIDE_0:%.+]] = constant 2 : index | ||||
|   // CHECK:     [[MUL_0:%.+]] = muli [[STRIDE_0]], %arg3 : index | ||||
|   // CHECK:     [[H:%.+]] = addi [[MUL_0]], %arg5 : index | ||||
|   // CHECK:     [[STRIDE_1:%.+]] = constant 2 : index | ||||
|   // CHECK:     [[MUL_1:%.+]] = muli [[STRIDE_1]], %arg4 : index | ||||
|   // CHECK:     [[W:%.+]] = addi [[MUL_1]], %arg6 : index | ||||
|   // CHECK:     [[LOAD_X:%.+]] = load %arg0[%arg1, %arg2, [[H]], [[W]]] : memref<1x3x32x32xf32> | ||||
|   // CHECK:     [[LOAD_Y:%.+]] = load [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x16x16xf32> | ||||
|   // CHECK:     [[COMPARE:%.+]] = cmpf "ogt", [[LOAD_Y]], [[LOAD_X]] : f32 | ||||
|   // CHECK:     [[SELECT:%.+]] = select [[COMPARE]], [[LOAD_Y]], [[LOAD_X]] : f32 | ||||
|   // CHECK:     store [[SELECT]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x16x16xf32> | ||||
|   // CHECK:   } | ||||
|   // CHECK: } | ||||
|   // CHECK: return [[RES]] : memref<1x3x16x16xf32> | ||||
| } | ||||
| 
 | ||||
| func @test_maxpooling_singleout_no_pad_w_strides_w_ceil_mode(%arg0 : tensor<1x3x32x32xf32>) -> tensor<*xf32> { | ||||
|   %0 = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", kernel_shape = [3, 3], strides = [2, 2], ceil_mode = 1} : (tensor<1x3x32x32xf32>) -> tensor<*xf32> | ||||
|   "std.return"(%0) : (tensor<*xf32>) -> () | ||||
| 
 | ||||
|   // CHECK-LABEL: test_maxpooling_singleout_no_pad_w_strides_w_ceil_mode | ||||
|   // CHECK: [[RES:%.+]] = alloc() : memref<1x3x16x16xf32> | ||||
|   // CHECK: [[DEF_LOOPS_0:%.+]]:4 = krnl.define_loops 4 | ||||
|   // CHECK: [[OPT_LOOPS_0:%.+]]:4 = krnl.optimize_loops  { | ||||
|   // CHECK:   krnl.return_loops [[DEF_LOOPS_0]]#0, [[DEF_LOOPS_0]]#1, [[DEF_LOOPS_0]]#2, [[DEF_LOOPS_0]]#3 | ||||
|   // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop) | ||||
|   // CHECK: krnl.iterate([[OPT_LOOPS_0]]#0, [[OPT_LOOPS_0]]#1, [[OPT_LOOPS_0]]#2, [[OPT_LOOPS_0]]#3) with ([[DEF_LOOPS_0]]#0 -> %arg1 = 0 to 1, [[DEF_LOOPS_0]]#1 -> %arg2 = 0 to 3, [[DEF_LOOPS_0]]#2 -> %arg3 = 0 to 16, [[DEF_LOOPS_0]]#3 -> %arg4 = 0 to 16) { | ||||
|   // CHECK:   [[NEGATIVE_INFINITY:%.+]] = constant 0xFF800000 : f32 | ||||
|   // CHECK:   store [[NEGATIVE_INFINITY]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x16x16xf32> | ||||
|   // CHECK:   [[DEF_LOOPS_1:%.+]]:2 = krnl.define_loops 2 | ||||
|   // CHECK:   [[OPT_LOOPS_1:%.+]]:2 = krnl.optimize_loops  { | ||||
|   // CHECK:     krnl.return_loops [[DEF_LOOPS_1]]#0, [[DEF_LOOPS_1]]#1 | ||||
|   // CHECK:   } : () -> (!krnl.loop, !krnl.loop) | ||||
|   // CHECK:   krnl.iterate([[OPT_LOOPS_1]]#0, [[OPT_LOOPS_1]]#1) with ([[DEF_LOOPS_1]]#0 -> %arg5 = 0 to 3, [[DEF_LOOPS_1]]#1 -> %arg6 = 0 to 3) { | ||||
|   // CHECK:     [[STRIDE_0:%.+]] = constant 2 : index | ||||
|   // CHECK:     [[MUL_0:%.+]] = muli [[STRIDE_0]], %arg3 : index | ||||
|   // CHECK:     [[SPATIAL_H:%.+]] = addi [[MUL_0]], %arg5 : index | ||||
|   // CHECK:     [[INPUT_INDEX_0:%.+]] = constant 31 : index | ||||
|   // CHECK:     [[CMP_0:%.+]] = cmpi "sgt", [[SPATIAL_H]], [[INPUT_INDEX_0]] : index | ||||
|   // CHECK:     [[H:%.+]] = select [[CMP_0]], [[INPUT_INDEX_0]], [[SPATIAL_H]] : index | ||||
|   // CHECK:     [[STRIDE_1:%.+]] = constant 2 : index | ||||
|   // CHECK:     [[MUL_1:%.+]] = muli [[STRIDE_1]], %arg4 : index | ||||
|   // CHECK:     [[SPATIAL_W:%.+]] = addi [[MUL_1]], %arg6 : index | ||||
|   // CHECK:     [[INPUT_INDEX_1:%.+]] = constant 31 : index | ||||
|   // CHECK:     [[CMP_1:%.+]] = cmpi "sgt", [[SPATIAL_W]], [[INPUT_INDEX_1]] : index | ||||
|   // CHECK:     [[W:%.+]] = select [[CMP_1]], [[INPUT_INDEX_1]], [[SPATIAL_W]] : index | ||||
|   // CHECK:     [[LOAD_X:%.+]] = load %arg0[%arg1, %arg2, [[H]], [[W]]] : memref<1x3x32x32xf32> | ||||
|   // CHECK:     [[LOAD_Y:%.+]] = load [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x16x16xf32> | ||||
|   // CHECK:     [[CMP_2:%.+]] = cmpf "ogt", [[LOAD_Y]], [[LOAD_X]] : f32 | ||||
|   // CHECK:     [[SELECT:%.+]] = select [[CMP_2]], [[LOAD_Y]], [[LOAD_X]] : f32 | ||||
|   // CHECK:     store [[SELECT]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x16x16xf32> | ||||
|   // CHECK:   } | ||||
|   // CHECK: } | ||||
|   // CHECK: return [[RES]] : memref<1x3x16x16xf32> | ||||
| } | ||||
| 
 | ||||
| func @test_maxpooling_singleout_no_pad_w_strides_w_ceil_mode_w_unknown_dims(%arg0 : tensor<?x3x?x32xf32>) -> tensor<*xf32> { | ||||
|   %0 = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", kernel_shape = [3, 3], strides = [2, 2], ceil_mode = 1} : (tensor<?x3x?x32xf32>) -> tensor<*xf32> | ||||
|   "std.return"(%0) : (tensor<*xf32>) -> () | ||||
| 
 | ||||
|   // CHECK-LABEL: test_maxpooling_singleout_no_pad_w_strides_w_ceil_mode_w_unknown_dims | ||||
| 
 | ||||
|   // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x3x?x32xf32> | ||||
|   // CHECK: [[ZERO:%.+]] = constant 0 : i64 | ||||
|   // CHECK: [[ONE:%.+]] = constant 1 : i64 | ||||
|   // CHECK: [[DIM_1:%.+]] = dim %arg0, 2 : memref<?x3x?x32xf32> | ||||
|   // CHECK: [[DIM_1_i64:%.+]] = index_cast [[DIM_1]] : index to i64 | ||||
|   // CHECK: [[KERNEL_PAD_DILATION:%.+]] = constant -1 : i64 | ||||
|   // CHECK: [[NUMERATOR:%.+]] = addi [[DIM_1_i64]], [[KERNEL_PAD_DILATION]] : i64 | ||||
|   // CHECK: [[DENOMINATOR:%.+]] = constant 2 : i64 | ||||
|   // CHECK: [[DIV:%.+]] = divi_signed [[NUMERATOR]], [[DENOMINATOR]] : i64 | ||||
|   // CHECK: [[REMAINDER:%.+]] = remi_signed [[NUMERATOR]], [[DENOMINATOR]] : i64 | ||||
|   // CHECK: [[IS_ZERO:%.+]] = cmpi "eq", [[REMAINDER]], [[ZERO]] : i64 | ||||
|   // CHECK: [[DIV_PLUS_ONE:%.+]] = addi [[DIV]], [[ONE]] : i64 | ||||
|   // CHECK: [[SELECT:%.+]] = select [[IS_ZERO]], [[DIV]], [[DIV_PLUS_ONE]] : i64 | ||||
|   // CHECK: [[SELECT_PLUS_ONE:%.+]] = addi [[SELECT]], [[ONE]] : i64 | ||||
|   // CHECK: [[DIM_1_FINAL:%.+]] = index_cast [[SELECT_PLUS_ONE]] : i64 to index | ||||
|   // CHECK: [[RES:%.+]] = alloc([[DIM_0]], [[DIM_1_FINAL]]) : memref<?x3x?x16xf32> | ||||
| 
 | ||||
|   // CHECK: [[DEF_LOOPS_0:%.+]]:4 = krnl.define_loops 4 | ||||
|   // CHECK: [[OPT_LOOPS_0:%.+]]:4 = krnl.optimize_loops  { | ||||
|   // CHECK:   krnl.return_loops [[DEF_LOOPS_0]]#0, [[DEF_LOOPS_0]]#1, [[DEF_LOOPS_0]]#2, [[DEF_LOOPS_0]]#3 | ||||
|   // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop) | ||||
|   // CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x3x?x16xf32> | ||||
|   // CHECK: [[DIM_3:%.+]] = dim [[RES]], 2 : memref<?x3x?x16xf32> | ||||
|   // CHECK: krnl.iterate([[OPT_LOOPS_0]]#0, [[OPT_LOOPS_0]]#1, [[OPT_LOOPS_0]]#2, [[OPT_LOOPS_0]]#3) with ([[DEF_LOOPS_0]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS_0]]#1 -> %arg2 = 0 to 3, [[DEF_LOOPS_0]]#2 -> %arg3 = 0 to [[DIM_3]], [[DEF_LOOPS_0]]#3 -> %arg4 = 0 to 16) { | ||||
|   // CHECK:   [[NEGATIVE_INFINITY:%.+]] = constant 0xFF800000 : f32 | ||||
|   // CHECK:   store [[NEGATIVE_INFINITY]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<?x3x?x16xf32> | ||||
|   // CHECK:   [[DEF_LOOPS_1:%.+]]:2 = krnl.define_loops 2 | ||||
|   // CHECK:   [[OPT_LOOPS_1:%.+]]:2 = krnl.optimize_loops  { | ||||
|   // CHECK:     krnl.return_loops [[DEF_LOOPS_1]]#0, [[DEF_LOOPS_1]]#1 | ||||
|   // CHECK:   } : () -> (!krnl.loop, !krnl.loop) | ||||
|   // CHECK:   krnl.iterate([[OPT_LOOPS_1]]#0, [[OPT_LOOPS_1]]#1) with ([[DEF_LOOPS_1]]#0 -> %arg5 = 0 to 3, [[DEF_LOOPS_1]]#1 -> %arg6 = 0 to 3) { | ||||
|   // CHECK:     [[STRIDE_0:%.+]] = constant 2 : index | ||||
|   // CHECK:     [[MUL_0:%.+]] = muli [[STRIDE_0]], %arg3 : index | ||||
|   // CHECK:     [[SPATIAL_H:%.+]] = addi [[MUL_0]], %arg5 : index | ||||
|   // CHECK:     [[DIM_0_0:%.+]] = dim %arg0, 2 : memref<?x3x?x32xf32> | ||||
|   // CHECK:     [[ONE_INDEX:%.+]] = constant 1 : index | ||||
|   // CHECK:     [[INPUT_INDEX_0:%.+]] = subi [[DIM_0_0]], [[ONE_INDEX]] : index | ||||
|   // CHECK:     [[CMP_0:%.+]] = cmpi "sgt", [[SPATIAL_H]], [[INPUT_INDEX_0]] : index | ||||
|   // CHECK:     [[H:%.+]] = select [[CMP_0]], [[INPUT_INDEX_0]], [[SPATIAL_H]] : index | ||||
| 
 | ||||
|   // CHECK:     [[STRIDE_1:%.+]] = constant 2 : index | ||||
|   // CHECK:     [[MUL_1:%.+]] = muli [[STRIDE_1]], %arg4 : index | ||||
|   // CHECK:     [[SPATIAL_W:%.+]] = addi [[MUL_1]], %arg6 : index | ||||
|   // CHECK:     [[INPUT_INDEX_1:%.+]] = constant 31 : index | ||||
|   // CHECK:     [[CMP_1:%.+]] = cmpi "sgt", [[SPATIAL_W]], [[INPUT_INDEX_1]] : index | ||||
|   // CHECK:     [[W:%.+]] = select [[CMP_1]], [[INPUT_INDEX_1]], [[SPATIAL_W]] : index | ||||
| 
 | ||||
|   // CHECK:     [[LOAD_X:%.+]] = load %arg0[%arg1, %arg2, [[H]], [[W]]] : memref<?x3x?x32xf32> | ||||
|   // CHECK:     [[LOAD_Y:%.+]] = load [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<?x3x?x16xf32> | ||||
|   // CHECK:     [[CMP_2:%.+]] = cmpf "ogt", [[LOAD_Y]], [[LOAD_X]] : f32 | ||||
|   // CHECK:     [[SELECT:%.+]] = select [[CMP_2]], [[LOAD_Y]], [[LOAD_X]] : f32 | ||||
|   // CHECK:     store [[SELECT]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<?x3x?x16xf32> | ||||
|   // CHECK:   } | ||||
|   // CHECK: } | ||||
|   // CHECK: return [[RES]] : memref<?x3x?x16xf32> | ||||
| } | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue