//===----- reduction.cpp - Lowering Reduction Ops -------------------------===// // // Copyright 2019 The IBM Research Authors. // // ============================================================================= // // This file lowers the ONNX Reduction Operators to Krnl dialect. // //===----------------------------------------------------------------------===// #include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp" using namespace mlir; // Identity values template <> float getIdentityValue(){ return (float)-std::numeric_limits::infinity(); } template <> int getIdentityValue(){ return std::numeric_limits::min(); } template <> float getIdentityValue(){ return (float)std::numeric_limits::infinity(); } template <> int getIdentityValue(){ return std::numeric_limits::max(); } template <> float getIdentityValue(){ return (float)1.0; } template <> int getIdentityValue(){ return 1; } template <> float getIdentityValue(){ return (float)0; } template <> int getIdentityValue(){ return 0; } // Scalar ops template <> struct ScalarOp { using FOp = MulFOp; using IOp = MulIOp; }; template <> struct ScalarOp { using FOp = AddFOp; using IOp = AddIOp; }; //===----------------------------------------------------------------------===// // Scalar unary ops for lowering ONNXReduceMaxOp //===----------------------------------------------------------------------===// template <> Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, ArrayRef operands, ConversionPatternRewriter &rewriter) { auto loc = op->getLoc(); Value lhs = operands[0]; Value rhs = operands[1]; Type element_type = lhs.getType(); if (element_type.isa()) { auto max = rewriter.create(loc, CmpIPredicate::sgt, lhs, rhs); auto result = rewriter.create(loc, max, lhs, rhs); return result; } else if (element_type.isa()) { auto max = rewriter.create(loc, CmpFPredicate::OGT, lhs, rhs); auto result = rewriter.create(loc, max, lhs, rhs); return result; } else { emitError(loc, "unsupported element type"); return nullptr; } } //===----------------------------------------------------------------------===// // Scalar unary ops for lowering ONNXReduceMinOp //===----------------------------------------------------------------------===// template <> Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, ArrayRef operands, ConversionPatternRewriter &rewriter) { auto loc = op->getLoc(); Value lhs = operands[0]; Value rhs = operands[1]; Type element_type = lhs.getType(); if (element_type.isa()) { auto min = rewriter.create(loc, CmpIPredicate::slt, lhs, rhs); auto result = rewriter.create(loc, min, lhs, rhs); return result; } else if (element_type.isa()) { auto min = rewriter.create(loc, CmpFPredicate::OLT, lhs, rhs); auto result = rewriter.create(loc, min, lhs, rhs); return result; } else { emitError(loc, "unsupported element type"); return nullptr; } } template struct ONNXReductionOpLowering : public ConversionPattern { ONNXReductionOpLowering(MLIRContext *ctx) : ConversionPattern(ONNXReductionOp::getOperationName(), 1, ctx) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { /* * Condition: reduction function must be associative and commutative. * * Example 1 (here, reduction function is `+`): * Induction variables: (i0, i1, i2) * axes = [0, 2] * keepdims = true * krnl.iterate() with (i0, i1, i2) { * Y(0, i1, 0) += X(i0, i1, i2) * } * * Example 2 (here, reduction function is `+`): * Induction variables: (i0, i1, i2) * axes = [0, 2] * keepdims = false * krnl.iterate() with (i0, i1, i2) { * Y(i1) += X(i0, i1, i2) * } * */ auto loc = op->getLoc(); auto memRefInType = operands[0].getType().cast(); auto memRefInShape = memRefInType.getShape(); auto memRefOutType = convertToMemRefType(*op->result_type_begin()); int64_t inRank = memRefInType.getRank(); int64_t outRank = memRefOutType.getRank(); // Get attributes ArrayAttr axisAttrs = llvm::dyn_cast(op).axesAttr(); std::vector axes; if (axisAttrs) { for (auto axisAttr : axisAttrs.getValue()) { int64_t axis = axisAttr.cast().getInt(); axis = axis >= 0 ? axis : (inRank + axis); assert(axis >= -inRank && axis <= inRank - 1); if (std::find(axes.begin(), axes.end(), axis) == axes.end()) axes.push_back(axis); } } else { for (decltype(inRank) i = 0; i < inRank; ++i) { axes.push_back(i); } } // KeepDims auto keepdims = llvm::dyn_cast(op).keepdims(); bool isKeepdims = (keepdims == 1) ? true : false; // Get type information auto memRefOutShape = memRefOutType.getShape(); auto elementOutType = memRefOutType.getElementType(); std::map outInDimMap = getReductionMapping(memRefInType, axes, isKeepdims); // Insert an allocation and deallocation for the result of this operation. Value alloc; bool insertDealloc = checkInsertDealloc(op); if (hasAllConstantDimensions(memRefOutType)) { alloc = insertAllocAndDealloc(memRefOutType, loc, rewriter, insertDealloc); } else { SmallVector allocOperands; for (decltype(outRank) i = 0; i < outRank; ++i) { if (memRefOutShape[i] < 0) { auto dim = rewriter.create(loc, operands[0], outInDimMap[i]); allocOperands.push_back(dim); } } alloc = rewriter.create(loc, memRefOutType, allocOperands); if (insertDealloc) { auto *parentBlock = alloc.getDefiningOp()->getBlock(); auto dealloc = rewriter.create(loc, alloc); dealloc.getOperation()->moveBefore(&parentBlock->back()); } } // There are two Krnl loops: // - One to initialize the result memref, and // - One to do reduction // Define loops to initialize the result. std::vector originalLoopsInit; std::vector optimizedLoopsInit; Block *optimizationBlockInit = defineLoops(rewriter, loc, originalLoopsInit, optimizedLoopsInit, outRank); // Iteration information KrnlIterateOperandPack packInit(rewriter, originalLoopsInit, optimizedLoopsInit); for (decltype(outRank) i = 0; i < outRank; ++i) { addDimensionToPack(rewriter, loc, packInit, alloc, i); } auto iterateOpInit = rewriter.create(loc, packInit); Block &iterationBlockInit = iterateOpInit.bodyRegion().front(); // Perform the insertions into the body of the initialization loop. // No optimization rewriter.setInsertionPointToEnd(optimizationBlockInit); rewriter.create(loc, originalLoopsInit); // Insert instructions inside the KernelIterateOp body. rewriter.setInsertionPointToStart(&iterationBlockInit); // Handle the operation: SmallVector loopIVs; for (auto arg : iterationBlockInit.getArguments()) { loopIVs.push_back(arg); } Value identity; if (elementOutType.isa()) { identity = rewriter.create( loc, FloatAttr::get(elementOutType, getIdentityValue())); } else if (elementOutType.isa()) { identity = rewriter.create( loc, IntegerAttr::get(elementOutType, getIdentityValue())); } else { emitError(loc, "unsupported element type"); } rewriter.create(loc, identity, alloc, loopIVs); // Define an Krnl loop to do reduction. rewriter.setInsertionPointAfter(iterateOpInit); std::vector originalLoops, optimizedLoops; Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops, optimizedLoops, inRank); // Iteration information KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops); for (decltype(inRank) i = 0; i < inRank; ++i) { addDimensionToPack(rewriter, loc, pack, operands[0], i); } auto iterateOp = rewriter.create(loc, pack); Block &iterationBlock = iterateOp.bodyRegion().front(); // Perform the insertions into the body of the reduction loop. // No optimization rewriter.setInsertionPointToEnd(optimizationBlock); rewriter.create(loc, originalLoops); // Insert instructions inside the KernelIterateOp body. rewriter.setInsertionPointToStart(&iterationBlock); // Handle the operation: SmallVector inLoopIVs, outLoopIVs; auto args = iterationBlock.getArguments(); for (int i = 0; i < args.size(); ++i) { inLoopIVs.push_back(args[i]); } Value zeroIndex = nullptr; for (decltype(inRank) i = 0; i < outRank; ++i) { if (outInDimMap.find(i) != outInDimMap.end()) { outLoopIVs.push_back(inLoopIVs[outInDimMap[i]]); } else { if (zeroIndex) { outLoopIVs.push_back(zeroIndex); } else { zeroIndex = rewriter.create(loc, 0); outLoopIVs.push_back(zeroIndex); } } } Value next, accumulated; next = rewriter.create(loc, operands[0], inLoopIVs); accumulated = rewriter.create(loc, alloc, outLoopIVs); accumulated = mapToLowerScalarOp( op, memRefOutType.getElementType(), {accumulated, next}, rewriter); rewriter.create(loc, accumulated, alloc, outLoopIVs); rewriter.replaceOp(op, alloc); return matchSuccess(); } }; void populateLoweringONNXReductionOpPattern( OwningRewritePatternList &patterns, MLIRContext *ctx) { patterns.insert, ONNXReductionOpLowering, ONNXReductionOpLowering, ONNXReductionOpLowering>(ctx); }