//===----- elementwise.inc - Elementwise Ops ------------------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file lowers ONNX element-wise operators to Krnl dialect.
//
//===----------------------------------------------------------------------===//

template <>
struct ScalarOp<ONNXAddOp> {
  using FOp = AddFOp;
  using IOp = AddIOp;
};

template <>
struct ScalarOp<ONNXMulOp> {
  using FOp = MulFOp;
  using IOp = MulIOp;
};

template <>
struct ScalarOp<ONNXDivOp> {
  using FOp = DivFOp;
  using IOp = SignedDivIOp;
};

template <>
struct ScalarOp<ONNXSubOp> {
  using FOp = SubFOp;
  using IOp = SubIOp;
};

template <>
struct ScalarOp<ONNXAndOp> {
  using FOp = AndOp; // not used
  using IOp = AndOp;
};

template <>
struct ScalarOp<ONNXOrOp> {
  using FOp = OrOp; // not used
  using IOp = OrOp;
};

template <>
struct ScalarOp<ONNXXorOp> {
  using FOp = XOrOp; // not used
  using IOp = XOrOp;
};

template <>
struct ScalarOp<ONNXExpOp> {
  using FOp = ExpOp;
  using IOp = ExpOp; // not used
};

template <>
struct ScalarOp<ONNXSumOp> {
  using FOp = AddFOp;
  using IOp = AddIOp;
};

template <>
struct ScalarOp<ONNXTanhOp> {
  using FOp = TanhOp;
  using IOp = TanhOp; // not used
};

template <>
struct ScalarOp<ONNXCosOp> {
  using FOp = CosOp;
  using IOp = CosOp; // not used
};

template <>
struct ScalarOp<ONNXLogOp> {
  using FOp = LogOp;
  using IOp = LogOp; // not used
};

template <>
struct ScalarOp<ONNXSqrtOp> {
  using FOp = KrnlSqrtOp;
  using IOp = KrnlSqrtOp; // not used
};

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXSinhOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXSinhOp>(Operation *op,
    ArrayRef<Type> result_types, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) {
  // ONNXSinhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
  //                         ConstantOp 2)
  auto loc = op->getLoc();
  Value operand = operands[0];
  auto elementType = result_types[0];

  auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
  auto two = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 2));
  auto neg = rewriter.create<SubFOp>(loc, zero, operand);
  auto exp = rewriter.create<ExpOp>(loc, operand);
  auto negExp = rewriter.create<ExpOp>(loc, neg);
  auto result = rewriter.create<DivFOp>(
      loc, rewriter.create<SubFOp>(loc, exp, negExp), two);

  return result;
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXCoshOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXCoshOp>(Operation *op,
    ArrayRef<Type> result_types, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) {
  // ONNXCoshOp(%X) = DivFOp(AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
  //                         ConstantOp 2)
  auto loc = op->getLoc();
  Value operand = operands[0];
  auto elementType = result_types[0];

  auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
  auto two = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 2));
  auto neg = rewriter.create<SubFOp>(loc, zero, operand);
  auto exp = rewriter.create<ExpOp>(loc, operand);
  auto negExp = rewriter.create<ExpOp>(loc, neg);
  auto result = rewriter.create<DivFOp>(
      loc, rewriter.create<AddFOp>(loc, exp, negExp), two);

  return result;
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXSigmoidOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXSigmoidOp>(Operation *op,
    ArrayRef<Type> result_types, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) {
  // ONNXSigmoidOp(%X) = DivFOp(ConstantOp 1,
  //                            AddFOp(ConstantOp 1, ExpOp(NegFOp(%X))))
  auto loc = op->getLoc();
  Value operand = operands[0];
  auto elementType = result_types[0];

  auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
  auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
  auto neg = rewriter.create<SubFOp>(loc, zero, operand);
  auto negExp = rewriter.create<ExpOp>(loc, neg);
  auto result = rewriter.create<DivFOp>(
      loc, one, rewriter.create<AddFOp>(loc, one, negExp));

  return result;
}
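// Note on the three decompositions above: the standard dialect provides no
// floating-point negation op at this level, so NegFOp(%X) is emulated as
// SubFOp(ConstantOp 0, %X). As a quick numeric check of the sinh lowering,
// for x = 1: exp(1) ~= 2.71828 and exp(-1) ~= 0.36788, so
// (2.71828 - 0.36788) / 2 ~= 1.17520, which is sinh(1).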
//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXHardSigmoidOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXHardSigmoidOp>(Operation *op,
    ArrayRef<Type> result_types, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) {
  // %Y = AddFOp(MulFOp(alpha, %X), beta)
  // %Z = SelectOp(CmpFOp(OGT, %Y, Constant 0),
  //               %Y,
  //               Constant 0)
  // ONNXHardSigmoidOp(%X) = SelectOp(CmpFOp(OLT, %Z, Constant 1),
  //                                  %Z,
  //                                  Constant 1)
  auto loc = op->getLoc();
  Value operand = operands[0];
  auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
      llvm::dyn_cast<ONNXHardSigmoidOp>(op).alpha().convertToFloat());
  auto betaAttribute = FloatAttr::get(rewriter.getF32Type(),
      llvm::dyn_cast<ONNXHardSigmoidOp>(op).beta().convertToFloat());
  auto elementType = result_types[0];

  auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
  auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
  auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
  auto beta = rewriter.create<ConstantOp>(loc, betaAttribute);

  auto add = rewriter.create<AddFOp>(
      loc, rewriter.create<MulFOp>(loc, alpha, operand), beta);
  auto maxPredicate =
      rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, add, zero);
  auto max = rewriter.create<SelectOp>(loc, maxPredicate, add, zero);
  auto minPredicate =
      rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, max, one);
  auto result = rewriter.create<SelectOp>(loc, minPredicate, max, one);

  return result;
}
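// Worked example for the HardSigmoid lowering above, i.e.
// y = clamp(alpha * x + beta, 0, 1), using alpha = 0.2 and beta = 0.5
// (the ONNX spec defaults, not values taken from this file):
//   x = -3: clamp(0.2 * -3 + 0.5, 0, 1) = clamp(-0.1, 0, 1) = 0
//   x =  0: clamp(0.5, 0, 1)                                = 0.5
//   x =  3: clamp(1.1, 0, 1)                                = 1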
//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXEluOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXEluOp>(Operation *op,
    ArrayRef<Type> result_types, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) {
  // ONNXEluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
  //                          MulFOp(alpha, SubFOp(ExpOp(%X), 1)),
  //                          %X)
  auto loc = op->getLoc();
  Value operand = operands[0];
  auto elementType = result_types[0];
  auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
      llvm::dyn_cast<ONNXEluOp>(op).alpha().convertToFloat());

  auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
  auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
  auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
  auto exp = rewriter.create<ExpOp>(loc, operand);
  auto lessThanZero =
      rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
  auto result = rewriter.create<SelectOp>(loc, lessThanZero,
      rewriter.create<MulFOp>(
          loc, alpha, rewriter.create<SubFOp>(loc, exp, one)),
      operand);

  return result;
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXReluOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXReluOp>(Operation *op,
    ArrayRef<Type> result_types, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) {
  // ONNXReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
  //                           ConstantOp 0,
  //                           %X)
  auto loc = op->getLoc();
  Value operand = operands[0];
  auto elementType = result_types[0];

  auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
  auto lessThanZero =
      rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
  auto result = rewriter.create<SelectOp>(loc, lessThanZero, zero, operand);

  return result;
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXLeakyReluOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXLeakyReluOp>(Operation *op,
    ArrayRef<Type> result_types, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) {
  // ONNXLeakyReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
  //                                MulFOp(alpha, %X),
  //                                %X)
  auto loc = op->getLoc();
  Value operand = operands[0];
  auto elementType = result_types[0];
  auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
      llvm::dyn_cast<ONNXLeakyReluOp>(op).alpha().convertToFloat());

  auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
  auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
  auto lessThanZero =
      rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
  auto result = rewriter.create<SelectOp>(loc, lessThanZero,
      rewriter.create<MulFOp>(loc, alpha, operand), operand);

  return result;
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXSeluOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXSeluOp>(Operation *op,
    ArrayRef<Type> result_types, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) {
  // ONNXSeluOp(%X) = SelectOp(CmpFOp(OGT, %X, ConstantOp 0),
  //                           MulFOp(gamma, %X),
  //                           MulFOp(gamma,
  //                                  SubFOp(MulFOp(alpha, ExpOp(%X)),
  //                                         alpha)))
  auto loc = op->getLoc();
  Value operand = operands[0];
  auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
      llvm::dyn_cast<ONNXSeluOp>(op).alpha().convertToFloat());
  auto gammaAttribute = FloatAttr::get(rewriter.getF32Type(),
      llvm::dyn_cast<ONNXSeluOp>(op).gamma().convertToFloat());
  auto elementType = result_types[0];

  auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
  auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
  auto gamma = rewriter.create<ConstantOp>(loc, gammaAttribute);
  auto exp = rewriter.create<ExpOp>(loc, operand);
  auto greaterThanZero =
      rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, operand, zero);
  auto select = rewriter.create<SelectOp>(loc, greaterThanZero, operand,
      rewriter.create<SubFOp>(
          loc, rewriter.create<MulFOp>(loc, alpha, exp), alpha));
  auto result = rewriter.create<MulFOp>(loc, gamma, select);

  return result;
}
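// Note on the Selu lowering above. The ONNX definition is
//   selu(x) = gamma * x                          for x > 0
//   selu(x) = gamma * (alpha * exp(x) - alpha)   for x <= 0
// The code factors gamma out of the SelectOp, so a single MulFOp by gamma is
// emitted after the select instead of one multiply per branch.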
//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXReciprocalOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXReciprocalOp>(Operation *op,
    ArrayRef<Type> result_types, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) {
  // ONNXReciprocalOp(%X) = DivFOp(ConstantOp 1, %X)
  auto loc = op->getLoc();
  Value operand = operands[0];
  auto elementType = result_types[0];

  auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
  auto result = rewriter.create<DivFOp>(loc, one, operand);

  return result;
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXSoftplusOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXSoftplusOp>(Operation *op,
    ArrayRef<Type> result_types, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) {
  // ONNXSoftplusOp(%X) = LogOp(AddFOp(ExpOp(%X), ConstantOp 1))
  auto loc = op->getLoc();
  Value operand = operands[0];
  auto elementType = result_types[0];

  auto exp = rewriter.create<ExpOp>(loc, operand);
  auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
  auto add = rewriter.create<AddFOp>(loc, exp, one);
  auto result = rewriter.create<LogOp>(loc, add);

  return result;
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXSoftsignOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXSoftsignOp>(Operation *op,
    ArrayRef<Type> result_types, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) {
  // ONNXSoftsignOp(%X) = DivFOp(%X, AddFOp(AbsFOp(%X), ConstantOp 1))
  auto loc = op->getLoc();
  Value operand = operands[0];
  auto elementType = result_types[0];

  auto abs = rewriter.create<AbsFOp>(loc, operand);
  auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
  auto add = rewriter.create<AddFOp>(loc, abs, one);
  auto result = rewriter.create<DivFOp>(loc, operand, add);

  return result;
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXSignOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXSignOp>(Operation *op,
    ArrayRef<Type> result_types, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) {
  auto loc = op->getLoc();
  Value operand = operands[0];
  Type element_type = operands.front().getType();
  // TODO: unsigned int should be supported separately?
  if (element_type.isa<IntegerType>()) {
    // %Y = SelectOp(CmpIOp(GT, %X, ConstantOp 0),
    //               ConstantOp 1,
    //               ConstantOp -1)
    // ONNXSignOp(%X) = SelectOp(CmpIOp(EQ, %X, ConstantOp 0),
    //                           ConstantOp 0,
    //                           %Y)
    auto zero =
        rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
    auto one = rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(1));
    auto minusOne =
        rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(-1));
    auto plusPredicate =
        rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, operand, zero);
    auto plusSelect =
        rewriter.create<SelectOp>(loc, plusPredicate, one, minusOne);
    auto zeroPredicate =
        rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, operand, zero);
    auto result =
        rewriter.create<SelectOp>(loc, zeroPredicate, zero, plusSelect);
    return result;
  } else if (element_type.isa<FloatType>()) {
    // %Y = SelectOp(CmpFOp(OGT, %X, ConstantOp 0),
    //               ConstantOp 1,
    //               ConstantOp -1)
    // ONNXSignOp(%X) = SelectOp(CmpFOp(OEQ, %X, ConstantOp 0),
    //                           ConstantOp 0,
    //                           %Y)
    auto zero =
        rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
    auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f));
    auto minusOne =
        rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(-1.0f));
    auto plusPredicate =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, operand, zero);
    auto plusSelect =
        rewriter.create<SelectOp>(loc, plusPredicate, one, minusOne);
    auto zeroPredicate =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, operand, zero);
    auto result =
        rewriter.create<SelectOp>(loc, zeroPredicate, zero, plusSelect);
    return result;
  } else {
    emitError(loc, "unsupported element type");
    return nullptr;
  }
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXMaxOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXMaxOp>(Operation *op,
    ArrayRef<Type> result_types, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) {
  // ONNXMaxOp(%X, %Y) = SelectOp(CmpFOp(OGT, %X, %Y),
  //                              %X,
  //                              %Y)
  auto loc = op->getLoc();
  Value lhs = operands[0];
  Value rhs = operands[1];
  auto max = rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs);
  auto result = rewriter.create<SelectOp>(loc, max, lhs, rhs);
  return result;
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXMinOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXMinOp>(Operation *op,
    ArrayRef<Type> result_types, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) {
  // ONNXMinOp(%X, %Y) = SelectOp(CmpFOp(OLT, %X, %Y),
  //                              %X,
  //                              %Y)
  auto loc = op->getLoc();
  Value lhs = operands[0];
  Value rhs = operands[1];
  auto min = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, lhs, rhs);
  auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs);
  return result;
}
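// Note on the Max/Min lowerings above: CmpFOp with an ordered predicate
// (OGT/OLT) evaluates to false when either operand is NaN, so the SelectOp
// then yields the second operand. With a NaN input the result therefore
// depends on operand order; this asymmetry is inherent to a compare-and-select
// lowering, as opposed to a dedicated min/max intrinsic.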
// Element-wise unary ops lowering to Krnl dialect.
//===----------------------------------------------------------------------===//
template <typename ElementwiseUnaryOp>
struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
  ONNXElementwiseUnaryOpLowering(MLIRContext *ctx)
      : ConversionPattern(ElementwiseUnaryOp::getOperationName(), 1, ctx) {}
  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const final {
    // TODO: Check that the types are valid.
    // An element-wise unary operation must have all operands and the result
    // of the same type. This should have been verified by the verifier.
    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
    auto loc = op->getLoc();

    // Insert an allocation and deallocation for the result of this operation.
    auto memRefType = convertTensorToMemRef(tensorType);

    // If the output has a dynamic dimension, pass the operands required for
    // each dynamic dimension to the AllocOp. The first operand of the
    // operation is used. The operands of the op need to match in terms of
    // dimensions with the result at this pre-optimization phase.
    // TODO: verify that dimensions match.
    // TODO: can the dimension of the result differ after optimizations?
    Value alloc;
    bool insertDealloc = checkInsertDealloc(op);

    if (hasAllConstantDimensions(memRefType))
      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
    else
      alloc = insertAllocAndDealloc(
          memRefType, loc, rewriter, insertDealloc, {operands[0]});

    std::vector<Value> originalLoops;
    KrnlOptimizeLoopsOp optimizedLoopsOp;
    KrnlIterateOp iterateOp;
    emitKrnlLoopsAndIterationForOperand(
        rewriter, loc, operands[0], originalLoops, optimizedLoopsOp, iterateOp);
    Block &optimizationBlock = optimizedLoopsOp.region().front();
    Block &iterationBlock = iterateOp.bodyRegion().front();

    // 1. Insert any optimizations in the KrnlOptimizeLoopsOp body.
    rewriter.setInsertionPointToEnd(&optimizationBlock);
    // Return from KrnlOptimizeLoopsOp body.
    // When no optimizations are present we just return the loops unchanged.
    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);

    // 2. Insert instructions inside the KrnlIterateOp body.
    rewriter.setInsertionPointToStart(&iterationBlock);

    // Handle the operation:
    SmallVector<Value, 4> loopIVs;
    for (auto arg : iterationBlock.getArguments())
      loopIVs.push_back(arg);

    auto loadedVal = rewriter.create<LoadOp>(loc, operands[0], loopIVs);
    auto loweredOpResult = mapToLowerScalarOp<ElementwiseUnaryOp>(
        op, memRefType.getElementType(), {loadedVal}, rewriter);

    // Store result in the resulting array.
    rewriter.create<StoreOp>(loc, loweredOpResult, alloc, loopIVs);

    rewriter.replaceOp(op, alloc);

    return matchSuccess();
  }
};
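// Conceptually, for a rank-2 input the unary pattern above emits the
// equivalent of the following loop nest (a sketch only, not actual Krnl
// dialect syntax; M, N, and f are illustrative):
//
//   %Y = alloc() : memref<MxNxf32>
//   for %i = 0 to M:
//     for %j = 0 to N:
//       %x = load %X[%i, %j]
//       store f(%x), %Y[%i, %j]
//
// where f is the scalar computation produced by mapToLowerScalarOp for the
// op being lowered.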
// Element-wise variadic ops lowering to Krnl dialect.
//===----------------------------------------------------------------------===//
template <typename ElementwiseVariadicOp>
struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
  ONNXElementwiseVariadicOpLowering(MLIRContext *ctx)
      : ConversionPattern(ElementwiseVariadicOp::getOperationName(), 1, ctx) {}
  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const final {
    // TODO: Check that the types are valid.
    // An element-wise variadic operation must have all operands and the
    // result of the same type. This should have been verified by the
    // verifier.
    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
    auto loc = op->getLoc();
    auto numArgs = op->getNumOperands();

    // Insert an allocation and deallocation for the result of this operation.
    auto memRefType = convertTensorToMemRef(tensorType);

    Value alloc;
    bool insertDealloc = checkInsertDealloc(op);
    // If the output has a dynamic dimension, we compute its dimension at
    // runtime by using dimensions from the operands. In particular, we need
    // to know from which operand a result dimension comes.
    // TODO: can the dimension of the result differ after optimizations?
    if (hasAllConstantDimensions(memRefType))
      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
    else
      alloc = insertAllocAndDealloc(
          memRefType, loc, rewriter, insertDealloc, operands);

    // Get run-time dimension information for unknown dimensions used for
    // broadcasting.
    std::map<int, std::map<int, Value>> broadcastedDimInfo =
        getBroadcastedDimInfo(loc, rewriter, memRefType, operands);

    std::vector<Value> originalLoops;
    KrnlOptimizeLoopsOp optimizedLoopsOp;
    KrnlIterateOp iterateOp;
    emitKrnlLoopsAndIterationForOperand(
        rewriter, loc, alloc, originalLoops, optimizedLoopsOp, iterateOp);
    Block &optimizationBlock = optimizedLoopsOp.region().front();
    Block &iterationBlock = iterateOp.bodyRegion().front();

    // 1. Insert any optimizations in the KrnlOptimizeLoopsOp body.
    rewriter.setInsertionPointToEnd(&optimizationBlock);
    // Return from KrnlOptimizeLoopsOp body.
    // When no optimizations are present we just return the loops unchanged.
    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);

    // 2. Insert instructions inside the KrnlIterateOp body.
    rewriter.setInsertionPointToStart(&iterationBlock);

    // Handle the operation:
    SmallVector<Value, 4> loopIVs;
    for (auto arg : iterationBlock.getArguments())
      loopIVs.push_back(arg);

    // Fold over operands for each of their scalar values.
    Value accumulated, next;
    auto accumulatedLoopIVs = getLoopIVsForBroadcasting(
        loc, rewriter, loopIVs, operands[0], broadcastedDimInfo[0]);
    accumulated = rewriter.create<LoadOp>(loc, operands[0], accumulatedLoopIVs);
    for (unsigned i = 1; i < numArgs; i++) {
      auto nextLoopIVs = getLoopIVsForBroadcasting(
          loc, rewriter, loopIVs, operands[i], broadcastedDimInfo[i]);
      next = rewriter.create<LoadOp>(loc, operands[i], nextLoopIVs);
      accumulated = mapToLowerScalarOp<ElementwiseVariadicOp>(
          op, memRefType.getElementType(), {accumulated, next}, rewriter);
    }
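    // Broadcasting example for the fold above (shapes are illustrative):
    // with operands of shape <3x1xf32> and <1x5xf32> the result has shape
    // <3x5xf32>. For result indices (%i, %j), getLoopIVsForBroadcasting maps
    // them to (%i, 0) for the first operand and (0, %j) for the second, so
    // each operand is loaded with its broadcast dimensions pinned to index
    // zero.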
    // Store result in the resulting array.
    rewriter.create<StoreOp>(loc, accumulated, alloc, loopIVs);

    rewriter.replaceOp(op, alloc);

    return matchSuccess();
  }
};

void populateLoweringONNXElementwiseOpPattern(
    OwningRewritePatternList &patterns, MLIRContext *ctx) {
  patterns.insert<ONNXElementwiseVariadicOpLowering<ONNXAddOp>,
      ONNXElementwiseVariadicOpLowering<ONNXAndOp>,
      ONNXElementwiseUnaryOpLowering<ONNXCosOp>,
      ONNXElementwiseUnaryOpLowering<ONNXCoshOp>,
      ONNXElementwiseVariadicOpLowering<ONNXDivOp>,
      ONNXElementwiseUnaryOpLowering<ONNXEluOp>,
      ONNXElementwiseUnaryOpLowering<ONNXExpOp>,
      ONNXElementwiseUnaryOpLowering<ONNXHardSigmoidOp>,
      ONNXElementwiseUnaryOpLowering<ONNXLeakyReluOp>,
      ONNXElementwiseUnaryOpLowering<ONNXLogOp>,
      ONNXElementwiseVariadicOpLowering<ONNXMaxOp>,
      ONNXElementwiseVariadicOpLowering<ONNXMinOp>,
      ONNXElementwiseVariadicOpLowering<ONNXMulOp>,
      ONNXElementwiseVariadicOpLowering<ONNXOrOp>,
      ONNXElementwiseUnaryOpLowering<ONNXReciprocalOp>,
      ONNXElementwiseUnaryOpLowering<ONNXReluOp>,
      ONNXElementwiseUnaryOpLowering<ONNXSeluOp>,
      ONNXElementwiseUnaryOpLowering<ONNXSigmoidOp>,
      ONNXElementwiseUnaryOpLowering<ONNXSignOp>,
      ONNXElementwiseUnaryOpLowering<ONNXSinhOp>,
      ONNXElementwiseUnaryOpLowering<ONNXSoftplusOp>,
      ONNXElementwiseUnaryOpLowering<ONNXSoftsignOp>,
      ONNXElementwiseUnaryOpLowering<ONNXSqrtOp>,
      ONNXElementwiseVariadicOpLowering<ONNXSubOp>,
      ONNXElementwiseVariadicOpLowering<ONNXSumOp>,
      ONNXElementwiseUnaryOpLowering<ONNXTanhOp>,
      ONNXElementwiseVariadicOpLowering<ONNXXorOp>>(ctx);
}
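// A typical call site for the entry point above (a sketch, not code from this
// file; the surrounding conversion pass and its ConversionTarget are assumed):
//
//   OwningRewritePatternList patterns;
//   populateLoweringONNXElementwiseOpPattern(patterns, &getContext());
//   // ... combine with the other ONNX-to-Krnl pattern sets, then run the
//   // dialect conversion driver, e.g. applyPartialConversion(module, target,
//   // patterns).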