//====- lower_frontend_to_krnl.cpp - Frontend dialects to Krnl lowering ---===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file implements the lowering of frontend operations to a combination of
// Krnl IR and standard operations.
//
//===----------------------------------------------------------------------===//

#include <limits>
#include <map>

#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Sequence.h"

#include "src/dialect/krnl/krnl_helper.hpp"
#include "src/dialect/krnl/krnl_ops.hpp"
#include "src/dialect/onnx/onnx_ops.hpp"
#include "src/pass/passes.hpp"

using namespace mlir;

//===----------------------------------------------------------------------===//
// FrontendToAffine RewritePatterns
//===----------------------------------------------------------------------===//

/// Check if all dimensions are known at compile time.
static bool hasAllConstantDimensions(MemRefType type) {
  auto memRefShape = type.getShape();
  for (int i = 0; i < memRefShape.size(); ++i)
    if (memRefShape[i] < 0)
      return false;
  return true;
}

/// Convert the given TensorType into the corresponding MemRefType.
static MemRefType convertTensorToMemRef(TensorType type) {
  assert(type.hasRank() && "expected only ranked shapes");
  return MemRefType::get(type.getShape(), type.getElementType());
}

/// Insert an allocation and deallocation for the given MemRefType.
static Value insertAllocAndDealloc(MemRefType type, Location loc,
                                   PatternRewriter &rewriter,
                                   bool insertDealloc,
                                   ArrayRef<Value> operands = {}) {
  // Put together alloc operands for any dynamic dimensions of the memref.
  AllocOp alloc;
  if (!operands.empty()) {
    auto memRefShape = type.getShape();
    auto rank = memRefShape.size();

    std::map<int, Value> fromOperands;
    for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
      int memRefDimIdx = rank - 1 - reversedIdx;
      if (memRefShape[memRefDimIdx] < 0) { // unknown dimension
        Value maxDim = nullptr;
        for (int i = 0; i < operands.size(); i++) {
          auto operandShape =
              operands[i].getType().cast<MemRefType>().getShape();
          int operandDimIdx = operandShape.size() - 1 - reversedIdx;

          if (operandDimIdx < 0)
            continue;

          // In case of operations with broadcasting, the dimension of the
          // alloc result is the maximum size along each dimension of the
          // operands.
          auto operandDim =
              rewriter.create<DimOp>(loc, operands[i], operandDimIdx);
          if (maxDim) {
            auto maxCondition = rewriter.create<CmpIOp>(
                loc, CmpIPredicate::sgt, operandDim, maxDim);
            maxDim = rewriter.create<SelectOp>(loc, maxCondition, operandDim,
                                               maxDim);
          } else {
            maxDim = operandDim;
          }
        }
        fromOperands.insert(std::make_pair(memRefDimIdx, maxDim));
      }
    }

    SmallVector<Value, 4> allocOperands;
    for (int i = 0; i < rank; ++i)
      if (memRefShape[i] < 0)
        allocOperands.push_back(fromOperands[i]);
    alloc = rewriter.create<AllocOp>(loc, type, allocOperands);
  } else {
    alloc = rewriter.create<AllocOp>(loc, type);
  }

  // Make sure to allocate at the beginning of the block if
  // all dimensions are known.
  auto *parentBlock = alloc.getOperation()->getBlock();
  if (hasAllConstantDimensions(type))
    alloc.getOperation()->moveBefore(&parentBlock->front());

  if (insertDealloc) {
    auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
    dealloc.getOperation()->moveBefore(&parentBlock->back());
  }

  return alloc;
}

// Determine if the current function returns the result value of the
// current op being lowered. If it does, then a dealloc should not be
// inserted.
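// For example (illustrative): if the op's single result is fed directly to the
// function's ReturnOp, the allocated buffer must outlive the op, so no
// DeallocOp is emitted for it by the pattern below.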
static bool checkInsertDealloc(Operation *currentOp) {
  auto parentBlock = currentOp->getBlock();

  bool insertDealloc = true;
  parentBlock->walk([&insertDealloc, currentOp](ReturnOp op) {
    assert(currentOp->getNumResults() < 2 &&
           "No more than one result supported (for now).");
    // If there is at least one result to investigate.
    if (currentOp->getNumResults() > 0) {
      auto result = currentOp->getResult(0);
      for (const auto &operand : op.getOperands())
        if (operand == result)
          insertDealloc = false;
    }
  });

  return insertDealloc;
}

unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
  auto elementType = memRefType.getElementType();

  unsigned sizeInBits;
  if (elementType.isIntOrFloat()) {
    sizeInBits = elementType.getIntOrFloatBitWidth();
  } else {
    auto vectorType = elementType.cast<VectorType>();
    sizeInBits =
        vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
  }
  return llvm::divideCeil(sizeInBits, 8);
}

// Get run-time dimension information for unknown dimensions used for
// broadcasting.
std::map<int, std::map<int, Value>>
getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
                      MemRefType memRefType, ArrayRef<Value> operands) {
  auto memRefShape = memRefType.getShape();
  int64_t rank = memRefShape.size();
  // For unknown dimensions, we need to get dimension values at runtime in
  // order to do broadcasting.
  std::map<int, std::map<int, Value>> DimInfo;
  // For each result dimension, compute the number of sharing operands.
  // Sharing operands are operands sharing the same index (counting from the
  // rightmost to the leftmost) for a given dimension.
  std::map<int, int> sharedDimCount;
  for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
    int dimIdx = rank - 1 - reversedIdx;
    sharedDimCount[dimIdx] = 0;
    for (int i = 0; i < operands.size(); ++i) {
      auto shape = operands[i].getType().cast<MemRefType>().getShape();
      if (reversedIdx <= shape.size() - 1)
        sharedDimCount[dimIdx]++;
    }
  }
  // An unknown dimension can have a value of 1 or N (N > 1).
  // If its value is 1, it is a broadcasted dimension.
  // Otherwise, it is a non-broadcasted dimension.
  // We only care about unknown dimensions whose number of sharing operands is
  // more than one, since they are potentially broadcasted dimensions.
  for (int i = 0; i < operands.size(); ++i) {
    std::map<int, Value> broadcastedDims;
    auto shape = operands[i].getType().cast<MemRefType>().getShape();
    int size = shape.size();
    for (int j = 0; j < shape.size(); ++j) {
      if (shape[j] < 0 && sharedDimCount[rank - size + j] > 1) {
        auto dim = rewriter.create<DimOp>(loc, operands[i], j).getResult();
        auto one = rewriter.create<ConstantIndexOp>(loc, 1);
        auto isBroadcasted =
            rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, dim, one);
        broadcastedDims.insert(std::make_pair(j, isBroadcasted));
      }
    }
    DimInfo.insert(std::make_pair(i, broadcastedDims));
  }
  return DimInfo;
}

// Extract induction variables that are used for broadcasting values of a
// given operand.
std::vector<Value>
getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
                          ArrayRef<Value> loopIVs, Value operand,
                          std::map<int, Value> broadcastedDims) {
  // `operand` must have a ranked type. This should have been checked by the
  // shape inference pass.
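  // For example (illustrative): if the result has shape [5, 3] and `operand`
  // has shape [1, 3], the induction variable of the first dimension is replaced
  // by the constant 0, so every result row reads row 0 of the operand.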
  auto operandShape = operand.getType().cast<MemRefType>().getShape();
  auto rank = operandShape.size();
  auto loopCount = loopIVs.size();

  std::vector<Value> newLoopIVs;
  for (unsigned reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
    auto dimIdx = rank - 1 - reversedIdx;
    auto loopIdx = loopCount - 1 - reversedIdx;
    if (operandShape[dimIdx] == 1) {
      // Broadcasted dimension.
      auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
      newLoopIVs.insert(newLoopIVs.begin(), zero);
    } else if ((operandShape[dimIdx] == -1) &&
               (broadcastedDims.find(dimIdx) != broadcastedDims.end())) {
      // Unknown dimension, it can have a value of 1 or N (N > 1).
      // If its value is 1, it is a broadcasted dimension.
      // Otherwise, it is a non-broadcasted dimension.
      auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
      auto idx = rewriter.create<SelectOp>(loc, broadcastedDims[dimIdx], zero,
                                           loopIVs[loopIdx]);
      newLoopIVs.insert(newLoopIVs.begin(), idx);
    } else {
      // Non-broadcasted dimension.
      newLoopIVs.insert(newLoopIVs.begin(), loopIVs[loopIdx]);
    }
  }
  return newLoopIVs;
}

namespace {

template <typename ONNXOp>
struct ScalarOp;

template <>
struct ScalarOp<ONNXAddOp> {
  using FOp = AddFOp;
  using IOp = AddIOp;
};

template <>
struct ScalarOp<ONNXMulOp> {
  using FOp = MulFOp;
  using IOp = MulIOp;
};

template <>
struct ScalarOp<ONNXDivOp> {
  using FOp = DivFOp;
  using IOp = SignedDivIOp;
};

template <>
struct ScalarOp<ONNXSubOp> {
  using FOp = SubFOp;
  using IOp = SubIOp;
};

template <>
struct ScalarOp<ONNXAndOp> {
  using FOp = AndOp; // not used
  using IOp = AndOp;
};

template <>
struct ScalarOp<ONNXOrOp> {
  using FOp = OrOp; // not used
  using IOp = OrOp;
};

template <>
struct ScalarOp<ONNXXorOp> {
  using FOp = XOrOp; // not used
  using IOp = XOrOp;
};

template <>
struct ScalarOp<ONNXExpOp> {
  using FOp = ExpOp;
  using IOp = ExpOp; // not used
};

template <>
struct ScalarOp<ONNXSumOp> {
  using FOp = AddFOp;
  using IOp = AddIOp;
};

template <>
struct ScalarOp<ONNXTanhOp> {
  using FOp = TanhOp;
  using IOp = TanhOp; // not used
};

template <>
struct ScalarOp<ONNXCosOp> {
  using FOp = CosOp;
  using IOp = CosOp; // not used
};

template <>
struct ScalarOp<ONNXLogOp> {
  using FOp = LogOp;
  using IOp = LogOp; // not used
};

template <typename ONNXOp>
using ScalarFOp = typename ScalarOp<ONNXOp>::FOp;
template <typename ONNXOp>
using ScalarIOp = typename ScalarOp<ONNXOp>::IOp;

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering to Krnl dialect.
//===----------------------------------------------------------------------===//
template <typename UnaryOp>
Value mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types,
                         ArrayRef<Value> operands,
                         ConversionPatternRewriter &rewriter) {
  /* Lower UnaryOp to Ops in the Standard dialect. */
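  // For example (illustrative): ONNXAddOp maps to AddIOp when the element type
  // is an integer type, and to AddFOp when it is a float type.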
  auto loc = op->getLoc();
  Type element_type = operands.front().getType();
  if (element_type.isa<IntegerType>()) {
    return rewriter.create<ScalarIOp<UnaryOp>>(loc, result_types, operands,
                                               mlir::None);
  } else if (element_type.isa<FloatType>()) {
    return rewriter.create<ScalarFOp<UnaryOp>>(loc, result_types, operands,
                                               mlir::None);
  } else {
    emitError(loc, "unsupported element type");
    return nullptr;
  }
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXSinhOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXSinhOp>(Operation *op,
                                     ArrayRef<Type> result_types,
                                     ArrayRef<Value> operands,
                                     ConversionPatternRewriter &rewriter) {
  // ONNXSinhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
  //                         ConstantOp 2)
  auto loc = op->getLoc();
  Value operand = operands[0];
  auto elementType = result_types[0];

  auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
  auto two = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 2));
  auto neg = rewriter.create<SubFOp>(loc, zero, operand);
  auto exp = rewriter.create<ExpOp>(loc, operand);
  auto negExp = rewriter.create<ExpOp>(loc, neg);
  auto result = rewriter.create<DivFOp>(
      loc, rewriter.create<SubFOp>(loc, exp, negExp), two);

  return result;
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXCoshOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXCoshOp>(Operation *op,
                                     ArrayRef<Type> result_types,
                                     ArrayRef<Value> operands,
                                     ConversionPatternRewriter &rewriter) {
  // ONNXCoshOp(%X) = DivFOp(AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
  //                         ConstantOp 2)
  auto loc = op->getLoc();
  Value operand = operands[0];
  auto elementType = result_types[0];

  auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
  auto two = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 2));
  auto neg = rewriter.create<SubFOp>(loc, zero, operand);
  auto exp = rewriter.create<ExpOp>(loc, operand);
  auto negExp = rewriter.create<ExpOp>(loc, neg);
  auto result = rewriter.create<DivFOp>(
      loc, rewriter.create<AddFOp>(loc, exp, negExp), two);

  return result;
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXSigmoidOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXSigmoidOp>(Operation *op,
                                        ArrayRef<Type> result_types,
                                        ArrayRef<Value> operands,
                                        ConversionPatternRewriter &rewriter) {
  // ONNXSigmoidOp(%X) = DivFOp(ConstantOp 1,
  //                            AddFOp(ConstantOp 1, ExpOp(NegFOp(%X))))
  auto loc = op->getLoc();
  Value operand = operands[0];
  auto elementType = result_types[0];

  auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
  auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
  auto neg = rewriter.create<SubFOp>(loc, zero, operand);
  auto negExp = rewriter.create<ExpOp>(loc, neg);
  auto result = rewriter.create<DivFOp>(
      loc, one, rewriter.create<AddFOp>(loc, one, negExp));

  return result;
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXHardSigmoidOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXHardSigmoidOp>(
    Operation *op, ArrayRef<Type> result_types, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) {
  // %Y = AddFOp(MulFOp(alpha, %X), beta)
  // %Z = SelectOp(CmpFOp(OGT, %Y, Constant 0),
  //               %Y,
  //               Constant 0)
  // ONNXHardSigmoidOp(%X) = SelectOp(CmpFOp(OLT, %Z, Constant 1),
  //                                  %Z,
  //                                  Constant 1)
  auto loc = op->getLoc();
  Value operand = operands[0];
  auto alphaAttr = op->getAttrOfType<FloatAttr>("alpha");
betaAttr = op->getAttrOfType("beta"); auto elementType = result_types[0]; auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); auto one = rewriter.create(loc, FloatAttr::get(elementType, 1)); auto alpha = rewriter.create(loc, alphaAttr); auto beta = rewriter.create(loc, betaAttr); auto add = rewriter.create( loc, rewriter.create(loc, alpha, operand), beta); auto maxPredicate = rewriter.create(loc, CmpFPredicate::OGT, add, zero); auto max = rewriter.create(loc, maxPredicate, add, zero); auto minPredicate = rewriter.create(loc, CmpFPredicate::OLT, max, one); auto result = rewriter.create(loc, minPredicate, max, one); return result; } //===----------------------------------------------------------------------===// // Scalar unary ops for lowering ONNXEluOp //===----------------------------------------------------------------------===// template <> Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, ArrayRef operands, ConversionPatternRewriter &rewriter) { // ONNXEluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0), // MulFOp(alpha, SubFOp(ExpOp(%X), 1)), // %X) auto loc = op->getLoc(); Value operand = operands[0]; auto elementType = result_types[0]; auto alphaAttr = op->getAttrOfType("alpha"); auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); auto one = rewriter.create(loc, FloatAttr::get(elementType, 1)); auto alpha = rewriter.create(loc, alphaAttr); auto exp = rewriter.create(loc, operand); auto lessThanZero = rewriter.create(loc, CmpFPredicate::OLT, operand, zero); auto result = rewriter.create( loc, lessThanZero, rewriter.create(loc, alpha, rewriter.create(loc, exp, one)), operand); return result; } //===----------------------------------------------------------------------===// // Scalar unary ops for lowering ONNXReluOp //===----------------------------------------------------------------------===// template <> Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, ArrayRef operands, ConversionPatternRewriter &rewriter) { // ONNXReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0), // ConstantOp 0, // %X) auto loc = op->getLoc(); Value operand = operands[0]; auto elementType = result_types[0]; auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); auto lessThanZero = rewriter.create(loc, CmpFPredicate::OLT, operand, zero); auto result = rewriter.create(loc, lessThanZero, zero, operand); return result; } //===----------------------------------------------------------------------===// // Scalar unary ops for lowering ONNXLeakyReluOp //===----------------------------------------------------------------------===// template <> Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, ArrayRef operands, ConversionPatternRewriter &rewriter) { // ONNXLeakyReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0), // MulFOp(alpha, %X), // %X) auto loc = op->getLoc(); Value operand = operands[0]; auto elementType = result_types[0]; auto alphaAttr = op->getAttrOfType("alpha"); auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); auto alpha = rewriter.create(loc, alphaAttr); auto lessThanZero = rewriter.create(loc, CmpFPredicate::OLT, operand, zero); auto result = rewriter.create( loc, lessThanZero, rewriter.create(loc, alpha, operand), operand); return result; } //===----------------------------------------------------------------------===// // Scalar unary ops for lowering ONNXSeluOp //===----------------------------------------------------------------------===// template <> Value mapToLowerScalarOp(Operation *op, ArrayRef 
                                     ArrayRef<Value> operands,
                                     ConversionPatternRewriter &rewriter) {
  // ONNXSeluOp(%X) = SelectOp(CmpFOp(OGT, %X, ConstantOp 0),
  //                           MulFOp(gamma, %X),
  //                           MulFOp(gamma,
  //                                  SubFOp(MulFOp(alpha, ExpOp(%X)),
  //                                         alpha)))
  auto loc = op->getLoc();
  Value operand = operands[0];
  auto alphaAttr = op->getAttrOfType<FloatAttr>("alpha");
  auto gammaAttr = op->getAttrOfType<FloatAttr>("gamma");
  auto elementType = result_types[0];

  auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
  auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
  auto gamma = rewriter.create<ConstantOp>(loc, gammaAttr);
  auto exp = rewriter.create<ExpOp>(loc, operand);
  auto greaterThanZero =
      rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, operand, zero);
  auto select = rewriter.create<SelectOp>(
      loc, greaterThanZero, operand,
      rewriter.create<SubFOp>(loc, rewriter.create<MulFOp>(loc, alpha, exp),
                              alpha));
  auto result = rewriter.create<MulFOp>(loc, gamma, select);

  return result;
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXReciprocalOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXReciprocalOp>(
    Operation *op, ArrayRef<Type> result_types, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) {
  // ONNXReciprocalOp(%X) = DivFOp(ConstantOp 1, %X)
  auto loc = op->getLoc();
  Value operand = operands[0];
  auto elementType = result_types[0];

  auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
  auto result = rewriter.create<DivFOp>(loc, one, operand);

  return result;
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXSoftplusOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXSoftplusOp>(
    Operation *op, ArrayRef<Type> result_types, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) {
  // ONNXSoftplusOp(%X) = LogOp(AddFOp(ExpOp(%X), ConstantOp 1))
  auto loc = op->getLoc();
  Value operand = operands[0];
  auto elementType = result_types[0];

  auto exp = rewriter.create<ExpOp>(loc, operand);
  auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
  auto add = rewriter.create<AddFOp>(loc, exp, one);
  auto result = rewriter.create<LogOp>(loc, add);

  return result;
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXSoftsignOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXSoftsignOp>(
    Operation *op, ArrayRef<Type> result_types, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) {
  // ONNXSoftsignOp(%X) = DivFOp(%X, AddFOp(AbsFOp(%X), ConstantOp 1))
  auto loc = op->getLoc();
  Value operand = operands[0];
  auto elementType = result_types[0];

  auto abs = rewriter.create<AbsFOp>(loc, operand);
  auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
  auto add = rewriter.create<AddFOp>(loc, abs, one);
  auto result = rewriter.create<DivFOp>(loc, operand, add);

  return result;
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXMaxOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXMaxOp>(Operation *op, ArrayRef<Type> result_types,
                                    ArrayRef<Value> operands,
                                    ConversionPatternRewriter &rewriter) {
  // ONNXMaxOp(%X, %Y) = SelectOp(CmpFOp(OGT, %X, %Y),
  //                              %X,
  //                              %Y)
  auto loc = op->getLoc();
  Value lhs = operands[0];
  Value rhs = operands[1];
  auto max = rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs);
  auto result = rewriter.create<SelectOp>(loc, max, lhs, rhs);

  return result;
}
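// For a single f32 element, the emitted IR is roughly (illustrative):
//   %c = cmpf "ogt", %x, %y : f32
//   %r = select %c, %x, %y : f32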
//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXMinOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXMinOp>(Operation *op, ArrayRef<Type> result_types,
                                    ArrayRef<Value> operands,
                                    ConversionPatternRewriter &rewriter) {
  // ONNXMinOp(%X, %Y) = SelectOp(CmpFOp(OLT, %X, %Y),
  //                              %X,
  //                              %Y)
  auto loc = op->getLoc();
  Value lhs = operands[0];
  Value rhs = operands[1];
  auto min = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, lhs, rhs);
  auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs);

  return result;
}

// Element-wise unary ops lowering to Krnl dialect.
//===----------------------------------------------------------------------===//
template <typename ElementwiseUnaryOp>
struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
  ONNXElementwiseUnaryOpLowering(MLIRContext *ctx)
      : ConversionPattern(ElementwiseUnaryOp::getOperationName(), 1, ctx) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    // TODO: Check that the types are valid.
    // An element-wise unary operation must have all operands and the result of
    // the same type. This should have been verified by the verifier.
    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
    auto loc = op->getLoc();

    // Insert an allocation and deallocation for the result of this operation.
    auto memRefType = convertTensorToMemRef(tensorType);

    // If the output has a dynamic dimension, pass the operands required for
    // each dynamic dimension to the AllocOp. The first operand of the
    // operation is used. The operands of the op need to match in terms of
    // dimensions with the result at this pre-optimization phase.
    // TODO: verify that dimensions match.
    // TODO: can the dimension of the result differ after optimizations?
    Value alloc;
    bool insertDealloc = checkInsertDealloc(op);

    if (hasAllConstantDimensions(memRefType))
      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
    else
      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
                                    {operands[0]});

    // Number of loops.
    auto memRefShape = memRefType.getShape();
    int64_t rank = memRefShape.size();

    // Define loops.
    auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, rank);
    std::vector<Value> originalLoops;
    originalLoops.reserve(rank);
    for (auto result : loopsOp.getResults()) {
      originalLoops.push_back(result);
    }

    // Define loop optimization.
    auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, rank);
    std::vector<Value> optimizedLoops;
    optimizedLoops.reserve(rank);
    for (auto result : optimizedLoopsOp.getResults()) {
      optimizedLoops.push_back(result);
    }
    Block &optimizationBlock = optimizedLoopsOp.region().front();

    KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
    // Iterate over the loop nest.
    // TODO (Tian): move this logic inside KrnlIterateOp. Pass MemRefShape
    // to KrnlIterateOp instead.
    for (int i = 0; i < rank; ++i) {
      if (memRefShape[i] < 0) {
        pack.pushConstantBound(0);
        pack.pushOperandBound(
            rewriter.create<DimOp>(loc, operands[0], i).getResult());
      } else {
        pack.pushConstantBound(0);
        pack.pushConstantBound(memRefShape[i]);
      }
    }

    auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
    Block &iterationBlock = iterateOp.bodyRegion().front();
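    // Schematically (illustrative), the IR built so far for a rank-2 result is:
    //   %loops:2 = krnl.define_loops 2
    //   %opt:2 = krnl.optimize_loops { ... krnl.return_loops ... }
    //   krnl.iterate(...) with (...) { <body filled in below> }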
    // Now perform the insertions into the body of the
    // just generated instructions:

    // 1. Insert any optimizations in the KrnlOptimizeLoopsOp body.
    rewriter.setInsertionPointToEnd(&optimizationBlock);
    // Return from KrnlOptimizeLoopsOp body.
    // When no optimizations are present we just return the loops unchanged.
    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
    rewriter.setInsertionPoint(optimizedLoopsOp);

    // 2. Insert instructions inside the KernelIterateOp body.
    rewriter.setInsertionPointToStart(&iterationBlock);

    // Handle the operation:
    SmallVector<Value, 4> loopIVs;
    for (auto arg : iterationBlock.getArguments())
      loopIVs.push_back(arg);

    auto loadedVal = rewriter.create<LoadOp>(loc, operands[0], loopIVs);
    auto loweredOpResult = mapToLowerScalarOp<ElementwiseUnaryOp>(
        op, memRefType.getElementType(), {loadedVal}, rewriter);
    // Store result in the resulting array.
    rewriter.create<StoreOp>(loc, loweredOpResult, alloc, loopIVs);

    rewriter.replaceOp(op, alloc);

    return matchSuccess();
  }
};

// Element-wise variadic ops lowering to Krnl dialect.
//===----------------------------------------------------------------------===//
template <typename ElementwiseVariadicOp>
struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
  ONNXElementwiseVariadicOpLowering(MLIRContext *ctx)
      : ConversionPattern(ElementwiseVariadicOp::getOperationName(), 1, ctx) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    // TODO: Check that the types are valid.
    // An element-wise variadic operation must have all operands and the result
    // of the same type. This should have been verified by the verifier.
    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
    auto loc = op->getLoc();
    auto numArgs = op->getNumOperands();

    // Insert an allocation and deallocation for the result of this operation.
    auto memRefType = convertTensorToMemRef(tensorType);

    Value alloc;
    bool insertDealloc = checkInsertDealloc(op);
    // If the output has a dynamic dimension, we compute its dimension at
    // runtime by using dimensions from the operands.
    // In particular, we need to know from which operand a result dimension
    // comes from.
    // TODO: can the dimension of the result differ after optimizations?
    if (hasAllConstantDimensions(memRefType))
      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
    else
      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
                                    operands);

    // Number of loops.
    auto memRefShape = memRefType.getShape();
    int64_t rank = memRefShape.size();

    // Define loops.
    auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, rank);
    std::vector<Value> originalLoops;
    originalLoops.reserve(rank);
    for (auto result : loopsOp.getResults()) {
      originalLoops.push_back(result);
    }

    // Define loop optimization.
    auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, rank);
    std::vector<Value> optimizedLoops;
    optimizedLoops.reserve(rank);
    for (auto result : optimizedLoopsOp.getResults()) {
      optimizedLoops.push_back(result);
    }
    Block &optimizationBlock = optimizedLoopsOp.region().front();

    KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
    // Iterate over the loop nest.
    // TODO (Tian): move this logic inside KrnlIterateOp. Pass MemRefShape
    // to KrnlIterateOp instead.
    for (int i = 0; i < rank; ++i) {
      if (memRefShape[i] < 0) {
        pack.pushConstantBound(0);
        pack.pushOperandBound(
            rewriter.create<DimOp>(loc, alloc, i).getResult());
      } else {
        pack.pushConstantBound(0);
        pack.pushConstantBound(memRefShape[i]);
      }
    }

    // Get run-time dimension information for unknown dimensions used for
    // broadcasting.
    std::map<int, std::map<int, Value>> broadcastedDimInfo =
        getBroadcastedDimInfo(loc, rewriter, memRefType, operands);

    auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
    Block &iterationBlock = iterateOp.bodyRegion().front();

    // Now perform the insertions into the body of the
    // just generated instructions:

    // 1. Insert any optimizations in the KrnlOptimizeLoopsOp body.
    rewriter.setInsertionPointToEnd(&optimizationBlock);
    // Return from KrnlOptimizeLoopsOp body.
    // When no optimizations are present we just return the loops unchanged.
    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
    rewriter.setInsertionPoint(optimizedLoopsOp);

    // 2. Insert instructions inside the KernelIterateOp body.
    rewriter.setInsertionPointToStart(&iterationBlock);

    // Handle the operation:
    SmallVector<Value, 4> loopIVs;
    for (auto arg : iterationBlock.getArguments())
      loopIVs.push_back(arg);

    // Fold over operands for each of their scalar values.
    Value accumulated, next;
    auto accumulatedLoopIVs = getLoopIVsForBroadcasting(
        loc, rewriter, loopIVs, operands[0], broadcastedDimInfo[0]);
    accumulated =
        rewriter.create<LoadOp>(loc, operands[0], accumulatedLoopIVs);
    for (unsigned i = 1; i < numArgs; i++) {
      auto nextLoopIVs = getLoopIVsForBroadcasting(
          loc, rewriter, loopIVs, operands[i], broadcastedDimInfo[i]);
      next = rewriter.create<LoadOp>(loc, operands[i], nextLoopIVs);
      accumulated = mapToLowerScalarOp<ElementwiseVariadicOp>(
          op, memRefType.getElementType(), {accumulated, next}, rewriter);
    }
    // Store result in the resulting array.
    rewriter.create<StoreOp>(loc, accumulated, alloc, loopIVs);

    rewriter.replaceOp(op, alloc);

    return matchSuccess();
  }
};

struct ONNXSoftmaxOpLowering : public ConversionPattern {
  ONNXSoftmaxOpLowering(MLIRContext *ctx)
      : ConversionPattern(mlir::ONNXSoftmaxOp::getOperationName(), 1, ctx) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    // softmax(x) = let max_x = max(x) in
    //              let exp_x = exp(x - max_x) in
    //              let sum = sum(exp_x) in
    //              exp_x / sum
    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
    int64_t rank = tensorType.getRank();
    int64_t axis = op->getAttrOfType<IntegerAttr>("axis").getInt();
    axis = axis >= 0 ? axis : rank + axis;
    assert(axis >= -rank && axis <= rank - 1);

    auto loc = op->getLoc();

    // Insert an allocation and deallocation for the result of this operation.
    auto memRefType = convertTensorToMemRef(tensorType);
    auto elementType = memRefType.getElementType();

    Value alloc;
    bool insertDealloc = checkInsertDealloc(op);
    if (hasAllConstantDimensions(memRefType))
      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
    else
      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
                                    operands[0]);

    // Shape of the result.
    auto memRefShape = memRefType.getShape();

    // Insert allocations and deallocations for sum and max.
    MemRefType scalarMemRefType = MemRefType::get({}, elementType, {}, 0);
    Value sumOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true);
    Value maxOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true);
    Value zero =
        rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
    Value negInfinity = rewriter.create<ConstantOp>(
        loc,
        FloatAttr::get(elementType, -std::numeric_limits<float>::infinity()));

    // Define loops.
    auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, rank);
    std::vector<Value> originalLoops;
    originalLoops.reserve(rank);
    for (auto result : loopsOp.getResults()) {
      originalLoops.push_back(result);
    }

    // Define loop optimization.
    auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, rank);
    std::vector<Value> optimizedLoops;
    optimizedLoops.reserve(rank);
    for (auto result : optimizedLoopsOp.getResults()) {
      optimizedLoops.push_back(result);
    }
    Block &optimizationBlock = optimizedLoopsOp.region().front();

    // Coerce the input into a 2-D tensor. `axis` will be the coercing point.
    // This coercing follows the softmax definition in ONNX:
    // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Softmax
    // Here, we create an outer loop and an inner loop for handling the two
    // dimensions. The outer loop is only created when `axis` is not zero.
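    // For example (illustrative): a 2x3x4 input with axis = 1 is treated as a
    // 2x12 matrix; the max, the sum, and the final division are each computed
    // over one row of 12 elements at a time.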
    // Define an outer loop with respect to axis.
    std::vector<Value> outerLoops, optimizedOuterLoops;
    outerLoops.reserve(axis);
    optimizedOuterLoops.reserve(axis);
    for (int i = 0; i < axis; ++i) {
      outerLoops.push_back(originalLoops[i]);
      optimizedOuterLoops.push_back(optimizedLoops[i]);
    }
    KrnlIterateOperandPack outerPack(rewriter, outerLoops,
                                     optimizedOuterLoops);
    for (int i = 0; i < axis; ++i) {
      if (memRefShape[i] < 0) {
        outerPack.pushConstantBound(0);
        outerPack.pushOperandBound(
            rewriter.create<DimOp>(loc, operands[0], i).getResult());
      } else {
        outerPack.pushConstantBound(0);
        outerPack.pushConstantBound(memRefShape[i]);
      }
    }

    // Define an inner loop with respect to axis.
    std::vector<Value> innerLoops, optimizedInnerLoops;
    innerLoops.reserve(rank - axis);
    optimizedInnerLoops.reserve(rank - axis);
    for (int i = axis; i < rank; ++i) {
      innerLoops.push_back(originalLoops[i]);
      optimizedInnerLoops.push_back(optimizedLoops[i]);
    }
    KrnlIterateOperandPack innerPack(rewriter, innerLoops,
                                     optimizedInnerLoops);
    for (int i = axis; i < rank; ++i) {
      if (memRefShape[i] < 0) {
        innerPack.pushConstantBound(0);
        innerPack.pushOperandBound(
            rewriter.create<DimOp>(loc, operands[0], i).getResult());
      } else {
        innerPack.pushConstantBound(0);
        innerPack.pushConstantBound(memRefShape[i]);
      }
    }

    KrnlIterateOp outerIterateOp, maxIterateOp, sumIterateOp, softmaxIterateOp;
    SmallVector<Value, 4> outerLoopIVs;
    if (axis != 0) {
      outerIterateOp = rewriter.create<KrnlIterateOp>(loc, outerPack);

      // No optimization.
      rewriter.setInsertionPointToEnd(&optimizationBlock);
      rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
      rewriter.setInsertionPoint(optimizedLoopsOp);

      // Insert instructions inside the outer loop.
      Block &outerIterationBlock = outerIterateOp.bodyRegion().front();
      rewriter.setInsertionPointToStart(&outerIterationBlock);
      for (auto arg : outerIterationBlock.getArguments())
        outerLoopIVs.push_back(arg);

      // Reset accumulators.
      rewriter.create<StoreOp>(loc, zero, sumOp);
      rewriter.create<StoreOp>(loc, negInfinity, maxOp);

      // Create an inner loop to compute max.
      maxIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
      // Create an inner loop to compute sum.
      sumIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
      // Create an inner loop to compute softmax.
      softmaxIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
    } else {
      // Reset accumulators.
      rewriter.create<StoreOp>(loc, zero, sumOp);
      rewriter.create<StoreOp>(loc, negInfinity, maxOp);

      // Create an inner loop to compute max.
      maxIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
      // Create an inner loop to compute sum.
      sumIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
      // Create an inner loop to compute softmax.
      softmaxIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);

      // No optimization.
      rewriter.setInsertionPointToEnd(&optimizationBlock);
      rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
      rewriter.setInsertionPoint(optimizedLoopsOp);
    }

    // Insert instructions inside the max loop.
    Block &maxIterationBlock = maxIterateOp.bodyRegion().front();
    rewriter.setInsertionPointToStart(&maxIterationBlock);

    // Get induction variables.
    SmallVector<Value, 4> maxLoopIVs;
    for (auto arg : outerLoopIVs)
      maxLoopIVs.push_back(arg);
    for (auto arg : maxIterationBlock.getArguments())
      maxLoopIVs.push_back(arg);

    // Compute the max value.
    Value max = rewriter.create<LoadOp>(loc, maxOp);
    Value nextMax = rewriter.create<LoadOp>(loc, operands[0], maxLoopIVs);
    auto maxCond =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, max, nextMax);
    max = rewriter.create<SelectOp>(loc, maxCond, max, nextMax);
    rewriter.create<StoreOp>(loc, max, maxOp);

    // Get the max.
    rewriter.setInsertionPoint(sumIterateOp);
    max = rewriter.create<LoadOp>(loc, maxOp);

    // Insert instructions inside the sum loop.
    Block &sumIterationBlock = sumIterateOp.bodyRegion().front();
    rewriter.setInsertionPointToStart(&sumIterationBlock);

    // Get induction variables.
    SmallVector<Value, 4> sumLoopIVs;
    for (auto arg : outerLoopIVs)
      sumLoopIVs.push_back(arg);
    for (auto arg : sumIterationBlock.getArguments())
      sumLoopIVs.push_back(arg);

    // Sum up values.
    Value sum = rewriter.create<LoadOp>(loc, sumOp);
    Value next = rewriter.create<LoadOp>(loc, operands[0], sumLoopIVs);
    Value sub = rewriter.create<SubFOp>(loc, next, max);
    Value exp = rewriter.create<ExpOp>(loc, sub);
    sum = rewriter.create<AddFOp>(loc, sum, exp);
    rewriter.create<StoreOp>(loc, sum, sumOp);
    // Store intermediate values in the result to avoid recomputation.
    rewriter.create<StoreOp>(loc, exp, alloc, sumLoopIVs);

    // Get the sum.
    rewriter.setInsertionPoint(softmaxIterateOp);
    sum = rewriter.create<LoadOp>(loc, sumOp);

    // Insert instructions inside the softmax loop.
    Block &softmaxIterationBlock = softmaxIterateOp.bodyRegion().front();
    rewriter.setInsertionPointToStart(&softmaxIterationBlock);

    // Get induction variables.
    SmallVector<Value, 4> softmaxLoopIVs;
    for (auto arg : outerLoopIVs)
      softmaxLoopIVs.push_back(arg);
    for (auto arg : softmaxIterationBlock.getArguments())
      softmaxLoopIVs.push_back(arg);

    // Compute softmax.
    Value expLoadedVal = rewriter.create<LoadOp>(loc, alloc, softmaxLoopIVs);
    Value result = rewriter.create<DivFOp>(loc, expLoadedVal, sum);
    rewriter.create<StoreOp>(loc, result, alloc, softmaxLoopIVs);

    rewriter.replaceOp(op, alloc);

    return matchSuccess();
  }
};

struct ONNXReshapeOpLowering : public ConversionPattern {
  ONNXReshapeOpLowering(MLIRContext *ctx)
      : ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
    auto loc = op->getLoc();

    // Insert an allocation and deallocation for the result of this operation.
    auto memRefType = convertTensorToMemRef(tensorType);
    Value alloc;

    // Compute size in bytes.
    Value tensorSize = rewriter.create<ConstantOp>(
        loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
                                     getMemRefEltSizeInBytes(memRefType)));
    bool insertDealloc = checkInsertDealloc(op);
    if (hasAllConstantDimensions(memRefType)) {
      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
    } else {
      auto memRefShape = memRefType.getShape();
      SmallVector<Value, 4> allocOperands;
      for (int i = 0; i < memRefShape.size(); ++i) {
        // The shape array can always be used to construct shape information of
        // the result.
        Value index = rewriter.create<ConstantOp>(
            loc, rewriter.getIntegerAttr(rewriter.getIndexType(), i));
        Value loadedVal = rewriter.create<LoadOp>(loc, operands[1], index);
        Value int64LoadedVal = rewriter.create<ZeroExtendIOp>(
            loc, loadedVal, rewriter.getIntegerType(64));
        tensorSize =
            rewriter.create<MulIOp>(loc, tensorSize, int64LoadedVal);
        allocOperands.push_back(rewriter.create<IndexCastOp>(
            loc, loadedVal, rewriter.getIndexType()));
      }
      AllocOp allocateMemref =
          rewriter.create<AllocOp>(loc, memRefType, allocOperands);
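      // At this point `tensorSize` holds the total output size in bytes, e.g.
      // (illustrative) 4 * d0 * d1 for an f32 result of shape d0 x d1; it is
      // handed to the KrnlMemcpyOp emitted below.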
      // Insert a deallocation for the dynamically allocated result if needed.
      auto *parentBlock = allocateMemref.getOperation()->getBlock();
      if (insertDealloc) {
        auto dealloc = rewriter.create<DeallocOp>(loc, allocateMemref);
        dealloc.getOperation()->moveBefore(&parentBlock->back());
      }
      alloc = allocateMemref;
    }

    rewriter.create<KrnlMemcpyOp>(loc, alloc, operands[0], tensorSize);
    rewriter.replaceOp(op, alloc);

    return matchSuccess();
  }
};

//===----------------------------------------------------------------------===//
// EntryPoint Op lowering to Krnl Entry Point.
//===----------------------------------------------------------------------===//
class ONNXEntryPointLowering : public OpRewritePattern<ONNXEntryPointOp> {
public:
  using OpRewritePattern<ONNXEntryPointOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(ONNXEntryPointOp op,
                                     PatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<KrnlEntryPointOp>(
        op,
        op.getAttrOfType<SymbolRefAttr>(
            ONNXEntryPointOp::getEntryPointFuncAttrName()),
        op.getAttrOfType<IntegerAttr>(
            ONNXEntryPointOp::getNumInputsAttrName()),
        op.getAttrOfType<IntegerAttr>(
            ONNXEntryPointOp::getNumOutputsAttrName()));
    return matchSuccess();
  }
};

//===----------------------------------------------------------------------===//
// Conversion from Tensor type to the Standard dialect MemRef type.
//===----------------------------------------------------------------------===//
struct TensorTypeConverter : public TypeConverter {
  using TypeConverter::TypeConverter;

  LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) override {
    if (auto tensor_type = t.dyn_cast<TensorType>()) {
      results.push_back(convertTensorToMemRef(tensor_type));
      return success();
    }
    results.push_back(t);
    return success();
  }

  /// Return true if the inputs and outputs of the given function type are
  /// legal. [Taken from MLIR and adapted to only check the legality of the
  /// inputs. Once unranked results can be handled gracefully this
  /// override needs to be removed in favour of the original MLIR one.]
  bool isSignatureLegal(FunctionType funcType) {
    return llvm::all_of(funcType.getInputs(),
                        [this](Type type) { return isLegal(type); });
  }
};

} // end anonymous namespace.

//===----------------------------------------------------------------------===//
// Frontend to Krnl Dialect lowering pass
//===----------------------------------------------------------------------===//

/// This is a partial lowering to Krnl loops of the ONNX operations.
namespace {
struct FrontendToKrnlLoweringPass
    : public ModulePass<FrontendToKrnlLoweringPass> {
  void runOnModule() final;
};
} // end anonymous namespace.

void FrontendToKrnlLoweringPass::runOnModule() {
  auto module = getModule();

  // The first thing to define is the conversion target. This will define the
  // final target for this lowering.
  ConversionTarget target(getContext());

  // We define the specific operations, or dialects, that are legal targets for
  // this lowering.
  target.addLegalDialect<KrnlOpsDialect, AffineOpsDialect,
                         StandardOpsDialect>();

  // TODO: enable this once more ops are supported.
  // We also define the ONNX dialect as Illegal so that the conversion will
  // fail if any of these operations are *not* converted.
  // target.addIllegalDialect<mlir::ONNXOpsDialect>();

  // TODO: add any other ops which are considered legal.
  // Some operations can be marked as being still legal.
  // Example: target.addLegalOp();

  // Now that the conversion target has been defined, we just need to provide
  // the set of patterns that will lower the frontend operations.
  OwningRewritePatternList patterns;

  // Convert TensorType to MemRef.
  TensorTypeConverter tensor_to_memref_converter;
  target.addDynamicallyLegalOp<mlir::FuncOp>([&](FuncOp op) {
    // FuncOp is legal only if types have been converted to Std types.
    return tensor_to_memref_converter.isSignatureLegal(op.getType());
  });

  // Type conversion for function signatures.
  // Call MLIR FuncOp signature conversion when result type is
  // a ranked tensor.
  populateFuncOpTypeConversionPattern(patterns, &getContext(),
                                      tensor_to_memref_converter);

  // Frontend operation lowering.
  patterns.insert<ONNXElementwiseUnaryOpLowering<mlir::ONNXExpOp>,
                  ONNXElementwiseUnaryOpLowering<mlir::ONNXTanhOp>,
                  ONNXElementwiseUnaryOpLowering<mlir::ONNXSinhOp>,
                  ONNXElementwiseUnaryOpLowering<mlir::ONNXCoshOp>,
                  ONNXElementwiseUnaryOpLowering<mlir::ONNXCosOp>,
                  ONNXElementwiseUnaryOpLowering<mlir::ONNXLogOp>,
                  ONNXElementwiseUnaryOpLowering<mlir::ONNXSigmoidOp>,
                  ONNXElementwiseUnaryOpLowering<mlir::ONNXHardSigmoidOp>,
                  ONNXElementwiseUnaryOpLowering<mlir::ONNXEluOp>,
                  ONNXElementwiseUnaryOpLowering<mlir::ONNXReluOp>,
                  ONNXElementwiseUnaryOpLowering<mlir::ONNXLeakyReluOp>,
                  ONNXElementwiseUnaryOpLowering<mlir::ONNXSeluOp>,
                  ONNXElementwiseUnaryOpLowering<mlir::ONNXReciprocalOp>,
                  ONNXElementwiseUnaryOpLowering<mlir::ONNXSoftplusOp>,
                  ONNXElementwiseUnaryOpLowering<mlir::ONNXSoftsignOp>,
                  ONNXElementwiseVariadicOpLowering<mlir::ONNXAddOp>,
                  ONNXElementwiseVariadicOpLowering<mlir::ONNXMulOp>,
                  ONNXElementwiseVariadicOpLowering<mlir::ONNXDivOp>,
                  ONNXElementwiseVariadicOpLowering<mlir::ONNXSubOp>,
                  ONNXElementwiseVariadicOpLowering<mlir::ONNXAndOp>,
                  ONNXElementwiseVariadicOpLowering<mlir::ONNXOrOp>,
                  ONNXElementwiseVariadicOpLowering<mlir::ONNXXorOp>,
                  ONNXElementwiseVariadicOpLowering<mlir::ONNXSumOp>,
                  ONNXElementwiseVariadicOpLowering<mlir::ONNXMaxOp>,
                  ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>,
                  ONNXReshapeOpLowering, ONNXEntryPointLowering,
                  ONNXSoftmaxOpLowering>(&getContext());

  // With the target and rewrite patterns defined, we can now attempt the
  // conversion. The conversion will signal failure if any of our `illegal`
  // operations were not converted successfully.
  if (failed(applyPartialConversion(module, target, patterns)))
    signalPassFailure();
}

std::unique_ptr<Pass> mlir::createLowerToKrnlPass() {
  return std::make_unique<FrontendToKrnlLoweringPass>();
}

static PassRegistration<FrontendToKrnlLoweringPass>
    pass("lower-frontend", "Lower frontend ops to Krnl dialect.");
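// Example usage (illustrative): the pass can be added programmatically to a
// mlir::PassManager with
//   pm.addPass(mlir::createLowerToKrnlPass());
// or selected by its registered name "lower-frontend" in an mlir-opt style
// driver.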