onnx-mlir/src/conversion/onnx_to_krnl/math/reduction.cpp

311 lines
11 KiB
C++

//===----- reduction.cpp - Lowering Reduction Ops -------------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file lowers the ONNX Reduction Operators to Krnl dialect.
//
//===----------------------------------------------------------------------===//
#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
using namespace mlir;
// Identity values
template <>
float getIdentityValue<float, ONNXReduceMaxOp>(){
return (float)-std::numeric_limits<float>::infinity();
}
template <>
int getIdentityValue<int, ONNXReduceMaxOp>(){
return std::numeric_limits<int>::min();
}
template <>
float getIdentityValue<float, ONNXReduceMinOp>(){
return (float)std::numeric_limits<float>::infinity();
}
template <>
int getIdentityValue<int, ONNXReduceMinOp>(){
return std::numeric_limits<int>::max();
}
template <>
float getIdentityValue<float, ONNXReduceProdOp>(){
return (float)1.0;
}
template <>
int getIdentityValue<int, ONNXReduceProdOp>(){
return 1;
}
template <>
float getIdentityValue<float, ONNXReduceSumOp>(){
return (float)0;
}
template <>
int getIdentityValue<int, ONNXReduceSumOp>(){
return 0;
}
// Scalar ops
template <>
struct ScalarOp<ONNXReduceProdOp> {
using FOp = MulFOp;
using IOp = MulIOp;
};
template <>
struct ScalarOp<ONNXReduceSumOp> {
using FOp = AddFOp;
using IOp = AddIOp;
};
//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXReduceMaxOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXReduceMaxOp>(Operation *op,
ArrayRef<Type> result_types,
ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) {
auto loc = op->getLoc();
Value lhs = operands[0];
Value rhs = operands[1];
Type element_type = lhs.getType();
if (element_type.isa<IntegerType>()) {
auto max = rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, lhs, rhs);
auto result = rewriter.create<SelectOp>(loc, max, lhs, rhs);
return result;
} else if (element_type.isa<FloatType>()) {
auto max = rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs);
auto result = rewriter.create<SelectOp>(loc, max, lhs, rhs);
return result;
} else {
emitError(loc, "unsupported element type");
return nullptr;
}
}
//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXReduceMinOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXReduceMinOp>(Operation *op,
ArrayRef<Type> result_types,
ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) {
auto loc = op->getLoc();
Value lhs = operands[0];
Value rhs = operands[1];
Type element_type = lhs.getType();
if (element_type.isa<IntegerType>()) {
auto min = rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, lhs, rhs);
auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs);
return result;
} else if (element_type.isa<FloatType>()) {
auto min = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, lhs, rhs);
auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs);
return result;
} else {
emitError(loc, "unsupported element type");
return nullptr;
}
}
template <typename ONNXReductionOp>
struct ONNXReductionOpLowering : public ConversionPattern {
ONNXReductionOpLowering(MLIRContext *ctx)
: ConversionPattern(ONNXReductionOp::getOperationName(), 1, ctx) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
/*
* Condition: reduction function must be associative and commutative.
*
* Example 1 (here, reduction function is `+`):
* Induction variables: (i0, i1, i2)
* axes = [0, 2]
* keepdims = true
* krnl.iterate() with (i0, i1, i2) {
* Y(0, i1, 0) += X(i0, i1, i2)
* }
*
* Example 2 (here, reduction function is `+`):
* Induction variables: (i0, i1, i2)
* axes = [0, 2]
* keepdims = false
* krnl.iterate() with (i0, i1, i2) {
* Y(i1) += X(i0, i1, i2)
* }
*
*/
auto loc = op->getLoc();
auto memRefInType = operands[0].getType().cast<MemRefType>();
auto memRefInShape = memRefInType.getShape();
auto memRefOutType = convertToMemRefType(*op->result_type_begin());
int64_t inRank = memRefInType.getRank();
int64_t outRank = memRefOutType.getRank();
// Get attributes
ArrayAttr axisAttrs = llvm::dyn_cast<ONNXReductionOp>(op).axesAttr();
std::vector<int64_t> axes;
if (axisAttrs) {
for (auto axisAttr : axisAttrs.getValue()) {
int64_t axis = axisAttr.cast<IntegerAttr>().getInt();
axis = axis >= 0 ? axis : (inRank + axis);
assert(axis >= -inRank && axis <= inRank - 1);
if (std::find(axes.begin(), axes.end(), axis) == axes.end())
axes.push_back(axis);
}
} else {
for (decltype(inRank) i = 0; i < inRank; ++i) {
axes.push_back(i);
}
}
// KeepDims
auto keepdims =
llvm::dyn_cast<ONNXReductionOp>(op).keepdims();
bool isKeepdims = (keepdims == 1) ? true : false;
// Get type information
auto memRefOutShape = memRefOutType.getShape();
auto elementOutType = memRefOutType.getElementType();
std::map<int64_t, int64_t> outInDimMap =
getReductionMapping(memRefInType, axes, isKeepdims);
// Insert an allocation and deallocation for the result of this operation.
Value alloc;
bool insertDealloc = checkInsertDealloc(op);
if (hasAllConstantDimensions(memRefOutType)) {
alloc = insertAllocAndDealloc(memRefOutType, loc, rewriter, insertDealloc);
} else {
SmallVector<Value, 2> allocOperands;
for (decltype(outRank) i = 0; i < outRank; ++i) {
if (memRefOutShape[i] < 0) {
auto dim = rewriter.create<DimOp>(loc, operands[0], outInDimMap[i]);
allocOperands.push_back(dim);
}
}
alloc = rewriter.create<AllocOp>(loc, memRefOutType, allocOperands);
if (insertDealloc) {
auto *parentBlock = alloc.getDefiningOp()->getBlock();
auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
dealloc.getOperation()->moveBefore(&parentBlock->back());
}
}
// There are two Krnl loops:
// - One to initialize the result memref, and
// - One to do reduction
// Define loops to initialize the result.
std::vector<Value> originalLoopsInit;
std::vector<Value> optimizedLoopsInit;
Block *optimizationBlockInit = defineLoops(rewriter, loc, originalLoopsInit,
optimizedLoopsInit, outRank);
// Iteration information
KrnlIterateOperandPack packInit(rewriter, originalLoopsInit,
optimizedLoopsInit);
for (decltype(outRank) i = 0; i < outRank; ++i) {
addDimensionToPack(rewriter, loc, packInit, alloc, i);
}
auto iterateOpInit = rewriter.create<KrnlIterateOp>(loc, packInit);
Block &iterationBlockInit = iterateOpInit.bodyRegion().front();
// Perform the insertions into the body of the initialization loop.
// No optimization
rewriter.setInsertionPointToEnd(optimizationBlockInit);
rewriter.create<KrnlReturnLoopsOp>(loc, originalLoopsInit);
// Insert instructions inside the KernelIterateOp body.
rewriter.setInsertionPointToStart(&iterationBlockInit);
// Handle the operation:
SmallVector<Value, 4> loopIVs;
for (auto arg : iterationBlockInit.getArguments()) {
loopIVs.push_back(arg);
}
Value identity;
if (elementOutType.isa<FloatType>()) {
identity = rewriter.create<ConstantOp>(
loc, FloatAttr::get(elementOutType,
getIdentityValue<float, ONNXReductionOp>()));
} else if (elementOutType.isa<IntegerType>()) {
identity = rewriter.create<ConstantOp>(
loc, IntegerAttr::get(elementOutType,
getIdentityValue<int, ONNXReductionOp>()));
} else {
emitError(loc, "unsupported element type");
}
rewriter.create<StoreOp>(loc, identity, alloc, loopIVs);
// Define an Krnl loop to do reduction.
rewriter.setInsertionPointAfter(iterateOpInit);
std::vector<Value> originalLoops, optimizedLoops;
Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops,
optimizedLoops, inRank);
// Iteration information
KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
for (decltype(inRank) i = 0; i < inRank; ++i) {
addDimensionToPack(rewriter, loc, pack, operands[0], i);
}
auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
Block &iterationBlock = iterateOp.bodyRegion().front();
// Perform the insertions into the body of the reduction loop.
// No optimization
rewriter.setInsertionPointToEnd(optimizationBlock);
rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
// Insert instructions inside the KernelIterateOp body.
rewriter.setInsertionPointToStart(&iterationBlock);
// Handle the operation:
SmallVector<Value, 4> inLoopIVs, outLoopIVs;
auto args = iterationBlock.getArguments();
for (int i = 0; i < args.size(); ++i) {
inLoopIVs.push_back(args[i]);
}
Value zeroIndex = nullptr;
for (decltype(inRank) i = 0; i < outRank; ++i) {
if (outInDimMap.find(i) != outInDimMap.end()) {
outLoopIVs.push_back(inLoopIVs[outInDimMap[i]]);
} else {
if (zeroIndex) {
outLoopIVs.push_back(zeroIndex);
} else {
zeroIndex = rewriter.create<ConstantIndexOp>(loc, 0);
outLoopIVs.push_back(zeroIndex);
}
}
}
Value next, accumulated;
next = rewriter.create<LoadOp>(loc, operands[0], inLoopIVs);
accumulated = rewriter.create<LoadOp>(loc, alloc, outLoopIVs);
accumulated = mapToLowerScalarOp<ONNXReductionOp>(
op, memRefOutType.getElementType(), {accumulated, next}, rewriter);
rewriter.create<StoreOp>(loc, accumulated, alloc, outLoopIVs);
rewriter.replaceOp(op, alloc);
return matchSuccess();
}
};
void populateLoweringONNXReductionOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx) {
patterns.insert<ONNXReductionOpLowering<mlir::ONNXReduceMaxOp>,
ONNXReductionOpLowering<mlir::ONNXReduceMinOp>,
ONNXReductionOpLowering<mlir::ONNXReduceProdOp>,
ONNXReductionOpLowering<mlir::ONNXReduceSumOp>>(ctx);
}