//===---------------- Elementwise.cpp - Elementwise Ops ------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file lowers ONNX element-wise operators to Krnl dialect.
//
//===----------------------------------------------------------------------===//

#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"

using namespace mlir;

template <>
struct ScalarOp<ONNXAddOp> {
  using FOp = AddFOp;
  using IOp = AddIOp;
};

template <>
struct ScalarOp<ONNXMulOp> {
  using FOp = MulFOp;
  using IOp = MulIOp;
};

template <>
struct ScalarOp<ONNXDivOp> {
  using FOp = DivFOp;
  using IOp = SignedDivIOp;
};

template <>
struct ScalarOp<ONNXSubOp> {
  using FOp = SubFOp;
  using IOp = SubIOp;
};

template <>
struct ScalarOp<ONNXAndOp> {
  using FOp = AndOp; // not used
  using IOp = AndOp;
};

template <>
struct ScalarOp<ONNXOrOp> {
  using FOp = OrOp; // not used
  using IOp = OrOp;
};

template <>
struct ScalarOp<ONNXXorOp> {
  using FOp = XOrOp; // not used
  using IOp = XOrOp;
};

template <>
struct ScalarOp<ONNXExpOp> {
  using FOp = ExpOp;
  using IOp = ExpOp; // not used
};

template <>
struct ScalarOp<ONNXSumOp> {
  using FOp = AddFOp;
  using IOp = AddIOp;
};

template <>
struct ScalarOp<ONNXCosOp> {
  using FOp = CosOp;
  using IOp = CosOp; // not used
};

template <>
struct ScalarOp<ONNXLogOp> {
  using FOp = LogOp;
  using IOp = LogOp; // not used
};

template <>
struct ScalarOp<ONNXSqrtOp> {
  using FOp = SqrtOp;
  using IOp = SqrtOp; // not used
};
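// Ops without a custom emitScalarOpFor specialization below are emitted by a
// generic dispatcher that picks ScalarOp<Op>::FOp or ScalarOp<Op>::IOp based
// on the element type. A minimal sketch of that dispatcher, assuming the
// version declared in ONNXToKrnlCommon.hpp (shown in a comment for reference
// only, since the real definition lives in the header):
//
//   template <typename Op>
//   Value emitScalarOpFor(ConversionPatternRewriter &rewriter, Location loc,
//       Operation *op, Type elementType, ArrayRef<Value> scalarOperands) {
//     if (elementType.isa<IntegerType>())
//       return rewriter.create<typename ScalarOp<Op>::IOp>(
//           loc, elementType, scalarOperands, mlir::None);
//     if (elementType.isa<FloatType>())
//       return rewriter.create<typename ScalarOp<Op>::FOp>(
//           loc, elementType, scalarOperands, mlir::None);
//     emitError(loc, "unsupported element type");
//     return {};
//   }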
//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXSinhOp
//===----------------------------------------------------------------------===//
template <>
Value emitScalarOpFor<ONNXSinhOp>(ConversionPatternRewriter &rewriter,
    Location loc, Operation *op, Type elementType,
    ArrayRef<Value> scalarOperands) {
  // ONNXSinhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
  //                         ConstantOp 2)
  Value operand = scalarOperands[0];

  auto zero = emitConstantOp(rewriter, loc, elementType, 0);
  auto two = emitConstantOp(rewriter, loc, elementType, 2);
  auto neg = rewriter.create<SubFOp>(loc, zero, operand);
  auto exp = rewriter.create<ExpOp>(loc, operand);
  auto negExp = rewriter.create<ExpOp>(loc, neg);
  auto result = rewriter.create<DivFOp>(
      loc, rewriter.create<SubFOp>(loc, exp, negExp), two);

  return result;
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXCoshOp
//===----------------------------------------------------------------------===//
template <>
Value emitScalarOpFor<ONNXCoshOp>(ConversionPatternRewriter &rewriter,
    Location loc, Operation *op, Type elementType,
    ArrayRef<Value> scalarOperands) {
  // ONNXCoshOp(%X) = DivFOp(AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
  //                         ConstantOp 2)
  Value operand = scalarOperands[0];

  auto zero = emitConstantOp(rewriter, loc, elementType, 0);
  auto two = emitConstantOp(rewriter, loc, elementType, 2);
  auto neg = rewriter.create<SubFOp>(loc, zero, operand);
  auto exp = rewriter.create<ExpOp>(loc, operand);
  auto negExp = rewriter.create<ExpOp>(loc, neg);
  auto result = rewriter.create<DivFOp>(
      loc, rewriter.create<AddFOp>(loc, exp, negExp), two);

  return result;
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXTanhOp
//===----------------------------------------------------------------------===//
template <>
Value emitScalarOpFor<ONNXTanhOp>(ConversionPatternRewriter &rewriter,
    Location loc, Operation *op, Type elementType,
    ArrayRef<Value> scalarOperands) {
  // ONNXTanhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
  //                         AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))))
  Value operand = scalarOperands[0];

  auto zero = emitConstantOp(rewriter, loc, elementType, 0);
  auto neg = rewriter.create<SubFOp>(loc, zero, operand);
  auto exp = rewriter.create<ExpOp>(loc, operand);
  auto negExp = rewriter.create<ExpOp>(loc, neg);
  auto dividend = rewriter.create<SubFOp>(loc, exp, negExp);
  auto divisor = rewriter.create<AddFOp>(loc, exp, negExp);
  auto result = rewriter.create<DivFOp>(loc, dividend, divisor);

  return result;
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXSigmoidOp
//===----------------------------------------------------------------------===//
template <>
Value emitScalarOpFor<ONNXSigmoidOp>(ConversionPatternRewriter &rewriter,
    Location loc, Operation *op, Type elementType,
    ArrayRef<Value> scalarOperands) {
  // ONNXSigmoidOp(%X) = DivFOp(ConstantOp 1,
  //                            AddFOp(ConstantOp 1, ExpOp(NegFOp(%X))))
  Value operand = scalarOperands[0];

  auto zero = emitConstantOp(rewriter, loc, elementType, 0);
  auto one = emitConstantOp(rewriter, loc, elementType, 1);
  auto neg = rewriter.create<SubFOp>(loc, zero, operand);
  auto negExp = rewriter.create<ExpOp>(loc, neg);
  auto result = rewriter.create<DivFOp>(
      loc, one, rewriter.create<AddFOp>(loc, one, negExp));

  return result;
}
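// As an illustration, for an f32 element the sigmoid lowering above emits
// roughly the following standard-dialect IR (a hand-written sketch, not
// verified compiler output; %x stands for the loaded scalar):
//
//   %zero = constant 0.000000e+00 : f32
//   %one  = constant 1.000000e+00 : f32
//   %neg  = subf %zero, %x : f32
//   %exp  = exp %neg : f32
//   %sum  = addf %one, %exp : f32
//   %res  = divf %one, %sum : f32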
//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXHardSigmoidOp
//===----------------------------------------------------------------------===//
template <>
Value emitScalarOpFor<ONNXHardSigmoidOp>(ConversionPatternRewriter &rewriter,
    Location loc, Operation *op, Type elementType,
    ArrayRef<Value> scalarOperands) {
  // %Y = AddFOp(MulFOp(alpha, %X), beta)
  // %Z = SelectOp(CmpFOp(OGT, %Y, Constant 0),
  //               %Y,
  //               Constant 0)
  // ONNXHardSigmoidOp(%X) = SelectOp(CmpFOp(OLT, %Z, Constant 1),
  //                                  %Z,
  //                                  Constant 1)
  Value operand = scalarOperands[0];
  auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
      llvm::dyn_cast<ONNXHardSigmoidOp>(op).alpha().convertToFloat());
  auto betaAttribute = FloatAttr::get(rewriter.getF32Type(),
      llvm::dyn_cast<ONNXHardSigmoidOp>(op).beta().convertToFloat());

  auto zero = emitConstantOp(rewriter, loc, elementType, 0);
  auto one = emitConstantOp(rewriter, loc, elementType, 1);
  auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
  auto beta = rewriter.create<ConstantOp>(loc, betaAttribute);

  auto add = rewriter.create<AddFOp>(
      loc, rewriter.create<MulFOp>(loc, alpha, operand), beta);
  auto maxPredicate =
      rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, add, zero);
  auto max = rewriter.create<SelectOp>(loc, maxPredicate, add, zero);
  auto minPredicate =
      rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, max, one);
  auto result = rewriter.create<SelectOp>(loc, minPredicate, max, one);

  return result;
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXEluOp
//===----------------------------------------------------------------------===//
template <>
Value emitScalarOpFor<ONNXEluOp>(ConversionPatternRewriter &rewriter,
    Location loc, Operation *op, Type elementType,
    ArrayRef<Value> scalarOperands) {
  // ONNXEluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
  //                          MulFOp(alpha, SubFOp(ExpOp(%X), 1)),
  //                          %X)
  Value operand = scalarOperands[0];
  auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
      llvm::dyn_cast<ONNXEluOp>(op).alpha().convertToFloat());

  auto zero = emitConstantOp(rewriter, loc, elementType, 0);
  auto one = emitConstantOp(rewriter, loc, elementType, 1);
  auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
  auto exp = rewriter.create<ExpOp>(loc, operand);
  auto lessThanZero =
      rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
  auto result = rewriter.create<SelectOp>(loc, lessThanZero,
      rewriter.create<MulFOp>(
          loc, alpha, rewriter.create<SubFOp>(loc, exp, one)),
      operand);

  return result;
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXReluOp
//===----------------------------------------------------------------------===//
template <>
Value emitScalarOpFor<ONNXReluOp>(ConversionPatternRewriter &rewriter,
    Location loc, Operation *op, Type elementType,
    ArrayRef<Value> scalarOperands) {
  // ONNXReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
  //                           ConstantOp 0,
  //                           %X)
  Value operand = scalarOperands[0];

  auto zero = emitConstantOp(rewriter, loc, elementType, 0);
  auto lessThanZero =
      rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
  auto result = rewriter.create<SelectOp>(loc, lessThanZero, zero, operand);

  return result;
}
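// Similarly, the Relu lowering above reduces to a single compare-and-select
// per element (again a hand-written sketch, not verified compiler output):
//
//   %zero = constant 0.000000e+00 : f32
//   %cond = cmpf "olt", %x, %zero : f32
//   %res  = select %cond, %zero, %x : f32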
//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXLeakyReluOp
//===----------------------------------------------------------------------===//
template <>
Value emitScalarOpFor<ONNXLeakyReluOp>(ConversionPatternRewriter &rewriter,
    Location loc, Operation *op, Type elementType,
    ArrayRef<Value> scalarOperands) {
  // ONNXLeakyReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
  //                                MulFOp(alpha, %X),
  //                                %X)
  Value operand = scalarOperands[0];
  auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
      llvm::dyn_cast<ONNXLeakyReluOp>(op).alpha().convertToFloat());

  auto zero = emitConstantOp(rewriter, loc, elementType, 0);
  auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
  auto lessThanZero =
      rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
  auto result = rewriter.create<SelectOp>(loc, lessThanZero,
      rewriter.create<MulFOp>(loc, alpha, operand), operand);

  return result;
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXSeluOp
//===----------------------------------------------------------------------===//
template <>
Value emitScalarOpFor<ONNXSeluOp>(ConversionPatternRewriter &rewriter,
    Location loc, Operation *op, Type elementType,
    ArrayRef<Value> scalarOperands) {
  // ONNXSeluOp(%X) = SelectOp(CmpFOp(OGT, %X, ConstantOp 0),
  //                           MulFOp(gamma, %X),
  //                           MulFOp(gamma,
  //                                  SubFOp(MulFOp(alpha, ExpOp(%X)),
  //                                         alpha)))
  Value operand = scalarOperands[0];
  auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
      llvm::dyn_cast<ONNXSeluOp>(op).alpha().convertToFloat());
  auto gammaAttribute = FloatAttr::get(rewriter.getF32Type(),
      llvm::dyn_cast<ONNXSeluOp>(op).gamma().convertToFloat());

  auto zero = emitConstantOp(rewriter, loc, elementType, 0);
  auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
  auto gamma = rewriter.create<ConstantOp>(loc, gammaAttribute);
  auto exp = rewriter.create<ExpOp>(loc, operand);
  auto greaterThanZero =
      rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, operand, zero);
  auto select = rewriter.create<SelectOp>(loc, greaterThanZero, operand,
      rewriter.create<SubFOp>(
          loc, rewriter.create<MulFOp>(loc, alpha, exp), alpha));
  auto result = rewriter.create<MulFOp>(loc, gamma, select);

  return result;
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXReciprocalOp
//===----------------------------------------------------------------------===//
template <>
Value emitScalarOpFor<ONNXReciprocalOp>(ConversionPatternRewriter &rewriter,
    Location loc, Operation *op, Type elementType,
    ArrayRef<Value> scalarOperands) {
  // ONNXReciprocalOp(%X) = DivFOp(ConstantOp 1, %X)
  Value operand = scalarOperands[0];

  auto one = emitConstantOp(rewriter, loc, elementType, 1);
  auto result = rewriter.create<DivFOp>(loc, one, operand);

  return result;
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXSoftplusOp
//===----------------------------------------------------------------------===//
template <>
Value emitScalarOpFor<ONNXSoftplusOp>(ConversionPatternRewriter &rewriter,
    Location loc, Operation *op, Type elementType,
    ArrayRef<Value> scalarOperands) {
  // ONNXSoftplusOp(%X) = LogOp(AddFOp(ExpOp(%X), ConstantOp 1))
  Value operand = scalarOperands[0];

  auto exp = rewriter.create<ExpOp>(loc, operand);
  auto one = emitConstantOp(rewriter, loc, elementType, 1);
  auto add = rewriter.create<AddFOp>(loc, exp, one);
  auto result = rewriter.create<LogOp>(loc, add);

  return result;
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXSoftsignOp
//===----------------------------------------------------------------------===//
template <>
Value emitScalarOpFor<ONNXSoftsignOp>(ConversionPatternRewriter &rewriter,
    Location loc, Operation *op, Type elementType,
    ArrayRef<Value> scalarOperands) {
  // ONNXSoftsignOp(%X) = DivFOp(%X, AddFOp(AbsFOp(%X), ConstantOp 1))
  Value operand = scalarOperands[0];

  auto abs = rewriter.create<AbsFOp>(loc, operand);
  auto one = emitConstantOp(rewriter, loc, elementType, 1);
  auto add = rewriter.create<AddFOp>(loc, abs, one);
  auto result = rewriter.create<DivFOp>(loc, operand, add);

  return result;
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXSignOp
//===----------------------------------------------------------------------===//
template <>
Value emitScalarOpFor<ONNXSignOp>(ConversionPatternRewriter &rewriter,
    Location loc, Operation *op, Type elementType,
    ArrayRef<Value> scalarOperands) {
  Value operand = scalarOperands[0];
  // TODO: unsigned int should be supported separately?
  if (elementType.isa<IntegerType>()) {
    // %Y = SelectOp(CmpIOp(GT, %X, ConstantOp 0),
    //               ConstantOp 1,
    //               ConstantOp -1)
    // ONNXSignOp(%X) = SelectOp(CmpIOp(EQ, %X, ConstantOp 0),
    //                           ConstantOp 0,
    //                           %Y)
    auto zero = emitConstantOp(rewriter, loc, elementType, 0);
    auto one = emitConstantOp(rewriter, loc, elementType, 1);
    auto minusOne = emitConstantOp(rewriter, loc, elementType, -1);
    auto plusPredicate =
        rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, operand, zero);
    auto plusSelect =
        rewriter.create<SelectOp>(loc, plusPredicate, one, minusOne);
    auto zeroPredicate =
        rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, operand, zero);
    auto result =
        rewriter.create<SelectOp>(loc, zeroPredicate, zero, plusSelect);
    return result;
  } else if (elementType.isa<FloatType>()) {
    // %Y = SelectOp(CmpFOp(OGT, %X, ConstantOp 0),
    //               ConstantOp 1,
    //               ConstantOp -1)
    // ONNXSignOp(%X) = SelectOp(CmpFOp(OEQ, %X, ConstantOp 0),
    //                           ConstantOp 0,
    //                           %Y)
    auto zero = emitConstantOp(rewriter, loc, elementType, 0);
    auto one = emitConstantOp(rewriter, loc, elementType, 1);
    auto minusOne = emitConstantOp(rewriter, loc, elementType, -1);
    auto plusPredicate =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, operand, zero);
    auto plusSelect =
        rewriter.create<SelectOp>(loc, plusPredicate, one, minusOne);
    auto zeroPredicate =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, operand, zero);
    auto result =
        rewriter.create<SelectOp>(loc, zeroPredicate, zero, plusSelect);
    return result;
  } else {
    emitError(loc, "unsupported element type");
    return {};
  }
}
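// For the integer branch of the Sign lowering above, the emitted scalar code
// is two nested selects (illustrative sketch for i32, not verified output):
//
//   %zero = constant 0 : i32
//   %one  = constant 1 : i32
//   %m1   = constant -1 : i32
//   %gt   = cmpi "sgt", %x, %zero : i32
//   %y    = select %gt, %one, %m1 : i32
//   %eq   = cmpi "eq", %x, %zero : i32
//   %res  = select %eq, %zero, %y : i32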
//===----------------------------------------------------------------------===//
// Scalar ops for lowering ONNXMaxOp
//===----------------------------------------------------------------------===//
template <>
Value emitScalarOpFor<ONNXMaxOp>(ConversionPatternRewriter &rewriter,
    Location loc, Operation *op, Type elementType,
    ArrayRef<Value> scalarOperands) {
  // ONNXMaxOp(%X, %Y) = SelectOp(CmpFOp(OGT, %X, %Y),
  //                              %X,
  //                              %Y)
  Value lhs = scalarOperands[0];
  Value rhs = scalarOperands[1];
  auto max = rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs);
  auto result = rewriter.create<SelectOp>(loc, max, lhs, rhs);
  return result;
}

//===----------------------------------------------------------------------===//
// Scalar ops for lowering ONNXMinOp
//===----------------------------------------------------------------------===//
template <>
Value emitScalarOpFor<ONNXMinOp>(ConversionPatternRewriter &rewriter,
    Location loc, Operation *op, Type elementType,
    ArrayRef<Value> scalarOperands) {
  // ONNXMinOp(%X, %Y) = SelectOp(CmpFOp(OLT, %X, %Y),
  //                              %X,
  //                              %Y)
  Value lhs = scalarOperands[0];
  Value rhs = scalarOperands[1];
  auto min = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, lhs, rhs);
  auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs);
  return result;
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXAbsOp
//===----------------------------------------------------------------------===//
template <>
Value emitScalarOpFor<ONNXAbsOp>(ConversionPatternRewriter &rewriter,
    Location loc, Operation *op, Type elementType,
    ArrayRef<Value> scalarOperands) {
  Value operand = scalarOperands[0];

  if (elementType.isa<FloatType>()) {
    return rewriter.create<AbsFOp>(loc, operand);
  } else if (elementType.isa<IntegerType>()) {
    auto zero = emitConstantOp(rewriter, loc, elementType, 0);
    auto lessThanZero =
        rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, operand, zero);
    auto negativeOperand = rewriter.create<SubIOp>(loc, zero, operand);
    return rewriter.create<SelectOp>(
        loc, lessThanZero, negativeOperand, operand);
  } else {
    emitError(loc, "unsupported element type");
    return {};
  }
}

// Element-wise unary ops lowering to Krnl dialect.
//===----------------------------------------------------------------------===//
template <typename ElementwiseUnaryOp>
struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
  ONNXElementwiseUnaryOpLowering(MLIRContext *ctx)
      : ConversionPattern(ElementwiseUnaryOp::getOperationName(), 1, ctx) {}
  LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const final {
    // TODO: Check that the types are valid.
    // All operands and the result of an element-wise unary operation must
    // have the same type. This should have been verified by the verifier.
    auto loc = op->getLoc();

    // Insert an allocation and deallocation for the result of this operation.
    auto memRefType = convertToMemRefType(*op->result_type_begin());

    // If the output has a dynamic dimension, pass the operands required for
    // each dynamic dimension to the AllocOp. The first operand of the
    // operation is used. The operands of the op need to match in terms of
    // dimensions with the result at this pre-optimization phase.
    // TODO: verify that dimensions match.
    // TODO: can the dimension of the result differ after optimizations?
    Value alloc;
    bool insertDealloc = checkInsertDealloc(op);

    if (hasAllConstantDimensions(memRefType))
      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
    else
      alloc = insertAllocAndDealloc(
          memRefType, loc, rewriter, insertDealloc, {operands[0]});

    std::vector<Value> originalLoops;
    KrnlOptimizeLoopsOp optimizedLoopsOp;
    KrnlIterateOp iterateOp;
    emitKrnlLoopsAndIterationForOperand(rewriter, loc, operands[0],
        originalLoops, optimizedLoopsOp, iterateOp);
    Block &optimizationBlock = optimizedLoopsOp.region().front();
    Block &iterationBlock = iterateOp.bodyRegion().front();

    // 1. Insert any optimizations in the KrnlOptimizeLoopsOp body.
    rewriter.setInsertionPointToEnd(&optimizationBlock);
    // Return from KrnlOptimizeLoopsOp body.
    // When no optimizations are present we just return the loops unchanged.
    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);

    // 2. Insert instructions inside the KrnlIterateOp body.
    rewriter.setInsertionPointToStart(&iterationBlock);

    // Handle the operation:
    SmallVector<Value, 4> loopIVs;
    for (auto arg : iterationBlock.getArguments())
      loopIVs.push_back(arg);

    auto loadedVal = rewriter.create<LoadOp>(loc, operands[0], loopIVs);
    auto loweredOpResult = emitScalarOpFor<ElementwiseUnaryOp>(
        rewriter, loc, op, memRefType.getElementType(), {loadedVal});

    // Store result in the resulting array.
    rewriter.create<StoreOp>(loc, loweredOpResult, alloc, loopIVs);

    rewriter.replaceOp(op, alloc);

    return success();
  }
};
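// For orientation, the krnl-dialect loop nest produced by the pattern above
// for a unary op over a 2-D tensor has roughly this shape (a hand-written
// sketch of the dialect syntax, not verified compiler output):
//
//   %0:2 = krnl.define_loops 2
//   %1:2 = krnl.optimize_loops {
//     krnl.return_loops %0#0, %0#1
//   } : () -> (!krnl.loop, !krnl.loop)
//   krnl.iterate(%1#0, %1#1) with (%0#0 -> %i = 0 to M, %0#1 -> %j = 0 to N) {
//     %2 = load %arg0[%i, %j] : memref<MxNxf32>
//     // ... scalar computation on %2 ...
//     store %3, %alloc[%i, %j] : memref<MxNxf32>
//   }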
// Element-wise variadic ops lowering to Krnl dialect.
//===----------------------------------------------------------------------===//
template <typename ElementwiseVariadicOp>
struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
  ONNXElementwiseVariadicOpLowering(MLIRContext *ctx)
      : ConversionPattern(ElementwiseVariadicOp::getOperationName(), 1, ctx) {}
  LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const final {
    // TODO: Check that the types are valid.
    // All operands and the result of an element-wise variadic operation must
    // have the same type. This should have been verified by the verifier.
    auto loc = op->getLoc();
    auto numArgs = op->getNumOperands();

    // Insert an allocation and deallocation for the result of this operation.
    auto memRefType = convertToMemRefType(*op->result_type_begin());

    Value alloc;
    bool insertDealloc = checkInsertDealloc(op);
    // If the output has a dynamic dimension, we compute its dimension at
    // runtime by using dimensions from the operands.
    // In particular, we need to know which operand a result dimension
    // comes from.
    // TODO: can the dimension of the result differ after optimizations?
    if (hasAllConstantDimensions(memRefType))
      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
    else
      alloc = insertAllocAndDealloc(
          memRefType, loc, rewriter, insertDealloc, operands);

    // Get run-time dimension information for unknown dimensions used for
    // broadcasting.
    std::map<int, std::map<int, Value>> broadcastedDimInfo =
        getBroadcastedDimInfo(loc, rewriter, memRefType, operands);

    std::vector<Value> originalLoops;
    KrnlOptimizeLoopsOp optimizedLoopsOp;
    KrnlIterateOp iterateOp;
    emitKrnlLoopsAndIterationForOperand(
        rewriter, loc, alloc, originalLoops, optimizedLoopsOp, iterateOp);
    Block &optimizationBlock = optimizedLoopsOp.region().front();
    Block &iterationBlock = iterateOp.bodyRegion().front();

    // 1. Insert any optimizations in the KrnlOptimizeLoopsOp body.
    rewriter.setInsertionPointToEnd(&optimizationBlock);
    // Return from KrnlOptimizeLoopsOp body.
    // When no optimizations are present we just return the loops unchanged.
    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);

    // 2. Insert instructions inside the KrnlIterateOp body.
    rewriter.setInsertionPointToStart(&iterationBlock);

    // Handle the operation:
    SmallVector<Value, 4> loopIVs;
    for (auto arg : iterationBlock.getArguments())
      loopIVs.push_back(arg);

    // Fold over operands for each of their scalar values.
    Value accumulated, next;
    auto accumulatedLoopIVs = getLoopIVsForBroadcasting(
        loc, rewriter, loopIVs, operands[0], broadcastedDimInfo[0]);
    accumulated =
        rewriter.create<LoadOp>(loc, operands[0], accumulatedLoopIVs);
    for (unsigned i = 1; i < numArgs; i++) {
      auto nextLoopIVs = getLoopIVsForBroadcasting(
          loc, rewriter, loopIVs, operands[i], broadcastedDimInfo[i]);
      next = rewriter.create<LoadOp>(loc, operands[i], nextLoopIVs);
      accumulated = emitScalarOpFor<ElementwiseVariadicOp>(rewriter, loc, op,
          memRefType.getElementType(), {accumulated, next});
    }

    // Store result in the resulting array.
    rewriter.create<StoreOp>(loc, accumulated, alloc, loopIVs);

    rewriter.replaceOp(op, alloc);

    return success();
  }
};
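// Note that the loop above lowers a variadic op as a left fold over its
// operands: for example, Sum(%a, %b, %c) becomes AddFOp(AddFOp(%a, %b), %c),
// with each operand's loop indices remapped for broadcasting before its
// scalar value is loaded.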
void populateLoweringONNXElementwiseOpPattern(
    OwningRewritePatternList &patterns, MLIRContext *ctx) {
  patterns.insert<ONNXElementwiseUnaryOpLowering<mlir::ONNXAbsOp>,
      ONNXElementwiseVariadicOpLowering<mlir::ONNXAddOp>,
      ONNXElementwiseVariadicOpLowering<mlir::ONNXAndOp>,
      ONNXElementwiseUnaryOpLowering<mlir::ONNXCosOp>,
      ONNXElementwiseUnaryOpLowering<mlir::ONNXCoshOp>,
      ONNXElementwiseVariadicOpLowering<mlir::ONNXDivOp>,
      ONNXElementwiseUnaryOpLowering<mlir::ONNXEluOp>,
      ONNXElementwiseUnaryOpLowering<mlir::ONNXExpOp>,
      ONNXElementwiseUnaryOpLowering<mlir::ONNXHardSigmoidOp>,
      ONNXElementwiseUnaryOpLowering<mlir::ONNXLeakyReluOp>,
      ONNXElementwiseUnaryOpLowering<mlir::ONNXLogOp>,
      ONNXElementwiseVariadicOpLowering<mlir::ONNXMaxOp>,
      ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>,
      ONNXElementwiseVariadicOpLowering<mlir::ONNXMulOp>,
      ONNXElementwiseVariadicOpLowering<mlir::ONNXOrOp>,
      ONNXElementwiseUnaryOpLowering<mlir::ONNXReciprocalOp>,
      ONNXElementwiseUnaryOpLowering<mlir::ONNXReluOp>,
      ONNXElementwiseUnaryOpLowering<mlir::ONNXSeluOp>,
      ONNXElementwiseUnaryOpLowering<mlir::ONNXSigmoidOp>,
      ONNXElementwiseUnaryOpLowering<mlir::ONNXSignOp>,
      ONNXElementwiseUnaryOpLowering<mlir::ONNXSinhOp>,
      ONNXElementwiseUnaryOpLowering<mlir::ONNXSoftplusOp>,
      ONNXElementwiseUnaryOpLowering<mlir::ONNXSoftsignOp>,
      ONNXElementwiseUnaryOpLowering<mlir::ONNXSqrtOp>,
      ONNXElementwiseVariadicOpLowering<mlir::ONNXSubOp>,
      ONNXElementwiseVariadicOpLowering<mlir::ONNXSumOp>,
      ONNXElementwiseUnaryOpLowering<mlir::ONNXTanhOp>,
      ONNXElementwiseVariadicOpLowering<mlir::ONNXXorOp>>(ctx);
}
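// A typical call site for this entry point, sketched under the assumption of
// a standard MLIR conversion pass (hypothetical pass body, for illustration
// only):
//
//   OwningRewritePatternList patterns;
//   populateLoweringONNXElementwiseOpPattern(patterns, &getContext());
//   // ... add patterns for the remaining ONNX ops, then run
//   // applyPartialConversion(module, target, patterns).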