From 4e66488ad336bbee2a54c68754935fbd01e3b968 Mon Sep 17 00:00:00 2001 From: "Tung D. Le" Date: Thu, 9 Apr 2020 17:06:56 +0900 Subject: [PATCH] Change the name and signature of mapToLowerScalarOp (#67) * Revise mapToLowerScalarOp() * Update TanhOp Co-authored-by: Tian Jin --- .../ONNXToKrnl/Math/Elementwise.cpp | 259 ++++++++---------- src/Conversion/ONNXToKrnl/Math/Reduction.cpp | 59 ++-- src/Conversion/ONNXToKrnl/NN/Pooling.cpp | 15 +- .../ONNXToKrnl/ONNXToKrnlCommon.hpp | 20 +- 4 files changed, 153 insertions(+), 200 deletions(-) diff --git a/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp b/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp index 222c538..be9d031 100644 --- a/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp +++ b/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp @@ -88,17 +88,15 @@ struct ScalarOp { // Scalar unary ops for lowering ONNXSinhOp //===----------------------------------------------------------------------===// template <> -Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, - ArrayRef operands, - ConversionPatternRewriter &rewriter) { +Value emitScalarOpFor(ConversionPatternRewriter &rewriter, + Location loc, Operation *op, Type elementType, + ArrayRef scalarOperands) { // ONNXSinhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))), // ConstantOp 2) - auto loc = op->getLoc(); - Value operand = operands[0]; - auto elementType = result_types[0]; + Value operand = scalarOperands[0]; auto zero = emitConstantOp(rewriter, loc, elementType, 0); - auto two = emitConstantOp(rewriter, loc, elementType, 2); + auto two = emitConstantOp(rewriter, loc, elementType, 2); auto neg = rewriter.create(loc, zero, operand); auto exp = rewriter.create(loc, operand); auto negExp = rewriter.create(loc, neg); @@ -112,17 +110,15 @@ Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, // Scalar unary ops for lowering ONNXCoshOp //===----------------------------------------------------------------------===// template <> -Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, - ArrayRef operands, - ConversionPatternRewriter &rewriter) { +Value emitScalarOpFor(ConversionPatternRewriter &rewriter, + Location loc, Operation *op, Type elementType, + ArrayRef scalarOperands) { // ONNXCoshOp(%X) = DivFOp(AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))), // ConstantOp 2) - auto loc = op->getLoc(); - Value operand = operands[0]; - auto elementType = result_types[0]; + Value operand = scalarOperands[0]; auto zero = emitConstantOp(rewriter, loc, elementType, 0); - auto two = emitConstantOp(rewriter, loc, elementType, 2); + auto two = emitConstantOp(rewriter, loc, elementType, 2); auto neg = rewriter.create(loc, zero, operand); auto exp = rewriter.create(loc, operand); auto negExp = rewriter.create(loc, neg); @@ -136,14 +132,12 @@ Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, // Scalar unary ops for lowering ONNXTanhOp //===----------------------------------------------------------------------===// template <> -Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, - ArrayRef operands, - ConversionPatternRewriter &rewriter) { +Value emitScalarOpFor(ConversionPatternRewriter &rewriter, + Location loc, Operation *op, Type elementType, + ArrayRef scalarOperands) { // ONNXTanhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))), // AddFOp(ExpOp(%X), ExpOp(NegFOp(%X)))) - auto loc = op->getLoc(); - Value operand = operands[0]; - auto elementType = result_types[0]; + Value operand = scalarOperands[0]; auto zero = emitConstantOp(rewriter, loc, elementType, 0); auto neg = rewriter.create(loc, zero, operand); @@ -160,15 +154,12 @@ Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, // Scalar unary ops for lowering ONNXSigmoidOp //===----------------------------------------------------------------------===// template <> -Value mapToLowerScalarOp(Operation *op, - ArrayRef result_types, - ArrayRef operands, - ConversionPatternRewriter &rewriter) { +Value emitScalarOpFor(ConversionPatternRewriter &rewriter, + Location loc, Operation *op, Type elementType, + ArrayRef scalarOperands) { // ONNXSigmoidOp(%X) = DivFOp(ConstantOp 1, // AddFOp(ConstantOp 1, ExpOp(NegFOp(%X)))) - auto loc = op->getLoc(); - Value operand = operands[0]; - auto elementType = result_types[0]; + Value operand = scalarOperands[0]; auto zero = emitConstantOp(rewriter, loc, elementType, 0); auto one = emitConstantOp(rewriter, loc, elementType, 1); @@ -184,9 +175,9 @@ Value mapToLowerScalarOp(Operation *op, // Scalar unary ops for lowering ONNXHardSigmoidOp //===----------------------------------------------------------------------===// template <> -Value mapToLowerScalarOp( - Operation *op, ArrayRef result_types, ArrayRef operands, - ConversionPatternRewriter &rewriter) { +Value emitScalarOpFor(ConversionPatternRewriter &rewriter, + Location loc, Operation *op, Type elementType, + ArrayRef scalarOperands) { // %Y = AddFOp(MulFOp(alpha, %X), beta) // %Z = SelectOp(CmpFOp(OGT, %Y, Constant 0), // %Y, @@ -194,13 +185,11 @@ Value mapToLowerScalarOp( // ONNXHardSigmoidOp(%X) = SelectOp(CmpFOp(OLT, %Z, Constant 1), // %Z, // Constant 1) - auto loc = op->getLoc(); - Value operand = operands[0]; + Value operand = scalarOperands[0]; auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(), llvm::dyn_cast(op).alpha().convertToFloat()); auto betaAttribute = FloatAttr::get(rewriter.getF32Type(), llvm::dyn_cast(op).beta().convertToFloat()); - auto elementType = result_types[0]; auto zero = emitConstantOp(rewriter, loc, elementType, 0); auto one = emitConstantOp(rewriter, loc, elementType, 1); @@ -223,15 +212,13 @@ Value mapToLowerScalarOp( // Scalar unary ops for lowering ONNXEluOp //===----------------------------------------------------------------------===// template <> -Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, - ArrayRef operands, - ConversionPatternRewriter &rewriter) { +Value emitScalarOpFor(ConversionPatternRewriter &rewriter, + Location loc, Operation *op, Type elementType, + ArrayRef scalarOperands) { // ONNXEluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0), // MulFOp(alpha, SubFOp(ExpOp(%X), 1)), // %X) - auto loc = op->getLoc(); - Value operand = operands[0]; - auto elementType = result_types[0]; + Value operand = scalarOperands[0]; auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(), llvm::dyn_cast(op).alpha().convertToFloat()); @@ -241,10 +228,9 @@ Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, auto exp = rewriter.create(loc, operand); auto lessThanZero = rewriter.create(loc, CmpFPredicate::OLT, operand, zero); - auto result = rewriter.create( - loc, lessThanZero, - rewriter.create(loc, alpha, - rewriter.create(loc, exp, one)), + auto result = rewriter.create(loc, lessThanZero, + rewriter.create( + loc, alpha, rewriter.create(loc, exp, one)), operand); return result; @@ -254,15 +240,13 @@ Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, // Scalar unary ops for lowering ONNXReluOp //===----------------------------------------------------------------------===// template <> -Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, - ArrayRef operands, - ConversionPatternRewriter &rewriter) { +Value emitScalarOpFor(ConversionPatternRewriter &rewriter, + Location loc, Operation *op, Type elementType, + ArrayRef scalarOperands) { // ONNXReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0), // ConstantOp 0, // %X) - auto loc = op->getLoc(); - Value operand = operands[0]; - auto elementType = result_types[0]; + Value operand = scalarOperands[0]; auto zero = emitConstantOp(rewriter, loc, elementType, 0); auto lessThanZero = @@ -276,16 +260,13 @@ Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, // Scalar unary ops for lowering ONNXLeakyReluOp //===----------------------------------------------------------------------===// template <> -Value mapToLowerScalarOp(Operation *op, - ArrayRef result_types, - ArrayRef operands, - ConversionPatternRewriter &rewriter) { +Value emitScalarOpFor(ConversionPatternRewriter &rewriter, + Location loc, Operation *op, Type elementType, + ArrayRef scalarOperands) { // ONNXLeakyReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0), // MulFOp(alpha, %X), // %X) - auto loc = op->getLoc(); - Value operand = operands[0]; - auto elementType = result_types[0]; + Value operand = scalarOperands[0]; auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(), llvm::dyn_cast(op).alpha().convertToFloat()); @@ -303,21 +284,19 @@ Value mapToLowerScalarOp(Operation *op, // Scalar unary ops for lowering ONNXSeluOp //===----------------------------------------------------------------------===// template <> -Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, - ArrayRef operands, - ConversionPatternRewriter &rewriter) { +Value emitScalarOpFor(ConversionPatternRewriter &rewriter, + Location loc, Operation *op, Type elementType, + ArrayRef scalarOperands) { // ONNXSeluOp(%X) = SelectOp(CmpFOp(OGT, %X, ConstantOp 0), // MulFOp(gamma, %X), // MulFOp(gamma, // SubFOp(MulFOp(alpha, ExpOp(%X)), // alpha))) - auto loc = op->getLoc(); - Value operand = operands[0]; + Value operand = scalarOperands[0]; auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(), llvm::dyn_cast(op).alpha().convertToFloat()); auto gammaAttribute = FloatAttr::get(rewriter.getF32Type(), llvm::dyn_cast(op).gamma().convertToFloat()); - auto elementType = result_types[0]; auto zero = emitConstantOp(rewriter, loc, elementType, 0); auto alpha = rewriter.create(loc, alphaAttribute); @@ -325,10 +304,9 @@ Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, auto exp = rewriter.create(loc, operand); auto greaterThanZero = rewriter.create(loc, CmpFPredicate::OGT, operand, zero); - auto select = rewriter.create( - loc, greaterThanZero, operand, - rewriter.create(loc, rewriter.create(loc, alpha, exp), - alpha)); + auto select = rewriter.create(loc, greaterThanZero, operand, + rewriter.create( + loc, rewriter.create(loc, alpha, exp), alpha)); auto result = rewriter.create(loc, gamma, select); return result; @@ -338,14 +316,11 @@ Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, // Scalar unary ops for lowering ONNXReciprocalOp //===----------------------------------------------------------------------===// template <> -Value mapToLowerScalarOp( - Operation *op, ArrayRef result_types, ArrayRef operands, - ConversionPatternRewriter &rewriter) { +Value emitScalarOpFor(ConversionPatternRewriter &rewriter, + Location loc, Operation *op, Type elementType, + ArrayRef scalarOperands) { // ONNXReciprocalOp(%X) = DivFOp(ConstantOp 1, %X) - auto loc = op->getLoc(); - Value operand = operands[0]; - auto elementType = result_types[0]; - + Value operand = scalarOperands[0]; auto one = emitConstantOp(rewriter, loc, elementType, 1); auto result = rewriter.create(loc, one, operand); @@ -356,13 +331,11 @@ Value mapToLowerScalarOp( // Scalar unary ops for lowering ONNXSoftplusOp //===----------------------------------------------------------------------===// template <> -Value mapToLowerScalarOp( - Operation *op, ArrayRef result_types, ArrayRef operands, - ConversionPatternRewriter &rewriter) { +Value emitScalarOpFor(ConversionPatternRewriter &rewriter, + Location loc, Operation *op, Type elementType, + ArrayRef scalarOperands) { // ONNXSoftplusOp(%X) = LogOp(AddFOp(ExpOp(%X), ConstantOp 1)) - auto loc = op->getLoc(); - Value operand = operands[0]; - auto elementType = result_types[0]; + Value operand = scalarOperands[0]; auto exp = rewriter.create(loc, operand); auto one = emitConstantOp(rewriter, loc, elementType, 1); @@ -376,13 +349,11 @@ Value mapToLowerScalarOp( // Scalar unary ops for lowering ONNXSoftsignOp //===----------------------------------------------------------------------===// template <> -Value mapToLowerScalarOp( - Operation *op, ArrayRef result_types, ArrayRef operands, - ConversionPatternRewriter &rewriter) { +Value emitScalarOpFor(ConversionPatternRewriter &rewriter, + Location loc, Operation *op, Type elementType, + ArrayRef scalarOperands) { // ONNXSoftsignOp(%X) = DivFOp(ConstantOp 1, %X) - auto loc = op->getLoc(); - Value operand = operands[0]; - auto elementType = result_types[0]; + Value operand = scalarOperands[0]; auto abs = rewriter.create(loc, operand); auto one = emitConstantOp(rewriter, loc, elementType, 1); @@ -396,13 +367,10 @@ Value mapToLowerScalarOp( // Scalar unary ops for lowering ONNXSignOp //===----------------------------------------------------------------------===// template <> -Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, - ArrayRef operands, - ConversionPatternRewriter &rewriter) { - - auto loc = op->getLoc(); - Value operand = operands[0]; - Type elementType = operands.front().getType(); +Value emitScalarOpFor(ConversionPatternRewriter &rewriter, + Location loc, Operation *op, Type elementType, + ArrayRef scalarOperands) { + Value operand = scalarOperands[0]; // TODO: unsigned int should be supported separately? if (elementType.isa()) { // %Y = SelectOP(CmpIOp(GT, %X, ConstantOp 0), @@ -451,15 +419,14 @@ Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, // Scalar unary ops for lowering ONNXMaxOp //===----------------------------------------------------------------------===// template <> -Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, - ArrayRef operands, - ConversionPatternRewriter &rewriter) { +Value emitScalarOpFor(ConversionPatternRewriter &rewriter, + Location loc, Operation *op, Type elementType, + ArrayRef scalarOperands) { // ONNXMaxOp(%X, %Y) = SelectOp(CmpFOp(OGT, %X, %Y), // %X, // %Y) - auto loc = op->getLoc(); - Value lhs = operands[0]; - Value rhs = operands[1]; + Value lhs = scalarOperands[0]; + Value rhs = scalarOperands[1]; auto max = rewriter.create(loc, CmpFPredicate::OGT, lhs, rhs); auto result = rewriter.create(loc, max, lhs, rhs); return result; @@ -469,15 +436,14 @@ Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, // Scalar unary ops for lowering ONNXMinOp //===----------------------------------------------------------------------===// template <> -Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, - ArrayRef operands, - ConversionPatternRewriter &rewriter) { +Value emitScalarOpFor(ConversionPatternRewriter &rewriter, + Location loc, Operation *op, Type elementType, + ArrayRef scalarOperands) { // ONNXMinOp(%X, %Y) = SelectOp(CmpFOp(OLT, %X, %Y), // %X, // %Y) - auto loc = op->getLoc(); - Value lhs = operands[0]; - Value rhs = operands[1]; + Value lhs = scalarOperands[0]; + Value rhs = scalarOperands[1]; auto min = rewriter.create(loc, CmpFPredicate::OLT, lhs, rhs); auto result = rewriter.create(loc, min, lhs, rhs); return result; @@ -487,11 +453,10 @@ Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, // Scalar unary ops for lowering ONNXAbsOp //===----------------------------------------------------------------------===// template <> -Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, - ArrayRef operands, ConversionPatternRewriter &rewriter) { - auto loc = op->getLoc(); - Value operand = operands[0]; - auto elementType = result_types[0]; +Value emitScalarOpFor(ConversionPatternRewriter &rewriter, + Location loc, Operation *op, Type elementType, + ArrayRef scalarOperands) { + Value operand = scalarOperands[0]; if (elementType.isa()) { return rewriter.create(loc, operand); @@ -536,15 +501,14 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern { if (hasAllConstantDimensions(memRefType)) alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); else - alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc, - {operands[0]}); + alloc = insertAllocAndDealloc( + memRefType, loc, rewriter, insertDealloc, {operands[0]}); std::vector originalLoops; KrnlOptimizeLoopsOp optimizedLoopsOp; KrnlIterateOp iterateOp; emitKrnlLoopsAndIterationForOperand( - rewriter, loc, operands[0], originalLoops, - optimizedLoopsOp, iterateOp); + rewriter, loc, operands[0], originalLoops, optimizedLoopsOp, iterateOp); Block &optimizationBlock = optimizedLoopsOp.region().front(); Block &iterationBlock = iterateOp.bodyRegion().front(); @@ -564,8 +528,8 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern { loopIVs.push_back(arg); auto loadedVal = rewriter.create(loc, operands[0], loopIVs); - auto loweredOpResult = mapToLowerScalarOp( - op, memRefType.getElementType(), {loadedVal}, rewriter); + auto loweredOpResult = emitScalarOpFor( + rewriter, loc, op, memRefType.getElementType(), {loadedVal}); // Store result in the resulting array. rewriter.create(loc, loweredOpResult, alloc, loopIVs); @@ -603,8 +567,8 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern { if (hasAllConstantDimensions(memRefType)) alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); else - alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc, - operands); + alloc = insertAllocAndDealloc( + memRefType, loc, rewriter, insertDealloc, operands); // Get run-time dimension information for unknown dimensions used for // broadcasting. @@ -615,8 +579,7 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern { KrnlOptimizeLoopsOp optimizedLoopsOp; KrnlIterateOp iterateOp; emitKrnlLoopsAndIterationForOperand( - rewriter, loc, alloc, originalLoops, - optimizedLoopsOp, iterateOp); + rewriter, loc, alloc, originalLoops, optimizedLoopsOp, iterateOp); Block &optimizationBlock = optimizedLoopsOp.region().front(); Block &iterationBlock = iterateOp.bodyRegion().front(); @@ -643,8 +606,8 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern { auto nextLoopIVs = getLoopIVsForBroadcasting( loc, rewriter, loopIVs, operands[i], broadcastedDimInfo[i]); next = rewriter.create(loc, operands[i], nextLoopIVs); - accumulated = mapToLowerScalarOp( - op, memRefType.getElementType(), {accumulated, next}, rewriter); + accumulated = emitScalarOpFor( + rewriter, loc, op, memRefType.getElementType(), {accumulated, next}); } // Store result in the resulting array. rewriter.create(loc, accumulated, alloc, loopIVs); @@ -658,31 +621,31 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern { void populateLoweringONNXElementwiseOpPattern( OwningRewritePatternList &patterns, MLIRContext *ctx) { patterns.insert, - ONNXElementwiseVariadicOpLowering, - ONNXElementwiseVariadicOpLowering, - ONNXElementwiseUnaryOpLowering, - ONNXElementwiseUnaryOpLowering, - ONNXElementwiseVariadicOpLowering, - ONNXElementwiseUnaryOpLowering, - ONNXElementwiseUnaryOpLowering, - ONNXElementwiseUnaryOpLowering, - ONNXElementwiseUnaryOpLowering, - ONNXElementwiseUnaryOpLowering, - ONNXElementwiseVariadicOpLowering, - ONNXElementwiseVariadicOpLowering, - ONNXElementwiseVariadicOpLowering, - ONNXElementwiseVariadicOpLowering, - ONNXElementwiseUnaryOpLowering, - ONNXElementwiseUnaryOpLowering, - ONNXElementwiseUnaryOpLowering, - ONNXElementwiseUnaryOpLowering, - ONNXElementwiseUnaryOpLowering, - ONNXElementwiseUnaryOpLowering, - ONNXElementwiseUnaryOpLowering, - ONNXElementwiseUnaryOpLowering, - ONNXElementwiseUnaryOpLowering, - ONNXElementwiseVariadicOpLowering, - ONNXElementwiseVariadicOpLowering, - ONNXElementwiseUnaryOpLowering, - ONNXElementwiseVariadicOpLowering>(ctx); + ONNXElementwiseVariadicOpLowering, + ONNXElementwiseVariadicOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseVariadicOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseVariadicOpLowering, + ONNXElementwiseVariadicOpLowering, + ONNXElementwiseVariadicOpLowering, + ONNXElementwiseVariadicOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseVariadicOpLowering, + ONNXElementwiseVariadicOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseVariadicOpLowering>(ctx); } diff --git a/src/Conversion/ONNXToKrnl/Math/Reduction.cpp b/src/Conversion/ONNXToKrnl/Math/Reduction.cpp index f1b0bb5..3520827 100644 --- a/src/Conversion/ONNXToKrnl/Math/Reduction.cpp +++ b/src/Conversion/ONNXToKrnl/Math/Reduction.cpp @@ -54,13 +54,11 @@ struct ScalarOp { // Scalar unary ops for lowering ONNXReduceMaxOp //===----------------------------------------------------------------------===// template <> -Value mapToLowerScalarOp(Operation *op, - ArrayRef result_types, - ArrayRef operands, - ConversionPatternRewriter &rewriter) { - auto loc = op->getLoc(); - Value lhs = operands[0]; - Value rhs = operands[1]; +Value emitScalarOpFor(ConversionPatternRewriter &rewriter, + Location loc, Operation *op, Type elementType, + ArrayRef scalarOperands) { + Value lhs = scalarOperands[0]; + Value rhs = scalarOperands[1]; Type element_type = lhs.getType(); if (element_type.isa()) { auto max = rewriter.create(loc, CmpIPredicate::sgt, lhs, rhs); @@ -80,19 +78,16 @@ Value mapToLowerScalarOp(Operation *op, // Scalar unary ops for lowering ONNXReduceMinOp //===----------------------------------------------------------------------===// template <> -Value mapToLowerScalarOp(Operation *op, - ArrayRef result_types, - ArrayRef operands, - ConversionPatternRewriter &rewriter) { - auto loc = op->getLoc(); - Value lhs = operands[0]; - Value rhs = operands[1]; - Type element_type = lhs.getType(); - if (element_type.isa()) { +Value emitScalarOpFor(ConversionPatternRewriter &rewriter, + Location loc, Operation *op, Type elementType, + ArrayRef scalarOperands) { + Value lhs = scalarOperands[0]; + Value rhs = scalarOperands[1]; + if (elementType.isa()) { auto min = rewriter.create(loc, CmpIPredicate::slt, lhs, rhs); auto result = rewriter.create(loc, min, lhs, rhs); return result; - } else if (element_type.isa()) { + } else if (elementType.isa()) { auto min = rewriter.create(loc, CmpFPredicate::OLT, lhs, rhs); auto result = rewriter.create(loc, min, lhs, rhs); return result; @@ -129,7 +124,7 @@ struct ONNXReductionOpLowering : public ConversionPattern { * Y(i1) += X(i0, i1, i2) * } * - */ + */ auto loc = op->getLoc(); auto memRefInType = operands[0].getType().cast(); auto memRefInShape = memRefInType.getShape(); @@ -154,8 +149,7 @@ struct ONNXReductionOpLowering : public ConversionPattern { } } // KeepDims - auto keepdims = - llvm::dyn_cast(op).keepdims(); + auto keepdims = llvm::dyn_cast(op).keepdims(); bool isKeepdims = (keepdims == 1) ? true : false; // Get type information @@ -168,7 +162,8 @@ struct ONNXReductionOpLowering : public ConversionPattern { Value alloc; bool insertDealloc = checkInsertDealloc(op); if (hasAllConstantDimensions(memRefOutType)) { - alloc = insertAllocAndDealloc(memRefOutType, loc, rewriter, insertDealloc); + alloc = + insertAllocAndDealloc(memRefOutType, loc, rewriter, insertDealloc); } else { SmallVector allocOperands; for (decltype(outRank) i = 0; i < outRank; ++i) { @@ -192,12 +187,12 @@ struct ONNXReductionOpLowering : public ConversionPattern { // Define loops to initialize the result. std::vector originalLoopsInit; std::vector optimizedLoopsInit; - Block *optimizationBlockInit = defineLoops(rewriter, loc, originalLoopsInit, - optimizedLoopsInit, outRank); + Block *optimizationBlockInit = defineLoops( + rewriter, loc, originalLoopsInit, optimizedLoopsInit, outRank); // Iteration information - KrnlIterateOperandPack packInit(rewriter, originalLoopsInit, - optimizedLoopsInit); + KrnlIterateOperandPack packInit( + rewriter, originalLoopsInit, optimizedLoopsInit); for (decltype(outRank) i = 0; i < outRank; ++i) { addDimensionToPack(rewriter, loc, packInit, alloc, i); } @@ -225,8 +220,8 @@ struct ONNXReductionOpLowering : public ConversionPattern { // Define an Krnl loop to do reduction. rewriter.setInsertionPointAfter(iterateOpInit); std::vector originalLoops, optimizedLoops; - Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops, - optimizedLoops, inRank); + Block *optimizationBlock = + defineLoops(rewriter, loc, originalLoops, optimizedLoops, inRank); // Iteration information KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops); for (decltype(inRank) i = 0; i < inRank; ++i) { @@ -266,8 +261,8 @@ struct ONNXReductionOpLowering : public ConversionPattern { Value next, accumulated; next = rewriter.create(loc, operands[0], inLoopIVs); accumulated = rewriter.create(loc, alloc, outLoopIVs); - accumulated = mapToLowerScalarOp( - op, memRefOutType.getElementType(), {accumulated, next}, rewriter); + accumulated = emitScalarOpFor( + rewriter, loc, op, memRefOutType.getElementType(), {accumulated, next}); rewriter.create(loc, accumulated, alloc, outLoopIVs); rewriter.replaceOp(op, alloc); @@ -278,7 +273,7 @@ struct ONNXReductionOpLowering : public ConversionPattern { void populateLoweringONNXReductionOpPattern( OwningRewritePatternList &patterns, MLIRContext *ctx) { patterns.insert, - ONNXReductionOpLowering, - ONNXReductionOpLowering, - ONNXReductionOpLowering>(ctx); + ONNXReductionOpLowering, + ONNXReductionOpLowering, + ONNXReductionOpLowering>(ctx); } diff --git a/src/Conversion/ONNXToKrnl/NN/Pooling.cpp b/src/Conversion/ONNXToKrnl/NN/Pooling.cpp index c3b090f..19f9b0d 100644 --- a/src/Conversion/ONNXToKrnl/NN/Pooling.cpp +++ b/src/Conversion/ONNXToKrnl/NN/Pooling.cpp @@ -20,12 +20,11 @@ Value getIdentityValue( } template <> -Value mapToLowerScalarOp(Operation *op, - ArrayRef result_types, ArrayRef operands, - ConversionPatternRewriter &rewriter) { - auto loc = op->getLoc(); - Value lhs = operands[0]; - Value rhs = operands[1]; +Value emitScalarOpFor( + ConversionPatternRewriter &rewriter, Location loc, Operation *op, + Type elementType, ArrayRef scalarOperands) { + Value lhs = scalarOperands[0]; + Value rhs = scalarOperands[1]; auto max = rewriter.create(loc, CmpFPredicate::OGT, lhs, rhs); auto result = rewriter.create(loc, max, lhs, rhs); return result; @@ -308,8 +307,8 @@ struct ONNXMaxPoolSingleOutOpLowering : public ConversionPattern { auto loadData = rewriter.create(loc, inputOperand, dataIndices); auto loadPartialResult = rewriter.create(loc, alloc, resultIndices); - Value result = mapToLowerScalarOp( - op, resultElementType, {loadPartialResult, loadData}, rewriter); + Value result = emitScalarOpFor(rewriter, loc, + op, resultElementType, {loadPartialResult, loadData}); rewriter.create(loc, result, alloc, resultIndices); } } diff --git a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp index f725b8d..2c0c5f7 100644 --- a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp +++ b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp @@ -148,17 +148,14 @@ Value getIdentityValue( // Use template specialization for each of different ONNX operations. //===----------------------------------------------------------------------===// template -Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, - ArrayRef operands, - ConversionPatternRewriter &rewriter) { - auto loc = op->getLoc(); - Type element_type = operands.front().getType(); - if (element_type.isa()) { - return rewriter.create>(loc, result_types, operands, - mlir::None); - } else if (element_type.isa()) { - return rewriter.create>(loc, result_types, operands, - mlir::None); +Value emitScalarOpFor(ConversionPatternRewriter &rewriter, Location loc, + Operation *op, Type elementType, ArrayRef scalarOperands) { + if (elementType.isa()) { + return rewriter.create>( + loc, elementType, scalarOperands, mlir::None); + } else if (elementType.isa()) { + return rewriter.create>( + loc, elementType, scalarOperands, mlir::None); } else { emitError(loc, "unsupported element type"); return nullptr; @@ -247,4 +244,3 @@ void populateLoweringONNXIdentityOpPattern( void populateLoweringONNXConstantOpPattern( OwningRewritePatternList &patterns, MLIRContext *ctx); -