//===---------------- Pooling.cpp - Lowering Pooling Ops ------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file lowers the ONNX Pooling Operators to Krnl dialect.
//
//===----------------------------------------------------------------------===//

#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"

using namespace mlir;

// Identity values.
// Max pooling starts from the identity of the max operation, i.e. negative
// infinity, so that any actual input value replaces it on the first compare.
template <>
Value getIdentityValue<ONNXMaxPoolSingleOutOp>(
    ConversionPatternRewriter &rewriter, Location loc, Type type) {
  return emitNegativeInfinityConstantOp(rewriter, loc, type);
}

// Scalar lowering of max for pooling: emit a floating-point "greater than"
// comparison followed by a select of the larger operand.
template <>
Value mapToLowerScalarOp<ONNXMaxPoolSingleOutOp>(Operation *op,
    ArrayRef<Type> result_types, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) {
  auto loc = op->getLoc();
  Value lhs = operands[0];
  Value rhs = operands[1];
  auto max = rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs);
  auto result = rewriter.create<SelectOp>(loc, max, lhs, rhs);
  return result;
}

struct ONNXMaxPoolSingleOutOpLowering : public ConversionPattern {
  ONNXMaxPoolSingleOutOpLowering(MLIRContext *ctx)
      : ConversionPattern(
            mlir::ONNXMaxPoolSingleOutOp::getOperationName(), 1, ctx) {}

  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const final {
    ONNXMaxPoolSingleOutOpOperandAdaptor operandAdaptor(operands);
    auto loc = op->getLoc();

    // Match
    ONNXMaxPoolSingleOutOp poolOp = llvm::dyn_cast<ONNXMaxPoolSingleOutOp>(op);

    // Read kernel_shape attribute
    SmallVector<int, 4> kernelShape;
    auto kernelShapeAttribute = poolOp.kernel_shapeAttr();
    for (auto dim : kernelShapeAttribute.getValue())
      kernelShape.emplace_back(dim.cast<IntegerAttr>().getInt());

    // Read strides attribute
    SmallVector<int, 4> strides;
    auto stridesAttribute = poolOp.stridesAttr();
    for (auto stride : stridesAttribute.getValue())
      strides.emplace_back(stride.cast<IntegerAttr>().getInt());

    // Read ceil_mode attribute
    auto ceilMode = poolOp.ceil_mode().getSExtValue();

    // Read pads attribute
    SmallVector<int, 4> pads;
    auto padsAttribute = poolOp.padsAttr();
    for (auto pad : padsAttribute.getValue())
      pads.emplace_back(pad.cast<IntegerAttr>().getInt());

    // Read dilations attribute
    SmallVector<int, 4> dilations;
    auto dilationsAttribute = poolOp.dilationsAttr();
    for (auto dilation : dilationsAttribute.getValue())
      dilations.emplace_back(dilation.cast<IntegerAttr>().getInt());

    // Type information about the input and result of this operation.
    auto inputOperand = operandAdaptor.X();
    auto inputShape = inputOperand.getType().cast<MemRefType>().getShape();
    auto memRefType = convertToMemRefType(*op->result_type_begin());
    auto resultShape = memRefType.getShape();
    auto resultElementType = memRefType.getElementType();

    // Batch indices: N and C dimensions
    int batchRank = 2;

    // Insert an allocation and deallocation for the result of this operation.
    Value alloc;
    bool insertDealloc = checkInsertDealloc(op);

    if (hasAllConstantDimensions(memRefType))
      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
    else {
      // Compute dimensions of the result of this operation.
      SmallVector<Value, 2> allocOperands;
      for (int i = 0; i < batchRank; ++i) {
        if (resultShape[i] < 0) {
          auto dim = rewriter.create<DimOp>(loc, inputOperand, i);
          allocOperands.emplace_back(dim);
        }
      }

      Value zero, one;
      if (ceilMode) {
        zero = rewriter.create<ConstantOp>(
            loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
      }
      one = rewriter.create<ConstantOp>(
          loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1));

      int spatialRank = resultShape.size() - batchRank;
      for (int i = batchRank; i < resultShape.size(); ++i) {
        if (resultShape[i] < 0) {
          // dim =
          //   let numerator = (input + pad - (kernel - 1) * dilation - 1)
          //   in let denominator = stride
          //   in
          //     if (ceilMode)
          //       ceil(numerator / denominator) + 1
          //     else
          //       floor(numerator / denominator) + 1
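          //
          // For illustration (values assumed, not from any particular
          // model): input = 5, pads = [0, 0], kernel = 2, dilation = 1,
          // stride = 2 gives numerator = 5 + 0 - (2 - 1) * 1 - 1 = 3, so
          // floor mode yields 3 / 2 + 1 = 2 output elements while ceil
          // mode yields ceil(3 / 2) + 1 = 3.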
          int spatialIndex = i - batchRank;

          // numerator = (input + pad - (kernel - 1) * dilation - 1)
          auto inputDim = rewriter.create<DimOp>(loc, inputOperand, i);
          auto inputVal = rewriter.create<IndexCastOp>(
              loc, inputDim, rewriter.getIntegerType(64));
          int64_t padKernelDilation =
              (pads[spatialIndex] + pads[spatialIndex + spatialRank]) -
              (kernelShape[spatialIndex] - 1) * dilations[spatialIndex] - 1;
          auto padKernelDilationVal = rewriter.create<ConstantOp>(
              loc, rewriter.getIntegerAttr(
                       rewriter.getIntegerType(64), padKernelDilation));
          auto numeratorVal =
              rewriter.create<AddIOp>(loc, inputVal, padKernelDilationVal);
          // denominator
          auto denominatorVal = rewriter.create<ConstantOp>(
              loc, rewriter.getIntegerAttr(
                       rewriter.getIntegerType(64), strides[spatialIndex]));

          // numerator / denominator
          Value dimVal =
              rewriter.create<SignedDivIOp>(loc, numeratorVal, denominatorVal);

          // Round up for ceil mode: signed division truncates, so add one
          // to the quotient whenever the division leaves a nonzero
          // remainder.
          if (ceilMode) {
            auto remainder = rewriter.create<SignedRemIOp>(
                loc, numeratorVal, denominatorVal);
            auto isZero = rewriter.create<CmpIOp>(
                loc, CmpIPredicate::eq, remainder, zero);
            auto dimPlusOne = rewriter.create<AddIOp>(loc, dimVal, one);
            dimVal = rewriter.create<SelectOp>(loc, isZero, dimVal, dimPlusOne);
          }

          dimVal = rewriter.create<AddIOp>(loc, dimVal, one);
          allocOperands.emplace_back(rewriter.create<IndexCastOp>(
              loc, dimVal, rewriter.getIndexType()));
        }
      }
      alloc = rewriter.create<AllocOp>(loc, memRefType, allocOperands);
      if (insertDealloc) {
        auto *parentBlock = alloc.getDefiningOp()->getBlock();
        auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
        dealloc.getOperation()->moveBefore(&parentBlock->back());
      }
    }

    // R = MaxPool(D)
    //
    // The input/output shapes will look like this:
    //
    // D (NxCxHxW) -> R (NxCxRHxRW)
    //
    // The loop nest will look as follows:
    //
    // strides = [s1, s2]
    //
    // for n = 0 .. N:
    //   for c = 0 .. C:
    //     for r1 = 0 .. RH:
    //       for r2 = 0 .. RW:
    //         R[n][c][r1][r2] = negative_infinity;
    //         for k1 = 0 .. KH:
    //           for k2 = 0 .. KW:
    //             t = D[n][c][s1 * r1 + k1 * d1][s2 * r2 + k2 * d2];
    //             R[n][c][r1][r2] = max(R[n][c][r1][r2], t);
    //
    // Naming:
    //   n, c, r1, r2: outer loop nest indices
    //   k1, k2: inner loop nest indices
    //   s1, s2: strides
    //   d1, d2: dilations
    //
    // TODO: handle padding.
    //
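    // Concrete instance (shapes chosen purely for illustration): with D of
    // shape 1x1x4x4, kernel [2, 2], strides [2, 2], and dilations [1, 1],
    // the nest produces R of shape 1x1x2x2, where each R[0][0][r1][r2] is
    // the max over the 2x2 window D[0][0][2*r1 .. 2*r1+1][2*r2 .. 2*r2+1].
    //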

    // 1. Define outer loops and emit empty optimization block.
    auto nOuterLoops = resultShape.size();
    BuildKrnlLoop outerLoops(rewriter, loc, nOuterLoops);
    outerLoops.createDefineOptimizeAndIterateOp(alloc);

    rewriter.setInsertionPointToStart(outerLoops.getIterateBlock());
    {
      // 2. Emit the body of the outer loop nest.
      SmallVector<Value, 4> resultIndices;
      for (int i = 0; i < nOuterLoops; ++i)
        resultIndices.emplace_back(outerLoops.getInductionVar(i));

      // 2.1 Emit: R[n][c][r1][r2] = negative_infinity;
      Value identity = getIdentityValue<ONNXMaxPoolSingleOutOp>(
          rewriter, loc, resultElementType);
      rewriter.create<StoreOp>(loc, identity, alloc, resultIndices);

      // 2.2 Define inner loops.
      int nInnerLoops = kernelShape.size();
      BuildKrnlLoop innerLoops(rewriter, loc, nInnerLoops);
      innerLoops.createDefineAndOptimizeOp();
      // for Kx = 0 .. KX
      for (int i = 0; i < nInnerLoops; ++i)
        innerLoops.pushBounds(0, kernelShape[i]);

      // 2.3 Emit inner loop nest.
      innerLoops.createIterateOp();
      rewriter.setInsertionPointToStart(innerLoops.getIterateBlock());
      {
        // 3. Emit inner loop body:
        //   t = D[n][c][s1 * r1 + k1 * d1][s2 * r2 + k2 * d2];
        //   R[n][c][r1][r2] = max(R[n][c][r1][r2], t);

        // 3.1 Prepare indices for accessing the data tensor.
        SmallVector<Value, 4> dataIndices;
        // 3.1.1 Batch indices: n, c
        for (int i = 0; i < batchRank; ++i)
          dataIndices.emplace_back(outerLoops.getInductionVar(i));
        // 3.1.2 Insert spatial indices: sX * rX + kX * dX
        for (int i = batchRank; i < nOuterLoops; ++i) {
          // Get the index along the inner loop's induction variables.
          // It is used to obtain the kernel/pad/stride/dilation index.
          int j = i - batchRank;

          Value spatialIndex = outerLoops.getInductionVar(i);
          // If strides are present (not default), then emit the correct
          // access index.
          // sX *= rX
          if (strides[j] > 1) {
            auto strideIndex = emitConstantOp(
                rewriter, loc, rewriter.getIndexType(), strides[j]);
            spatialIndex = rewriter.create<MulIOp>(
                loc, strideIndex, outerLoops.getInductionVar(i));
          }

          // Dilate the kernel index only if the dilation value is not one
          // (not default). Otherwise, just add kernelIndex.
          auto kernelIndex = innerLoops.getInductionVar(j);
          if (dilations[j] > 1) {
            // sX += dX * kX
            auto dilationIndex = emitConstantOp(
                rewriter, loc, rewriter.getIndexType(), dilations[j]);
            auto dilationKernelIndex =
                rewriter.create<MulIOp>(loc, dilationIndex, kernelIndex);
            spatialIndex =
                rewriter.create<AddIOp>(loc, spatialIndex, dilationKernelIndex);
          } else {
            // sX += kX
            spatialIndex =
                rewriter.create<AddIOp>(loc, spatialIndex, kernelIndex);
          }

          // If ceil mode or dilation is enabled, the computed access index
          // may exceed its dimension. In such a case, clamp it to the
          // maximum valid index, which causes the element at that index to
          // be visited multiple times.
          // TODO: Avoid multiple visits.
          // Example of an out-of-bound access:
          // - Given a 5x5 input X
          //   X = [[0, 0, 0, 0, 0],
          //        [1, 1, 1, 1, 1],
          //        [2, 2, 2, 2, 2],
          //        [3, 3, 3, 3, 3],
          //        [4, 4, 4, 4, 4]]
          // - Do MaxPool with strides=[2, 2], kernel=[2, 2], ceilMode=true;
          //   the output is a 3x3 array:
          //   Y = [[1, 1, 1],
          //        [3, 3, 3],
          //        [4, 4, 4]]
          // - When computing Y[2, 0]:
          //   - In case of kernelIndex = 1, stride = 2
          //     - No dilation: spatialIndex = 2 * 2 + 1 = 5
          //       => out of bound
          //     - dilation = 2: spatialIndex = 2 * 2 + 2 * 1 = 6
          //       => out of bound
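          //   With the clamp below, both cases read the last valid row
          //   index 4 instead, so Y[2, 0] takes its max over row 4 of X and
          //   evaluates to 4, matching the last row of Y above.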
          if (dilations[j] > 1 || ceilMode) {
            Value upperIndex;
            if (inputShape[i] < 0) {
              Value inputDim = rewriter.create<DimOp>(loc, inputOperand, i);
              Value one = rewriter.create<ConstantIndexOp>(loc, 1);
              upperIndex = rewriter.create<SubIOp>(loc, inputDim, one);
            } else {
              upperIndex =
                  rewriter.create<ConstantIndexOp>(loc, inputShape[i] - 1);
            }
            auto greaterCondition = rewriter.create<CmpIOp>(
                loc, CmpIPredicate::sgt, spatialIndex, upperIndex);
            spatialIndex = rewriter.create<SelectOp>(
                loc, greaterCondition, upperIndex, spatialIndex);
          }

          dataIndices.emplace_back(spatialIndex);
        }

        // 3.2 Do pooling.
        auto loadData = rewriter.create<LoadOp>(loc, inputOperand, dataIndices);
        auto loadPartialResult =
            rewriter.create<LoadOp>(loc, alloc, resultIndices);
        Value result = mapToLowerScalarOp<ONNXMaxPoolSingleOutOp>(
            op, resultElementType, {loadPartialResult, loadData}, rewriter);
        rewriter.create<StoreOp>(loc, result, alloc, resultIndices);
      }
    }
    rewriter.replaceOp(op, alloc);

    return matchSuccess();
  }
};

// Register the max-pooling lowering pattern with the given pattern list.
void populateLoweringONNXPoolingOpPattern(
    OwningRewritePatternList &patterns, MLIRContext *ctx) {
  patterns.insert<ONNXMaxPoolSingleOutOpLowering>(ctx);
}
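
// A minimal sketch of how this entry point is typically consumed when
// assembling the ONNX-to-Krnl conversion; the surrounding pass code below is
// an assumed usage example, not part of this file:
//
//   OwningRewritePatternList patterns;
//   populateLoweringONNXPoolingOpPattern(patterns, &getContext());
//   // ... populate the remaining ONNX-to-Krnl patterns, then run the
//   // dialect conversion (e.g., via applyPartialConversion).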