//====- lower_frontend_to_krnl.cpp - Frontend dialects to Krnl lowering ---===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file implements the lowering of frontend operations to a combination of
// Krnl IR and standard operations.
//
//===----------------------------------------------------------------------===//

#include <map>

#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Sequence.h"

#include "src/dialect/krnl/krnl_helper.hpp"
#include "src/dialect/krnl/krnl_ops.hpp"
#include "src/dialect/onnx/onnx_ops.hpp"
#include "src/pass/passes.hpp"

using namespace mlir;

//===----------------------------------------------------------------------===//
// FrontendToAffine RewritePatterns
//===----------------------------------------------------------------------===//

/// Check if all dimensions are known at compile time.
static bool hasAllConstantDimensions(MemRefType type) {
  auto memRefShape = type.getShape();
  for (int i = 0; i < memRefShape.size(); ++i)
    if (memRefShape[i] < 0)
      return false;
  return true;
}

/// Convert the given TensorType into the corresponding MemRefType.
static MemRefType convertTensorToMemRef(TensorType type) {
  assert(type.hasRank() && "expected only ranked shapes");
  return MemRefType::get(type.getShape(), type.getElementType());
}

/// Insert an allocation and deallocation for the given MemRefType.
static Value insertAllocAndDealloc(MemRefType type, Location loc,
    PatternRewriter &rewriter, bool insertDealloc,
    ArrayRef<Value> operands = {}) {
  // Put together alloc operands for any dynamic dimensions of the memref.
  AllocOp alloc;
  if (!operands.empty()) {
    auto memRefShape = type.getShape();
    auto rank = memRefShape.size();

    std::map<int, Value> fromOperands;
    for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
      int memRefDimIdx = rank - 1 - reversedIdx;
      if (memRefShape[memRefDimIdx] < 0) { // unknown dimension
        Value maxDim = nullptr;
        for (int i = 0; i < operands.size(); i++) {
          auto operandShape =
              operands[i].getType().cast<MemRefType>().getShape();
          int operandDimIdx = operandShape.size() - 1 - reversedIdx;

          if (operandDimIdx < 0)
            continue;

          // In case of operations with broadcasting, the dimension of the
          // alloc result is the maximum size along each dimension of the
          // operands.
          auto operandDim =
              rewriter.create<DimOp>(loc, operands[i], operandDimIdx);
          if (maxDim) {
            auto maxCondition = rewriter.create<CmpIOp>(
                loc, CmpIPredicate::sgt, operandDim, maxDim);
            maxDim = rewriter.create<SelectOp>(
                loc, maxCondition, operandDim, maxDim);
          } else {
            maxDim = operandDim;
          }
        }
        fromOperands.insert(std::make_pair(memRefDimIdx, maxDim));
      }
    }

    SmallVector<Value, 4> allocOperands;
    for (int i = 0; i < rank; ++i)
      if (memRefShape[i] < 0)
        allocOperands.push_back(fromOperands[i]);
    alloc = rewriter.create<AllocOp>(loc, type, allocOperands);
  } else {
    alloc = rewriter.create<AllocOp>(loc, type);
  }

  // Make sure to allocate at the beginning of the block if
  // all dimensions are known.
  auto *parentBlock = alloc.getOperation()->getBlock();
  if (hasAllConstantDimensions(type))
    alloc.getOperation()->moveBefore(&parentBlock->front());

  if (insertDealloc) {
    auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
    dealloc.getOperation()->moveBefore(&parentBlock->back());
  }

  return alloc;
}
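// Illustrative sketch (not part of this pass; hypothetical types, and
// assuming the standard-dialect textual syntax of this MLIR version): for a
// result of type memref<?x10xf32> whose first operand also has type
// memref<?x10xf32>, the helper above emits IR of roughly this shape:
//
//   %d0 = dim %arg0, 0 : memref<?x10xf32>
//   %0 = alloc(%d0) : memref<?x10xf32>
//   ...
//   dealloc %0 : memref<?x10xf32>   // only when insertDealloc is true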
// Determine if current function returns the result value of the
// current op being lowered. If it does then dealloc should not be
// inserted.
static bool checkInsertDealloc(Operation *currentOp) {
  auto parentBlock = currentOp->getBlock();

  bool insertDealloc = true;
  parentBlock->walk([&insertDealloc, currentOp](ReturnOp op) {
    assert(currentOp->getNumResults() < 2 &&
           "No more than one result supported (for now).");
    // If there is at least one result to investigate.
    if (currentOp->getNumResults() > 0) {
      auto result = currentOp->getResult(0);
      for (const auto &operand : op.getOperands())
        if (operand == result)
          insertDealloc = false;
    }
  });

  return insertDealloc;
}

// Create a mapping from result type's dimensions to input type's dimensions,
// given that the result type is the result of a reduction op over the input
// type.
std::map<int64_t, int64_t> getReductionMapping(
    MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims) {
  std::map<int64_t, int64_t> OutInDimMap;
  int64_t rank = inputTy.getRank();

  // Mark reduction axes.
  std::vector<bool> isReductionAxis;
  for (decltype(rank) i = 0; i < rank; ++i) {
    if (std::find(axes.begin(), axes.end(), i) != axes.end())
      isReductionAxis.push_back(true);
    else
      isReductionAxis.push_back(false);
  }

  for (decltype(rank) inIndex = 0, outIndex = 0; inIndex < rank; ++inIndex) {
    // If it is a reduction axis, there is no relationship among dimensions.
    if (isReductionAxis[inIndex]) {
      if (keepdims)
        outIndex++;
    } else {
      OutInDimMap.insert(std::make_pair(outIndex, inIndex));
      outIndex++;
    }
  }

  return OutInDimMap;
}

// Add bounds associated with the op operand to the KRNL iteration pack.
// Dynamic dimensions are supported.
static void addDimensionToPack(ConversionPatternRewriter &rewriter,
    Location loc, KrnlIterateOperandPack &pack, Value operand, int index) {
  auto shape = operand.getType().cast<MemRefType>().getShape();
  if (shape[index] < 0) {
    pack.pushConstantBound(0);
    pack.pushOperandBound(
        rewriter.create<DimOp>(loc, operand, index).getResult());
  } else {
    pack.pushConstantBound(0);
    pack.pushConstantBound(shape[index]);
  }
}

// Function that defines the KRNL dialect loops and their respective
// optimized version.
static KrnlOptimizeLoopsOp emitOptimizedLoops(
    ConversionPatternRewriter &rewriter, Location loc,
    std::vector<Value> &loops, std::vector<Value> &optimizedLoops,
    int64_t numLoops) {
  // Define loops.
  auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, numLoops);
  loops.reserve(numLoops);
  for (auto result : loopsOp.getResults())
    loops.push_back(result);

  // Define optimized version of the loops.
  auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, numLoops);
  optimizedLoops.reserve(numLoops);
  for (auto result : optimizedLoopsOp.getResults())
    optimizedLoops.push_back(result);

  return optimizedLoopsOp;
}

// Function that emits the loops and their optimized version.
// The function returns a reference to the inner optimization block.
static Block *defineLoops(ConversionPatternRewriter &rewriter, Location loc,
    std::vector<Value> &loops, std::vector<Value> &optimizedLoops,
    int64_t numLoops) {
  KrnlOptimizeLoopsOp optimizedLoopsOp =
      emitOptimizedLoops(rewriter, loc, loops, optimizedLoops, numLoops);
  return &optimizedLoopsOp.region().front();
}
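// Illustrative sketch (abridged; the exact Krnl textual syntax may differ in
// this version): defineLoops(..., 2) materializes the scaffolding that the
// patterns below fill in:
//
//   %i0, %i1 = krnl.define_loops 2
//   %o0, %o1 = krnl.optimize_loops {
//     krnl.return_loops %i0, %i1   // inserted later; identity = no opt.
//   } : () -> (!krnl.loop, !krnl.loop)
//   krnl.iterate(%o0, %o1 with %i0 -> %arg0 = 0 to ..., ...) { ... }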
// Function which emits a basic set of loops and optimized loops
// for a given operation argument. A reference to the loop optimization
// block is returned in the last argument of the function.
static void emitKrnlLoopsAndIterationForOperand(
    ConversionPatternRewriter &rewriter, Location loc, Value operand,
    std::vector<Value> &originalLoops, KrnlOptimizeLoopsOp &optimizedLoopsOp,
    KrnlIterateOp &iterateOp) {
  // Operand shape.
  auto shape = operand.getType().cast<MemRefType>().getShape();

  // Number of loops.
  int64_t rank = shape.size();

  // Define loops and optimized loops.
  std::vector<Value> optimizedLoops;
  optimizedLoopsOp =
      emitOptimizedLoops(rewriter, loc, originalLoops, optimizedLoops, rank);

  KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
  // Iterate over the loop nest.
  for (int i = 0; i < rank; ++i)
    addDimensionToPack(rewriter, loc, pack, operand, i);

  iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
}

unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
  auto elementType = memRefType.getElementType();

  unsigned sizeInBits;
  if (elementType.isIntOrFloat()) {
    sizeInBits = elementType.getIntOrFloatBitWidth();
  } else {
    auto vectorType = elementType.cast<VectorType>();
    sizeInBits =
        vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
  }
  return llvm::divideCeil(sizeInBits, 8);
}

// Get run-time dimension information for unknown dimensions used for
// broadcasting.
std::map<int, std::map<int, Value>> getBroadcastedDimInfo(Location loc,
    ConversionPatternRewriter &rewriter, MemRefType memRefType,
    ArrayRef<Value> operands) {
  auto memRefShape = memRefType.getShape();
  int64_t rank = memRefShape.size();
  // For unknown dimensions, we need to get dimension values at runtime in
  // order to do broadcasting.
  std::map<int, std::map<int, Value>> DimInfo;
  // For each result dimension, compute the number of sharing operands.
  // Sharing operands are operands sharing the same index (counting from the
  // rightmost to the leftmost) for a given dimension.
  std::map<int, int> sharedDimCount;
  for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
    int dimIdx = rank - 1 - reversedIdx;
    sharedDimCount[dimIdx] = 0;
    for (int i = 0; i < operands.size(); ++i) {
      auto shape = operands[i].getType().cast<MemRefType>().getShape();
      if (reversedIdx <= shape.size() - 1)
        sharedDimCount[dimIdx]++;
    }
  }
  // An unknown dimension can have a value of 1 or N (N > 1).
  // If its value is 1, it is a broadcasted dimension.
  // Otherwise, it is a non-broadcasted dimension.
  // We only care about unknown dimensions whose number of sharing operands is
  // more than one, since they are potentially broadcasted dimensions.
  for (int i = 0; i < operands.size(); ++i) {
    std::map<int, Value> broadcastedDims;
    auto shape = operands[i].getType().cast<MemRefType>().getShape();
    int size = shape.size();
    for (int j = 0; j < shape.size(); ++j) {
      if (shape[j] < 0 && sharedDimCount[rank - size + j] > 1) {
        auto dim = rewriter.create<DimOp>(loc, operands[i], j).getResult();
        auto one = rewriter.create<ConstantIndexOp>(loc, 1);
        auto isBroadcasted =
            rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, dim, one);
        broadcastedDims.insert(std::make_pair(j, isBroadcasted));
      }
    }
    DimInfo.insert(std::make_pair(i, broadcastedDims));
  }
  return DimInfo;
}
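// Worked example (hypothetical shapes): for operands of types
// memref<3x?xf32> and memref<?xf32>, the rightmost result dimension is shared
// by both operands (sharedDimCount = 2), so each unknown rightmost dimension
// gets a runtime flag "dim == 1". At loop-emission time that flag selects
// index 0 (broadcast) or the loop induction variable (no broadcast).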
// Extract induction variables that are used for broadcasting values of a
// given operand.
std::vector<Value> getLoopIVsForBroadcasting(Location loc,
    ConversionPatternRewriter &rewriter, ArrayRef<Value> loopIVs,
    Value operand, std::map<int, Value> broadcastedDims) {
  // `operand` must have a ranked type. This should have been checked by the
  // shape inference pass.
  auto operandShape = operand.getType().cast<MemRefType>().getShape();
  auto rank = operandShape.size();
  auto loopCount = loopIVs.size();

  std::vector<Value> newLoopIVs;
  for (unsigned reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
    auto dimIdx = rank - 1 - reversedIdx;
    auto loopIdx = loopCount - 1 - reversedIdx;
    if (operandShape[dimIdx] == 1) {
      // Broadcasted dimension
      auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
      newLoopIVs.insert(newLoopIVs.begin(), zero);
    } else if ((operandShape[dimIdx] == -1) &&
               (broadcastedDims.find(dimIdx) != broadcastedDims.end())) {
      // Unknown dimension, it can have a value of 1 or N (N > 1).
      // If its value is 1, it is a broadcasted dimension.
      // Otherwise, it is a non-broadcasted dimension.
      auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
      auto idx = rewriter.create<SelectOp>(
          loc, broadcastedDims[dimIdx], zero, loopIVs[loopIdx]);
      newLoopIVs.insert(newLoopIVs.begin(), idx);
    } else {
      // Non-broadcasted dimension
      newLoopIVs.insert(newLoopIVs.begin(), loopIVs[loopIdx]);
    }
  }
  return newLoopIVs;
}

namespace {

template <typename ElementwiseNaryOp>
struct ScalarOp;

template <>
struct ScalarOp<ONNXAddOp> {
  using FOp = AddFOp;
  using IOp = AddIOp;
};

template <>
struct ScalarOp<ONNXMulOp> {
  using FOp = MulFOp;
  using IOp = MulIOp;
};

template <>
struct ScalarOp<ONNXDivOp> {
  using FOp = DivFOp;
  using IOp = SignedDivIOp;
};

template <>
struct ScalarOp<ONNXSubOp> {
  using FOp = SubFOp;
  using IOp = SubIOp;
};

template <>
struct ScalarOp<ONNXAndOp> {
  using FOp = AndOp; // not used
  using IOp = AndOp;
};

template <>
struct ScalarOp<ONNXOrOp> {
  using FOp = OrOp; // not used
  using IOp = OrOp;
};

template <>
struct ScalarOp<ONNXXorOp> {
  using FOp = XOrOp; // not used
  using IOp = XOrOp;
};

template <>
struct ScalarOp<ONNXExpOp> {
  using FOp = ExpOp;
  using IOp = ExpOp; // not used
};

template <>
struct ScalarOp<ONNXSumOp> {
  using FOp = AddFOp;
  using IOp = AddIOp;
};

template <>
struct ScalarOp<ONNXTanhOp> {
  using FOp = TanhOp;
  using IOp = TanhOp; // not used
};

template <>
struct ScalarOp<ONNXCosOp> {
  using FOp = CosOp;
  using IOp = CosOp; // not used
};

template <>
struct ScalarOp<ONNXLogOp> {
  using FOp = LogOp;
  using IOp = LogOp; // not used
};

template <>
struct ScalarOp<ONNXReduceProdOp> {
  using FOp = MulFOp;
  using IOp = MulIOp;
};

template <>
struct ScalarOp<ONNXReduceSumOp> {
  using FOp = AddFOp;
  using IOp = AddIOp;
};

template <>
struct ScalarOp<ONNXSqrtOp> {
  using FOp = KrnlSqrtOp;
  using IOp = KrnlSqrtOp; // not used
};

template <typename ElementwiseNaryOp>
using ScalarFOp = typename ScalarOp<ElementwiseNaryOp>::FOp;
template <typename ElementwiseNaryOp>
using ScalarIOp = typename ScalarOp<ElementwiseNaryOp>::IOp;

// Get the identity element of an operation.
// Return NULL if the function does not have identity.
template <typename DataType, typename Op>
DataType getIdentityValue() {
  return NULL;
}

template <>
float getIdentityValue<float, ONNXReduceMaxOp>() {
  return (float)-std::numeric_limits<float>::infinity();
}

template <>
int getIdentityValue<int, ONNXReduceMaxOp>() {
  return std::numeric_limits<int>::min();
}

template <>
float getIdentityValue<float, ONNXReduceMinOp>() {
  return (float)std::numeric_limits<float>::infinity();
}

template <>
int getIdentityValue<int, ONNXReduceMinOp>() {
  return std::numeric_limits<int>::max();
}

template <>
float getIdentityValue<float, ONNXReduceProdOp>() {
  return (float)1.0;
}

template <>
int getIdentityValue<int, ONNXReduceProdOp>() {
  return 1;
}

template <>
float getIdentityValue<float, ONNXReduceSumOp>() {
  return (float)0;
}

template <>
int getIdentityValue<int, ONNXReduceSumOp>() {
  return 0;
}
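// Usage sketch (illustrative): the generic lowering below instantiates these
// traits, so for example
//   rewriter.create<ScalarFOp<ONNXAddOp>>(...)   // emits std.addf
//   rewriter.create<ScalarIOp<ONNXAddOp>>(...)   // emits std.addi
// and getIdentityValue<float, ONNXReduceSumOp>() yields 0.0f, the neutral
// element used to initialize a sum reduction.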
//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering to Krnl dialect.
//===----------------------------------------------------------------------===//
template <typename UnaryOp>
Value mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types,
    ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) {
  /* Lower UnaryOp to Ops in the Standard dialect.
   */
  auto loc = op->getLoc();
  Type element_type = operands.front().getType();
  if (element_type.isa<IntegerType>()) {
    return rewriter.create<ScalarIOp<UnaryOp>>(
        loc, result_types, operands, mlir::None);
  } else if (element_type.isa<FloatType>()) {
    return rewriter.create<ScalarFOp<UnaryOp>>(
        loc, result_types, operands, mlir::None);
  } else {
    emitError(loc, "unsupported element type");
    return nullptr;
  }
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXSinhOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXSinhOp>(Operation *op,
    ArrayRef<Type> result_types, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) {
  // ONNXSinhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
  //                         ConstantOp 2)
  auto loc = op->getLoc();
  Value operand = operands[0];
  auto elementType = result_types[0];

  auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
  auto two = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 2));
  auto neg = rewriter.create<SubFOp>(loc, zero, operand);
  auto exp = rewriter.create<ExpOp>(loc, operand);
  auto negExp = rewriter.create<ExpOp>(loc, neg);
  auto result = rewriter.create<DivFOp>(
      loc, rewriter.create<SubFOp>(loc, exp, negExp), two);

  return result;
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXCoshOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXCoshOp>(Operation *op,
    ArrayRef<Type> result_types, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) {
  // ONNXCoshOp(%X) = DivFOp(AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
  //                         ConstantOp 2)
  auto loc = op->getLoc();
  Value operand = operands[0];
  auto elementType = result_types[0];

  auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
  auto two = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 2));
  auto neg = rewriter.create<SubFOp>(loc, zero, operand);
  auto exp = rewriter.create<ExpOp>(loc, operand);
  auto negExp = rewriter.create<ExpOp>(loc, neg);
  auto result = rewriter.create<DivFOp>(
      loc, rewriter.create<AddFOp>(loc, exp, negExp), two);

  return result;
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXSigmoidOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXSigmoidOp>(Operation *op,
    ArrayRef<Type> result_types, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) {
  // ONNXSigmoidOp(%X) = DivFOp(ConstantOp 1,
  //                            AddFOp(ConstantOp 1, ExpOp(NegFOp(%X))))
  auto loc = op->getLoc();
  Value operand = operands[0];
  auto elementType = result_types[0];

  auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
  auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
  auto neg = rewriter.create<SubFOp>(loc, zero, operand);
  auto negExp = rewriter.create<ExpOp>(loc, neg);
  auto result = rewriter.create<DivFOp>(
      loc, one, rewriter.create<AddFOp>(loc, one, negExp));

  return result;
}
//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXHardSigmoidOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXHardSigmoidOp>(Operation *op,
    ArrayRef<Type> result_types, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) {
  // %Y = AddFOp(MulFOp(alpha, %X), beta)
  // %Z = SelectOp(CmpFOp(OGT, %Y, Constant 0),
  //               %Y,
  //               Constant 0)
  // ONNXHardSigmoidOp(%X) = SelectOp(CmpFOp(OLT, %Z, Constant 1),
  //                                  %Z,
  //                                  Constant 1)
  auto loc = op->getLoc();
  Value operand = operands[0];
  auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
      llvm::dyn_cast<ONNXHardSigmoidOp>(op).alpha().convertToFloat());
  auto betaAttribute = FloatAttr::get(rewriter.getF32Type(),
      llvm::dyn_cast<ONNXHardSigmoidOp>(op).beta().convertToFloat());
  auto elementType = result_types[0];

  auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
  auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
  auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
  auto beta = rewriter.create<ConstantOp>(loc, betaAttribute);

  auto add = rewriter.create<AddFOp>(
      loc, rewriter.create<MulFOp>(loc, alpha, operand), beta);
  auto maxPredicate =
      rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, add, zero);
  auto max = rewriter.create<SelectOp>(loc, maxPredicate, add, zero);
  auto minPredicate =
      rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, max, one);
  auto result = rewriter.create<SelectOp>(loc, minPredicate, max, one);

  return result;
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXEluOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXEluOp>(Operation *op, ArrayRef<Type> result_types,
    ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) {
  // ONNXEluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
  //                          MulFOp(alpha, SubFOp(ExpOp(%X), 1)),
  //                          %X)
  auto loc = op->getLoc();
  Value operand = operands[0];
  auto elementType = result_types[0];

  auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
      llvm::dyn_cast<ONNXEluOp>(op).alpha().convertToFloat());
  auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
  auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
  auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
  auto exp = rewriter.create<ExpOp>(loc, operand);
  auto lessThanZero =
      rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
  auto result = rewriter.create<SelectOp>(loc, lessThanZero,
      rewriter.create<MulFOp>(
          loc, alpha, rewriter.create<SubFOp>(loc, exp, one)),
      operand);

  return result;
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXReluOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXReluOp>(Operation *op,
    ArrayRef<Type> result_types, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) {
  // ONNXReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
  //                           ConstantOp 0,
  //                           %X)
  auto loc = op->getLoc();
  Value operand = operands[0];
  auto elementType = result_types[0];

  auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
  auto lessThanZero =
      rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
  auto result = rewriter.create<SelectOp>(loc, lessThanZero, zero, operand);

  return result;
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXLeakyReluOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXLeakyReluOp>(Operation *op,
    ArrayRef<Type> result_types, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) {
  // ONNXLeakyReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
  //                                MulFOp(alpha, %X),
  //                                %X)
  auto loc = op->getLoc();
  Value operand = operands[0];
  auto elementType = result_types[0];

  auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
      llvm::dyn_cast<ONNXLeakyReluOp>(op).alpha().convertToFloat());
  auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
  auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
  auto lessThanZero =
      rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
  auto result = rewriter.create<SelectOp>(loc, lessThanZero,
      rewriter.create<MulFOp>(loc, alpha, operand), operand);

  return result;
}
//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXSeluOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXSeluOp>(Operation *op,
    ArrayRef<Type> result_types, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) {
  // ONNXSeluOp(%X) = SelectOp(CmpFOp(OGT, %X, ConstantOp 0),
  //                           MulFOp(gamma, %X),
  //                           MulFOp(gamma,
  //                                  SubFOp(MulFOp(alpha, ExpOp(%X)),
  //                                         alpha)))
  auto loc = op->getLoc();
  Value operand = operands[0];
  auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
      llvm::dyn_cast<ONNXSeluOp>(op).alpha().convertToFloat());
  auto gammaAttribute = FloatAttr::get(rewriter.getF32Type(),
      llvm::dyn_cast<ONNXSeluOp>(op).gamma().convertToFloat());
  auto elementType = result_types[0];

  auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
  auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
  auto gamma = rewriter.create<ConstantOp>(loc, gammaAttribute);
  auto exp = rewriter.create<ExpOp>(loc, operand);
  auto greaterThanZero =
      rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, operand, zero);
  auto select = rewriter.create<SelectOp>(loc, greaterThanZero, operand,
      rewriter.create<SubFOp>(
          loc, rewriter.create<MulFOp>(loc, alpha, exp), alpha));
  auto result = rewriter.create<MulFOp>(loc, gamma, select);

  return result;
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXReciprocalOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXReciprocalOp>(Operation *op,
    ArrayRef<Type> result_types, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) {
  // ONNXReciprocalOp(%X) = DivFOp(ConstantOp 1, %X)
  auto loc = op->getLoc();
  Value operand = operands[0];
  auto elementType = result_types[0];

  auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
  auto result = rewriter.create<DivFOp>(loc, one, operand);

  return result;
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXSoftplusOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXSoftplusOp>(Operation *op,
    ArrayRef<Type> result_types, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) {
  // ONNXSoftplusOp(%X) = LogOp(AddFOp(ExpOp(%X), ConstantOp 1))
  auto loc = op->getLoc();
  Value operand = operands[0];
  auto elementType = result_types[0];

  auto exp = rewriter.create<ExpOp>(loc, operand);
  auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
  auto add = rewriter.create<AddFOp>(loc, exp, one);
  auto result = rewriter.create<LogOp>(loc, add);

  return result;
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXSoftsignOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXSoftsignOp>(Operation *op,
    ArrayRef<Type> result_types, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) {
  // ONNXSoftsignOp(%X) = DivFOp(%X, AddFOp(AbsFOp(%X), ConstantOp 1))
  auto loc = op->getLoc();
  Value operand = operands[0];
  auto elementType = result_types[0];

  auto abs = rewriter.create<AbsFOp>(loc, operand);
  auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
  auto add = rewriter.create<AddFOp>(loc, abs, one);
  auto result = rewriter.create<DivFOp>(loc, operand, add);

  return result;
}
//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXSignOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXSignOp>(Operation *op,
    ArrayRef<Type> result_types, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) {
  auto loc = op->getLoc();
  Value operand = operands[0];
  Type element_type = operands.front().getType();
  // TODO: unsigned int should be supported separately?
  if (element_type.isa<IntegerType>()) {
    // %Y = SelectOp(CmpIOp(GT, %X, ConstantOp 0),
    //               ConstantOp 1,
    //               ConstantOp -1)
    // ONNXSignOp(%X) = SelectOp(CmpIOp(EQ, %X, ConstantOp 0),
    //                           ConstantOp 0,
    //                           %Y)
    auto zero =
        rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
    auto one = rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(1));
    auto minusOne =
        rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(-1));
    auto plusPredicate =
        rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, operand, zero);
    auto plusSelect =
        rewriter.create<SelectOp>(loc, plusPredicate, one, minusOne);
    auto zeroPredicate =
        rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, operand, zero);
    auto result =
        rewriter.create<SelectOp>(loc, zeroPredicate, zero, plusSelect);
    return result;
  } else if (element_type.isa<FloatType>()) {
    // %Y = SelectOp(CmpFOp(OGT, %X, ConstantOp 0),
    //               ConstantOp 1,
    //               ConstantOp -1)
    // ONNXSignOp(%X) = SelectOp(CmpFOp(OEQ, %X, ConstantOp 0),
    //                           ConstantOp 0,
    //                           %Y)
    auto zero =
        rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
    auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f));
    auto minusOne =
        rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(-1.0f));
    auto plusPredicate =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, operand, zero);
    auto plusSelect =
        rewriter.create<SelectOp>(loc, plusPredicate, one, minusOne);
    auto zeroPredicate =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, operand, zero);
    auto result =
        rewriter.create<SelectOp>(loc, zeroPredicate, zero, plusSelect);
    return result;
  } else {
    emitError(loc, "unsupported element type");
    return nullptr;
  }
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXMaxOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXMaxOp>(Operation *op, ArrayRef<Type> result_types,
    ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) {
  // ONNXMaxOp(%X, %Y) = SelectOp(CmpFOp(OGT, %X, %Y),
  //                              %X,
  //                              %Y)
  auto loc = op->getLoc();
  Value lhs = operands[0];
  Value rhs = operands[1];
  auto max = rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs);
  auto result = rewriter.create<SelectOp>(loc, max, lhs, rhs);
  return result;
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXMinOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXMinOp>(Operation *op, ArrayRef<Type> result_types,
    ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) {
  // ONNXMinOp(%X, %Y) = SelectOp(CmpFOp(OLT, %X, %Y),
  //                              %X,
  //                              %Y)
  auto loc = op->getLoc();
  Value lhs = operands[0];
  Value rhs = operands[1];
  auto min = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, lhs, rhs);
  auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs);
  return result;
}
//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXReduceMaxOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXReduceMaxOp>(Operation *op,
    ArrayRef<Type> result_types, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) {
  auto loc = op->getLoc();
  Value lhs = operands[0];
  Value rhs = operands[1];
  Type element_type = lhs.getType();
  if (element_type.isa<IntegerType>()) {
    auto max = rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, lhs, rhs);
    auto result = rewriter.create<SelectOp>(loc, max, lhs, rhs);
    return result;
  } else if (element_type.isa<FloatType>()) {
    auto max = rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs);
    auto result = rewriter.create<SelectOp>(loc, max, lhs, rhs);
    return result;
  } else {
    emitError(loc, "unsupported element type");
    return nullptr;
  }
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXReduceMinOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXReduceMinOp>(Operation *op,
    ArrayRef<Type> result_types, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) {
  auto loc = op->getLoc();
  Value lhs = operands[0];
  Value rhs = operands[1];
  Type element_type = lhs.getType();
  if (element_type.isa<IntegerType>()) {
    auto min = rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, lhs, rhs);
    auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs);
    return result;
  } else if (element_type.isa<FloatType>()) {
    auto min = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, lhs, rhs);
    auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs);
    return result;
  } else {
    emitError(loc, "unsupported element type");
    return nullptr;
  }
}
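// Illustrative sketch (abridged, hypothetical types): the pattern below turns
//   %0 = "onnx.Exp"(%arg0) : (tensor<10xf32>) -> tensor<10xf32>
// into roughly
//   %0 = alloc() : memref<10xf32>
//   ... krnl.define_loops / krnl.optimize_loops ...
//   krnl.iterate(...) {
//     %1 = load %arg0[%i] : memref<10xf32>
//     %2 = exp %1 : f32
//     store %2, %0[%i] : memref<10xf32>
//   }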
// Element-wise unary ops lowering to Krnl dialect.
//===----------------------------------------------------------------------===//
template <typename ElementwiseUnaryOp>
struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
  ONNXElementwiseUnaryOpLowering(MLIRContext *ctx)
      : ConversionPattern(ElementwiseUnaryOp::getOperationName(), 1, ctx) {}
  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const final {
    // TODO: Check that the types are valid.
    // An element-wise unary operation must have all operands and the result of
    // the same type. This should have been verified by the verifier.
    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
    auto loc = op->getLoc();

    // Insert an allocation and deallocation for the result of this operation.
    auto memRefType = convertTensorToMemRef(tensorType);

    // If the output has a dynamic dimension, pass the operands required for
    // each dynamic dimension to the AllocOp. The first operand of the
    // operation is used. The operands of the op need to match in terms of
    // dimensions with the result at this pre-optimization phase.
    // TODO: verify that dimensions match.
    // TODO: can the dimension of the result differ after optimizations?
    Value alloc;
    bool insertDealloc = checkInsertDealloc(op);

    if (hasAllConstantDimensions(memRefType))
      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
    else
      alloc = insertAllocAndDealloc(
          memRefType, loc, rewriter, insertDealloc, {operands[0]});

    std::vector<Value> originalLoops;
    KrnlOptimizeLoopsOp optimizedLoopsOp;
    KrnlIterateOp iterateOp;
    emitKrnlLoopsAndIterationForOperand(
        rewriter, loc, operands[0], originalLoops, optimizedLoopsOp,
        iterateOp);
    Block &optimizationBlock = optimizedLoopsOp.region().front();
    Block &iterationBlock = iterateOp.bodyRegion().front();

    // 1. Insert any optimizations in the KrnlOptimizeLoopsOp body.
    rewriter.setInsertionPointToEnd(&optimizationBlock);
    // Return from KrnlOptimizeLoopsOp body.
    // When no optimizations are present we just return the loops
    // unchanged.
    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);

    // 2. Insert instructions inside the KrnlIterateOp body.
    rewriter.setInsertionPointToStart(&iterationBlock);

    // Handle the operation:
    SmallVector<Value, 4> loopIVs;
    for (auto arg : iterationBlock.getArguments())
      loopIVs.push_back(arg);

    auto loadedVal = rewriter.create<LoadOp>(loc, operands[0], loopIVs);
    auto loweredOpResult = mapToLowerScalarOp<ElementwiseUnaryOp>(
        op, memRefType.getElementType(), {loadedVal}, rewriter);

    // Store result in the resulting array.
    rewriter.create<StoreOp>(loc, loweredOpResult, alloc, loopIVs);

    rewriter.replaceOp(op, alloc);

    return matchSuccess();
  }
};
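// Illustrative sketch (hypothetical shapes): for
//   %0 = "onnx.Add"(%a, %b)
//       : (tensor<3x4xf32>, tensor<4xf32>) -> tensor<3x4xf32>
// the variadic pattern below loads %a at [%i, %j] but %b at the
// broadcast-adjusted index [%j], folding the scalar ops left to right
// across however many operands the op has.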
// Element-wise variadic ops lowering to Krnl dialect.
//===----------------------------------------------------------------------===//
template <typename ElementwiseVariadicOp>
struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
  ONNXElementwiseVariadicOpLowering(MLIRContext *ctx)
      : ConversionPattern(ElementwiseVariadicOp::getOperationName(), 1, ctx) {}
  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const final {
    // TODO: Check that the types are valid.
    // An element-wise variadic operation must have all operands and the result
    // of the same type. This should have been verified by the verifier.
    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
    auto loc = op->getLoc();
    auto numArgs = op->getNumOperands();

    // Insert an allocation and deallocation for the result of this operation.
    auto memRefType = convertTensorToMemRef(tensorType);

    Value alloc;
    bool insertDealloc = checkInsertDealloc(op);
    // If the output has a dynamic dimension, we compute its dimension at
    // runtime by using dimensions from the operands.
    // In particular, we need to know which operand a result dimension
    // comes from.
    // TODO: can the dimension of the result differ after optimizations?
    if (hasAllConstantDimensions(memRefType))
      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
    else
      alloc = insertAllocAndDealloc(
          memRefType, loc, rewriter, insertDealloc, operands);

    // Get run-time dimension information for unknown dimensions used for
    // broadcasting.
    std::map<int, std::map<int, Value>> broadcastedDimInfo =
        getBroadcastedDimInfo(loc, rewriter, memRefType, operands);

    std::vector<Value> originalLoops;
    KrnlOptimizeLoopsOp optimizedLoopsOp;
    KrnlIterateOp iterateOp;
    emitKrnlLoopsAndIterationForOperand(
        rewriter, loc, alloc, originalLoops, optimizedLoopsOp, iterateOp);
    Block &optimizationBlock = optimizedLoopsOp.region().front();
    Block &iterationBlock = iterateOp.bodyRegion().front();

    // 1. Insert any optimizations in the KrnlOptimizeLoopsOp body.
    rewriter.setInsertionPointToEnd(&optimizationBlock);
    // Return from KrnlOptimizeLoopsOp body.
    // When no optimizations are present we just return the loops unchanged.
    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);

    // 2. Insert instructions inside the KrnlIterateOp body.
    rewriter.setInsertionPointToStart(&iterationBlock);

    // Handle the operation:
    SmallVector<Value, 4> loopIVs;
    for (auto arg : iterationBlock.getArguments())
      loopIVs.push_back(arg);

    // Fold over operands for each of their scalar values.
    Value accumulated, next;
    auto accumulatedLoopIVs = getLoopIVsForBroadcasting(
        loc, rewriter, loopIVs, operands[0], broadcastedDimInfo[0]);
    accumulated =
        rewriter.create<LoadOp>(loc, operands[0], accumulatedLoopIVs);
    for (unsigned i = 1; i < numArgs; i++) {
      auto nextLoopIVs = getLoopIVsForBroadcasting(
          loc, rewriter, loopIVs, operands[i], broadcastedDimInfo[i]);
      next = rewriter.create<LoadOp>(loc, operands[i], nextLoopIVs);
      accumulated = mapToLowerScalarOp<ElementwiseVariadicOp>(
          op, memRefType.getElementType(), {accumulated, next}, rewriter);
    }

    // Store result in the resulting array.
    rewriter.create<StoreOp>(loc, accumulated, alloc, loopIVs);

    rewriter.replaceOp(op, alloc);

    return matchSuccess();
  }
};

struct ONNXSoftmaxOpLowering : public ConversionPattern {
  ONNXSoftmaxOpLowering(MLIRContext *ctx)
      : ConversionPattern(mlir::ONNXSoftmaxOp::getOperationName(), 1, ctx) {}
  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const final {
    // softmax(x) = let max_x = max(x) in
    //              let exp_x = exp(x - max_x) in
    //              let sum = sum(exp_x) in
    //              exp_x / sum
    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
    int64_t rank = tensorType.getRank();
    int64_t axis = llvm::dyn_cast<ONNXSoftmaxOp>(op).axis().getSExtValue();
    axis = axis >= 0 ? axis : rank + axis;
    assert(axis >= -rank && axis <= rank - 1);

    auto loc = op->getLoc();

    // Insert an allocation and deallocation for the result of this operation.
    auto memRefType = convertTensorToMemRef(tensorType);
    auto elementType = memRefType.getElementType();

    Value alloc;
    bool insertDealloc = checkInsertDealloc(op);
    if (hasAllConstantDimensions(memRefType))
      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
    else
      alloc = insertAllocAndDealloc(
          memRefType, loc, rewriter, insertDealloc, operands[0]);

    // Shape of the result.
    auto memRefShape = memRefType.getShape();

    // Insert allocations and deallocations for sum and max.
    MemRefType scalarMemRefType = MemRefType::get({}, elementType, {}, 0);
    Value sumOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true);
    Value maxOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true);
    Value zero =
        rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
    Value negInfinity = rewriter.create<ConstantOp>(loc,
        FloatAttr::get(
            elementType, -std::numeric_limits<float>::infinity()));

    // Define loops.
    std::vector<Value> originalLoops;
    std::vector<Value> optimizedLoops;
    Block *optimizationBlock =
        defineLoops(rewriter, loc, originalLoops, optimizedLoops, rank);

    // Coerce the input into a 2-D tensor. `axis` will be the coercing point.
    // This coercing follows the softmax definition in ONNX:
    // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Softmax
    // Here, we create an outer loop and inner loop for handling the two
    // dimensions. The outer loop is only created once `axis` is not zero.

    // Define an outer loop with respect to axis.
    std::vector<Value> outerLoops, optimizedOuterLoops;
    outerLoops.reserve(axis);
    optimizedOuterLoops.reserve(axis);
    for (int i = 0; i < axis; ++i) {
      outerLoops.push_back(originalLoops[i]);
      optimizedOuterLoops.push_back(optimizedLoops[i]);
    }
    KrnlIterateOperandPack outerPack(
        rewriter, outerLoops, optimizedOuterLoops);
    for (int i = 0; i < axis; ++i)
      addDimensionToPack(rewriter, loc, outerPack, operands[0], i);

    // Define an inner loop with respect to axis.
    std::vector<Value> innerLoops, optimizedInnerLoops;
    innerLoops.reserve(rank - axis);
    optimizedInnerLoops.reserve(rank - axis);
    for (int i = axis; i < rank; ++i) {
      innerLoops.push_back(originalLoops[i]);
      optimizedInnerLoops.push_back(optimizedLoops[i]);
    }
    KrnlIterateOperandPack innerPack(
        rewriter, innerLoops, optimizedInnerLoops);
    for (int i = axis; i < rank; ++i)
      addDimensionToPack(rewriter, loc, innerPack, operands[0], i);

    KrnlIterateOp outerIterateOp, maxIterateOp, sumIterateOp, softmaxIterateOp;
    SmallVector<Value, 4> outerLoopIVs;
    if (axis != 0) {
      outerIterateOp = rewriter.create<KrnlIterateOp>(loc, outerPack);

      // No optimization
      rewriter.setInsertionPointToEnd(optimizationBlock);
      rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);

      // Insert instructions inside the outer loop.
      Block &outerIterationBlock = outerIterateOp.bodyRegion().front();
      rewriter.setInsertionPointToStart(&outerIterationBlock);
      for (auto arg : outerIterationBlock.getArguments())
        outerLoopIVs.push_back(arg);

      // Reset accumulators.
      rewriter.create<StoreOp>(loc, zero, sumOp);
      rewriter.create<StoreOp>(loc, negInfinity, maxOp);

      // Create an inner loop to compute max.
      maxIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
      // Create an inner loop to compute sum.
      sumIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
      // Create an inner loop to compute softmax.
      softmaxIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
    } else {
      // Reset accumulators.
      rewriter.create<StoreOp>(loc, zero, sumOp);
      rewriter.create<StoreOp>(loc, negInfinity, maxOp);

      // Create an inner loop to compute max.
      maxIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
      // Create an inner loop to compute sum.
      sumIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
      // Create an inner loop to compute softmax.
      softmaxIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);

      // No optimization
      rewriter.setInsertionPointToEnd(optimizationBlock);
      rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
    }

    // Insert instructions inside the max loop.
    Block &maxIterationBlock = maxIterateOp.bodyRegion().front();
    rewriter.setInsertionPointToStart(&maxIterationBlock);

    // Get induction variables.
    SmallVector<Value, 4> maxLoopIVs;
    for (auto arg : outerLoopIVs)
      maxLoopIVs.push_back(arg);
    for (auto arg : maxIterationBlock.getArguments())
      maxLoopIVs.push_back(arg);

    // Compute the max value.
    Value max = rewriter.create<LoadOp>(loc, maxOp);
    Value nextMax = rewriter.create<LoadOp>(loc, operands[0], maxLoopIVs);
    auto maxCond =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, max, nextMax);
    max = rewriter.create<SelectOp>(loc, maxCond, max, nextMax);
    rewriter.create<StoreOp>(loc, max, maxOp);

    // Get the max.
    rewriter.setInsertionPoint(sumIterateOp);
    max = rewriter.create<LoadOp>(loc, maxOp);

    // Insert instructions inside the sum loop.
    Block &sumIterationBlock = sumIterateOp.bodyRegion().front();
    rewriter.setInsertionPointToStart(&sumIterationBlock);

    // Get induction variables.
    SmallVector<Value, 4> sumLoopIVs;
    for (auto arg : outerLoopIVs)
      sumLoopIVs.push_back(arg);
    for (auto arg : sumIterationBlock.getArguments())
      sumLoopIVs.push_back(arg);

    // Sum up values.
    Value sum = rewriter.create<LoadOp>(loc, sumOp);
    Value next = rewriter.create<LoadOp>(loc, operands[0], sumLoopIVs);
    Value sub = rewriter.create<SubFOp>(loc, next, max);
    Value exp = rewriter.create<ExpOp>(loc, sub);
    sum = rewriter.create<AddFOp>(loc, sum, exp);
    rewriter.create<StoreOp>(loc, sum, sumOp);
    // Store intermediate values in the result to avoid recomputation.
    rewriter.create<StoreOp>(loc, exp, alloc, sumLoopIVs);

    // Get the sum.
    rewriter.setInsertionPoint(softmaxIterateOp);
    sum = rewriter.create<LoadOp>(loc, sumOp);

    // Insert instructions inside the softmax loop.
    Block &softmaxIterationBlock = softmaxIterateOp.bodyRegion().front();
    rewriter.setInsertionPointToStart(&softmaxIterationBlock);

    // Get induction variables.
    SmallVector<Value, 4> softmaxLoopIVs;
    for (auto arg : outerLoopIVs)
      softmaxLoopIVs.push_back(arg);
    for (auto arg : softmaxIterationBlock.getArguments())
      softmaxLoopIVs.push_back(arg);

    // Compute softmax.
    Value expLoadedVal = rewriter.create<LoadOp>(loc, alloc, softmaxLoopIVs);
    Value result = rewriter.create<DivFOp>(loc, expLoadedVal, sum);
    rewriter.create<StoreOp>(loc, result, alloc, softmaxLoopIVs);

    rewriter.replaceOp(op, alloc);

    return matchSuccess();
  }
};
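// Loop-structure sketch (informal): with the input coerced to
// [outer x inner] around `axis`, the pattern above emits, per outer index o:
//
//   max = -inf; for i in inner: max = max(max, X[o][i])
//   sum = 0;    for i in inner: { e = exp(X[o][i] - max); sum += e;
//                                 Y[o][i] = e }
//               for i in inner: Y[o][i] = Y[o][i] / sum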
struct ONNXReshapeOpLowering : public ConversionPattern {
  ONNXReshapeOpLowering(MLIRContext *ctx)
      : ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {}
  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const final {
    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
    auto inputShape = operands[0].getType().cast<MemRefType>().getShape();
    auto loc = op->getLoc();

    // Insert an allocation and deallocation for the result of this operation.
    auto memRefType = convertTensorToMemRef(tensorType);
    auto memRefShape = memRefType.getShape();
    Value alloc;

    // Compute size in bytes using the input tensor.
    Value tensorSize = rewriter.create<ConstantOp>(loc,
        rewriter.getIntegerAttr(rewriter.getIntegerType(64),
            getMemRefEltSizeInBytes(memRefType)));
    for (int i = 0; i < inputShape.size(); ++i) {
      Value dimVal;
      if (inputShape[i] < 0) {
        Value dim = rewriter.create<DimOp>(loc, operands[0], i);
        dimVal = rewriter.create<IndexCastOp>(
            loc, dim, rewriter.getIntegerType(64));
      } else {
        dimVal = rewriter.create<ConstantOp>(loc,
            rewriter.getIntegerAttr(
                rewriter.getIntegerType(64), inputShape[i]));
      }
      tensorSize = rewriter.create<MulIOp>(loc, tensorSize, dimVal);
    }

    bool insertDealloc = checkInsertDealloc(op);
    if (hasAllConstantDimensions(memRefType)) {
      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
    } else {
      // If a dimension is zero, the actual dimension value is taken from the
      // input tensor.
      //
      // If the shape array has a negative dimension (-1), we compute its
      // actual dimension value from the other dimensions. But we don't have
      // enough information about the other dimensions at this point. So, we
      // need to scan the shape first to calculate the product of all of the
      // dimensions. If the product is negative, then the shape array contains
      // a negative dimension. Otherwise, the product is the same as the one
      // computed from the input tensor.
      Value tensorSizeFromShape = rewriter.create<ConstantOp>(loc,
          rewriter.getIntegerAttr(rewriter.getIntegerType(64),
              getMemRefEltSizeInBytes(memRefType)));
      SmallVector<Value, 4> DimInfo;
      for (int i = 0; i < memRefShape.size(); ++i) {
        Value index = rewriter.create<ConstantOp>(
            loc, rewriter.getIntegerAttr(rewriter.getIndexType(), i));
        // Load index from array of indices.
        Value loadedVal = rewriter.create<LoadOp>(loc, operands[1], index);
        // If a dimension is zero, the actual dimension value is taken from
        // the input tensor.
        //
        // If a dimension is negative, it is computed from the other
        // dimensions. But we don't have enough information about the other
        // dimensions at this point. So, we leave it as it is (-1), and
        // compute it later.
        if (i < inputShape.size()) {
          Value dimVal;
          auto loadedValType = loadedVal.getType().cast<IntegerType>();
          if (inputShape[i] < 0) {
            Value dim = rewriter.create<DimOp>(loc, operands[0], i);
            dimVal = rewriter.create<IndexCastOp>(loc, dim, loadedValType);
          } else {
            dimVal = rewriter.create<ConstantOp>(
                loc, rewriter.getIntegerAttr(loadedValType, inputShape[i]));
          }
          auto zero = rewriter.create<ConstantOp>(
              loc, rewriter.getIntegerAttr(loadedValType, 0));
          auto isZero = rewriter.create<CmpIOp>(
              loc, CmpIPredicate::eq, loadedVal, zero);
          loadedVal =
              rewriter.create<SelectOp>(loc, isZero, dimVal, loadedVal);
        }
        // Check if the loaded index is already the correct width of 64 bits.
        // Convert the value to a 64 bit integer if needed.
        Value int64LoadedVal = loadedVal;
        if (loadedVal.getType().cast<IntegerType>().getWidth() < 64)
          // Sign-extend so a -1 entry stays -1 after widening.
          int64LoadedVal = rewriter.create<SignExtendIOp>(
              loc, loadedVal, rewriter.getIntegerType(64));
        tensorSizeFromShape =
            rewriter.create<MulIOp>(loc, tensorSizeFromShape, int64LoadedVal);
        // Store intermediate results to use later.
        DimInfo.emplace_back(int64LoadedVal);
      }
      // Reverse tensorSizeFromShape since it is negative if the shape array
      // has a negative dimension. This is safe since we only use it to
      // compute the actual value for the negative dimension.
      auto zero = rewriter.create<ConstantOp>(
          loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
      tensorSizeFromShape =
          rewriter.create<SubIOp>(loc, zero, tensorSizeFromShape);

      // Obtain operands for AllocOp.
      SmallVector<Value, 4> allocOperands;
      auto negOne = rewriter.create<ConstantOp>(
          loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), -1));

      for (int i = 0; i < memRefShape.size(); ++i) {
        auto dimVal = DimInfo[i];
        auto isNegOne =
            rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, dimVal, negOne);
        // If dimension is negative, compute its value from the other
        // dimensions.
        auto actualDimVal = rewriter.create<SignedDivIOp>(
            loc, tensorSize, tensorSizeFromShape);
        auto loadedVal =
            rewriter.create<SelectOp>(loc, isNegOne, actualDimVal, dimVal);
        allocOperands.push_back(rewriter.create<IndexCastOp>(
            loc, loadedVal, rewriter.getIndexType()));
      }
      AllocOp allocateMemref =
          rewriter.create<AllocOp>(loc, memRefType, allocOperands);

      // Make sure to allocate at the beginning of the block if
      // all dimensions are known.
      auto *parentBlock = allocateMemref.getOperation()->getBlock();
      if (insertDealloc) {
        auto dealloc = rewriter.create<DeallocOp>(loc, allocateMemref);
        dealloc.getOperation()->moveBefore(&parentBlock->back());
      }
      alloc = allocateMemref;
    }

    rewriter.create<KrnlMemcpyOp>(loc, alloc, operands[0], tensorSize);

    rewriter.replaceOp(op, alloc);

    return matchSuccess();
  }
};
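// Worked example (hypothetical values): for an input of type
// memref<2x3x4xf32> and a shape array [2, 0, -1], tensorSize is
// 4 * (2*3*4) = 96 bytes. The 0 entry copies the input's dim (3), so
// tensorSizeFromShape = 4 * 2 * 3 * (-1) = -24, negated to 24, and the -1
// entry is recovered as tensorSize / tensorSizeFromShape = 96 / 24 = 4.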
struct ONNXGemmOpLowering : public ConversionPattern {
  ONNXGemmOpLowering(MLIRContext *ctx)
      : ConversionPattern(mlir::ONNXGemmOp::getOperationName(), 1, ctx) {}
  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const final {
    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
    auto loc = op->getLoc();

    Value A, B, C;
    A = operands[0];
    B = operands[1];
    C = operands[2];

    auto alphaAttr = FloatAttr::get(tensorType.getElementType(),
        llvm::dyn_cast<ONNXGemmOp>(op).alpha().convertToFloat());
    auto betaAttr = FloatAttr::get(tensorType.getElementType(),
        llvm::dyn_cast<ONNXGemmOp>(op).beta().convertToFloat());
    auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
    auto beta = rewriter.create<ConstantOp>(loc, betaAttr);

    bool isTransA = (llvm::dyn_cast<ONNXGemmOp>(op).transA() != 0);
    bool isTransB = (llvm::dyn_cast<ONNXGemmOp>(op).transB() != 0);

    // Result type.
    auto memRefType = convertTensorToMemRef(tensorType);

    // Insert an allocation and deallocation for the result of this operation.
    Value alloc;
    bool insertDealloc = checkInsertDealloc(op);
    if (hasAllConstantDimensions(memRefType))
      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
    else {
      auto memRefShape = memRefType.getShape();
      SmallVector<Value, 2> allocOperands;
      if (memRefShape[0] < 0) {
        auto dim = rewriter.create<DimOp>(loc, A, (isTransA) ? 1 : 0);
        allocOperands.emplace_back(dim);
      }
      if (memRefShape[1] < 0) {
        auto dim = rewriter.create<DimOp>(loc, B, (isTransB) ? 0 : 1);
        allocOperands.emplace_back(dim);
      }
      alloc = rewriter.create<AllocOp>(loc, memRefType, allocOperands);
      if (insertDealloc) {
        auto *parentBlock = alloc.getDefiningOp()->getBlock();
        auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
        dealloc.getOperation()->moveBefore(&parentBlock->back());
      }
    }

    // Number of loops.
    auto memRefShape = memRefType.getShape();
    int64_t numLoops = 3;

    // Define loops.
    std::vector<Value> originalLoops;
    std::vector<Value> optimizedLoops;
    Block *optimizationBlock =
        defineLoops(rewriter, loc, originalLoops, optimizedLoops, numLoops);

    // We have two Krnl loops:
    // - Outer loop iterates over the output matrix dimensions, and
    // - Reduction loop iterates over the reduction dimension.

    // Outer loop
    std::vector<Value> outerLoops, optimizedOuterLoops;
    outerLoops.reserve(2);
    optimizedOuterLoops.reserve(2);
    for (int i = 0; i < 2; ++i) {
      outerLoops.push_back(originalLoops[i]);
      optimizedOuterLoops.push_back(optimizedLoops[i]);
    }
    KrnlIterateOperandPack outerPack(
        rewriter, outerLoops, optimizedOuterLoops);
    // Induction variables for the outer loops
    for (int i = 0; i < 2; ++i)
      addDimensionToPack(rewriter, loc, outerPack, alloc, i);

    // Reduction loop
    std::vector<Value> reductionLoops, optimizedReductionLoops;
    reductionLoops.reserve(1);
    optimizedReductionLoops.reserve(1);
    reductionLoops.push_back(originalLoops[2]);
    optimizedReductionLoops.push_back(optimizedLoops[2]);
    KrnlIterateOperandPack reductionPack(
        rewriter, reductionLoops, optimizedReductionLoops);
    // Induction variable for the reduction dimension
    // Try to find and use a static value from A or B first.
    // If it failed then use a dynamic value.
    auto ATy = A.getType().cast<MemRefType>();
    auto BTy = B.getType().cast<MemRefType>();
    int64_t K_A_Idx = (isTransA) ? 0 : 1;
    int64_t K_B_Idx = (isTransB) ? 1 : 0;
    reductionPack.pushConstantBound(0);
    if (ATy.getShape()[K_A_Idx] != -1)
      reductionPack.pushConstantBound(ATy.getShape()[K_A_Idx]);
    else if (BTy.getShape()[K_B_Idx] != -1)
      reductionPack.pushConstantBound(BTy.getShape()[K_B_Idx]);
    else
      reductionPack.pushOperandBound(
          rewriter.create<DimOp>(loc, B, K_B_Idx).getResult());

    // Get run-time dimension information for unknown dimensions used for
    // broadcasting.
    // GemmOp supports unidirectional broadcasting from C to A*B.
    // Hence, it must be enough to get broadcasting information for C only.
    std::map<int, Value> broadcastedDimInfo;
    auto shape = C.getType().cast<MemRefType>().getShape();
    for (int i = 0; i < shape.size(); ++i) {
      if (shape[i] < 0) {
        auto dim = rewriter.create<DimOp>(loc, C, i).getResult();
        auto one = rewriter.create<ConstantIndexOp>(loc, 1);
        auto isBroadcasted =
            rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, dim, one);
        broadcastedDimInfo.insert(std::make_pair(i, isBroadcasted));
      }
    }

    auto outerIterateOp = rewriter.create<KrnlIterateOp>(loc, outerPack);

    // Now perform the insertions into the body of the
    // just generated instructions:

    // No optimization
    rewriter.setInsertionPointToEnd(optimizationBlock);
    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);

    // Insert instructions inside the outer loop.
    Block &outerIterationBlock = outerIterateOp.bodyRegion().front();
    rewriter.setInsertionPointToStart(&outerIterationBlock);

    // Induction variables
    SmallVector<Value, 4> loopMNIVs;
    for (auto arg : outerIterationBlock.getArguments()) {
      loopMNIVs.emplace_back(arg);
    }

    // Initialize the output of A*B
    auto zero = rewriter.create<ConstantOp>(
        loc, FloatAttr::get(memRefType.getElementType(), 0));
    rewriter.create<StoreOp>(loc, zero, alloc, loopMNIVs);

    // Compute A*B
    auto matmulIterateOp = rewriter.create<KrnlIterateOp>(loc, reductionPack);

    // Compute beta*C, and add up to alpha*A*B (unidirectional broadcasting)
    auto loopCIVs = getLoopIVsForBroadcasting(
        loc, rewriter, loopMNIVs, C, broadcastedDimInfo);
    auto loadedC = rewriter.create<LoadOp>(loc, C, loopCIVs);
    auto loadedAB = rewriter.create<LoadOp>(loc, alloc, loopMNIVs);
    auto alphaAB = rewriter.create<MulFOp>(loc, alpha, loadedAB);
    auto betaC = rewriter.create<MulFOp>(loc, beta, loadedC);
    auto Y = rewriter.create<AddFOp>(loc, alphaAB, betaC);
    rewriter.create<StoreOp>(loc, Y, alloc, loopMNIVs);

    // Insert instructions to do matrix multiplication: A*B
    Block &matmulIterationBlock = matmulIterateOp.bodyRegion().front();
    rewriter.setInsertionPointToStart(&matmulIterationBlock);

    // Induction variables
    SmallVector<Value, 4> loopKIVs, loopAIVs, loopBIVs;
    for (auto arg : matmulIterationBlock.getArguments())
      loopKIVs.emplace_back(arg);
    if (isTransA) {
      loopAIVs.emplace_back(loopKIVs[0]);
      loopAIVs.emplace_back(loopMNIVs[0]);
    } else {
      loopAIVs.emplace_back(loopMNIVs[0]);
      loopAIVs.emplace_back(loopKIVs[0]);
    }
    if (isTransB) {
      loopBIVs.emplace_back(loopMNIVs[1]);
      loopBIVs.emplace_back(loopKIVs[0]);
    } else {
      loopBIVs.emplace_back(loopKIVs[0]);
      loopBIVs.emplace_back(loopMNIVs[1]);
    }

    // Matmul computation
    auto loadedA = rewriter.create<LoadOp>(loc, A, loopAIVs);
    auto loadedB = rewriter.create<LoadOp>(loc, B, loopBIVs);
    auto loadedY = rewriter.create<LoadOp>(loc, alloc, loopMNIVs);
    auto AB = rewriter.create<MulFOp>(loc, loadedA, loadedB);
    auto accumulated = rewriter.create<AddFOp>(loc, loadedY, AB);
    rewriter.create<StoreOp>(loc, accumulated, alloc, loopMNIVs);

    rewriter.replaceOp(op, alloc);

    return matchSuccess();
  }
};
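// Loop-structure sketch (informal, ignoring transposes): the pattern above
// emits
//   for m, n:
//     Y[m][n] = 0
//     for k: Y[m][n] += A[m][k] * B[k][n]
//     Y[m][n] = alpha * Y[m][n] + beta * C[broadcast(m, n)]
// The matmul iterate op is created before its body is filled in, so the
// alpha/beta epilogue textually precedes it above but executes after the
// reduction loop.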
struct ONNXUnsqueezeOpLowering : public ConversionPattern {
  ONNXUnsqueezeOpLowering(MLIRContext *ctx)
      : ConversionPattern(mlir::ONNXUnsqueezeOp::getOperationName(), 1, ctx) {}
  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const final {
    auto loc = op->getLoc();
    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
    int outRank = tensorType.getRank();

    // Assume that `axes` has been validated by shape inference.
    // So, here we just get it.
    ArrayAttr axisAttrs = llvm::dyn_cast<ONNXUnsqueezeOp>(op).axesAttr();
    SmallVector<int, 4> axes;
    for (auto axisAttr : axisAttrs.getValue()) {
      int axis = axisAttr.cast<IntegerAttr>().getInt();
      axis = axis >= 0 ? axis : (outRank + axis);
      axes.emplace_back(axis);
    }

    // Insert an allocation and deallocation for the result of this operation.
    auto memRefType = convertTensorToMemRef(tensorType);
    Value alloc;

    // Compute size in bytes.
    Value tensorSize = rewriter.create<ConstantOp>(loc,
        rewriter.getIntegerAttr(rewriter.getIntegerType(64),
            getMemRefEltSizeInBytes(memRefType)));

    bool insertDealloc = checkInsertDealloc(op);
    auto memRefShape = memRefType.getShape();
    if (hasAllConstantDimensions(memRefType)) {
      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
      for (int i = 0; i < memRefShape.size(); ++i) {
        Value dimVal = rewriter.create<ConstantOp>(loc,
            rewriter.getIntegerAttr(
                rewriter.getIntegerType(64), memRefShape[i]));
        tensorSize = rewriter.create<MulIOp>(loc, tensorSize, dimVal);
      }
    } else {
      // Unknown dimensions are always the operand's dimensions.
      SmallVector<Value, 4> allocOperands;
      for (int outIdx = 0, inIdx = 0; outIdx < memRefShape.size(); ++outIdx) {
        Value dimVal = nullptr;
        if (memRefShape[outIdx] < 0) {
          Value index = rewriter.create<DimOp>(loc, operands[0], inIdx);
          dimVal = rewriter.create<IndexCastOp>(
              loc, index, rewriter.getIntegerType(64));
          allocOperands.emplace_back(index);
        } else {
          dimVal = rewriter.create<ConstantOp>(loc,
              rewriter.getIntegerAttr(
                  rewriter.getIntegerType(64), memRefShape[outIdx]));
        }
        tensorSize = rewriter.create<MulIOp>(loc, tensorSize, dimVal);
        if (std::find(axes.begin(), axes.end(), outIdx) == axes.end())
          inIdx++;
      }
      alloc = rewriter.create<AllocOp>(loc, memRefType, allocOperands);
      auto *parentBlock = alloc.getDefiningOp()->getBlock();
      if (insertDealloc) {
        auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
        dealloc.getOperation()->moveBefore(&parentBlock->back());
      }
    }
    rewriter.create<KrnlMemcpyOp>(loc, alloc, operands[0], tensorSize);
    rewriter.replaceOp(op, alloc);
    return matchSuccess();
  }
};
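// Worked example (hypothetical): for perm = [2, 0, 1], the pattern below
// reads the input at [%i0, %i1, %i2] and writes the output at
// [%i2, %i0, %i1], i.e. outLoopIVs[j] = inLoopIVs[perm[j]].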
struct ONNXTransposeOpLowering : public ConversionPattern {
  ONNXTransposeOpLowering(MLIRContext *ctx)
      : ConversionPattern(mlir::ONNXTransposeOp::getOperationName(), 1, ctx) {}

  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const final {
    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
    auto loc = op->getLoc();

    // Insert an allocation and deallocation for the result of this operation.
    auto memRefType = convertTensorToMemRef(tensorType);
    Value alloc;
    bool insertDealloc = checkInsertDealloc(op);

    if (hasAllConstantDimensions(memRefType))
      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
    else
      alloc = insertAllocAndDealloc(
          memRefType, loc, rewriter, insertDealloc, {operands[0]});

    // Number of loops.
    auto memRefShape = memRefType.getShape();
    int64_t rank = memRefShape.size();

    // Define loops.
    std::vector<Value> originalLoops;
    std::vector<Value> optimizedLoops;
    Block *optimizationBlock =
        defineLoops(rewriter, loc, originalLoops, optimizedLoops, rank);

    KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
    // Iterate over the loop nest using the input shape.
    for (int i = 0; i < rank; ++i)
      addDimensionToPack(rewriter, loc, pack, operands[0], i);

    auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
    Block &iterationBlock = iterateOp.bodyRegion().front();

    // Now perform the insertions into the body of the
    // just generated instructions:

    // 1. Insert any optimizations in the KrnlOptimizeLoopsOp body.
    rewriter.setInsertionPointToEnd(optimizationBlock);
    // Return from KrnlOptimizeLoopsOp body.
    // When no optimizations are present we just return the loops unchanged.
    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);

    // 2. Insert instructions inside the KrnlIterateOp body.
    rewriter.setInsertionPointToStart(&iterationBlock);

    // Handle the operation.

    // Read perm attribute.
    SmallVector<int, 4> perm;
    auto permAttribute = llvm::dyn_cast<ONNXTransposeOp>(op).permAttr();
    if (permAttribute) {
      for (auto permVal : permAttribute.getValue())
        perm.emplace_back(permVal.cast<IntegerAttr>().getInt());
    } else {
      // TODO: Remove when perm is guaranteed to be present (even for
      // the default case). This means that perm was added by shape
      // inference or another pass to contain the values corresponding
      // to the default behavior of Transpose, which is to reverse the
      // dimensions.
      for (int i = iterationBlock.getArguments().size() - 1; i >= 0; i--)
        perm.emplace_back(i);
    }

    SmallVector<Value, 4> inLoopIVs;
    for (auto arg : iterationBlock.getArguments())
      inLoopIVs.emplace_back(arg);

    SmallVector<Value, 4> outLoopIVs;
    for (int i = 0; i < iterationBlock.getArguments().size(); ++i)
      outLoopIVs.emplace_back(inLoopIVs[perm[i]]);

    auto inVal = rewriter.create<LoadOp>(loc, operands[0], inLoopIVs);
    rewriter.create<StoreOp>(loc, inVal, alloc, outLoopIVs);

    rewriter.replaceOp(op, alloc);

    return matchSuccess();
  }
};
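// As a concrete illustration of the transpose lowering above (hypothetical
// 2-D case): with perm = [1, 0], the loop nest iterates over the input shape
// and, for induction variables (i0, i1), loads X[i0][i1] and stores it to
// Y[i1][i0], i.e. a plain matrix transpose. Higher ranks permute the indices
// the same way.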
struct ONNXIdentityOpLowering : public ConversionPattern {
  ONNXIdentityOpLowering(MLIRContext *ctx)
      : ConversionPattern(mlir::ONNXIdentityOp::getOperationName(), 1, ctx) {}

  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const final {
    rewriter.replaceOp(op, operands[0]);
    return matchSuccess();
  }
};

struct ONNXConvNoBiasOpLowering : public ConversionPattern {
  ONNXConvNoBiasOpLowering(MLIRContext *ctx)
      : ConversionPattern(
            mlir::ONNXConvNoBiasOp::getOperationName(), 1, ctx) {}

  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const final {
    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
    auto loc = op->getLoc();

    // Insert an allocation and deallocation for the result of this operation.
    auto memRefType = convertTensorToMemRef(tensorType);
    Value alloc;
    bool insertDealloc = checkInsertDealloc(op);
    ONNXConvNoBiasOp convOp = llvm::dyn_cast<ONNXConvNoBiasOp>(op);

    if (hasAllConstantDimensions(memRefType))
      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
    else
      alloc = insertAllocAndDealloc(
          memRefType, loc, rewriter, insertDealloc, {operands[0]});

    auto resultShape = memRefType.getShape();
    auto inputShape = operands[0].getType().cast<MemRefType>().getShape();
    auto kernelShape = operands[1].getType().cast<MemRefType>().getShape();

    // R = ConvNoBias(D, K)
    //
    // The input/output shapes will look like this:
    //
    // D (NxCxHxW) x K (MxC/groupxKHxKW) -> R (NxMxRHxRW)
    //
    // M is a multiple of the number of groups:
    //   M = group * kernelsPerGroup
    //
    // The loop nest will look as follows:
    //
    // kernelsPerGroup = M / group;
    // for n = 0 .. N:
    //   for g = 0 .. group:
    //     for m = 0 .. kernelsPerGroup:
    //       kernel = g * kernelsPerGroup + m;
    //       for r1 = 0 .. RH:
    //         for r2 = 0 .. RW:
    //           R[n][kernel][r1][r2] = 0;
    //           for c = 0 .. C/group:
    //             for k1 = 0 .. KH:
    //               for k2 = 0 .. KW:
    //                 R[n][kernel][r1][r2] +=
    //                     D[n][g * (C / group) + c][r1 + k1][r2 + k2] *
    //                     K[kernel][c][k1][k2];
    //
    // TODO: handle padding.
    //
    // In the general case:
    //
    // D (NxCxD1xD2x...xDdim) x K (MxC/groupxK1xK2x...xKdim)
    //   -> R (NxMxR1xR2x...xRdim)
    //
    // The above loop nest can be adapted by increasing the number
    // of r- and k-index loops, i.e. the r1 r2 and k1 k2 loops.

    // Set up outermost loops: n g m r1 r2 ... rdim. Skip g if group is 1.

    // Before we start the iteration we need to compute the number of
    // unsplit kernels and fetch the number of groups from the attribute
    // list. Group is always a compilation constant.
    int64_t group = convOp.group().getSExtValue();
    // Compute the number of unsplit kernels. The number of kernels
    // must be a multiple of the number of groups.
    int64_t kernelsPerGroup = kernelShape[0] / group;
    auto kernelsPerGroupValue =
        rewriter.create<ConstantIndexOp>(loc, kernelsPerGroup);
    auto zero = rewriter.create<ConstantOp>(
        loc, FloatAttr::get(memRefType.getElementType(), 0));
    Value subchannels;
    if (kernelShape[1] < 0) {
      subchannels = rewriter.create<DimOp>(loc, operands[1], 1).getResult();
    } else {
      subchannels = rewriter.create<ConstantIndexOp>(loc, kernelShape[1]);
    }

    // 1. Define outer loops and emit empty optimization block:
    int64_t nOuterLoops = (group > 1) ? 3 : 2;
    std::vector<Value> outerLoops;
    std::vector<Value> optimizedOuterLoops;
    Block *optimizationBlock = defineLoops(
        rewriter, loc, outerLoops, optimizedOuterLoops, nOuterLoops);

    // Prepare iteration arguments over outer loop nest.
    KrnlIterateOperandPack pack(rewriter, outerLoops, optimizedOuterLoops);
    //   for n = 0 .. N:
    pack.pushConstantBound(0);
    if (inputShape[0] < 0)
      pack.pushOperandBound(
          rewriter.create<DimOp>(loc, operands[0], 0).getResult());
    else
      pack.pushConstantBound(inputShape[0]);
    //   for g = 0 .. group:
    if (group > 1) {
      pack.pushConstantBound(0);
      pack.pushConstantBound(group);
    }
    //   for m = 0 .. kernelsPerGroup:
    pack.pushConstantBound(0);
    pack.pushConstantBound(kernelsPerGroup);

    // Outer loop iteration.
    auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
    Block &outerIterationBlock = iterateOp.bodyRegion().front();
    // Emit optimizations for outer loops:
    rewriter.setInsertionPointToEnd(optimizationBlock);
    rewriter.create<KrnlReturnLoopsOp>(loc, outerLoops);
    rewriter.setInsertionPointToStart(&outerIterationBlock);
    {
      // 2. Emit the body of the outer loop nest.

      // 2.1 Compute kernel order number: kernel = g * kernelsPerGroup + m;
      // If group is not set then the value of the kernel ID is
      // identical to that of the loop over kernels.
      Value kernel = outerIterationBlock.getArguments()[1];
      if (group > 1) {
        // Middle loop is over groups and third loop is over the
        // kernel identifiers in the current group.
        auto kernelsOffset = rewriter.create<MulIOp>(loc,
            outerIterationBlock.getArguments()[1], kernelsPerGroupValue);
        kernel = rewriter.create<AddIOp>(
            loc, kernelsOffset, outerIterationBlock.getArguments()[2]);
      }
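      // Illustration with hypothetical sizes: for M = 8 kernels and
      // group = 2, kernelsPerGroup = 4, so (g, m) = (1, 2) addresses
      // kernel = 1 * 4 + 2 = 6. When group == 1 the g loop is omitted
      // and the kernel ID is just m.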
      // 2.2 Define spatial loops.
      int64_t nSpatialLoops = resultShape.size() - 2;
      std::vector<Value> spatialLoops;
      std::vector<Value> optimizedSpatialLoops;
      Block *optSpatialLoopBlock = defineLoops(
          rewriter, loc, spatialLoops, optimizedSpatialLoops, nSpatialLoops);

      // 2.3 Prepare iteration arguments for spatial loop nest.
      KrnlIterateOperandPack spatialPack(
          rewriter, spatialLoops, optimizedSpatialLoops);
      for (int i = 2; i < resultShape.size(); ++i)
        addDimensionToPack(rewriter, loc, spatialPack, alloc, i);

      // 2.4 Emit loop nest over output spatial dimensions.
      //   for rX = 0 .. RX
      auto spatialIterateOp =
          rewriter.create<KrnlIterateOp>(loc, spatialPack);
      Block &spatialIterationBlock = spatialIterateOp.bodyRegion().front();
      // 2.5 Emit optimizations for spatial loops:
      rewriter.setInsertionPointToEnd(optSpatialLoopBlock);
      rewriter.create<KrnlReturnLoopsOp>(loc, spatialLoops);
      rewriter.setInsertionPointToStart(&spatialIterationBlock);
      {
        // 3. Emit the body of the spatial loop nest.

        // 3.1 Emit: R[n][kernel][r1][r2] = 0;
        SmallVector<Value, 4> resultIndices;
        // n
        resultIndices.emplace_back(outerIterationBlock.getArguments()[0]);
        // kernel
        resultIndices.emplace_back(kernel);
        // rX
        for (auto arg : spatialIterationBlock.getArguments())
          resultIndices.emplace_back(arg);
        // Store initializer value into output location.
        rewriter.create<StoreOp>(loc, zero, alloc, resultIndices);

        // 3.2 Define inner loops.
        int64_t nInnerLoops = 1 + (kernelShape.size() - 2);
        std::vector<Value> innerLoops;
        std::vector<Value> optimizedInnerLoops;
        Block *optInnerLoopBlock = defineLoops(
            rewriter, loc, innerLoops, optimizedInnerLoops, nInnerLoops);

        // 3.3 Prepare iteration arguments for inner loop nest.
        KrnlIterateOperandPack innerPack(
            rewriter, innerLoops, optimizedInnerLoops);
        //   for c = 0 .. C/group
        innerPack.pushConstantBound(0);
        innerPack.pushConstantBound(kernelShape[1]);
        //   for Kx = 0 .. KX
        for (int i = 2; i < kernelShape.size(); ++i)
          addDimensionToPack(rewriter, loc, innerPack, operands[1], i);

        // 3.4 Emit inner loop nest.
        auto innerIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
        Block &innerIterationBlock = innerIterateOp.bodyRegion().front();
        // 3.5 Emit optimizations for inner loops:
        rewriter.setInsertionPointToEnd(optInnerLoopBlock);
        rewriter.create<KrnlReturnLoopsOp>(loc, innerLoops);
        rewriter.setInsertionPointToStart(&innerIterationBlock);
        {
          // 4. Emit inner loop body:
          //    R[n][kernel][r1][r2] +=
          //        D[n][g * (C / group) + c][r1 + k1][r2 + k2] *
          //        K[kernel][c][k1][k2];

          // 4.1 Prepare indices for accessing the data tensor.
          SmallVector<Value, 4> dataIndices;
          // n
          dataIndices.emplace_back(outerIterationBlock.getArguments()[0]);
          // g * (C / group) + c
          Value channelDepth = innerIterationBlock.getArguments()[0];
          if (group > 1)
            channelDepth = rewriter.create<AddIOp>(loc, channelDepth,
                rewriter.create<MulIOp>(loc, subchannels,
                    outerIterationBlock.getArguments()[1]));
          dataIndices.emplace_back(channelDepth);
          // rX + kX
          for (int i = 0; i < kernelShape.size() - 2; ++i)
            dataIndices.emplace_back(rewriter.create<AddIOp>(loc,
                spatialIterationBlock.getArguments()[i],
                innerIterationBlock.getArguments()[i + 1]));

          // 4.2 Prepare indices for accessing the kernel tensor.
          SmallVector<Value, 4> kernelIndices;
          // kernel
          kernelIndices.emplace_back(kernel);
          // c
          kernelIndices.emplace_back(innerIterationBlock.getArguments()[0]);
          // kX
          for (int i = 0; i < kernelShape.size() - 2; ++i)
            kernelIndices.emplace_back(
                innerIterationBlock.getArguments()[i + 1]);

          // 4.3 Compute the convolution: load the data and kernel values,
          // multiply them, and accumulate into the partial sum.
          auto loadData =
              rewriter.create<LoadOp>(loc, operands[0], dataIndices);
          auto loadKernel =
              rewriter.create<LoadOp>(loc, operands[1], kernelIndices);
          auto loadPartialSum =
              rewriter.create<LoadOp>(loc, alloc, resultIndices);
          Value result = rewriter.create<AddFOp>(loc, loadPartialSum,
              rewriter.create<MulFOp>(loc, loadData, loadKernel));
          // 4.4 Store computed value into output location.
          rewriter.create<StoreOp>(loc, result, alloc, resultIndices);
        }
      }
    }
    rewriter.replaceOp(op, alloc);

    return matchSuccess();
  }
};

//===----------------------------------------------------------------------===//
// Reduction ops lowering to Krnl dialect.
//===----------------------------------------------------------------------===//

template <typename ONNXReductionOp>
struct ONNXReductionOpLowering : public ConversionPattern {
  ONNXReductionOpLowering(MLIRContext *ctx)
      : ConversionPattern(ONNXReductionOp::getOperationName(), 1, ctx) {}

  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const final {
    /*
     * Condition: reduction function must be associative and commutative.
     *
     * Example 1 (here, reduction function is `+`):
     * Induction variables: (i0, i1, i2)
     * axes = [0, 2]
     * keepdims = true
     * krnl.iterate() with (i0, i1, i2) {
     *   Y(0, i1, 0) += X(i0, i1, i2)
     * }
     *
     * Example 2 (here, reduction function is `+`):
     * Induction variables: (i0, i1, i2)
     * axes = [0, 2]
     * keepdims = false
     * krnl.iterate() with (i0, i1, i2) {
     *   Y(i1) += X(i0, i1, i2)
     * }
     */
    auto loc = op->getLoc();
    auto memRefInType = operands[0].getType().cast<MemRefType>();
    auto memRefInShape = memRefInType.getShape();
    auto tensorOutType = (*op->result_type_begin()).cast<TensorType>();
    int64_t inRank = memRefInType.getRank();
    int64_t outRank = tensorOutType.getRank();

    // Get the `axes` attribute, normalizing negative axes; when the
    // attribute is absent, reduce over all dimensions.
    ArrayAttr axisAttrs = llvm::dyn_cast<ONNXReductionOp>(op).axesAttr();
    std::vector<int64_t> axes;
    if (axisAttrs) {
      for (auto axisAttr : axisAttrs.getValue()) {
        int64_t axis = axisAttr.cast<IntegerAttr>().getInt();
        axis = axis >= 0 ? axis : (inRank + axis);
        assert(axis >= -inRank && axis <= inRank - 1);
        if (std::find(axes.begin(), axes.end(), axis) == axes.end())
          axes.push_back(axis);
      }
    } else {
      for (decltype(inRank) i = 0; i < inRank; ++i) {
        axes.push_back(i);
      }
    }

    // KeepDims.
    auto keepdims = llvm::dyn_cast<ONNXReductionOp>(op).keepdims();
    bool isKeepdims = (keepdims == 1);

    // Get type information.
    auto memRefOutType = convertTensorToMemRef(tensorOutType);
    auto memRefOutShape = memRefOutType.getShape();
    auto elementOutType = memRefOutType.getElementType();
    std::map<int64_t, int64_t> outInDimMap =
        getReductionMapping(memRefInType, axes, isKeepdims);
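    // Example of the mapping (hypothetical shapes): for a 3-D input with
    // axes = [0, 2] and keepdims = false, the output is 1-D and
    // outInDimMap = {0 -> 1}: output dim 0 follows input dim 1, while the
    // reduced input dims 0 and 2 have no counterpart in the output.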
    // Insert an allocation and deallocation for the result of this operation.
    Value alloc;
    bool insertDealloc = checkInsertDealloc(op);
    if (hasAllConstantDimensions(memRefOutType)) {
      alloc =
          insertAllocAndDealloc(memRefOutType, loc, rewriter, insertDealloc);
    } else {
      SmallVector<Value, 2> allocOperands;
      for (decltype(outRank) i = 0; i < outRank; ++i) {
        if (memRefOutShape[i] < 0) {
          auto dim = rewriter.create<DimOp>(loc, operands[0], outInDimMap[i]);
          allocOperands.push_back(dim);
        }
      }
      alloc = rewriter.create<AllocOp>(loc, memRefOutType, allocOperands);
      if (insertDealloc) {
        auto *parentBlock = alloc.getDefiningOp()->getBlock();
        auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
        dealloc.getOperation()->moveBefore(&parentBlock->back());
      }
    }

    // There are two Krnl loops:
    // - one to initialize the result memref, and
    // - one to do the reduction.

    // Define loops to initialize the result.
    std::vector<Value> originalLoopsInit;
    std::vector<Value> optimizedLoopsInit;
    Block *optimizationBlockInit = defineLoops(
        rewriter, loc, originalLoopsInit, optimizedLoopsInit, outRank);

    // Iteration information.
    KrnlIterateOperandPack packInit(
        rewriter, originalLoopsInit, optimizedLoopsInit);
    for (decltype(outRank) i = 0; i < outRank; ++i) {
      addDimensionToPack(rewriter, loc, packInit, alloc, i);
    }
    auto iterateOpInit = rewriter.create<KrnlIterateOp>(loc, packInit);
    Block &iterationBlockInit = iterateOpInit.bodyRegion().front();

    // Perform the insertions into the body of the initialization loop.
    // No optimization.
    rewriter.setInsertionPointToEnd(optimizationBlockInit);
    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoopsInit);

    // Insert instructions inside the KrnlIterateOp body.
    rewriter.setInsertionPointToStart(&iterationBlockInit);

    // Handle the operation: initialize the output with the reduction's
    // identity value.
    SmallVector<Value, 4> loopIVs;
    for (auto arg : iterationBlockInit.getArguments()) {
      loopIVs.push_back(arg);
    }

    Value identity;
    if (elementOutType.isa<FloatType>()) {
      identity = rewriter.create<ConstantOp>(
          loc, FloatAttr::get(
                   elementOutType, getIdentityValue<ONNXReductionOp>()));
    } else if (elementOutType.isa<IntegerType>()) {
      identity = rewriter.create<ConstantOp>(
          loc, IntegerAttr::get(
                   elementOutType, getIdentityValue<ONNXReductionOp>()));
    } else {
      emitError(loc, "unsupported element type");
    }
    rewriter.create<StoreOp>(loc, identity, alloc, loopIVs);

    // Define a Krnl loop to do the reduction.
    rewriter.setInsertionPointAfter(iterateOpInit);
    std::vector<Value> originalLoops, optimizedLoops;
    Block *optimizationBlock =
        defineLoops(rewriter, loc, originalLoops, optimizedLoops, inRank);
    // Iteration information.
    KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
    for (decltype(inRank) i = 0; i < inRank; ++i) {
      addDimensionToPack(rewriter, loc, pack, operands[0], i);
    }
    auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
    Block &iterationBlock = iterateOp.bodyRegion().front();

    // Perform the insertions into the body of the reduction loop.
    // No optimization.
    rewriter.setInsertionPointToEnd(optimizationBlock);
    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
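    // In the reduction body emitted below, the output IVs reuse the input
    // IVs for dimensions that survive the reduction (via outInDimMap);
    // reduced dimensions that are kept (keepdims == 1) all share a single
    // constant index 0, matching the Y(0, i1, 0) pattern sketched above.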
    // Insert instructions inside the KrnlIterateOp body.
    rewriter.setInsertionPointToStart(&iterationBlock);

    // Handle the operation:
    SmallVector<Value, 4> inLoopIVs, outLoopIVs;
    auto args = iterationBlock.getArguments();
    for (int i = 0; i < args.size(); ++i) {
      inLoopIVs.push_back(args[i]);
    }
    Value zeroIndex = nullptr;
    for (decltype(inRank) i = 0; i < outRank; ++i) {
      if (outInDimMap.find(i) != outInDimMap.end()) {
        outLoopIVs.push_back(inLoopIVs[outInDimMap[i]]);
      } else {
        if (zeroIndex) {
          outLoopIVs.push_back(zeroIndex);
        } else {
          zeroIndex = rewriter.create<ConstantIndexOp>(loc, 0);
          outLoopIVs.push_back(zeroIndex);
        }
      }
    }

    Value next, accumulated;
    next = rewriter.create<LoadOp>(loc, operands[0], inLoopIVs);
    accumulated = rewriter.create<LoadOp>(loc, alloc, outLoopIVs);
    accumulated = mapToLowerScalarOp<ONNXReductionOp>(
        op, memRefOutType.getElementType(), {accumulated, next}, rewriter);
    rewriter.create<StoreOp>(loc, accumulated, alloc, outLoopIVs);

    rewriter.replaceOp(op, alloc);
    return matchSuccess();
  }
};

//===----------------------------------------------------------------------===//
// EntryPoint Op lowering to Krnl Entry Point.
//===----------------------------------------------------------------------===//

class ONNXEntryPointLowering : public OpRewritePattern<ONNXEntryPointOp> {
public:
  using OpRewritePattern<ONNXEntryPointOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(
      ONNXEntryPointOp op, PatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<KrnlEntryPointOp>(op,
        op.getAttrOfType<SymbolRefAttr>(
            ONNXEntryPointOp::getEntryPointFuncAttrName()),
        op.getAttrOfType<IntegerAttr>(
            ONNXEntryPointOp::getNumInputsAttrName()),
        op.getAttrOfType<IntegerAttr>(
            ONNXEntryPointOp::getNumOutputsAttrName()));
    return matchSuccess();
  }
};

//===----------------------------------------------------------------------===//
// Conversion from Tensor type to the Standard dialect MemRef type.
//===----------------------------------------------------------------------===//

struct TensorTypeConverter : public TypeConverter {
  using TypeConverter::TypeConverter;

  LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) override {
    if (auto tensor_type = t.dyn_cast<TensorType>()) {
      results.push_back(convertTensorToMemRef(tensor_type));
      return success();
    }
    results.push_back(t);
    return success();
  }

  /// Return true if the inputs and outputs of the given function type are
  /// legal. [Taken from MLIR and adapted to only check the legality of the
  /// inputs. Once unranked results can be handled gracefully this
  /// override needs to be removed in favour of the original MLIR one.]
  bool isSignatureLegal(FunctionType funcType) {
    return llvm::all_of(
        funcType.getInputs(), [this](Type type) { return isLegal(type); });
  }
};
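// For instance, the converter above turns a signature such as
//   func @f(%arg0: tensor<10x20xf32>) -> tensor<10x20xf32>
// into
//   func @f(%arg0: memref<10x20xf32>) -> memref<10x20xf32>
// (a sketch with a hypothetical function @f; the actual rewrite is performed
// by the FuncOp type conversion pattern registered in the pass below).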
} // end anonymous namespace.

//===----------------------------------------------------------------------===//
// Frontend to Krnl Dialect lowering pass
//===----------------------------------------------------------------------===//

/// This is a partial lowering to Krnl loops of the ONNX operations.
namespace {
struct FrontendToKrnlLoweringPass
    : public ModulePass<FrontendToKrnlLoweringPass> {
  void runOnModule() final;
};
} // end anonymous namespace.

void FrontendToKrnlLoweringPass::runOnModule() {
  auto module = getModule();

  // The first thing to define is the conversion target. This will define the
  // final target for this lowering.
  ConversionTarget target(getContext());

  // We define the specific operations, or dialects, that are legal targets
  // for this lowering.
  target
      .addLegalDialect<KrnlOpsDialect, AffineOpsDialect, StandardOpsDialect>();

  // TODO: enable this once more ops are supported.
  // We also define the ONNX dialect as illegal so that the conversion will
  // fail if any of these operations are *not* converted.
  // target.addIllegalDialect<mlir::ONNXOpsDialect>();

  // TODO: add any other ops which are considered legal.
  // Some operations can be marked as being still legal.
  // Example: target.addLegalOp<OpName>();

  // Now that the conversion target has been defined, we just need to provide
  // the set of patterns that will lower the frontend operations.
  OwningRewritePatternList patterns;

  // Convert TensorType to MemRef.
  TensorTypeConverter tensor_to_memref_converter;
  target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
    // FuncOp is legal only if types have been converted to Std types.
    return tensor_to_memref_converter.isSignatureLegal(op.getType());
  });

  // Type conversion for function signatures.
  // Call MLIR FuncOp signature conversion when result type is
  // a ranked tensor.
  populateFuncOpTypeConversionPattern(
      patterns, &getContext(), tensor_to_memref_converter);

  // Frontend operation lowering.
  patterns.insert<ONNXElementwiseUnaryOpLowering<mlir::ONNXExpOp>,
      ONNXElementwiseUnaryOpLowering<mlir::ONNXTanhOp>,
      ONNXElementwiseUnaryOpLowering<mlir::ONNXSinhOp>,
      ONNXElementwiseUnaryOpLowering<mlir::ONNXCoshOp>,
      ONNXElementwiseUnaryOpLowering<mlir::ONNXCosOp>,
      ONNXElementwiseUnaryOpLowering<mlir::ONNXLogOp>,
      ONNXElementwiseUnaryOpLowering<mlir::ONNXSigmoidOp>,
      ONNXElementwiseUnaryOpLowering<mlir::ONNXHardSigmoidOp>,
      ONNXElementwiseUnaryOpLowering<mlir::ONNXEluOp>,
      ONNXElementwiseUnaryOpLowering<mlir::ONNXReluOp>,
      ONNXElementwiseUnaryOpLowering<mlir::ONNXLeakyReluOp>,
      ONNXElementwiseUnaryOpLowering<mlir::ONNXSeluOp>,
      ONNXElementwiseUnaryOpLowering<mlir::ONNXReciprocalOp>,
      ONNXElementwiseUnaryOpLowering<mlir::ONNXSoftplusOp>,
      ONNXElementwiseUnaryOpLowering<mlir::ONNXSoftsignOp>,
      ONNXElementwiseUnaryOpLowering<mlir::ONNXSqrtOp>,
      ONNXElementwiseVariadicOpLowering<mlir::ONNXAddOp>,
      ONNXElementwiseVariadicOpLowering<mlir::ONNXMulOp>,
      ONNXElementwiseVariadicOpLowering<mlir::ONNXDivOp>,
      ONNXElementwiseVariadicOpLowering<mlir::ONNXSubOp>,
      ONNXElementwiseVariadicOpLowering<mlir::ONNXAndOp>,
      ONNXElementwiseVariadicOpLowering<mlir::ONNXOrOp>,
      ONNXElementwiseVariadicOpLowering<mlir::ONNXXorOp>,
      ONNXElementwiseVariadicOpLowering<mlir::ONNXSumOp>,
      ONNXElementwiseVariadicOpLowering<mlir::ONNXMaxOp>,
      ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>,
      ONNXReshapeOpLowering, ONNXEntryPointLowering,
      ONNXReductionOpLowering<mlir::ONNXReduceMaxOp>,
      ONNXReductionOpLowering<mlir::ONNXReduceMinOp>,
      ONNXReductionOpLowering<mlir::ONNXReduceProdOp>,
      ONNXReductionOpLowering<mlir::ONNXReduceSumOp>,
      ONNXSoftmaxOpLowering, ONNXGemmOpLowering,
      ONNXUnsqueezeOpLowering, ONNXTransposeOpLowering,
      ONNXIdentityOpLowering, ONNXConvNoBiasOpLowering>(&getContext());

  // With the target and rewrite patterns defined, we can now attempt the
  // conversion. The conversion will signal failure if any of our `illegal`
  // operations were not converted successfully.
  if (failed(applyPartialConversion(module, target, patterns)))
    signalPassFailure();
}

std::unique_ptr<Pass> mlir::createLowerToKrnlPass() {
  return std::make_unique<FrontendToKrnlLoweringPass>();
}

static PassRegistration<FrontendToKrnlLoweringPass> pass(
    "lower-frontend", "Lower frontend ops to Krnl dialect.");