diff --git a/src/Conversion/ONNXToKrnl/NN/Pooling.cpp b/src/Conversion/ONNXToKrnl/NN/Pooling.cpp index 19f9b0d..d008f10 100644 --- a/src/Conversion/ONNXToKrnl/NN/Pooling.cpp +++ b/src/Conversion/ONNXToKrnl/NN/Pooling.cpp @@ -8,6 +8,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/IR/AffineExpr.h" #include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp" using namespace mlir; @@ -19,6 +20,19 @@ Value getIdentityValue( return emitNegativeInfinityConstantOp(rewriter, loc, type); } +template <> +Value getIdentityValue( + ConversionPatternRewriter &rewriter, Location loc, Type type) { + return emitConstantOp(rewriter, loc, type, 0); +} + +// Scalar operations +template <> +struct ScalarOp { + using FOp = AddFOp; + using IOp = AddIOp; +}; + template <> Value emitScalarOpFor( ConversionPatternRewriter &rewriter, Location loc, Operation *op, @@ -30,288 +44,519 @@ Value emitScalarOpFor( return result; } -struct ONNXMaxPoolSingleOutOpLowering : public ConversionPattern { - ONNXMaxPoolSingleOutOpLowering(MLIRContext *ctx) - : ConversionPattern( - mlir::ONNXMaxPoolSingleOutOp::getOperationName(), 1, ctx) {} +//===----------------------------------------------------------------------===// +// Get dilation values +// +template +std::vector getDilations(PoolOp poolOp) { + return {}; +} + +// MaxPool has dilations attribute. +template <> +std::vector getDilations( + ONNXMaxPoolSingleOutOp poolOp) { + std::vector dilations; + auto dilationsAttribute = poolOp.dilationsAttr(); + bool isDefaultDilations = true; + for (auto dilation : dilationsAttribute.getValue()) { + int64_t dilationValue = dilation.cast().getInt(); + if (dilationValue > 1 and isDefaultDilations) + isDefaultDilations = false; + dilations.emplace_back(dilationValue); + } + if (isDefaultDilations) + return {}; + else + return dilations; +} + +//===----------------------------------------------------------------------===// +// Get count_include_pad values +// +template +bool getCountIncludePad(PoolOp poolOp) { + return false; +} + +// AveragePool has count_include_pad attribute. +template <> +bool getCountIncludePad(ONNXAveragePoolOp poolOp) { + return (poolOp.count_include_pad() == 1); +} + +//===----------------------------------------------------------------------===// +// Helper function to do post-processing after applying a filter window. +// +template +void postProcessPoolingWindow(ConversionPatternRewriter &rewriter, Location loc, + PoolOp poolOp, Value alloc, ArrayRef resultIndices, + ArrayRef kernelShape, ArrayRef poolDimValues) {} + +// Calculate the average value for AveragePool. +template <> +void postProcessPoolingWindow( + ConversionPatternRewriter &rewriter, Location loc, ONNXAveragePoolOp poolOp, + Value alloc, ArrayRef resultIndices, ArrayRef kernelShape, + ArrayRef poolDimValues) { + // AveragePool's result type is FloatType, so it's safe to use DivFOp, SubFOp. 
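  // Illustrative sketch (not part of this patch): a worked example of the
  // denominator selection implemented below. Assume a 3x3 kernel,
  // pads = [1, 1, 1, 1] and unit strides on a 5x5 input. For the corner
  // output pixel (0, 0) only a 2x2 sub-window overlaps the un-padded input,
  // so:
  //   count_include_pad == 1: denominator = kH * kW     = 3 * 3 = 9
  //   count_include_pad == 0: denominator = hDim * wDim = 2 * 2 = 4
  // In scalar form the selection is simply (windowDims is a hypothetical
  // name for the run-time window extents, i.e. poolDimValues below):
  //   int64_t denominator = 1;
  //   for (int64_t d : countIncludePad ? kernelShape : windowDims)
  //     denominator *= d;
  // The emitted IR computes the same product with MulIOp, casts it to the
  // element type, and feeds it to the DivFOp.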
+ bool countIncludePad = getCountIncludePad(poolOp); + Value numerator = rewriter.create(loc, alloc, resultIndices); + Value denominator; + if (countIncludePad) { + int64_t kernelSize = 1; + for (int i = 0; i < kernelShape.size(); ++i) + kernelSize *= kernelShape[i]; + denominator = + emitConstantOp(rewriter, loc, numerator.getType(), kernelSize); + } else { + denominator = poolDimValues[0]; + for (int i = 1; i < poolDimValues.size(); ++i) + denominator = rewriter.create(loc, denominator, poolDimValues[i]); + denominator = rewriter.create( + loc, denominator, rewriter.getIntegerType(64)); + denominator = + rewriter.create(loc, denominator, numerator.getType()); + } + + Value average = rewriter.create(loc, numerator, denominator); + + rewriter.create(loc, average, alloc, resultIndices); +} + +//===----------------------------------------------------------------------===// +// Helper function to insert alloc and dealloc ops for memref of dynamic shape. +// +Value insertAllocAndDeallocForPooling(ConversionPatternRewriter &rewriter, + Location loc, bool insertDealloc, MemRefType memRefType, Value inputOperand, + ArrayRef kernelShape, ArrayRef pads, + ArrayRef strides, ArrayRef dilations, bool ceilMode) { + AllocOp alloc; + + // Shape and rank information related to result and kernel. + auto resultShape = memRefType.getShape(); + auto resultRank = resultShape.size(); + auto kernelRank = kernelShape.size(); + auto kernelOffset = resultRank - kernelRank; + + // Compute dimensions of the result of this operation. + SmallVector allocOperands; + for (int i = 0; i < kernelOffset; ++i) { + if (resultShape[i] < 0) { + auto dim = rewriter.create(loc, inputOperand, i); + allocOperands.emplace_back(dim); + } + } + + Value zero, one; + if (ceilMode) { + zero = rewriter.create( + loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + } + one = rewriter.create( + loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1)); + + for (int i = kernelOffset; i < resultShape.size(); ++i) { + if (resultShape[i] < 0) { + // dim = + // let numerator = (input + pad - (kernel - 1) * dilation - 1) + // in let denominator = stride + // in + // if (ceilMode) + // ceil(numerator / denominator) + 1 + // else + // floor(numerator / denominator) + 1 + int spatialIndex = i - kernelOffset; + + // numerator = (input + pad - (kernel - 1) * dilation - 1) + int64_t dilation = dilations.empty() ? 
1 : dilations[spatialIndex]; + int64_t padKernelDilation = + (pads[spatialIndex] + pads[spatialIndex + kernelRank]) - + (kernelShape[spatialIndex] - 1) * dilation - 1; + auto padKernelDilationVal = emitConstantOp( + rewriter, loc, rewriter.getIntegerType(64), padKernelDilation); + auto inputDim = rewriter.create(loc, inputOperand, i); + auto inputDimVal = rewriter.create( + loc, inputDim, rewriter.getIntegerType(64)); + auto numeratorVal = + rewriter.create(loc, inputDimVal, padKernelDilationVal); + // denominator + auto denominatorVal = emitConstantOp( + rewriter, loc, rewriter.getIntegerType(64), strides[spatialIndex]); + + // numerator / denominator + Value dimVal = + rewriter.create(loc, numeratorVal, denominatorVal); + + if (ceilMode) { + auto remainder = + rewriter.create(loc, numeratorVal, denominatorVal); + auto isZero = + rewriter.create(loc, CmpIPredicate::eq, remainder, zero); + auto dimPlusOne = rewriter.create(loc, dimVal, one); + dimVal = rewriter.create(loc, isZero, dimVal, dimPlusOne); + } + + dimVal = rewriter.create(loc, dimVal, one); + allocOperands.emplace_back( + rewriter.create(loc, dimVal, rewriter.getIndexType())); + } + } + alloc = rewriter.create(loc, memRefType, allocOperands); + if (insertDealloc) { + auto *parentBlock = alloc.getOperation()->getBlock(); + auto dealloc = rewriter.create(loc, alloc); + dealloc.getOperation()->moveBefore(&parentBlock->back()); + } + return alloc; +} + +//===----------------------------------------------------------------------===// +// Template function that does pooling. +// +template +struct ONNXPoolOpLowering : public ConversionPattern { + ONNXPoolOpLowering(MLIRContext *ctx) + : ConversionPattern(PoolOp::getOperationName(), 1, ctx) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { ONNXMaxPoolSingleOutOpOperandAdaptor operandAdaptor(operands); auto loc = op->getLoc(); - // Match - ONNXMaxPoolSingleOutOp poolOp = llvm::dyn_cast(op); + PoolOp poolOp = llvm::dyn_cast(op); // Read kernel_shape attribute - SmallVector kernelShape; + SmallVector kernelShape; auto kernelShapeAttribute = poolOp.kernel_shapeAttr(); - for (auto dim : kernelShapeAttribute.getValue()) + for (Attribute dim : kernelShapeAttribute.getValue()) kernelShape.emplace_back(dim.cast().getInt()); // Read strides attribute - SmallVector strides; + SmallVector strides; auto stridesAttribute = poolOp.stridesAttr(); - for (auto stride : stridesAttribute.getValue()) + for (Attribute stride : stridesAttribute.getValue()) strides.emplace_back(stride.cast().getInt()); // Read ceil_mode attribute auto ceilMode = poolOp.ceil_mode().getSExtValue(); // Read pads attribute - SmallVector pads; + SmallVector pads; auto padsAttribute = poolOp.padsAttr(); - for (auto pad : padsAttribute.getValue()) + for (Attribute pad : padsAttribute.getValue()) pads.emplace_back(pad.cast().getInt()); - // Read dilations attribute - SmallVector dilations; - auto dilationsAttribute = poolOp.dilationsAttr(); - for (auto dilation : dilationsAttribute.getValue()) - dilations.emplace_back(dilation.cast().getInt()); + // Read dilations attribute if the op has. + std::vector dilations = getDilations(poolOp); + bool isDilated = !dilations.empty(); // Type information about the input and result of this operation. 
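    // Illustrative sketch (not part of this patch): the closed form computed
    // by insertAllocAndDeallocForPooling above for one spatial output
    // dimension, written as a standalone helper with hypothetical names:
    //   int64_t outputSpatialDim(int64_t in, int64_t padBegin, int64_t padEnd,
    //       int64_t kernel, int64_t stride, int64_t dilation, bool ceilMode) {
    //     int64_t numerator =
    //         in + padBegin + padEnd - (kernel - 1) * dilation - 1;
    //     int64_t dim = numerator / stride; // floor for non-negative values
    //     if (ceilMode && numerator % stride != 0)
    //       dim += 1;                       // turns the floor into a ceil
    //     return dim + 1;
    //   }
    // For example, in = 32, pads = 0, kernel = 2, stride = 1, dilation = 1
    // gives 31, matching the 1x3x31x31 output memrefs in the lowering tests
    // at the end of this patch.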
auto inputOperand = operandAdaptor.X(); auto inputShape = inputOperand.getType().cast().getShape(); auto memRefType = convertToMemRefType(*op->result_type_begin()); - auto resultShape = memRefType.getShape(); - auto resultElementType = memRefType.getElementType(); + auto outputShape = memRefType.getShape(); + auto outputElementType = memRefType.getElementType(); - // Batch indices: N and C dimensions - int batchRank = 2; + // Kernel offset in the input shape. + int kernelOffset = inputShape.size() - kernelShape.size(); - // Insert an allocation and deallocation for the result of this operation. + // Insert an allocation and deallocation for the output of this operation. Value alloc; bool insertDealloc = checkInsertDealloc(op); if (hasAllConstantDimensions(memRefType)) alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); else { - // Compute dimensions of the result of this operation. - SmallVector allocOperands; - for (int i = 0; i < batchRank; ++i) { - if (resultShape[i] < 0) { - auto dim = rewriter.create(loc, inputOperand, i); - allocOperands.emplace_back(dim); - } - } - - Value zero, one; - if (ceilMode) { - zero = rewriter.create( - loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); - } - one = rewriter.create( - loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1)); - - int spatialRank = resultShape.size() - batchRank; - for (int i = batchRank; i < resultShape.size(); ++i) { - if (resultShape[i] < 0) { - // dim = - // let numerator = (input + pad - (kernel - 1) * dilation - 1) - // in let denomitor = stride - // in - // if (ceilMode) - // ceil(numerator / denominator) + 1 - // else - // floor(numerator / denominator) + 1 - int spatialIndex = i - batchRank; - - // numerator = (input + pad - (kernel - 1) * dilation - 1) - auto inputDim = rewriter.create(loc, inputOperand, i); - auto inputVal = rewriter.create( - loc, inputDim, rewriter.getIntegerType(64)); - int64_t padKernelDilation = - (pads[spatialIndex] + pads[spatialIndex + spatialRank]) - - (kernelShape[spatialIndex] - 1) * dilations[spatialIndex] - 1; - auto padKernelDilationVal = rewriter.create( - loc, rewriter.getIntegerAttr( - rewriter.getIntegerType(64), padKernelDilation)); - auto numeratorVal = - rewriter.create(loc, inputVal, padKernelDilationVal); - // denominator - auto denominatorVal = rewriter.create( - loc, rewriter.getIntegerAttr( - rewriter.getIntegerType(64), strides[spatialIndex])); - - // numerator / denominator - Value dimVal = - rewriter.create(loc, numeratorVal, denominatorVal); - - if (ceilMode) { - auto remainder = rewriter.create( - loc, numeratorVal, denominatorVal); - auto isZero = rewriter.create( - loc, CmpIPredicate::eq, remainder, zero); - auto dimPlusOne = rewriter.create(loc, dimVal, one); - dimVal = rewriter.create(loc, isZero, dimVal, dimPlusOne); - } - - dimVal = rewriter.create(loc, dimVal, one); - allocOperands.emplace_back(rewriter.create( - loc, dimVal, rewriter.getIndexType())); - } - } - alloc = rewriter.create(loc, memRefType, allocOperands); - if (insertDealloc) { - auto *parentBlock = alloc.getDefiningOp()->getBlock(); - auto dealloc = rewriter.create(loc, alloc); - dealloc.getOperation()->moveBefore(&parentBlock->back()); - } + alloc = insertAllocAndDeallocForPooling(rewriter, loc, insertDealloc, + memRefType, inputOperand, kernelShape, pads, strides, dilations, + ceilMode); } - // R = MaxPool(D) + // input = Pool(output) // // The input/output shapes will look like this: // - // D (NxCxHxW) -> R (NxCxRHxRW) + // input (NxCxHxW) -> output 
(NxCxHOxWO) // // The loop nest will look as follows: // - // strides = [s1, s2] + // kernelShape = [kH, kW] + // pads = [ptH, ptW, pbH, pbW] + // strides = [sH, sW] + // dilations = [dH, dW] + // round = ceil if ceilMode else floor // - // for n = 0 .. N: - // for c = 0 .. C: - // for r1 = 0 .. RH: - // for r2 = 0 .. RW: - // R[n][c][r1][r2] = negative_infinity; - // for k1 = 0 .. KH: - // for k2 = 0 .. KW: - // t = D[n][c][s1 * r1 + k1 * d1][s2 * r2 + k2 * d2]; - // R[n][c][r1][r2] = max(R[n][c][r1][r2], t); + // for n in range(N): + // for c in range(C): + // for ho in range(HO): + // for wo in range(WO): + // # Initialize values for the output. + // output[n][c][ho][wo] = getIdentityValue(...); // - // Naming: - // n, c, r1, r2: outer loop nest indices - // k1, k2: inner loop nest indices - // s1, s2: strides - // d1, d2: dilations + // # Thanks to Tian (@tjingrant) for the following derivation about + // # firstValid. + // # When dilation is non-unit, the first valid pixel to + // # apply pooling on will not be the 0-th pixel, but rather + // # the smallest integer n to make -pH + n * 3 greater than + // # or equal to 0. + // # We derive what is this smallest n: + // # -pH + n * dH >= 0 + // # n * dH >= pH + // # n >= pH/dH + // # thus n = ceil(pH/dH) + // # thus the first valid pixel location is + // # ceil(pH / dilation) * dilation - pH // - // TODO: handle padding. + // firstValidH = ceil(float(ptH / dH)) * dH - ptH + // startH = max(firstValidH, ho * sH - ptH) + // endH = min(H, ho * sH + (kH -1) * dH + 1 - ptH) + // + // firstValidW= ceil(float(pW / dW)) * dW - ptW + // startW = max(firstValidW, wo * sW - ptW) + // endW = min(W, wo * sW + (kW - 1) * dW + 1 - ptW) + // + // hDim= round(float(endH - startH) / float(dH)) + // wDim= round(float(endW - startW) / float(dW)) + // + // # Apply the pooling window. + // # The pooling window can be smaller than the kernel when slicing + // # over the border edges. + // for hi in range(startH, endH, dH): + // for wi in range(startW, endW, dW): + // output[n, c, ho, wo] = emitScalarOpFor(output[n, c, ho, wo], + // input[n, c, hi, wi]); + // + // # The above two for-loops are rewritten as follows: + // # (since KrnlIterateOp has not supported `step` yet) + // for hp in range(hDim): + // for wp in range(wDim): + // hi = hp * dH + startH + // wi = wp * dW + startW + // output[n, c, ho, wo] = emitScalarOpFor(output[n, c, ho, wo], + // input[n, c, hi, wi]); + // + // # Do post processing such as taking average pooling: + // postProcessPoolingWindow(...) + // + // Helper functions: + // getIdentityValue(): to return the indentity value + // - negative infinity for MaxPool + // - 0 for AveragePool + // emitScalarOpFor(): to do primitive computation for Pooling, e.g. + // - compute max for MaxPool + // - compute sum for AveragePool + // postProcessPoolingWindow(): to do post processing over the whole + // pooling window, e.g. + // - do nothing in case of MaxPool + // - calculate the average in case of AveragePool, e.g. + // if hDim * wDim> 0: + // output[n, c, ho, wo] = output[n, c, ho, wo] / (hDim*wDim) // - // 1. Define outer loops and emit empty optimization block. - auto nOuterLoops = resultShape.size(); - BuildKrnlLoop outerLoops(rewriter, loc, nOuterLoops); - outerLoops.createDefineOptimizeAndIterateOp(alloc); + // Identity value of the operation. + auto identity = getIdentityValue(rewriter, loc, outputElementType); - rewriter.setInsertionPointToStart(outerLoops.getIterateBlock()); + // 1. Define output loops to compute one output pixel. 
+ // for n in range(N): + // for c in range(C): + // for ho in range(HO): + // for wo in range(WO): + BuildKrnlLoop outputLoops(rewriter, loc, outputShape.size()); + outputLoops.createDefineOptimizeAndIterateOp(alloc); + + auto ipMainRegion = rewriter.saveInsertionPoint(); + rewriter.setInsertionPointToStart(outputLoops.getIterateBlock()); { - // 2. Emit the body of the outer loop nest. - SmallVector resultIndices; - for (int i = 0; i < nOuterLoops; ++i) - resultIndices.emplace_back(outerLoops.getInductionVar(i)); + // 2. Emit the body of the output loop nest, which applies a pooling + // window to a region in the input, producing one output pixel. + SmallVector outputIndices; + for (int i = 0; i < outputShape.size(); ++i) + outputIndices.emplace_back(outputLoops.getInductionVar(i)); - // 2.1 Emit: R[n][c][r1][r2] = negative_infinity; - Value identity = getIdentityValue( - rewriter, loc, resultElementType); - rewriter.create(loc, identity, alloc, resultIndices); + // 2.1 Emit: output[n][c][ho][wo] = identity + rewriter.create(loc, identity, alloc, outputIndices); - // 2.2 Define inner loops. - int nInnerLoops = kernelShape.size(); - BuildKrnlLoop innerLoops(rewriter, loc, nInnerLoops); - innerLoops.createDefineAndOptimizeOp(); - // for Kx = 0 .. KX - for (int i = 0; i < nInnerLoops; ++i) - innerLoops.pushBounds(0, kernelShape[i]); + // 2.2 Emit affine maps which express the lower and upper bounds for the + // pooling window's dimensions. + // The pooling window can be smaller than the kernel when slicing it over + // the border edges. Thus, we will compute the start and end indices for + // each dimension as follows. + // firstValidH = ceil(float(ptH / dH)) * dH - ptH + // startH = max(firstValidH, ho * sH - ptH) + // endH = min(H, ho * sH + (kH - 1) * dH + 1 - pbH) + // hDim = round(float(endH - startH) / float(dH)) - // 2.3 Emit inner loop nest. - innerLoops.createIterateOp(); - rewriter.setInsertionPointToStart(innerLoops.getIterateBlock()); - { - // 3. Emit inner loop body - // t = D[n][c][s1 * r1 + k1 * d1][s2 * r2 + k2 * d2]; - // R[n][c][r1][r2] = max(R[n][c][r1][r2], t); - - // 3.1 Prepare indices for accesing the data tensor. - SmallVector dataIndices; - // 3.1.1 Batch indices: n, c - for (int i = 0; i < batchRank; ++i) - dataIndices.emplace_back(outerLoops.getInductionVar(i)); - // 3.1.2 Insert spatial indices: sX * rX + kX * dX - for (int i = batchRank; i < nOuterLoops; ++i) { - // Get index along the inner loop's induction variables. - // It is used to obtain kernel/pad/stride/dilation index. - int j = i - batchRank; - - Value spatialIndex = outerLoops.getInductionVar(i); - // If strides are present (not default) then emit the correct access - // index. - // sX *= rX - if (strides[i - batchRank] > 1) { - auto strideIndex = emitConstantOp( - rewriter, loc, rewriter.getIndexType(), strides[j]); - spatialIndex = rewriter.create( - loc, strideIndex, outerLoops.getInductionVar(i)); - } - - // Dilate the kernel index only if the dilation value is not one (not - // default). Otherwise, just add kernelIndex. - auto kernelIndex = innerLoops.getInductionVar(j); - if (dilations[j] > 1) { - // sX += dX * kW - auto dilationIndex = emitConstantOp( - rewriter, loc, rewriter.getIndexType(), dilations[j]); - auto dilationKernelIndex = - rewriter.create(loc, dilationIndex, kernelIndex); - spatialIndex = - rewriter.create(loc, spatialIndex, dilationKernelIndex); + // Prepare induction variables and constants as arguments for the affine + // maps. 
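      // Illustrative sketch (not part of this patch): the window bounds above
      // evaluated for a concrete case. Assume H = 32, kH = 3, sH = 2, dH = 2,
      // ptH = 1, ceil_mode = 0, and the first output row ho = 0:
      //   firstValidH = ceil(1 / 2) * 2 - 1                  = 1
      //   startH      = max(firstValidH, 0 * 2 - 1)          = max(1, -1) = 1
      //   endH        = min(32, 0 * 2 + (3 - 1) * 2 + 1 - 1) = min(32, 4) = 4
      //   hDim        = floor((4 - 1 + 1) / 2)               = 2
      // so the window visits input rows hi = 1 and hi = 3; the kernel tap that
      // would have landed on row -1 (inside the top padding) is skipped.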
+ SmallVector, 4> IVsAndConstants; + { // Construct IVsAndConstants. + for (int i = 0; i < kernelShape.size(); ++i) { + SmallVector ic; + // d0, output + ic.emplace_back(outputLoops.getInductionVar(i + kernelOffset)); + // s0, input dim + if (inputShape[i + kernelOffset] < 0) { + ic.emplace_back( + rewriter.create(loc, inputOperand, i + kernelOffset)); } else { - // sX += kX - spatialIndex = - rewriter.create(loc, spatialIndex, kernelIndex); + ic.emplace_back(emitConstantOp(rewriter, loc, + rewriter.getIndexType(), inputShape[i + kernelOffset])); } + // s1, kernel dim + ic.emplace_back(emitConstantOp( + rewriter, loc, rewriter.getIndexType(), kernelShape[i])); + // s2, pad dim + ic.emplace_back( + emitConstantOp(rewriter, loc, rewriter.getIndexType(), pads[i])); + // s3, stride dim + ic.emplace_back(emitConstantOp( + rewriter, loc, rewriter.getIndexType(), strides[i])); + // s4, dilation dim + ic.emplace_back(emitConstantOp(rewriter, loc, rewriter.getIndexType(), + (isDilated) ? dilations[i] : 1)); + IVsAndConstants.emplace_back(ic); + } + } - // If ceil mode or dilation is enabled, then the calculated access - // index may exceed its dimension. In such a case, we will use the - // maximum index, which causes multiple visits to the element of the - // maximum index. - // TODO: Avoid multiple visits. - // Example of out-of-bound. - // - Given a 5x5 input X - // X = [[0, 0, 0, 0, 0], - // [1, 1, 1, 1, 1], - // [2, 2, 2, 2, 2], - // [3, 3, 3, 3, 3], - // [4, 4, 4, 4, 4]] - // - Do MaxPool with strides=[2, 2], kernel=[2, 2], ceilMode=true, - // output is a 3x3 array: - // Y = [[1, 1, 1], - // [3, 3, 3], - // [4, 4, 4]] - // - When computing Y[2, 0]: - // - In case of kernelIndex = 1, stride = 2 - // - No dilation: spatialIndex = 2 * 2 + 1 = 5 - // => out of bound - // - dilation = 2: spatialIndex = 2 * 2 + 2 * 1 = 6 - // => out of bound - if (dilations[j] > 1 or ceilMode) { - Value upperIndex; - if (inputShape[i] < 0) { - Value inputDim = rewriter.create(loc, inputOperand, i); - Value one = rewriter.create(loc, 1); - upperIndex = rewriter.create(loc, inputDim, one); - } else { - upperIndex = - rewriter.create(loc, inputShape[i] - 1); + // Affine maps for the pooling window. + AffineMap poolStartMap, poolEndMap, poolDimMap; + { // Construct poolStartMap, poolEndMap and poolDimMap. + // AffineExpr(s) to obtain the dimensions and symbols. 
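        // For reference, with d0 = the output index and s0..s4 = (input dim,
        // kernel dim, pad dim, stride dim, dilation dim) as packed into
        // IVsAndConstants above, the poolDimMap built below prints (in the
        // un-dilated case) as the map checked by test_pool_general_computation
        // later in this patch:
        //   affine_map<(d0)[s0, s1, s2, s3, s4] ->
        //       (s0 - ((s2 ceildiv s4) * s4 - s2),
        //        -(d0 * s3 - s2) + s0,
        //        d0 * s3 + (s1 - 1) * s4 - s2 - ((s2 ceildiv s4) * s4 - s2) + 1,
        //        d0 * s3 + (s1 - 1) * s4 - s2 - (d0 * s3 - s2) + 1)>
        // i.e. the four "end - start" candidates whose minimum becomes the
        // loop upper bound.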
+ AffineExpr outputIndex = rewriter.getAffineDimExpr(0); + AffineExpr inputDim = rewriter.getAffineSymbolExpr(0); + AffineExpr kernelDim = rewriter.getAffineSymbolExpr(1); + AffineExpr padTopDim = rewriter.getAffineSymbolExpr(2); + AffineExpr strideDim = rewriter.getAffineSymbolExpr(3); + AffineExpr dilationDim = rewriter.getAffineSymbolExpr(4); + AffineExpr start1 = + (padTopDim).ceilDiv(dilationDim) * dilationDim - padTopDim; + AffineExpr start2 = outputIndex * strideDim - padTopDim; + AffineExpr end1 = inputDim; + AffineExpr end2 = outputIndex * strideDim + + (kernelDim - 1) * dilationDim + 1 - padTopDim; + + // poolDimMap + SmallVector dimExpr; + // Upperbound for an affine.for is `min AffineMap`, where `min` is + // automatically inserted when an affine.for is constructed from + // an AffineMap, thus we rewrite `endH - startH` as follows: + // endH - start H + // = min(end1, end2) - max(start1, start2) + // = min(end1 - start1, end1 - start2, end2 - start1, end2 - start2) + AffineExpr dimExpr1 = end1 - start1; + AffineExpr dimExpr2 = end1 - start2; + AffineExpr dimExpr3 = end2 - start1; + AffineExpr dimExpr4 = end2 - start2; + for (AffineExpr de : {dimExpr1, dimExpr2, dimExpr3, dimExpr4}) { + if (isDilated) { + de = de + 1; + de = + (ceilMode) ? de.ceilDiv(dilationDim) : de.floorDiv(dilationDim); + } + dimExpr.emplace_back(de); + } + poolDimMap = AffineMap::get(1, 5, dimExpr); + + // poolStartMap and poolEndMap + poolStartMap = AffineMap::get(1, 5, {start1, start2}); + poolEndMap = AffineMap::get(1, 5, {end1, end2}); + } + + // Obtain values from the affine maps. + SmallVector poolStartValues; + SmallVector poolDimValues; + { // Construct poolStartValues and poolDimValues. + for (int i = 0; i < kernelShape.size(); ++i) { + Value startIndex = rewriter.create( + loc, poolStartMap, ValueRange(IVsAndConstants[i])); + poolStartValues.emplace_back(startIndex); + + Value endIndex = rewriter.create( + loc, poolEndMap, ValueRange(IVsAndConstants[i])); + + Value dim = rewriter.create(loc, endIndex, startIndex); + if (isDilated) { + Value one = + emitConstantOp(rewriter, loc, rewriter.getIndexType(), 1); + Value numerator = rewriter.create(loc, dim, one); + Value denominator = IVsAndConstants[i][5]; // dilations[i] + dim = rewriter.create(loc, numerator, denominator); + if (ceilMode) { + auto remainder = + rewriter.create(loc, numerator, denominator); + Value zero = + emitConstantOp(rewriter, loc, rewriter.getIndexType(), 0); + auto isZero = rewriter.create( + loc, CmpIPredicate::eq, remainder, zero); + auto dimPlusOne = rewriter.create(loc, dim, one); + dim = rewriter.create(loc, isZero, dim, dimPlusOne); } - auto greaterCondition = rewriter.create( - loc, CmpIPredicate::sgt, spatialIndex, upperIndex); - spatialIndex = rewriter.create( - loc, greaterCondition, upperIndex, spatialIndex); } + poolDimValues.emplace_back(dim); + } + } - dataIndices.emplace_back(spatialIndex); + // 2.3 Define pooling loops. 
+ // for hp in range(hDim): + // for wp in range(wDim): + // hi = hp * dH + startH + // wi = wp * dW + startW + // output[n][c][ho][wo] = + // emitScalarOpFor(output[n][c][ho][wo], input[n, c, hi, wi]); + BuildKrnlLoop poolingLoops(rewriter, loc, kernelShape.size()); + poolingLoops.createDefineAndOptimizeOp(); + for (int i = 0; i < kernelShape.size(); ++i) + poolingLoops.pushBounds( + 0, poolDimMap, llvm::makeArrayRef(IVsAndConstants[i])); + poolingLoops.createIterateOp(); + + auto ipOuterLoops = rewriter.saveInsertionPoint(); + rewriter.setInsertionPointToStart(poolingLoops.getIterateBlock()); + { + // 2.4 Emit the body of the pooling loop nest. + // Prepare indices to access a pixel in the input. + std::vector inputIndices; + { // Construct inputIndices + for (int i = 0; i < kernelOffset; ++i) + inputIndices.emplace_back(outputIndices[i]); + for (int i = kernelOffset; i < inputShape.size(); ++i) { + int j = i - kernelOffset; + if (isDilated) { + // hi = hp * dH + startH + Value index = rewriter.create( + loc, poolingLoops.getInductionVar(j), IVsAndConstants[j][5]); + index = rewriter.create(loc, index, poolStartValues[j]); + inputIndices.emplace_back(index); + } else { + // hi = hp + startH + inputIndices.emplace_back(rewriter.create( + loc, poolingLoops.getInductionVar(j), poolStartValues[j])); + } + } } - // 3.2 Do pooling. - auto loadData = rewriter.create(loc, inputOperand, dataIndices); - auto loadPartialResult = - rewriter.create(loc, alloc, resultIndices); - Value result = emitScalarOpFor(rewriter, loc, - op, resultElementType, {loadPartialResult, loadData}); - rewriter.create(loc, result, alloc, resultIndices); + // Apply pooling operation. + // output[n][c][ho][wo] = + // emitScalarOpFor(output[n][c][ho][wo], input[n, c, hi, wi]); + Value loadInput = + rewriter.create(loc, inputOperand, inputIndices); + Value loadPartialOutput = + rewriter.create(loc, alloc, outputIndices); + Value output = emitScalarOpFor(rewriter, loc, op, + outputElementType, {loadPartialOutput, loadInput}); + rewriter.create(loc, output, alloc, outputIndices); } + + // 2.5 Post-processing for the pooling window, e.g. taking average. + rewriter.restoreInsertionPoint(ipOuterLoops); + postProcessPoolingWindow(rewriter, loc, poolOp, alloc, + outputIndices, kernelShape, poolDimValues); } + + // Go back to the main region. 
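    // Illustrative note (not part of this patch): what the generic hooks used
    // above lower to for the two instantiations of this pattern, as checked by
    // the lowering tests later in the patch (register names illustrative):
    //   * MaxPoolSingleOut: emitScalarOpFor emits
    //       %gt  = cmpf "ogt", %partialOutput, %input : f32
    //       %max = select %gt, %partialOutput, %input : f32
    //     and postProcessPoolingWindow is a no-op.
    //   * AveragePool: emitScalarOpFor emits
    //       %sum = addf %partialOutput, %input : f32
    //     and postProcessPoolingWindow divides the accumulated sum by the
    //     window size (divf), honouring count_include_pad.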
+ rewriter.restoreInsertionPoint(ipMainRegion); + rewriter.replaceOp(op, alloc); return success(); @@ -320,5 +565,6 @@ struct ONNXMaxPoolSingleOutOpLowering : public ConversionPattern { void populateLoweringONNXPoolingOpPattern( OwningRewritePatternList &patterns, MLIRContext *ctx) { - patterns.insert(ctx); + patterns.insert>(ctx); + patterns.insert>(ctx); } diff --git a/src/Dialect/Krnl/KrnlHelper.cpp b/src/Dialect/Krnl/KrnlHelper.cpp index fa3818a..09af502 100644 --- a/src/Dialect/Krnl/KrnlHelper.cpp +++ b/src/Dialect/Krnl/KrnlHelper.cpp @@ -149,6 +149,15 @@ void KrnlIterateOperandPack::pushOperandBound(Value operand) { _operands.emplace_back(operand); } +void KrnlIterateOperandPack::pushAffineMapBound( + AffineMap map, ArrayRef operands) { + if (boundMaps.size() % 2 == 0) + _operands.emplace_back(inputLoops[boundMaps.size() / 2]); + boundMaps.emplace_back(AffineMapAttr::get(map)); + for (auto operand : operands) + _operands.emplace_back(operand); +} + BuildKrnlLoop::BuildKrnlLoop( ConversionPatternRewriter &rewriter, Location loc, int loopNum) : rewriter(rewriter), loc(loc), originalLoopNum(loopNum), pack(NULL), @@ -209,6 +218,13 @@ int BuildKrnlLoop::pushBounds(int64_t lowerBound, Value upperBound) { return pushCount++; } +int BuildKrnlLoop::pushBounds(int64_t lowerBound, AffineMap upperBound, + ArrayRef operandsForUpperBoundMap) { + pack->pushConstantBound(lowerBound); + pack->pushAffineMapBound(upperBound, operandsForUpperBoundMap); + return pushCount++; +} + int BuildKrnlLoop::pushBounds(int64_t lowerBound, Value upperBoundMemRefOperand, int upperBoundMemRefIndex, bool upperBoundMustBeConstant) { pack->pushConstantBound(lowerBound); diff --git a/src/Dialect/Krnl/KrnlHelper.hpp b/src/Dialect/Krnl/KrnlHelper.hpp index 5c54e80..d4d155a 100644 --- a/src/Dialect/Krnl/KrnlHelper.hpp +++ b/src/Dialect/Krnl/KrnlHelper.hpp @@ -87,6 +87,8 @@ struct KrnlIterateOperandPack { void pushOperandBound(mlir::Value operand); + void pushAffineMapBound(mlir::AffineMap map, ArrayRef operands); + llvm::SmallVector getOperands() const { return _operands; } mlir::ArrayAttr getAttributes() const { @@ -159,6 +161,8 @@ public: // must be of MemRef type. int pushBounds(int64_t lowerBound, int64_t upperBound); int pushBounds(int64_t lowerBound, Value upperBound); + int pushBounds(int64_t lowerBound, AffineMap upperBound, + ArrayRef operandsForUpperBoundMap); int pushBounds(Value lowerBound, Value upperBound); int pushBounds(int64_t lowerBound, Value upperBoundMemRefOperand, int upperBoundMemRefIndex, bool upperBoundMustBeConstant = false); diff --git a/src/Dialect/ONNX/ONNXOps.td b/src/Dialect/ONNX/ONNXOps.td index b220d06..9956f0d 100644 --- a/src/Dialect/ONNX/ONNXOps.td +++ b/src/Dialect/ONNX/ONNXOps.td @@ -98,7 +98,6 @@ def ONNXEntryPointOp: ONNX_Op<"EntryPoint"> { def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut", [NoSideEffect, DeclareOpInterfaceMethods]> { - let hasCanonicalizer = 1; let summary = "ONNX MaxPool operation with a single output."; let description = [{ "ONNX MaxPool operation with a single output." diff --git a/src/Transform/ONNX/ONNXRewrite.cpp b/src/Transform/ONNX/ONNXRewrite.cpp index cde2cb3..a6f1b3f 100644 --- a/src/Transform/ONNX/ONNXRewrite.cpp +++ b/src/Transform/ONNX/ONNXRewrite.cpp @@ -33,7 +33,7 @@ bool hasNonZeroInArrayAttr(ArrayAttr attrs) { } // Create an ArrayAttr of IntergerAttr(s) of zero values. -// This function is used for padding attribute in MaxPoolSingleOut. +// This function is used for padding attribute in Conv. 
ArrayAttr createArrayAttrOfZeros( PatternRewriter &rewriter, ArrayAttr origAttrs) { int nElements = origAttrs.getValue().size(); @@ -51,7 +51,7 @@ ArrayAttr createArrayAttrOfZeros( // |_____| |_____| // nZeros nZeros // -// This function is used for padding attribute in MaxPoolSingleOut. +// This function is used for padding attribute in Conv. ArrayAttr insertZerosForNonPaddedDims( PatternRewriter &rewriter, ArrayAttr origAttrs, int extensionLength) { int nDims = (int)origAttrs.getValue().size() / 2; @@ -72,11 +72,6 @@ ArrayAttr insertZerosForNonPaddedDims( } // end anonymous namespace -/// on the ONNXMaxPoolSingleOutOp. -void ONNXMaxPoolSingleOutOp::getCanonicalizationPatterns( - OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); -} /// on the ONNXConvOp. void ONNXConvOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { diff --git a/src/Transform/ONNX/ONNXRewrite.td b/src/Transform/ONNX/ONNXRewrite.td index 8e0a559..496cdc2 100644 --- a/src/Transform/ONNX/ONNXRewrite.td +++ b/src/Transform/ONNX/ONNXRewrite.td @@ -33,13 +33,8 @@ class StringAttrOfValue: class FloatAttrOfValue: NativeCodeCall<"FloatAttr::get($0.getType().cast().getElementType(), " # val # ")">; -// Create a FloatAttr for the negative infinity. -def FloatAttrOfNegativeInfinity: - NativeCodeCall<"FloatAttr::get($0.getType().cast().getElementType(), " - "-std::numeric_limits::infinity())">; - // Create an ArrayAttr of IntergerAttr(s) of zero values. -// This function is used for padding attribute in MaxPoolSingleOut. +// This function is used for padding attribute in Conv. def createArrayAttrOfZerosFrom: NativeCodeCall<"createArrayAttrOfZeros($_builder, $0)">; @@ -53,7 +48,7 @@ def createArrayAttrOfZerosFrom: // |_____| |_____| // nZeros nZeros // -// This function is used for padding attribute in MaxPoolSingleOut. +// This function is used for padding attribute in Conv. class insertZerosForNonPaddedDims: NativeCodeCall<"insertZerosForNonPaddedDims($_builder, $0," # extensionLength # ")">; @@ -66,37 +61,6 @@ def HasNonZeroInArrayAttr: Constraint, class IsNotStringAttrOfValue: Constraint().getValue() != \"" # val # "\"">>; -//===----------------------------------------------------------------------===// -// Rewrite: -// %0 = onnx.MaxPoolSingleOutOp(%D : tensor) -// {pads = [b0, b1, ... bK, e0, e1, ..., eK]} -> -// tensor -// -// as: -// %0 = onnx.PadConstantValuePadOp(%D) -// {pads = [0, 0, b0, b1, ... 
bK, 0, 0, e0, e1, ..., eK]} -> -// tensor -// %1 = onnx.MaxPoolSingleOut(%0 : tensor) {pads = [0, ..., 0]} -> -// tensor -//===----------------------------------------------------------------------===// - -def MaxPoolSingleOutOpPaddingPattern: Pat< - (ONNXMaxPoolSingleOutOp:$res - $x, - $auto_pad, $ceil_mode, $dilation, $kernel_shape, - $pads, - $storage_order, $strides), - (ONNXMaxPoolSingleOutOp - (ONNXPadConstantValuePadOp $x, - (insertZerosForNonPaddedDims<2> $pads), - (FloatAttrOfNegativeInfinity $res), - (StringAttrOfValue<"constant">)), - $auto_pad, $ceil_mode, $dilation, $kernel_shape, - (createArrayAttrOfZerosFrom $pads), - $storage_order, $strides), - [(HasNonZeroInArrayAttr:$pads), (IsNotStringAttrOfValue<"VALID"> $auto_pad)] ->; - //===----------------------------------------------------------------------===// // Rewrite: // %0 = onnx.ConvOp(%D : tensor, %K) diff --git a/test/backend/test.py b/test/backend/test.py index 02d1959..c0cb3b3 100644 --- a/test/backend/test.py +++ b/test/backend/test.py @@ -327,7 +327,7 @@ test_to_enable = [ "test_batchnorm_epsilon_cpu", "test_batchnorm_example_cpu", - # Pooling + # MaxPoolSingleOut "test_maxpool_1d_default_cpu", "test_maxpool_2d_ceil_cpu", "test_maxpool_2d_default_cpu", @@ -341,6 +341,21 @@ test_to_enable = [ "test_maxpool_2d_strides_cpu", "test_maxpool_3d_default_cpu", + # AveragePool + "test_averagepool_1d_default_cpu", + "test_averagepool_2d_ceil_cpu", + "test_averagepool_2d_default_cpu", + "test_averagepool_2d_pads_count_include_pad_cpu", + "test_averagepool_2d_pads_cpu", + "test_averagepool_2d_precomputed_pads_count_include_pad_cpu", + "test_averagepool_2d_precomputed_pads_cpu", + "test_averagepool_2d_precomputed_same_upper_cpu", + "test_averagepool_2d_precomputed_strides_cpu", + "test_averagepool_2d_same_lower_cpu", + "test_averagepool_2d_same_upper_cpu", + "test_averagepool_2d_strides_cpu", + "test_averagepool_3d_default_cpu", + ] # Extract name of all test cases. 
diff --git a/test/mlir/onnx/onnx_canonicalization.mlir b/test/mlir/onnx/onnx_canonicalization.mlir index eb27b39..123d77f 100644 --- a/test/mlir/onnx/onnx_canonicalization.mlir +++ b/test/mlir/onnx/onnx_canonicalization.mlir @@ -94,27 +94,3 @@ func @test_gemm_add_fusion_rank3(%arg0: tensor<128x128x256xf32>, %arg1: tensor<1 // return [[GEMM]] : tensor<*xf32> } -// ----- - -//CHECK-LABEL: @test_maxpoolsingleout_split(%{{.*}}: tensor<5x5x32x32xf32>) -> tensor<5x5x36x38xf32> { -func @test_maxpoolsingleout_split(%arg0: tensor<5x5x32x32xf32>) -> tensor<5x5x36x38xf32> { - %0 = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0, kernel_shape = [5,3], pads = [1, 2, 3, 4] } : (tensor<5x5x32x32xf32>) -> tensor<5x5x36x38xf32> - "std.return"(%0) : (tensor<5x5x36x38xf32>) -> () - - // CHECK-NEXT: %0 = "onnx.PadConstantValuePad"(%arg0) {constant_value = 0xFF800000 : f32, mode = "constant", pads = [0, 0, 1, 2, 0, 0, 3, 4]} : (tensor<5x5x32x32xf32>) -> tensor<5x5x36x38xf32> - // CHECK-NEXT: %1 = "onnx.MaxPoolSingleOut"(%0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, kernel_shape = [5, 3], pads = [0, 0, 0, 0], storage_order = 0 : i64} : (tensor<5x5x36x38xf32>) -> tensor<5x5x36x38xf32> - // CHECK-NEXT: return %1 : tensor<5x5x36x38xf32> -} - -// ----- - -//CHECK-LABEL: @test_maxpoolsingleout_split_unknown_dims(%{{.*}}: tensor<*xf32>) -> tensor<*xf32> { -func @test_maxpoolsingleout_split_unknown_dims(%arg0: tensor<*xf32>) -> tensor<*xf32> { - %0 = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", ceil_mode = 0, kernel_shape = [5,3], pads = [1, 2, 3, 4] } : (tensor<*xf32>) -> tensor<*xf32> - "std.return"(%0) : (tensor<*xf32>) -> () - - // CHECK-NEXT: %0 = "onnx.PadConstantValuePad"(%arg0) {constant_value = 0xFF800000 : f32, mode = "constant", pads = [0, 0, 1, 2, 0, 0, 3, 4]} : (tensor<*xf32>) -> tensor<*xf32> - // CHECK-NEXT: %1 = "onnx.MaxPoolSingleOut"(%0) {auto_pad = "NOTSET", ceil_mode = 0 : i64, kernel_shape = [5, 3], pads = [0, 0, 0, 0], storage_order = 0 : i64} : (tensor<*xf32>) -> tensor<*xf32> - // CHECK-NEXT: return %1 : tensor<*xf32> -} - diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir index 9f0af27..2757cf4 100644 --- a/test/mlir/onnx/onnx_lowering.mlir +++ b/test/mlir/onnx/onnx_lowering.mlir @@ -1505,231 +1505,6 @@ func @test_batchnorm_testmode_1d(%arg0: tensor<10xf32>, %arg1: tensor<1xf32>, %a // ----- -func @test_maxpooling_singleout_no_pad(%arg0 : tensor<1x3x32x32xf32>) -> tensor<*xf32> { - %0 = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", kernel_shape = [2, 2]} : (tensor<1x3x32x32xf32>) -> tensor<*xf32> - "std.return"(%0) : (tensor<*xf32>) -> () - - // CHECK-LABEL: test_maxpooling_singleout_no_pad - // CHECK: [[RES:%.+]] = alloc() : memref<1x3x31x31xf32> - // CHECK: [[DEF_LOOPS_0:%.+]]:4 = krnl.define_loops 4 - // CHECK: [[OPT_LOOPS_0:%.+]]:4 = krnl.optimize_loops { - // CHECK: krnl.return_loops [[DEF_LOOPS_0]]#0, [[DEF_LOOPS_0]]#1, [[DEF_LOOPS_0]]#2, [[DEF_LOOPS_0]]#3 - // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop) - // CHECK: krnl.iterate([[OPT_LOOPS_0]]#0, [[OPT_LOOPS_0]]#1, [[OPT_LOOPS_0]]#2, [[OPT_LOOPS_0]]#3) with ([[DEF_LOOPS_0]]#0 -> %arg1 = 0 to 1, [[DEF_LOOPS_0]]#1 -> %arg2 = 0 to 3, [[DEF_LOOPS_0]]#2 -> %arg3 = 0 to 31, [[DEF_LOOPS_0]]#3 -> %arg4 = 0 to 31) { - // CHECK: [[NEGATIVE_INFINITY:%.+]] = constant 0xFF800000 : f32 - // CHECK: store [[NEGATIVE_INFINITY]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x31x31xf32> - // CHECK: [[DEF_LOOPS_1:%.+]]:2 = krnl.define_loops 2 - // CHECK: 
[[OPT_LOOPS_1:%.+]]:2 = krnl.optimize_loops { - // CHECK: krnl.return_loops [[DEF_LOOPS_1]]#0, [[DEF_LOOPS_1]]#1 - // CHECK: } : () -> (!krnl.loop, !krnl.loop) - // CHECK: krnl.iterate([[OPT_LOOPS_1]]#0, [[OPT_LOOPS_1]]#1) with ([[DEF_LOOPS_1]]#0 -> %arg5 = 0 to 2, [[DEF_LOOPS_1]]#1 -> %arg6 = 0 to 2) { - // CHECK: [[H:%.+]] = addi %arg3, %arg5 : index - // CHECK: [[W:%.+]] = addi %arg4, %arg6 : index - // CHECK: [[LOAD_X:%.+]] = load %arg0[%arg1, %arg2, [[H]], [[W]]] : memref<1x3x32x32xf32> - // CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x31x31xf32> - // CHECK: [[COMPARE:%.+]] = cmpf "ogt", [[LOAD_Y]], [[LOAD_X]] : f32 - // CHECK: [[SELECT:%.+]] = select [[COMPARE]], [[LOAD_Y]], [[LOAD_X]] : f32 - // CHECK: store [[SELECT]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x31x31xf32> - // CHECK: } - // CHECK: } - // CHECK: return [[RES]] : memref<1x3x31x31xf32> -} - -// ----- - -func @test_maxpooling_singleout_no_pad_w_strides(%arg0 : tensor<1x3x32x32xf32>) -> tensor<*xf32> { - %0 = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", kernel_shape = [2, 2], strides = [2, 2]} : (tensor<1x3x32x32xf32>) -> tensor<*xf32> - "std.return"(%0) : (tensor<*xf32>) -> () - - // CHECK-LABEL: test_maxpooling_singleout_no_pad_w_strides - // CHECK: [[RES:%.+]] = alloc() : memref<1x3x16x16xf32> - // CHECK: [[DEF_LOOPS_0:%.+]]:4 = krnl.define_loops 4 - // CHECK: [[OPT_LOOPS_0:%.+]]:4 = krnl.optimize_loops { - // CHECK: krnl.return_loops [[DEF_LOOPS_0]]#0, [[DEF_LOOPS_0]]#1, [[DEF_LOOPS_0]]#2, [[DEF_LOOPS_0]]#3 - // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop) - // CHECK: krnl.iterate([[OPT_LOOPS_0]]#0, [[OPT_LOOPS_0]]#1, [[OPT_LOOPS_0]]#2, [[OPT_LOOPS_0]]#3) with ([[DEF_LOOPS_0]]#0 -> %arg1 = 0 to 1, [[DEF_LOOPS_0]]#1 -> %arg2 = 0 to 3, [[DEF_LOOPS_0]]#2 -> %arg3 = 0 to 16, [[DEF_LOOPS_0]]#3 -> %arg4 = 0 to 16) { - // CHECK: [[NEGATIVE_INFINITY:%.+]] = constant 0xFF800000 : f32 - // CHECK: store [[NEGATIVE_INFINITY]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x16x16xf32> - // CHECK: [[DEF_LOOPS_1:%.+]]:2 = krnl.define_loops 2 - // CHECK: [[OPT_LOOPS_1:%.+]]:2 = krnl.optimize_loops { - // CHECK: krnl.return_loops [[DEF_LOOPS_1]]#0, [[DEF_LOOPS_1]]#1 - // CHECK: } : () -> (!krnl.loop, !krnl.loop) - // CHECK: krnl.iterate([[OPT_LOOPS_1]]#0, [[OPT_LOOPS_1]]#1) with ([[DEF_LOOPS_1]]#0 -> %arg5 = 0 to 2, [[DEF_LOOPS_1]]#1 -> %arg6 = 0 to 2) { - // CHECK: [[STRIDE_0:%.+]] = constant 2 : index - // CHECK: [[MUL_0:%.+]] = muli [[STRIDE_0]], %arg3 : index - // CHECK: [[H:%.+]] = addi [[MUL_0]], %arg5 : index - // CHECK: [[STRIDE_1:%.+]] = constant 2 : index - // CHECK: [[MUL_1:%.+]] = muli [[STRIDE_1]], %arg4 : index - // CHECK: [[W:%.+]] = addi [[MUL_1]], %arg6 : index - // CHECK: [[LOAD_X:%.+]] = load %arg0[%arg1, %arg2, [[H]], [[W]]] : memref<1x3x32x32xf32> - // CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x16x16xf32> - // CHECK: [[COMPARE:%.+]] = cmpf "ogt", [[LOAD_Y]], [[LOAD_X]] : f32 - // CHECK: [[SELECT:%.+]] = select [[COMPARE]], [[LOAD_Y]], [[LOAD_X]] : f32 - // CHECK: store [[SELECT]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x16x16xf32> - // CHECK: } - // CHECK: } - // CHECK: return [[RES]] : memref<1x3x16x16xf32> -} - -// ----- - -func @test_maxpooling_singleout_no_pad_w_strides_w_ceil_mode(%arg0 : tensor<1x3x32x32xf32>) -> tensor<*xf32> { - %0 = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", kernel_shape = [3, 3], strides = [2, 2], ceil_mode = 1} : (tensor<1x3x32x32xf32>) -> tensor<*xf32> - 
"std.return"(%0) : (tensor<*xf32>) -> () - - // CHECK-LABEL: test_maxpooling_singleout_no_pad_w_strides_w_ceil_mode - // CHECK: [[RES:%.+]] = alloc() : memref<1x3x16x16xf32> - // CHECK: [[DEF_LOOPS_0:%.+]]:4 = krnl.define_loops 4 - // CHECK: [[OPT_LOOPS_0:%.+]]:4 = krnl.optimize_loops { - // CHECK: krnl.return_loops [[DEF_LOOPS_0]]#0, [[DEF_LOOPS_0]]#1, [[DEF_LOOPS_0]]#2, [[DEF_LOOPS_0]]#3 - // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop) - // CHECK: krnl.iterate([[OPT_LOOPS_0]]#0, [[OPT_LOOPS_0]]#1, [[OPT_LOOPS_0]]#2, [[OPT_LOOPS_0]]#3) with ([[DEF_LOOPS_0]]#0 -> %arg1 = 0 to 1, [[DEF_LOOPS_0]]#1 -> %arg2 = 0 to 3, [[DEF_LOOPS_0]]#2 -> %arg3 = 0 to 16, [[DEF_LOOPS_0]]#3 -> %arg4 = 0 to 16) { - // CHECK: [[NEGATIVE_INFINITY:%.+]] = constant 0xFF800000 : f32 - // CHECK: store [[NEGATIVE_INFINITY]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x16x16xf32> - // CHECK: [[DEF_LOOPS_1:%.+]]:2 = krnl.define_loops 2 - // CHECK: [[OPT_LOOPS_1:%.+]]:2 = krnl.optimize_loops { - // CHECK: krnl.return_loops [[DEF_LOOPS_1]]#0, [[DEF_LOOPS_1]]#1 - // CHECK: } : () -> (!krnl.loop, !krnl.loop) - // CHECK: krnl.iterate([[OPT_LOOPS_1]]#0, [[OPT_LOOPS_1]]#1) with ([[DEF_LOOPS_1]]#0 -> %arg5 = 0 to 3, [[DEF_LOOPS_1]]#1 -> %arg6 = 0 to 3) { - // CHECK: [[STRIDE_0:%.+]] = constant 2 : index - // CHECK: [[MUL_0:%.+]] = muli [[STRIDE_0]], %arg3 : index - // CHECK: [[SPATIAL_H:%.+]] = addi [[MUL_0]], %arg5 : index - // CHECK: [[UPPER_INDEX_0:%.+]] = constant 31 : index - // CHECK: [[GREATER_THAN_UPPER_0:%.+]] = cmpi "sgt", [[SPATIAL_H]], [[UPPER_INDEX_0]] : index - // CHECK: [[H:%.+]] = select [[GREATER_THAN_UPPER_0]], [[UPPER_INDEX_0]], [[SPATIAL_H]] : index - - // CHECK: [[STRIDE_1:%.+]] = constant 2 : index - // CHECK: [[MUL_1:%.+]] = muli [[STRIDE_1]], %arg4 : index - // CHECK: [[SPATIAL_W:%.+]] = addi [[MUL_1]], %arg6 : index - // CHECK: [[UPPER_INDEX_1:%.+]] = constant 31 : index - // CHECK: [[GREATER_THAN_UPPER_1:%.+]] = cmpi "sgt", [[SPATIAL_W]], [[UPPER_INDEX_1]] : index - // CHECK: [[W:%.+]] = select [[GREATER_THAN_UPPER_1]], [[UPPER_INDEX_1]], [[SPATIAL_W]] : index - - // CHECK: [[LOAD_X:%.+]] = load %arg0[%arg1, %arg2, [[H]], [[W]]] : memref<1x3x32x32xf32> - // CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x16x16xf32> - // CHECK: [[CMP_2:%.+]] = cmpf "ogt", [[LOAD_Y]], [[LOAD_X]] : f32 - // CHECK: [[SELECT:%.+]] = select [[CMP_2]], [[LOAD_Y]], [[LOAD_X]] : f32 - // CHECK: store [[SELECT]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x16x16xf32> - // CHECK: } - // CHECK: } - // CHECK: return [[RES]] : memref<1x3x16x16xf32> -} - -// ----- - -func @test_maxpooling_singleout_no_pad_w_strides_w_dilation(%arg0 : tensor<1x3x32x32xf32>) -> tensor<*xf32> { - %0 = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", kernel_shape = [3, 3], strides = [2, 2], dilations = [2, 2]} : (tensor<1x3x32x32xf32>) -> tensor<*xf32> - "std.return"(%0) : (tensor<*xf32>) -> () - - // CHECK-LABEL: test_maxpooling_singleout_no_pad_w_strides_w_dilation - // CHECK: [[RES:%.+]] = alloc() : memref<1x3x14x14xf32> - // CHECK: [[DEF_LOOPS_0:%.+]]:4 = krnl.define_loops 4 - // CHECK: [[OPT_LOOPS_0:%.+]]:4 = krnl.optimize_loops { - // CHECK: krnl.return_loops [[DEF_LOOPS_0]]#0, [[DEF_LOOPS_0]]#1, [[DEF_LOOPS_0]]#2, [[DEF_LOOPS_0]]#3 - // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop) - // CHECK: krnl.iterate([[OPT_LOOPS_0]]#0, [[OPT_LOOPS_0]]#1, [[OPT_LOOPS_0]]#2, [[OPT_LOOPS_0]]#3) with ([[DEF_LOOPS_0]]#0 -> %arg1 = 0 to 1, [[DEF_LOOPS_0]]#1 -> %arg2 = 0 to 3, 
[[DEF_LOOPS_0]]#2 -> %arg3 = 0 to 14, [[DEF_LOOPS_0]]#3 -> %arg4 = 0 to 14) { - // CHECK: [[NEGATIVE_INFINITY:%.+]] = constant 0xFF800000 : f32 - // CHECK: store [[NEGATIVE_INFINITY]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x14x14xf32> - // CHECK: [[DEF_LOOPS_1:%.+]]:2 = krnl.define_loops 2 - // CHECK: [[OPT_LOOPS_1:%.+]]:2 = krnl.optimize_loops { - // CHECK: krnl.return_loops [[DEF_LOOPS_1]]#0, [[DEF_LOOPS_1]]#1 - // CHECK: } : () -> (!krnl.loop, !krnl.loop) - // CHECK: krnl.iterate([[OPT_LOOPS_1]]#0, [[OPT_LOOPS_1]]#1) with ([[DEF_LOOPS_1]]#0 -> %arg5 = 0 to 3, [[DEF_LOOPS_1]]#1 -> %arg6 = 0 to 3) { - // CHECK: [[STRIDE_0:%.+]] = constant 2 : index - // CHECK: [[MUL_0:%.+]] = muli [[STRIDE_0]], %arg3 : index - // CHECK: [[STRIDE_1:%.+]] = constant 2 : index - // CHECK: [[MUL_1:%.+]] = muli [[STRIDE_1]], %arg5 : index - // CHECK: [[SPATIAL_H:%.+]] = addi [[MUL_0]], [[MUL_1]] : index - // CHECK: [[UPPER_INDEX_0:%.+]] = constant 31 : index - // CHECK: [[GREATER_THAN_UPPER_0:%.+]] = cmpi "sgt", [[SPATIAL_H]], [[UPPER_INDEX_0]] : index - // CHECK: [[H:%.+]] = select [[GREATER_THAN_UPPER_0]], [[UPPER_INDEX_0]], [[SPATIAL_H]] : index - - // CHECK: [[STRIDE_0_1:%.+]] = constant 2 : index - // CHECK: [[MUL_0_1:%.+]] = muli [[STRIDE_0_1]], %arg4 : index - // CHECK: [[STRIDE_1_1:%.+]] = constant 2 : index - // CHECK: [[MUL_1_1:%.+]] = muli [[STRIDE_1_1]], %arg6 : index - // CHECK: [[SPATIAL_W:%.+]] = addi [[MUL_0_1]], [[MUL_1_1]] : index - // CHECK: [[UPPER_INDEX_1:%.+]] = constant 31 : index - // CHECK: [[GREATER_THAN_UPPER_1:%.+]] = cmpi "sgt", [[SPATIAL_W]], [[UPPER_INDEX_1]] : index - // CHECK: [[W:%.+]] = select [[GREATER_THAN_UPPER_1]], [[UPPER_INDEX_1]], [[SPATIAL_W]] : index - - // CHECK: [[LOAD_X:%.+]] = load %arg0[%arg1, %arg2, [[H]], [[W]]] : memref<1x3x32x32xf32> - // CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x14x14xf32> - // CHECK: [[CMP_2:%.+]] = cmpf "ogt", [[LOAD_Y]], [[LOAD_X]] : f32 - // CHECK: [[SELECT:%.+]] = select [[CMP_2]], [[LOAD_Y]], [[LOAD_X]] : f32 - // CHECK: store [[SELECT]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x14x14xf32> - // CHECK: } - // CHECK: } - // CHECK: return [[RES]] : memref<1x3x14x14xf32> -} - -// ----- - -func @test_maxpooling_singleout_no_pad_w_strides_w_ceil_mode_w_unknown_dims(%arg0 : tensor) -> tensor<*xf32> { - %0 = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", kernel_shape = [3, 3], strides = [2, 2], ceil_mode = 1} : (tensor) -> tensor<*xf32> - "std.return"(%0) : (tensor<*xf32>) -> () - - // CHECK-LABEL: test_maxpooling_singleout_no_pad_w_strides_w_ceil_mode_w_unknown_dims - - // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref - // CHECK: [[ZERO:%.+]] = constant 0 : i64 - // CHECK: [[ONE:%.+]] = constant 1 : i64 - // CHECK: [[DIM_1:%.+]] = dim %arg0, 2 : memref - // CHECK: [[DIM_1_i64:%.+]] = index_cast [[DIM_1]] : index to i64 - // CHECK: [[KERNEL_PAD_DILATION:%.+]] = constant -3 : i64 - // CHECK: [[NUMERATOR:%.+]] = addi [[DIM_1_i64]], [[KERNEL_PAD_DILATION]] : i64 - // CHECK: [[DENOMINATOR:%.+]] = constant 2 : i64 - // CHECK: [[DIV:%.+]] = divi_signed [[NUMERATOR]], [[DENOMINATOR]] : i64 - // CHECK: [[REMAINDER:%.+]] = remi_signed [[NUMERATOR]], [[DENOMINATOR]] : i64 - // CHECK: [[IS_ZERO:%.+]] = cmpi "eq", [[REMAINDER]], [[ZERO]] : i64 - // CHECK: [[DIV_PLUS_ONE:%.+]] = addi [[DIV]], [[ONE]] : i64 - // CHECK: [[SELECT:%.+]] = select [[IS_ZERO]], [[DIV]], [[DIV_PLUS_ONE]] : i64 - // CHECK: [[SELECT_PLUS_ONE:%.+]] = addi [[SELECT]], [[ONE]] : i64 - // CHECK: [[DIM_1_FINAL:%.+]] = index_cast 
[[SELECT_PLUS_ONE]] : i64 to index - // CHECK: [[RES:%.+]] = alloc([[DIM_0]], [[DIM_1_FINAL]]) : memref - - // CHECK: [[DEF_LOOPS_0:%.+]]:4 = krnl.define_loops 4 - // CHECK: [[OPT_LOOPS_0:%.+]]:4 = krnl.optimize_loops { - // CHECK: krnl.return_loops [[DEF_LOOPS_0]]#0, [[DEF_LOOPS_0]]#1, [[DEF_LOOPS_0]]#2, [[DEF_LOOPS_0]]#3 - // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop) - // CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref - // CHECK: [[DIM_3:%.+]] = dim [[RES]], 2 : memref - // CHECK: krnl.iterate([[OPT_LOOPS_0]]#0, [[OPT_LOOPS_0]]#1, [[OPT_LOOPS_0]]#2, [[OPT_LOOPS_0]]#3) with ([[DEF_LOOPS_0]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS_0]]#1 -> %arg2 = 0 to 3, [[DEF_LOOPS_0]]#2 -> %arg3 = 0 to [[DIM_3]], [[DEF_LOOPS_0]]#3 -> %arg4 = 0 to 16) { - // CHECK: [[NEGATIVE_INFINITY:%.+]] = constant 0xFF800000 : f32 - // CHECK: store [[NEGATIVE_INFINITY]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref - // CHECK: [[DEF_LOOPS_1:%.+]]:2 = krnl.define_loops 2 - // CHECK: [[OPT_LOOPS_1:%.+]]:2 = krnl.optimize_loops { - // CHECK: krnl.return_loops [[DEF_LOOPS_1]]#0, [[DEF_LOOPS_1]]#1 - // CHECK: } : () -> (!krnl.loop, !krnl.loop) - // CHECK: krnl.iterate([[OPT_LOOPS_1]]#0, [[OPT_LOOPS_1]]#1) with ([[DEF_LOOPS_1]]#0 -> %arg5 = 0 to 3, [[DEF_LOOPS_1]]#1 -> %arg6 = 0 to 3) { - // CHECK: [[STRIDE_0:%.+]] = constant 2 : index - // CHECK: [[MUL_0:%.+]] = muli [[STRIDE_0]], %arg3 : index - // CHECK: [[SPATIAL_H:%.+]] = addi [[MUL_0]], %arg5 : index - // CHECK: [[DIM_0_0:%.+]] = dim %arg0, 2 : memref - // CHECK: [[ONE_INDEX:%.+]] = constant 1 : index - // CHECK: [[UPPER_INDEX_0:%.+]] = subi [[DIM_0_0]], [[ONE_INDEX]] : index - // CHECK: [[GREATER_THAN_UPPER_0:%.+]] = cmpi "sgt", [[SPATIAL_H]], [[UPPER_INDEX_0]] : index - // CHECK: [[H:%.+]] = select [[GREATER_THAN_UPPER_0]], [[UPPER_INDEX_0]], [[SPATIAL_H]] : index - - // CHECK: [[STRIDE_1:%.+]] = constant 2 : index - // CHECK: [[MUL_1:%.+]] = muli [[STRIDE_1]], %arg4 : index - // CHECK: [[SPATIAL_W:%.+]] = addi [[MUL_1]], %arg6 : index - // CHECK: [[UPPER_INDEX_1:%.+]] = constant 31 : index - // CHECK: [[GREATER_THAN_UPPER_1:%.+]] = cmpi "sgt", [[SPATIAL_W]], [[UPPER_INDEX_1]] : index - // CHECK: [[W:%.+]] = select [[GREATER_THAN_UPPER_1]], [[UPPER_INDEX_1]], [[SPATIAL_W]] : index - - // CHECK: [[LOAD_X:%.+]] = load %arg0[%arg1, %arg2, [[H]], [[W]]] : memref - // CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg1, %arg2, %arg3, %arg4] : memref - // CHECK: [[CMP_2:%.+]] = cmpf "ogt", [[LOAD_Y]], [[LOAD_X]] : f32 - // CHECK: [[SELECT:%.+]] = select [[CMP_2]], [[LOAD_Y]], [[LOAD_X]] : f32 - // CHECK: store [[SELECT]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref - // CHECK: } - // CHECK: } - // CHECK: return [[RES]] : memref -} - -// ----- - func @test_abs_float(%arg0 : tensor) -> tensor<*xf32> { %0 = "onnx.Abs"(%arg0) : (tensor) -> tensor<*xf32> "std.return"(%0) : (tensor<*xf32>) -> () @@ -1810,7 +1585,6 @@ func @test_constant_dense_2d_value(%arg0: tensor<1xf32>) -> tensor<*xf32> { // CHECK: return [[RES]] : memref<3x2xf32> } - // ----- func @test_concat_1(%arg0 : tensor<5x5x1x32xf32>, %arg1 : tensor<5x5x3x32xf32>, %arg2 : tensor<5x5x5x32xf32>) -> tensor<5x5x9x32xf32> { @@ -1849,3 +1623,133 @@ func @test_concat_1(%arg0 : tensor<5x5x1x32xf32>, %arg1 : tensor<5x5x3x32xf32>, // CHECK: return [[RES]] : memref<5x5x9x32xf32> } + +// ----- + +func @test_pool_general_computation(%arg0 : tensor<1x3x32x32xf32>) -> tensor<*xf32> { + %0 = "onnx.AveragePool"(%arg0) {auto_pad = "NOTSET", kernel_shape = [2, 2]} : (tensor<1x3x32x32xf32>) -> tensor<*xf32> + 
"std.return"(%0) : (tensor<*xf32>) -> () + + // CHECK-DAG: #{{.*}} = affine_map<(d0)[s0, s1, s2, s3, s4] -> ((s2 ceildiv s4) * s4 - s2, d0 * s3 - s2)> + // CHECK-DAG: #{{.*}} = affine_map<(d0)[s0, s1, s2, s3, s4] -> (s0, d0 * s3 + (s1 - 1) * s4 - s2 + 1)> + // CHECK-DAG: #{{.*}} = affine_map<() -> (0)> + // CHECK-DAG: #{{.*}} = affine_map<(d0)[s0, s1, s2, s3, s4] -> (s0 - ((s2 ceildiv s4) * s4 - s2), -(d0 * s3 - s2) + s0, d0 * s3 + (s1 - 1) * s4 - s2 - ((s2 ceildiv s4) * s4 - s2) + 1, d0 * s3 + (s1 - 1) * s4 - s2 - (d0 * s3 - s2) + 1)> + + // CHECK-LABEL: @test_pool_general_computation + + // CHECK: [[RES:%.+]] = alloc() : memref<1x3x31x31xf32> + // CHECK: [[IDENTITY:%.+]] = constant 0.000000e+00 : f32 + + // CHECK: [[OUTPUT_LOOPS:%.+]]:4 = krnl.define_loops 4 + // CHECK: [[OPT_OUTPUT_LOOPS:%.+]]:4 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[OUTPUT_LOOPS]]#0, [[OUTPUT_LOOPS]]#1, [[OUTPUT_LOOPS]]#2, [[OUTPUT_LOOPS]]#3 + // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop) + // CHECK: krnl.iterate([[OPT_OUTPUT_LOOPS]]#0, [[OPT_OUTPUT_LOOPS]]#1, [[OPT_OUTPUT_LOOPS]]#2, [[OPT_OUTPUT_LOOPS]]#3) with ([[OUTPUT_LOOPS]]#0 -> %arg1 = 0 to 1, [[OUTPUT_LOOPS]]#1 -> %arg2 = 0 to 3, [[OUTPUT_LOOPS]]#2 -> %arg3 = 0 to 31, [[OUTPUT_LOOPS]]#3 -> %arg4 = 0 to 31) { + + // CHECK: store [[IDENTITY]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x31x31xf32> + + // CHECK: [[POOL_LOOPS:%.+]]:2 = krnl.define_loops 2 + // CHECK: [[OPT_POOL_LOOPS:%.+]]:2 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[POOL_LOOPS]]#0, [[POOL_LOOPS]]#1 + // CHECK: } : () -> (!krnl.loop, !krnl.loop) + // CHECK: krnl.iterate([[OPT_POOL_LOOPS]]#0, [[OPT_POOL_LOOPS]]#1) with ([[POOL_LOOPS]]#0 -> %arg5 = 0 to min #map3(%arg3)[%c32, %c2, %c0, %c1, %c1_0], [[POOL_LOOPS]]#1 -> %arg6 = 0 to min #map3(%arg4)[%c32_1, %c2_2, %c0_3, %c1_4, %c1_5]) { + // CHECK: {{.*}} = load %arg0[%arg1, %arg2, {{.*}}, {{.*}}] : memref<1x3x32x32xf32> + // CHECK: {{.*}} = load [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x31x31xf32> + // CHECK: store {{.*}}, [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x31x31xf32> + // CHECK: } + + // CHECK: {{.*}} = load [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x31x31xf32> + // CHECK: store {{.*}}, [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x31x31xf32> + // CHECK: } +} + +// ----- + +func @test_averagepool_identity_value(%arg0 : tensor<1x3x32x32xf32>) -> tensor<*xf32> { + %0 = "onnx.AveragePool"(%arg0) {auto_pad = "NOTSET", kernel_shape = [2, 2]} : (tensor<1x3x32x32xf32>) -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () + + // CHECK-LABEL: @test_averagepool_identity_value + // CHECK: [[RES:%.+]] = alloc() : memref<1x3x31x31xf32> + // CHECK: [[IDENTITY:%.+]] = constant 0.000000e+00 : f32 + // CHECK: store [[IDENTITY]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x31x31xf32> +} + +// ----- + +func @test_maxpool_identity_value(%arg0 : tensor<1x3x32x32xf32>) -> tensor<*xf32> { + %0 = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", kernel_shape = [2, 2]} : (tensor<1x3x32x32xf32>) -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () + + // CHECK-LABEL: @test_maxpool_identity_value + // CHECK: [[RES:%.+]] = alloc() : memref<1x3x31x31xf32> + // CHECK: [[IDENTITY:%.+]] = constant 0xFF800000 : f32 + // CHECK: store [[IDENTITY]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x31x31xf32> +} + +// ----- + +func @test_averagepool_pooling_operation(%arg0 : tensor<1x3x32x32xf32>) -> tensor<*xf32> { + %0 = "onnx.AveragePool"(%arg0) {auto_pad = "NOTSET", 
kernel_shape = [2, 2]} : (tensor<1x3x32x32xf32>) -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () + + // CHECK-LABEL: @test_averagepool_pooling_operation + // CHECK: [[RES:%.+]] = alloc() : memref<1x3x31x31xf32> + + // CHECK: [[OUTPUT_LOOPS:%.+]]:4 = krnl.define_loops 4 + // CHECK: [[OPT_OUTPUT_LOOPS:%.+]]:4 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[OUTPUT_LOOPS]]#0, [[OUTPUT_LOOPS]]#1, [[OUTPUT_LOOPS]]#2, [[OUTPUT_LOOPS]]#3 + // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop) + // CHECK: krnl.iterate([[OPT_OUTPUT_LOOPS]]#0, [[OPT_OUTPUT_LOOPS]]#1, [[OPT_OUTPUT_LOOPS]]#2, [[OPT_OUTPUT_LOOPS]]#3) with ([[OUTPUT_LOOPS]]#0 -> %arg1 = 0 to 1, [[OUTPUT_LOOPS]]#1 -> %arg2 = 0 to 3, [[OUTPUT_LOOPS]]#2 -> %arg3 = 0 to 31, [[OUTPUT_LOOPS]]#3 -> %arg4 = 0 to 31) { + + // CHECK: [[POOL_LOOPS:%.+]]:2 = krnl.define_loops 2 + // CHECK: [[OPT_POOL_LOOPS:%.+]]:2 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[POOL_LOOPS]]#0, [[POOL_LOOPS]]#1 + // CHECK: } : () -> (!krnl.loop, !krnl.loop) + // CHECK: krnl.iterate([[OPT_POOL_LOOPS]]#0, [[OPT_POOL_LOOPS]]#1) with ([[POOL_LOOPS]]#0 -> %arg5 = 0 to min #map3(%arg3)[%c32, %c2, %c0, %c1, %c1_0], [[POOL_LOOPS]]#1 -> %arg6 = 0 to min #map3(%arg4)[%c32_1, %c2_2, %c0_3, %c1_4, %c1_5]) { + + // CHECK: [[INPUT_LOAD:%.+]] = load %arg0[%arg1, %arg2, {{.*}}, {{.*}}] : memref<1x3x32x32xf32> + // CHECK: [[OUTPUT_LOAD:%.+]] = load [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x31x31xf32> + // CHECK: [[SUM:%.+]] = addf [[OUTPUT_LOAD]], [[INPUT_LOAD]] : f32 + // CHECK: store [[SUM]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x31x31xf32> + // CHECK: } + + // CHECK: [[NUMERATOR:%.+]] = load [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x31x31xf32> + // CHECK: [[AVERAGE:%.+]] = divf [[NUMERATOR]], {{.*}} : f32 + // CHECK: store [[AVERAGE]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x31x31xf32> + // CHECK: } +} + +// ----- + +func @test_maxpool_pooling_operation(%arg0 : tensor<1x3x32x32xf32>) -> tensor<*xf32> { + %0 = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", kernel_shape = [2, 2]} : (tensor<1x3x32x32xf32>) -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () + + // CHECK-LABEL: @test_maxpool_pooling_operation + // CHECK: [[RES:%.+]] = alloc() : memref<1x3x31x31xf32> + + // CHECK: [[OUTPUT_LOOPS:%.+]]:4 = krnl.define_loops 4 + // CHECK: [[OPT_OUTPUT_LOOPS:%.+]]:4 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[OUTPUT_LOOPS]]#0, [[OUTPUT_LOOPS]]#1, [[OUTPUT_LOOPS]]#2, [[OUTPUT_LOOPS]]#3 + // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop) + // CHECK: krnl.iterate([[OPT_OUTPUT_LOOPS]]#0, [[OPT_OUTPUT_LOOPS]]#1, [[OPT_OUTPUT_LOOPS]]#2, [[OPT_OUTPUT_LOOPS]]#3) with ([[OUTPUT_LOOPS]]#0 -> %arg1 = 0 to 1, [[OUTPUT_LOOPS]]#1 -> %arg2 = 0 to 3, [[OUTPUT_LOOPS]]#2 -> %arg3 = 0 to 31, [[OUTPUT_LOOPS]]#3 -> %arg4 = 0 to 31) { + + // CHECK: [[POOL_LOOPS:%.+]]:2 = krnl.define_loops 2 + // CHECK: [[OPT_POOL_LOOPS:%.+]]:2 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[POOL_LOOPS]]#0, [[POOL_LOOPS]]#1 + // CHECK: } : () -> (!krnl.loop, !krnl.loop) + // CHECK: krnl.iterate([[OPT_POOL_LOOPS]]#0, [[OPT_POOL_LOOPS]]#1) with ([[POOL_LOOPS]]#0 -> %arg5 = 0 to min #map3(%arg3)[%c32, %c2, %c0, %c1, %c1_0], [[POOL_LOOPS]]#1 -> %arg6 = 0 to min #map3(%arg4)[%c32_1, %c2_2, %c0_3, %c1_4, %c1_5]) { + + // CHECK: [[INPUT_LOAD:%.+]] = load %arg0[%arg1, %arg2, {{.*}}, {{.*}}] : memref<1x3x32x32xf32> + // CHECK: [[OUTPUT_LOAD:%.+]] = load [[RES]][%arg1, %arg2, %arg3, %arg4] : 
memref<1x3x31x31xf32> + // CHECK: [[GREATER:%.+]] = cmpf "ogt", [[OUTPUT_LOAD]], [[INPUT_LOAD]] : f32 + // CHECK: [[SELECT:%.+]] = select [[GREATER]], [[OUTPUT_LOAD]], [[INPUT_LOAD]] : f32 + // CHECK: store [[SELECT]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x31x31xf32> + // CHECK: } + + // CHECK-NOT: {{.*}} = load [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x31x31xf32> + // CHECK-NOT: store {{.*}}, [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x31x31xf32> + // CHECK: } +} +