//====- lower_frontend_to_krnl.cpp - Frontend dialects to Krnl lowering ---===// // // Copyright 2019 The IBM Research Authors. // // ============================================================================= // // This file implements the lowering of frontend operations to a combination of // Krnl IR and standard operations. // //===----------------------------------------------------------------------===// #include #include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/Sequence.h" #include "src/dialect/krnl/krnl_helper.hpp" #include "src/dialect/krnl/krnl_ops.hpp" #include "src/dialect/onnx/onnx_ops.hpp" #include "src/pass/passes.hpp" using namespace mlir; //===----------------------------------------------------------------------===// // FrontendToAffine RewritePatterns //===----------------------------------------------------------------------===// /// Check is all dimensions are known at compile time. static bool hasAllConstantDimensions(MemRefType type) { auto memRefShape = type.getShape(); for (int i = 0; i < memRefShape.size(); ++i) if (memRefShape[i] < 0) return false; return true; } /// Convert the given TensorType into the corresponding MemRefType. static MemRefType convertTensorToMemRef(TensorType type) { assert(type.hasRank() && "expected only ranked shapes"); return MemRefType::get(type.getShape(), type.getElementType()); } /// Insert an allocation and deallocation for the given MemRefType. static Value insertAllocAndDealloc(MemRefType type, Location loc, PatternRewriter &rewriter, bool insertDealloc, ArrayRef operands = {}) { // Put together alloc operands for any dynamic dimensions of the memref. 
AllocOp alloc; if (!operands.empty()) { auto memRefShape = type.getShape(); auto rank = memRefShape.size(); std::map fromOperands; for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) { int memRefDimIdx = rank - 1 - reversedIdx; if (memRefShape[memRefDimIdx] < 0) { // unknown dimension Value maxDim = nullptr; for (int i = 0; i < operands.size(); i++) { auto operandShape = operands[i].getType().cast().getShape(); int operandDimIdx = operandShape.size() - 1 - reversedIdx; if (operandDimIdx < 0) continue; // In case of operations with broadcasting, the dimension of the // alloc result is the maximum size along each dimension of the // operands. auto operandDim = rewriter.create(loc, operands[i], operandDimIdx); if (maxDim) { auto maxCondition = rewriter.create(loc, CmpIPredicate::sgt, operandDim, maxDim); maxDim = rewriter.create(loc, maxCondition, operandDim, maxDim); } else { maxDim = operandDim; } } fromOperands.insert(std::make_pair(memRefDimIdx, maxDim)); } } SmallVector allocOperands; for (int i = 0; i < rank; ++i) if (memRefShape[i] < 0) allocOperands.push_back(fromOperands[i]); alloc = rewriter.create(loc, type, allocOperands); } else { alloc = rewriter.create(loc, type); } // Make sure to allocate at the beginning of the block if // all dimensions are known. auto *parentBlock = alloc.getOperation()->getBlock(); if (hasAllConstantDimensions(type)) alloc.getOperation()->moveBefore(&parentBlock->front()); if (insertDealloc) { auto dealloc = rewriter.create(loc, alloc); dealloc.getOperation()->moveBefore(&parentBlock->back()); } return alloc; } // Determine if current function returns the result value of the // current op being lowered. If it does then dealloc should not be // inserted. 
static bool checkInsertDealloc(Operation *currentOp) { auto parentBlock = currentOp->getBlock(); bool insertDealloc = true; parentBlock->walk([&insertDealloc, currentOp](ReturnOp op) { assert(currentOp->getNumResults() < 2 && "No more than one result supported (for now)."); // If there is at least one result to investigate. if (currentOp->getNumResults() > 0) { auto result = currentOp->getResult(0); for (const auto &operand : op.getOperands()) if (operand == result) insertDealloc = false; } }); return insertDealloc; } unsigned getMemRefEltSizeInBytes(MemRefType memRefType) { auto elementType = memRefType.getElementType(); unsigned sizeInBits; if (elementType.isIntOrFloat()) { sizeInBits = elementType.getIntOrFloatBitWidth(); } else { auto vectorType = elementType.cast(); sizeInBits = vectorType.getElementTypeBitWidth() * vectorType.getNumElements(); } return llvm::divideCeil(sizeInBits, 8); } // Get run-time dimension information for unknown dimensions used for // broadcasting. std::map> getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter, MemRefType memRefType, ArrayRef operands) { auto memRefShape = memRefType.getShape(); int64_t rank = memRefShape.size(); // For unknown dimensions, we need to get dimension values at runtime in // order to do broadcasting. std::map> DimInfo; // For each result dimension, compute the number of sharing operands. // Sharing operands are operands sharing the same index (counting from the // rightmost to the leftmost) for a given dimension. std::map sharedDimCount; for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) { int dimIdx = rank - 1 - reversedIdx; sharedDimCount[dimIdx] = 0; for (int i = 0; i < operands.size(); ++i) { auto shape = operands[i].getType().cast().getShape(); if (reversedIdx <= shape.size() - 1) sharedDimCount[dimIdx]++; } } // An unknown dimension can have a value of 1 or N (N > 1). // If its value is 1, it is broadcasted dimension. // Otherwise, non-broadcasted dimension. 
// We only care about unknown dimensions whose number of sharing operands is // more than one, since they are potentially broadcasted dimensions. for (int i = 0; i < operands.size(); ++i) { std::map broadcastedDims; auto shape = operands[i].getType().cast().getShape(); int size = shape.size(); for (int j = 0; j < shape.size(); ++j) { if (shape[j] < 0 and sharedDimCount[rank - size + j] > 1) { auto dim = rewriter.create(loc, operands[i], j).getResult(); auto one = rewriter.create(loc, 1); auto isBroadcasted = rewriter.create(loc, CmpIPredicate::eq, dim, one); broadcastedDims.insert(std::make_pair(j, isBroadcasted)); } } DimInfo.insert(std::make_pair(i, broadcastedDims)); } return DimInfo; } // Extract induction variables that are used for broadcasting values of a // given operand. std::vector getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter, ArrayRef loopIVs, Value operand, std::map broadcastedDims) { // `operand` must has a ranked type. This should have been checked by the // shape inference pass. auto operandShape = operand.getType().cast().getShape(); auto rank = operandShape.size(); auto loopCount = loopIVs.size(); std::vector newLoopIVs; for (unsigned reversedIdx = 0; reversedIdx < rank; ++reversedIdx) { auto dimIdx = rank - 1 - reversedIdx; auto loopIdx = loopCount - 1 - reversedIdx; if (operandShape[dimIdx] == 1) { // Broadcasted dimension auto zero = rewriter.create(loc, 0); newLoopIVs.insert(newLoopIVs.begin(), zero); } else if ((operandShape[dimIdx] == -1) && (broadcastedDims.find(dimIdx) != broadcastedDims.end())) { // Unknown dimension, it can have a value of 1 or N (N > 1). // If its value is 1, it is broadcasted dimension. // Otherwise, non-broadcasted dimension. 
auto zero = rewriter.create(loc, 0); auto idx = rewriter.create(loc, broadcastedDims[dimIdx], zero, loopIVs[loopIdx]); newLoopIVs.insert(newLoopIVs.begin(), idx); } else { // Non-broadcasted dimension newLoopIVs.insert(newLoopIVs.begin(), loopIVs[loopIdx]); } } return newLoopIVs; } namespace { template struct ScalarOp; template <> struct ScalarOp { using FOp = AddFOp; using IOp = AddIOp; }; template <> struct ScalarOp { using FOp = MulFOp; using IOp = MulIOp; }; template <> struct ScalarOp { using FOp = DivFOp; using IOp = SignedDivIOp; }; template <> struct ScalarOp { using FOp = SubFOp; using IOp = SubIOp; }; template <> struct ScalarOp { using FOp = AndOp; // not use using IOp = AndOp; }; template <> struct ScalarOp { using FOp = OrOp; // not use using IOp = OrOp; }; template <> struct ScalarOp { using FOp = XOrOp; // not use using IOp = XOrOp; }; template <> struct ScalarOp { using FOp = ExpOp; using IOp = ExpOp; // not use }; template <> struct ScalarOp { using FOp = AddFOp; using IOp = AddIOp; }; template <> struct ScalarOp { using FOp = TanhOp; using IOp = TanhOp; // not use }; template <> struct ScalarOp { using FOp = CosOp; using IOp = CosOp; // not use }; template <> struct ScalarOp { using FOp = LogOp; using IOp = LogOp; // not use }; template using ScalarFOp = typename ScalarOp::FOp; template using ScalarIOp = typename ScalarOp::IOp; //===----------------------------------------------------------------------===// // Scalar unary ops for lowering to Krnl dialect. //===----------------------------------------------------------------------===// template Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, ArrayRef operands, ConversionPatternRewriter &rewriter) { /* Lower UnaryOp to Ops in the Standard dialect. 
*/ auto loc = op->getLoc(); Type element_type = operands.front().getType(); if (element_type.isa()) { return rewriter.create>(loc, result_types, operands, mlir::None); } else if (element_type.isa()) { return rewriter.create>(loc, result_types, operands, mlir::None); } else { emitError(loc, "unsupported element type"); return nullptr; } } //===----------------------------------------------------------------------===// // Scalar unary ops for lowering ONNXSinhOp //===----------------------------------------------------------------------===// template <> Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, ArrayRef operands, ConversionPatternRewriter &rewriter) { // ONNXSinhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))), // ConstantOp 2) auto loc = op->getLoc(); Value operand = operands[0]; auto elementType = result_types[0]; auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); auto two = rewriter.create(loc, FloatAttr::get(elementType, 2)); auto neg = rewriter.create(loc, zero, operand); auto exp = rewriter.create(loc, operand); auto negExp = rewriter.create(loc, neg); auto result = rewriter.create( loc, rewriter.create(loc, exp, negExp), two); return result; } //===----------------------------------------------------------------------===// // Scalar unary ops for lowering ONNXCoshOp //===----------------------------------------------------------------------===// template <> Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, ArrayRef operands, ConversionPatternRewriter &rewriter) { // ONNXCoshOp(%X) = DivFOp(AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))), // ConstantOp 2) auto loc = op->getLoc(); Value operand = operands[0]; auto elementType = result_types[0]; auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); auto two = rewriter.create(loc, FloatAttr::get(elementType, 2)); auto neg = rewriter.create(loc, zero, operand); auto exp = rewriter.create(loc, operand); auto negExp = rewriter.create(loc, neg); auto result = 
rewriter.create( loc, rewriter.create(loc, exp, negExp), two); return result; } //===----------------------------------------------------------------------===// // Scalar unary ops for lowering ONNXSigmoidOp //===----------------------------------------------------------------------===// template <> Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, ArrayRef operands, ConversionPatternRewriter &rewriter) { // ONNXSigmoidOp(%X) = DivFOp(ConstantOp 1, // AddFOp(ConstantOp 1, ExpOp(NegFOp(%X)))) auto loc = op->getLoc(); Value operand = operands[0]; auto elementType = result_types[0]; auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); auto one = rewriter.create(loc, FloatAttr::get(elementType, 1)); auto neg = rewriter.create(loc, zero, operand); auto negExp = rewriter.create(loc, neg); auto result = rewriter.create( loc, one, rewriter.create(loc, one, negExp)); return result; } //===----------------------------------------------------------------------===// // Scalar unary ops for lowering ONNXHardSigmoidOp //===----------------------------------------------------------------------===// template <> Value mapToLowerScalarOp( Operation *op, ArrayRef result_types, ArrayRef operands, ConversionPatternRewriter &rewriter) { // %Y = AddFOp(MulFOp(alpha, %X), beta) // %Z = SelectOp(CmpFOp(OGT, %Y, Constant 0), // %Y, // Constant 0) // ONNXHardSigmoidOp(%X) = SelectOp(CmpFOp(OLT, %Z, Constant 1), // %Z, // Constant 1) auto loc = op->getLoc(); Value operand = operands[0]; auto alphaAttr = op->getAttrOfType("HardSigmoid.alpha"); auto betaAttr = op->getAttrOfType("HardSigmoid.beta"); auto elementType = result_types[0]; auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); auto one = rewriter.create(loc, FloatAttr::get(elementType, 1)); auto alpha = rewriter.create(loc, alphaAttr); auto beta = rewriter.create(loc, betaAttr); auto add = rewriter.create( loc, rewriter.create(loc, alpha, operand), beta); auto maxPredicate = rewriter.create(loc, 
CmpFPredicate::OGT, add, zero); auto max = rewriter.create(loc, maxPredicate, add, zero); auto minPredicate = rewriter.create(loc, CmpFPredicate::OLT, max, one); auto result = rewriter.create(loc, minPredicate, max, one); return result; } //===----------------------------------------------------------------------===// // Scalar unary ops for lowering ONNXEluOp //===----------------------------------------------------------------------===// template <> Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, ArrayRef operands, ConversionPatternRewriter &rewriter) { // ONNXEluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0), // MulFOp(alpha, SubFOp(ExpOp(%X), 1)), // %X) auto loc = op->getLoc(); Value operand = operands[0]; auto elementType = result_types[0]; auto alphaAttr = op->getAttrOfType("Elu.alpha"); auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); auto one = rewriter.create(loc, FloatAttr::get(elementType, 1)); auto alpha = rewriter.create(loc, alphaAttr); auto exp = rewriter.create(loc, operand); auto lessThanZero = rewriter.create(loc, CmpFPredicate::OLT, operand, zero); auto result = rewriter.create( loc, lessThanZero, rewriter.create(loc, alpha, rewriter.create(loc, exp, one)), operand); return result; } //===----------------------------------------------------------------------===// // Scalar unary ops for lowering ONNXReluOp //===----------------------------------------------------------------------===// template <> Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, ArrayRef operands, ConversionPatternRewriter &rewriter) { // ONNXReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0), // ConstantOp 0, // %X) auto loc = op->getLoc(); Value operand = operands[0]; auto elementType = result_types[0]; auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); auto lessThanZero = rewriter.create(loc, CmpFPredicate::OLT, operand, zero); auto result = rewriter.create(loc, lessThanZero, zero, operand); return result; } 
//===----------------------------------------------------------------------===// // Scalar unary ops for lowering ONNXLeakyReluOp //===----------------------------------------------------------------------===// template <> Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, ArrayRef operands, ConversionPatternRewriter &rewriter) { // ONNXLeakyReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0), // MulFOp(alpha, %X), // %X) auto loc = op->getLoc(); Value operand = operands[0]; auto elementType = result_types[0]; auto alphaAttr = op->getAttrOfType("LeakyRelu.alpha"); auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); auto alpha = rewriter.create(loc, alphaAttr); auto lessThanZero = rewriter.create(loc, CmpFPredicate::OLT, operand, zero); auto result = rewriter.create( loc, lessThanZero, rewriter.create(loc, alpha, operand), operand); return result; } //===----------------------------------------------------------------------===// // Scalar unary ops for lowering ONNXSeluOp //===----------------------------------------------------------------------===// template <> Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, ArrayRef operands, ConversionPatternRewriter &rewriter) { // ONNXSeluOp(%X) = SelectOp(CmpFOp(OGT, %X, ConstantOp 0), // MulFOp(gamma, %X), // MulFOp(gamma, // SubFOp(MulFOp(alpha, ExpOp(%X)), // alpha))) auto loc = op->getLoc(); Value operand = operands[0]; auto alphaAttr = op->getAttrOfType("Selu.alpha"); auto gammaAttr = op->getAttrOfType("Selu.gamma"); auto elementType = result_types[0]; auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); auto alpha = rewriter.create(loc, alphaAttr); auto gamma = rewriter.create(loc, gammaAttr); auto exp = rewriter.create(loc, operand); auto greaterThanZero = rewriter.create(loc, CmpFPredicate::OGT, operand, zero); auto select = rewriter.create( loc, greaterThanZero, operand, rewriter.create(loc, rewriter.create(loc, alpha, exp), alpha)); auto result = 
rewriter.create(loc, gamma, select); return result; } //===----------------------------------------------------------------------===// // Scalar unary ops for lowering ONNXReciprocalOp //===----------------------------------------------------------------------===// template <> Value mapToLowerScalarOp( Operation *op, ArrayRef result_types, ArrayRef operands, ConversionPatternRewriter &rewriter) { // ONNXReciprocalOp(%X) = DivFOp(ConstantOp 1, %X) auto loc = op->getLoc(); Value operand = operands[0]; auto elementType = result_types[0]; auto one = rewriter.create(loc, FloatAttr::get(elementType, 1)); auto result = rewriter.create(loc, one, operand); return result; } //===----------------------------------------------------------------------===// // Scalar unary ops for lowering ONNXMaxOp //===----------------------------------------------------------------------===// template <> Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, ArrayRef operands, ConversionPatternRewriter &rewriter) { // ONNXMaxOp(%X, %Y) = SelectOp(CmpFOp(OGT, %X, %Y), // %X, // %Y) auto loc = op->getLoc(); Value lhs = operands[0]; Value rhs = operands[1]; auto max = rewriter.create(loc, CmpFPredicate::OGT, lhs, rhs); auto result = rewriter.create(loc, max, lhs, rhs); return result; } //===----------------------------------------------------------------------===// // Scalar unary ops for lowering ONNXMinOp //===----------------------------------------------------------------------===// template <> Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, ArrayRef operands, ConversionPatternRewriter &rewriter) { // ONNXMinOp(%X, %Y) = SelectOp(CmpFOp(OLT, %X, %Y), // %X, // %Y) auto loc = op->getLoc(); Value lhs = operands[0]; Value rhs = operands[1]; auto min = rewriter.create(loc, CmpFPredicate::OLT, lhs, rhs); auto result = rewriter.create(loc, min, lhs, rhs); return result; } // Element-wise unary ops lowering to Krnl dialect. 
//===----------------------------------------------------------------------===// template struct ONNXElementwiseUnaryOpLowering : public ConversionPattern { ONNXElementwiseUnaryOpLowering(MLIRContext *ctx) : ConversionPattern(ElementwiseUnaryOp::getOperationName(), 1, ctx) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { // TODO: Check that the types are valid. // An element-wise unary operation must have all operands and the result of // the same type. This should have been verified by the verifier. auto tensorType = (*op->result_type_begin()).cast(); auto loc = op->getLoc(); // Insert an allocation and deallocation for the result of this operation. auto memRefType = convertTensorToMemRef(tensorType); // If the output has a dynamic dimension, pass the operands required for // each dynamic dimension to the AllocOp. The first operand of the // operation is used. The operands of the op need to match in terms of // dimensions with the result at this pre-optimization phase. // TODO: verify that dimensions match. // TODO: can the dimension of the result differ after optimizations? Value alloc; bool insertDealloc = checkInsertDealloc(op); if (hasAllConstantDimensions(memRefType)) alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); else alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc, {operands[0]}); // Number of loops auto memRefShape = memRefType.getShape(); int64_t rank = memRefShape.size(); // Define loops. auto loopsOp = rewriter.create(loc, rank); std::vector originalLoops; originalLoops.reserve(rank); for (auto result : loopsOp.getResults()) { originalLoops.push_back(result); } // Define loop optimization. 
auto optimizedLoopsOp = rewriter.create(loc, rank); std::vector optimizedLoops; optimizedLoops.reserve(rank); for (auto result : optimizedLoopsOp.getResults()) { optimizedLoops.push_back(result); } Block &optimizationBlock = optimizedLoopsOp.region().front(); KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops); // Iterate over the loop nest. // TODO (Tian): move this logic inside KrnlIterateOp. Pass MemRefShape // to KrnlIterateOp instead. for (int i = 0; i < rank; ++i) { if (memRefShape[i] < 0) { pack.pushConstantBound(0); pack.pushOperandBound( rewriter.create(loc, operands[0], i).getResult()); } else { pack.pushConstantBound(0); pack.pushConstantBound(memRefShape[i]); } } auto iterateOp = rewriter.create(loc, pack); Block &iterationBlock = iterateOp.bodyRegion().front(); // Now perform the insertions into the body of the // just generated instructions: // 1. Insert any optimizations in the KrnlOptimizeLoopsOp body. rewriter.setInsertionPointToEnd(&optimizationBlock); // Return from KrnlOptimizeLoopsOp body. // When no optimizations are present we just return the loops // unchaged. rewriter.create(loc, originalLoops); rewriter.setInsertionPoint(optimizedLoopsOp); // 2. Insert instructions inside the KernelIterateOp body. rewriter.setInsertionPointToStart(&iterationBlock); // Handle the operation: SmallVector loopIVs; for (auto arg : iterationBlock.getArguments()) loopIVs.push_back(arg); auto loadedVal = rewriter.create(loc, operands[0], loopIVs); auto loweredOpResult = mapToLowerScalarOp( op, memRefType.getElementType(), {loadedVal}, rewriter); // Store result in the resulting array. rewriter.create(loc, loweredOpResult, alloc, loopIVs); rewriter.replaceOp(op, alloc); return matchSuccess(); } }; // Element-wise variadic ops lowering to Krnl dialect. 
//===----------------------------------------------------------------------===// template struct ONNXElementwiseVariadicOpLowering : public ConversionPattern { ONNXElementwiseVariadicOpLowering(MLIRContext *ctx) : ConversionPattern(ElementwiseVariadicOp::getOperationName(), 1, ctx) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { // TODO: Check that the types are valid. // An element-wise variadic operation must have all operands and the result // of the same type. This should have been verified by the verifier. auto tensorType = (*op->result_type_begin()).cast(); auto loc = op->getLoc(); auto numArgs = op->getNumOperands(); // Insert an allocation and deallocation for the result of this operation. auto memRefType = convertTensorToMemRef(tensorType); Value alloc; bool insertDealloc = checkInsertDealloc(op); // If the output has a dynamic dimension, we compute its dimension at // runtime by using dimensions from the operands. // In particular, we need to know from which operand a result dimension // comes from. // TODO: can the dimension of the result differ after optimizations? if (hasAllConstantDimensions(memRefType)) alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); else alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc, operands); // Number of loops auto memRefShape = memRefType.getShape(); int64_t rank = memRefShape.size(); // Define loops. auto loopsOp = rewriter.create(loc, rank); std::vector originalLoops; originalLoops.reserve(rank); for (auto result : loopsOp.getResults()) { originalLoops.push_back(result); } // Define loop optimization. 
auto optimizedLoopsOp = rewriter.create(loc, rank); std::vector optimizedLoops; optimizedLoops.reserve(rank); for (auto result : optimizedLoopsOp.getResults()) { optimizedLoops.push_back(result); } Block &optimizationBlock = optimizedLoopsOp.region().front(); KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops); // Iterate over the loop nest. // TODO (Tian): move this logic inside KrnlIterateOp. Pass MemRefShape // to KrnlIterateOp instead. for (int i = 0; i < rank; ++i) { if (memRefShape[i] < 0) { pack.pushConstantBound(0); pack.pushOperandBound( rewriter.create(loc, alloc, i).getResult()); } else { pack.pushConstantBound(0); pack.pushConstantBound(memRefShape[i]); } } // Get run-time dimension information for unknown dimensions used for // broadcasting. std::map> broadcastedDimInfo = getBroadcastedDimInfo(loc, rewriter, memRefType, operands); auto iterateOp = rewriter.create(loc, pack); Block &iterationBlock = iterateOp.bodyRegion().front(); // Now perform the insertions into the body of the // just generated instructions: // 1. Insert any optimizations in the KrnlOptimizeLoopsOp body. rewriter.setInsertionPointToEnd(&optimizationBlock); // Return from KrnlOptimizeLoopsOp body. // When no optimizations are present we just return the loops unchaged. rewriter.create(loc, originalLoops); rewriter.setInsertionPoint(optimizedLoopsOp); // 2. Insert instructions inside the KernelIterateOp body. 
rewriter.setInsertionPointToStart(&iterationBlock); // Handle the operation: SmallVector loopIVs; for (auto arg : iterationBlock.getArguments()) loopIVs.push_back(arg); // Fold over operands for each of their scalar values Value accumulated, next; auto accumulatedLoopIVs = getLoopIVsForBroadcasting( loc, rewriter, loopIVs, operands[0], broadcastedDimInfo[0]); accumulated = rewriter.create(loc, operands[0], accumulatedLoopIVs); for (unsigned i = 1; i < numArgs; i++) { auto nextLoopIVs = getLoopIVsForBroadcasting( loc, rewriter, loopIVs, operands[i], broadcastedDimInfo[i]); next = rewriter.create(loc, operands[i], nextLoopIVs); accumulated = mapToLowerScalarOp( op, memRefType.getElementType(), {accumulated, next}, rewriter); } // Store result in the resulting array. rewriter.create(loc, accumulated, alloc, loopIVs); rewriter.replaceOp(op, alloc); return matchSuccess(); } }; struct ONNXSoftmaxOpLowering : public ConversionPattern { ONNXSoftmaxOpLowering(MLIRContext *ctx) : ConversionPattern(mlir::ONNXSoftmaxOp::getOperationName(), 1, ctx) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { // softmax(x) = let max_x = max(x) in // let exp_x = exp(x - max_x) in // let sum = sum(exp_x) in // exp_x / sum auto tensorType = (*op->result_type_begin()).cast(); int64_t rank = tensorType.getRank(); int64_t axis = op->getAttrOfType("Softmax.axis").getInt(); axis = axis >= 0 ? axis : rank + axis; assert(axis >= -rank && axis <= rank - 1); auto loc = op->getLoc(); // Insert an allocation and deallocation for the result of this operation. 
auto memRefType = convertTensorToMemRef(tensorType); auto elementType = memRefType.getElementType(); Value alloc; bool insertDealloc = checkInsertDealloc(op); if (hasAllConstantDimensions(memRefType)) alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); else alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc, operands[0]); // Shape of the result auto memRefShape = memRefType.getShape(); // Insert allocations and deallocations for sum and max. MemRefType scalarMemRefType = MemRefType::get({}, elementType, {}, 0); Value sumOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true); Value maxOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true); Value zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); Value negInfinity = rewriter.create( loc, FloatAttr::get(elementType, -std::numeric_limits::infinity())); // Define loops. auto loopsOp = rewriter.create(loc, rank); std::vector originalLoops; originalLoops.reserve(rank); for (auto result : loopsOp.getResults()) { originalLoops.push_back(result); } // Define loop optimization. auto optimizedLoopsOp = rewriter.create(loc, rank); std::vector optimizedLoops; optimizedLoops.reserve(rank); for (auto result : optimizedLoopsOp.getResults()) { optimizedLoops.push_back(result); } Block &optimizationBlock = optimizedLoopsOp.region().front(); // Coerce the input into a 2-D tensor. `axis` will be the coercing point. // This coercing follows the softmax definition in ONNX: // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Softmax // Here, we create an outer loop and inner loop for handling the two // dimensions. The outer loop is only created once `axis` is not zero. // Define an outer loop with respect to axis. 
std::vector outerLoops, optimizedOuterLoops; outerLoops.reserve(axis); optimizedOuterLoops.reserve(axis); for (int i = 0; i < axis; ++i) { outerLoops.push_back(originalLoops[i]); optimizedOuterLoops.push_back(optimizedLoops[i]); } KrnlIterateOperandPack outerPack(rewriter, outerLoops, optimizedOuterLoops); for (int i = 0; i < axis; ++i) { if (memRefShape[i] < 0) { outerPack.pushConstantBound(0); outerPack.pushOperandBound( rewriter.create(loc, operands[0], i).getResult()); } else { outerPack.pushConstantBound(0); outerPack.pushConstantBound(memRefShape[i]); } } // Define an inner loop with respect to axis. std::vector innerLoops, optimizedInnerLoops; innerLoops.reserve(rank - axis); optimizedInnerLoops.reserve(rank - axis); for (int i = axis; i < rank; ++i) { innerLoops.push_back(originalLoops[i]); optimizedInnerLoops.push_back(optimizedLoops[i]); } KrnlIterateOperandPack innerPack(rewriter, innerLoops, optimizedInnerLoops); for (int i = axis; i < rank; ++i) { if (memRefShape[i] < 0) { innerPack.pushConstantBound(0); innerPack.pushOperandBound( rewriter.create(loc, operands[0], i).getResult()); } else { innerPack.pushConstantBound(0); innerPack.pushConstantBound(memRefShape[i]); } } KrnlIterateOp outerIterateOp, maxIterateOp, sumIterateOp, softmaxIterateOp; SmallVector outerLoopIVs; if (axis != 0) { outerIterateOp = rewriter.create(loc, outerPack); // No optimization rewriter.setInsertionPointToEnd(&optimizationBlock); rewriter.create(loc, originalLoops); rewriter.setInsertionPoint(optimizedLoopsOp); // Insert instructions inside the outer loop. Block &outerIterationBlock = outerIterateOp.bodyRegion().front(); rewriter.setInsertionPointToStart(&outerIterationBlock); for (auto arg : outerIterationBlock.getArguments()) outerLoopIVs.push_back(arg); // Reset accumulators. rewriter.create(loc, zero, sumOp); rewriter.create(loc, negInfinity, maxOp); // Create an inner loop to compute max. 
maxIterateOp = rewriter.create(loc, innerPack); // Create an inner loop to compute sum. sumIterateOp = rewriter.create(loc, innerPack); // Create an inner loop to compute softmax. softmaxIterateOp = rewriter.create(loc, innerPack); } else { // Reset accumulators. rewriter.create(loc, zero, sumOp); rewriter.create(loc, negInfinity, maxOp); // Create an inner loop to compute max. maxIterateOp = rewriter.create(loc, innerPack); // Create an inner loop to compute sum. sumIterateOp = rewriter.create(loc, innerPack); // Create an inner loop to compute softmax. softmaxIterateOp = rewriter.create(loc, innerPack); // No optimization rewriter.setInsertionPointToEnd(&optimizationBlock); rewriter.create(loc, originalLoops); rewriter.setInsertionPoint(optimizedLoopsOp); } // Insert instructions inside the max loop. Block &maxIterationBlock = maxIterateOp.bodyRegion().front(); rewriter.setInsertionPointToStart(&maxIterationBlock); // Get induction variables. SmallVector maxLoopIVs; for (auto arg : outerLoopIVs) maxLoopIVs.push_back(arg); for (auto arg : maxIterationBlock.getArguments()) maxLoopIVs.push_back(arg); // Compute the max value. Value max = rewriter.create(loc, maxOp); Value nextMax = rewriter.create(loc, operands[0], maxLoopIVs); auto maxCond = rewriter.create(loc, CmpFPredicate::OGT, max, nextMax); max = rewriter.create(loc, maxCond, max, nextMax); rewriter.create(loc, max, maxOp); // Get the max. rewriter.setInsertionPoint(sumIterateOp); max = rewriter.create(loc, maxOp); // Insert instructions inside the sum loop. Block &sumIterationBlock = sumIterateOp.bodyRegion().front(); rewriter.setInsertionPointToStart(&sumIterationBlock); // Get induction variables. SmallVector sumLoopIVs; for (auto arg : outerLoopIVs) sumLoopIVs.push_back(arg); for (auto arg : sumIterationBlock.getArguments()) sumLoopIVs.push_back(arg); // Sum up values. 
Value sum = rewriter.create(loc, sumOp); Value next = rewriter.create(loc, operands[0], sumLoopIVs); Value sub = rewriter.create(loc, next, max); Value exp = rewriter.create(loc, sub); sum = rewriter.create(loc, sum, exp); rewriter.create(loc, sum, sumOp); // Store intermediate values in the result to avoid recomputation. rewriter.create(loc, exp, alloc, sumLoopIVs); // Get the sum. rewriter.setInsertionPoint(softmaxIterateOp); sum = rewriter.create(loc, sumOp); // Insert instructions inside the softmax loop. Block &softmaxIterationBlock = softmaxIterateOp.bodyRegion().front(); rewriter.setInsertionPointToStart(&softmaxIterationBlock); // Get induction variables. SmallVector softmaxLoopIVs; for (auto arg : outerLoopIVs) softmaxLoopIVs.push_back(arg); for (auto arg : softmaxIterationBlock.getArguments()) softmaxLoopIVs.push_back(arg); // Compute softmax. Value expLoadedVal = rewriter.create(loc, alloc, softmaxLoopIVs); Value result = rewriter.create(loc, expLoadedVal, sum); rewriter.create(loc, result, alloc, softmaxLoopIVs); rewriter.replaceOp(op, alloc); return matchSuccess(); } }; struct ONNXReshapeOpLowering : public ConversionPattern { ONNXReshapeOpLowering(MLIRContext *ctx) : ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { auto tensorType = (*op->result_type_begin()).cast(); auto loc = op->getLoc(); // Insert an allocation and deallocation for the result of this operation. auto memRefType = convertTensorToMemRef(tensorType); Value alloc; // Compute size in bytes. 
Value tensorSize = rewriter.create( loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), getMemRefEltSizeInBytes(memRefType))); bool insertDealloc = checkInsertDealloc(op); if (hasAllConstantDimensions(memRefType)) { alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); } else { auto memRefShape = memRefType.getShape(); SmallVector allocOperands; for (int i = 0; i < memRefShape.size(); ++i) { // The shape array can always be used to construct shape information of // the result. Value index = rewriter.create( loc, rewriter.getIntegerAttr(rewriter.getIndexType(), i)); Value loadedVal = rewriter.create(loc, operands[1], index); Value int64LoadedVal = rewriter.create( loc, loadedVal, rewriter.getIntegerType(64)); tensorSize = rewriter.create(loc, tensorSize, int64LoadedVal); allocOperands.push_back(rewriter.create( loc, loadedVal, rewriter.getIndexType())); } AllocOp allocateMemref = rewriter.create(loc, memRefType, allocOperands); // Make sure to allocate at the beginning of the block if // all dimensions are known. auto *parentBlock = allocateMemref.getOperation()->getBlock(); if (insertDealloc) { auto dealloc = rewriter.create(loc, allocateMemref); dealloc.getOperation()->moveBefore(&parentBlock->back()); } alloc = allocateMemref; } rewriter.create(loc, alloc, operands[0], tensorSize); rewriter.replaceOp(op, alloc); return matchSuccess(); } }; //===----------------------------------------------------------------------===// // EntryPoint Op lowering to Krnl Entry Point. 
//===----------------------------------------------------------------------===// class ONNXEntryPointLowering : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; PatternMatchResult matchAndRewrite(ONNXEntryPointOp op, PatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp( op, op.getAttrOfType( ONNXEntryPointOp::getEntryPointFuncAttrName()), op.getAttrOfType(ONNXEntryPointOp::getNumInputsAttrName()), op.getAttrOfType( ONNXEntryPointOp::getNumOutputsAttrName())); return matchSuccess(); } }; //===----------------------------------------------------------------------===// // Conversion from Tensor type to the Standard dialect MemRef type. //===----------------------------------------------------------------------===// struct TensorTypeConverter : public TypeConverter { using TypeConverter::TypeConverter; LogicalResult convertType(Type t, SmallVectorImpl &results) override { if (auto tensor_type = t.dyn_cast()) { results.push_back(convertTensorToMemRef(tensor_type)); return success(); } results.push_back(t); return success(); } /// Return true if the inputs and outputs of the given function type are /// legal. [Taken from MLIR and adapted to only check the legality of the /// inputs. Once unranked results can be handled gracefully this /// override needs to be removed in favour of the original MLIR one.] bool isSignatureLegal(FunctionType funcType) { return llvm::all_of(funcType.getInputs(), [this](Type type) { return isLegal(type); }); } }; } // end anonymous namespace. //===----------------------------------------------------------------------===// // Frontend to Krnl Dialect lowering pass //===----------------------------------------------------------------------===// /// This is a partial lowering to Krnl loops of the ONNX operations. namespace { struct FrontendToKrnlLoweringPass : public ModulePass { void runOnModule() final; }; } // end anonymous namespace. 
/// Runs the frontend-to-Krnl lowering over the whole module.
///
/// It builds a ConversionTarget, registers the function-signature type
/// conversion plus the ONNX lowering patterns, and applies a *partial*
/// conversion. Ops that no pattern handles and that are not marked illegal
/// are left in place.
/// The pass is marked failed if applyPartialConversion reports failure.
void FrontendToKrnlLoweringPass::runOnModule() {
  auto module = getModule();

  // The first thing to define is the conversion target. This will define the
  // final target for this lowering.
  ConversionTarget target(getContext());

  // We define the specific operations, or dialects, that are legal targets for
  // this lowering.
  target
      .addLegalDialect();

  // TODO: enable this once more ops are supported.
  // We also define the ONNX dialect as Illegal so that the conversion will fail
  // if any of these operations are *not* converted.
  // target.addIllegalDialect();

  // TODO: add any other ops which are considered legal.
  // Some operations can be marked as being still legal.
  // Example: target.addLegalOp();

  // Now that the conversion target has been defined, we just need to provide
  // the set of patterns that will lower the frontend operations.
  OwningRewritePatternList patterns;

  // Convert TensorType to MemRef
  TensorTypeConverter tensor_to_memref_converter;
  target.addDynamicallyLegalOp([&](FuncOp op) {
    // FuncOp is legal only if types have been converted to Std types.
    // TensorTypeConverter::isSignatureLegal inspects only the input types.
    return tensor_to_memref_converter.isSignatureLegal(op.getType());
  });

  // Type conversion for function signatures.
  // Call MLIR FuncOp signature conversion when result type is
  // a ranked tensor.
  populateFuncOpTypeConversionPattern(patterns, &getContext(),
                                      tensor_to_memref_converter);

  // Frontend operation lowering.
  patterns.insert,
                  ONNXElementwiseUnaryOpLowering,
                  ONNXElementwiseUnaryOpLowering,
                  ONNXElementwiseUnaryOpLowering,
                  ONNXElementwiseUnaryOpLowering,
                  ONNXElementwiseUnaryOpLowering,
                  ONNXElementwiseUnaryOpLowering,
                  ONNXElementwiseUnaryOpLowering,
                  ONNXElementwiseUnaryOpLowering,
                  ONNXElementwiseUnaryOpLowering,
                  ONNXElementwiseUnaryOpLowering,
                  ONNXElementwiseUnaryOpLowering,
                  ONNXElementwiseUnaryOpLowering,
                  ONNXElementwiseVariadicOpLowering,
                  ONNXElementwiseVariadicOpLowering,
                  ONNXElementwiseVariadicOpLowering,
                  ONNXElementwiseVariadicOpLowering,
                  ONNXElementwiseVariadicOpLowering,
                  ONNXElementwiseVariadicOpLowering,
                  ONNXElementwiseVariadicOpLowering,
                  ONNXElementwiseVariadicOpLowering,
                  ONNXElementwiseVariadicOpLowering,
                  ONNXElementwiseVariadicOpLowering,
                  ONNXReshapeOpLowering, ONNXEntryPointLowering,
                  ONNXSoftmaxOpLowering>(&getContext());

  // With the target and rewrite patterns defined, we can now attempt the
  // conversion. The conversion will signal failure if any of our `illegal`
  // operations were not converted successfully.
  if (failed(applyPartialConversion(module, target, patterns)))
    signalPassFailure();
}

/// Factory for the lowering pass. It returns a newly allocated pass object;
/// the make_unique template argument is not visible in this listing
/// (presumably FrontendToKrnlLoweringPass).
std::unique_ptr mlir::createLowerToKrnlPass() {
  return std::make_unique();
}

/// Registers the pass under the command-line name "lower-frontend".
static PassRegistration
    pass("lower-frontend", "Lower frontend ops to Krnl dialect.");