//====- lower_frontend_to_krnl.cpp - Frontend dialects to Krnl lowering ---===// // // Copyright 2019 The IBM Research Authors. // // ============================================================================= // // This file implements the lowering of frontend operations to a combination of // Krnl IR and standard operations. // //===----------------------------------------------------------------------===// #include #include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/Sequence.h" #include "src/compiler/dialect/krnl/krnl_helper.hpp" #include "src/compiler/dialect/krnl/krnl_ops.hpp" #include "src/compiler/dialect/onnx/onnx_ops.hpp" #include "src/compiler/pass/passes.hpp" using namespace mlir; //===----------------------------------------------------------------------===// // FrontendToAffine RewritePatterns //===----------------------------------------------------------------------===// /// Check is all dimensions are known at compile time. static bool hasAllConstantDimensions(MemRefType type) { auto memRefShape = type.getShape(); for (int i = 0; i < memRefShape.size(); ++i) if (memRefShape[i] < 0) return false; return true; } /// Convert the given TensorType into the corresponding MemRefType. static MemRefType convertTensorToMemRef(TensorType type) { assert(type.hasRank() && "expected only ranked shapes"); return MemRefType::get(type.getShape(), type.getElementType()); } /// Insert an allocation and deallocation for the given MemRefType. static Value *insertAllocAndDealloc(MemRefType type, Location loc, PatternRewriter &rewriter, bool insertDealloc, ArrayRef operands = {}) { // Put together alloc operands for any dynamic dimensions of the memref. AllocOp alloc; if (!operands.empty()) { auto memRefShape = type.getShape(); auto rank = memRefShape.size(); std::map fromOperands; for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) { int memRefDimIdx = rank - 1 - reversedIdx; if (memRefShape[memRefDimIdx] < 0) { // unknown dimension Value *maxDim = nullptr; for (int i = 0; i < operands.size(); i++) { auto operandShape = operands[i]->getType().cast().getShape(); int operandDimIdx = operandShape.size() - 1 - reversedIdx; if (operandDimIdx < 0) continue; // In case of operations with broadcasting, the dimension of the // alloc result is the maximum size along each dimension of the // operands. auto operandDim = rewriter.create(loc, operands[i], operandDimIdx); if (maxDim) { auto maxCondition = rewriter.create(loc, CmpIPredicate::sgt, operandDim, maxDim); maxDim = rewriter.create(loc, maxCondition, operandDim, maxDim); } else { maxDim = operandDim; } } fromOperands.insert(std::make_pair(memRefDimIdx, maxDim)); } } SmallVector allocOperands; for (int i = 0; i < rank; ++i) if (memRefShape[i] < 0) allocOperands.push_back(fromOperands[i]); alloc = rewriter.create(loc, type, allocOperands); } else { alloc = rewriter.create(loc, type); } // Make sure to allocate at the beginning of the block if // all dimensions are known. auto *parentBlock = alloc.getOperation()->getBlock(); if (hasAllConstantDimensions(type)) alloc.getOperation()->moveBefore(&parentBlock->front()); if (insertDealloc) { auto dealloc = rewriter.create(loc, alloc); dealloc.getOperation()->moveBefore(&parentBlock->back()); } return alloc; } // Determine if current function returns the result value of the // current op being lowered. If it does then dealloc should not be // inserted. static bool checkInsertDealloc(Operation *currentOp) { auto parentBlock = currentOp->getBlock(); bool insertDealloc = true; parentBlock->walk([&insertDealloc, currentOp](ReturnOp op) { assert(currentOp->getNumResults() < 2 && "No more than one result supported (for now)."); // If there is at least one result to investigate. if (currentOp->getNumResults() > 0) { auto result = currentOp->getResult(0); for (const auto &operand : op.getOperands()) if (operand == result) insertDealloc = false; } }); return insertDealloc; } unsigned getMemRefEltSizeInBytes(MemRefType memRefType) { auto elementType = memRefType.getElementType(); unsigned sizeInBits; if (elementType.isIntOrFloat()) { sizeInBits = elementType.getIntOrFloatBitWidth(); } else { auto vectorType = elementType.cast(); sizeInBits = vectorType.getElementTypeBitWidth() * vectorType.getNumElements(); } return llvm::divideCeil(sizeInBits, 8); } // Get run-time dimension information for unknown dimensions used for // broadcasting. std::map> getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter, MemRefType memRefType, ArrayRef operands) { auto memRefShape = memRefType.getShape(); int64_t rank = memRefShape.size(); // For unknown dimensions, we need to get dimension values at runtime in // order to do broadcasting. std::map> DimInfo; // For each result dimension, compute the number of sharing operands. // Sharing operands are operands sharing the same index (counting from the // rightmost to the leftmost) for a given dimension. std::map sharedDimCount; for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) { int dimIdx = rank - 1 - reversedIdx; sharedDimCount[dimIdx] = 0; for (int i = 0; i < operands.size(); ++i) { auto shape = operands[i]->getType().cast().getShape(); if (reversedIdx <= shape.size() - 1) sharedDimCount[dimIdx]++; } } // An unknown dimension can have a value of 1 or N (N > 1). // If its value is 1, it is broadcasted dimension. // Otherwise, non-broadcasted dimension. // We only care about unknown dimensions whose number of sharing operands is // more than one, since they are potentially broadcasted dimensions. for (int i = 0; i < operands.size(); ++i) { std::map broadcastedDims; auto shape = operands[i]->getType().cast().getShape(); int size = shape.size(); for (int j = 0; j < shape.size(); ++j) { if (shape[j] < 0 and sharedDimCount[rank - size + j] > 1) { auto dim = rewriter.create(loc, operands[i], j).getResult(); auto one = rewriter.create(loc, 1); auto isBroadcasted = rewriter.create(loc, CmpIPredicate::eq, dim, one); broadcastedDims.insert(std::make_pair(j, isBroadcasted)); } } DimInfo.insert(std::make_pair(i, broadcastedDims)); } return DimInfo; } // Extract induction variables that are used for broadcasting values of a // given operand. std::vector getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter, ArrayRef loopIVs, Value *operand, std::map broadcastedDims) { // `operand` must has a ranked type. This should have been checked by the // shape inference pass. auto operandShape = operand->getType().cast().getShape(); auto rank = operandShape.size(); auto loopCount = loopIVs.size(); std::vector newLoopIVs; for (unsigned reversedIdx = 0; reversedIdx < rank; ++reversedIdx) { auto dimIdx = rank - 1 - reversedIdx; auto loopIdx = loopCount - 1 - reversedIdx; if (operandShape[dimIdx] == 1) { // Broadcasted dimension auto zero = rewriter.create(loc, 0); newLoopIVs.insert(newLoopIVs.begin(), zero); } else if ((operandShape[dimIdx] == -1) && (broadcastedDims.find(dimIdx) != broadcastedDims.end())) { // Unknown dimension, it can have a value of 1 or N (N > 1). // If its value is 1, it is broadcasted dimension. // Otherwise, non-broadcasted dimension. auto zero = rewriter.create(loc, 0); auto idx = rewriter.create(loc, broadcastedDims[dimIdx], zero, loopIVs[loopIdx]); newLoopIVs.insert(newLoopIVs.begin(), idx); } else { // Non-broadcasted dimension newLoopIVs.insert(newLoopIVs.begin(), loopIVs[loopIdx]); } } return newLoopIVs; } namespace { template struct ScalarOp; template <> struct ScalarOp { using FOp = AddFOp; using IOp = AddIOp; }; template <> struct ScalarOp { using FOp = MulFOp; using IOp = MulIOp; }; template <> struct ScalarOp { using FOp = DivFOp; using IOp = DivISOp; }; template <> struct ScalarOp { using FOp = SubFOp; using IOp = SubIOp; }; template <> struct ScalarOp { using FOp = AndOp; // not use using IOp = AndOp; }; template <> struct ScalarOp { using FOp = OrOp; // not use using IOp = OrOp; }; template <> struct ScalarOp { using FOp = XOrOp; // not use using IOp = XOrOp; }; template <> struct ScalarOp { using FOp = ExpOp; using IOp = ExpOp; // not use }; template <> struct ScalarOp { using FOp = AddFOp; using IOp = AddIOp; }; template using ScalarFOp = typename ScalarOp::FOp; template using ScalarIOp = typename ScalarOp::IOp; //===----------------------------------------------------------------------===// // Scalar unary ops for lowering to Krnl dialect. //===----------------------------------------------------------------------===// template Value *mapToLowerScalarOp(Operation *op, ArrayRef result_types, ArrayRef operands, ConversionPatternRewriter &rewriter) { /* Lower UnaryOp to Ops in the Standard dialect. */ auto loc = op->getLoc(); Type element_type = operands.front()->getType(); if (element_type.isa()) { return rewriter.create>(loc, result_types, operands, mlir::None); } else if (element_type.isa()) { return rewriter.create>(loc, result_types, operands, mlir::None); } else { emitError(loc, "unsupported element type"); return nullptr; } } //===----------------------------------------------------------------------===// // Scalar unary ops for lowering ONNXTanhOp //===----------------------------------------------------------------------===// template <> Value *mapToLowerScalarOp(Operation *op, ArrayRef result_types, ArrayRef operands, ConversionPatternRewriter &rewriter) { // ONNXTanhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))), // AddFOp(ExpOp(%X), ExpOp(NegFOp(%X)))) auto loc = op->getLoc(); Value *operand = operands[0]; auto zero = rewriter.create(loc, rewriter.getF32FloatAttr(0.0f)); auto neg = rewriter.create(loc, zero, operand); auto exp = rewriter.create(loc, operand); auto negExp = rewriter.create(loc, neg); auto result = rewriter.create(loc, rewriter.create(loc, exp, negExp), rewriter.create(loc, exp, negExp)); return result; } //===----------------------------------------------------------------------===// // Scalar unary ops for lowering ONNXSinhOp //===----------------------------------------------------------------------===// template <> Value *mapToLowerScalarOp(Operation *op, ArrayRef result_types, ArrayRef operands, ConversionPatternRewriter &rewriter) { // ONNXSinhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))), // ConstantOp 2) auto loc = op->getLoc(); Value *operand = operands[0]; auto zero = rewriter.create(loc, rewriter.getF32FloatAttr(0.0f)); auto two = rewriter.create(loc, rewriter.getF32FloatAttr(2.0f)); auto neg = rewriter.create(loc, zero, operand); auto exp = rewriter.create(loc, operand); auto negExp = rewriter.create(loc, neg); auto result = rewriter.create( loc, rewriter.create(loc, exp, negExp), two); return result; } //===----------------------------------------------------------------------===// // Scalar unary ops for lowering ONNXCoshOp //===----------------------------------------------------------------------===// template <> Value *mapToLowerScalarOp(Operation *op, ArrayRef result_types, ArrayRef operands, ConversionPatternRewriter &rewriter) { // ONNXCoshOp(%X) = DivFOp(AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))), // ConstantOp 2) auto loc = op->getLoc(); Value *operand = operands[0]; auto zero = rewriter.create(loc, rewriter.getF32FloatAttr(0.0f)); auto two = rewriter.create(loc, rewriter.getF32FloatAttr(2.0f)); auto neg = rewriter.create(loc, zero, operand); auto exp = rewriter.create(loc, operand); auto negExp = rewriter.create(loc, neg); auto result = rewriter.create( loc, rewriter.create(loc, exp, negExp), two); return result; } //===----------------------------------------------------------------------===// // Scalar unary ops for lowering ONNXSigmoidOp //===----------------------------------------------------------------------===// template <> Value *mapToLowerScalarOp(Operation *op, ArrayRef result_types, ArrayRef operands, ConversionPatternRewriter &rewriter) { // ONNXSigmoidOp(%X) = DivFOp(ConstantOp 1, // AddFOp(ConstantOp 1, ExpOp(NegFOp(%X)))) auto loc = op->getLoc(); Value *operand = operands[0]; auto zero = rewriter.create(loc, rewriter.getF32FloatAttr(0.0f)); auto one = rewriter.create(loc, rewriter.getF32FloatAttr(1.0f)); auto neg = rewriter.create(loc, zero, operand); auto negExp = rewriter.create(loc, neg); auto result = rewriter.create( loc, one, rewriter.create(loc, one, negExp)); return result; } //===----------------------------------------------------------------------===// // Scalar unary ops for lowering ONNXHardSigmoidOp //===----------------------------------------------------------------------===// template <> Value *mapToLowerScalarOp( Operation *op, ArrayRef result_types, ArrayRef operands, ConversionPatternRewriter &rewriter) { // %Y = AddFOp(MulFOp(alpha, %X), beta) // %Z = SelectOp(CmpFOp(OGT, %Y, Constant 0), // %Y, // Constant 0) // ONNXHardSigmoidOp(%X) = SelectOp(CmpFOp(OLT, %Z, Constant 1), // %Z, // Constant 1) auto loc = op->getLoc(); Value *operand = operands[0]; auto alphaAttr = op->getAttrOfType("HardSigmoid.alpha"); auto betaAttr = op->getAttrOfType("HardSigmoid.beta"); auto zero = rewriter.create(loc, rewriter.getF32FloatAttr(0.0f)); auto one = rewriter.create(loc, rewriter.getF32FloatAttr(1.0f)); auto alpha = rewriter.create(loc, alphaAttr); auto beta = rewriter.create(loc, betaAttr); auto add = rewriter.create( loc, rewriter.create(loc, alpha, operand), beta); auto maxPredicate = rewriter.create(loc, CmpFPredicate::OGT, add, zero); auto max = rewriter.create(loc, maxPredicate, add, zero); auto minPredicate = rewriter.create(loc, CmpFPredicate::OLT, max, one); auto result = rewriter.create(loc, minPredicate, max, one); return result; } //===----------------------------------------------------------------------===// // Scalar unary ops for lowering ONNXEluOp //===----------------------------------------------------------------------===// template <> Value *mapToLowerScalarOp(Operation *op, ArrayRef result_types, ArrayRef operands, ConversionPatternRewriter &rewriter) { // ONNXEluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0), // MulFOp(alpha, SubFOp(ExpOp(%X), 1)), // %X) auto loc = op->getLoc(); Value *operand = operands[0]; auto alphaAttr = op->getAttrOfType("Elu.alpha"); auto zero = rewriter.create(loc, rewriter.getF32FloatAttr(0.0f)); auto one = rewriter.create(loc, rewriter.getF32FloatAttr(1.0f)); auto alpha = rewriter.create(loc, alphaAttr); auto exp = rewriter.create(loc, operand); auto lessThanZero = rewriter.create(loc, CmpFPredicate::OLT, operand, zero); auto result = rewriter.create( loc, lessThanZero, rewriter.create(loc, alpha, rewriter.create(loc, exp, one)), operand); return result; } //===----------------------------------------------------------------------===// // Scalar unary ops for lowering ONNXReluOp //===----------------------------------------------------------------------===// template <> Value *mapToLowerScalarOp(Operation *op, ArrayRef result_types, ArrayRef operands, ConversionPatternRewriter &rewriter) { // ONNXReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0), // ConstantOp 0, // %X) auto loc = op->getLoc(); Value *operand = operands[0]; auto zero = rewriter.create(loc, rewriter.getF32FloatAttr(0.0f)); auto lessThanZero = rewriter.create(loc, CmpFPredicate::OLT, operand, zero); auto result = rewriter.create(loc, lessThanZero, zero, operand); return result; } //===----------------------------------------------------------------------===// // Scalar unary ops for lowering ONNXLeakyReluOp //===----------------------------------------------------------------------===// template <> Value * mapToLowerScalarOp(Operation *op, ArrayRef result_types, ArrayRef operands, ConversionPatternRewriter &rewriter) { // ONNXLeakyReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0), // MulFOp(alpha, %X), // %X) auto loc = op->getLoc(); Value *operand = operands[0]; auto alphaAttr = op->getAttrOfType("LeakyRelu.alpha"); auto zero = rewriter.create(loc, rewriter.getF32FloatAttr(0.0f)); auto alpha = rewriter.create(loc, alphaAttr); auto lessThanZero = rewriter.create(loc, CmpFPredicate::OLT, operand, zero); auto result = rewriter.create( loc, lessThanZero, rewriter.create(loc, alpha, operand), operand); return result; } //===----------------------------------------------------------------------===// // Scalar unary ops for lowering ONNXSeluOp //===----------------------------------------------------------------------===// template <> Value *mapToLowerScalarOp(Operation *op, ArrayRef result_types, ArrayRef operands, ConversionPatternRewriter &rewriter) { // ONNXSeluOp(%X) = SelectOp(CmpFOp(OGT, %X, ConstantOp 0), // MulFOp(gamma, %X), // MulFOp(gamma, // SubFOp(MulFOp(alpha, ExpOp(%X)), // alpha))) auto loc = op->getLoc(); Value *operand = operands[0]; auto alphaAttr = op->getAttrOfType("Selu.alpha"); auto gammaAttr = op->getAttrOfType("Selu.gamma"); auto zero = rewriter.create(loc, rewriter.getF32FloatAttr(0.0f)); auto alpha = rewriter.create(loc, alphaAttr); auto gamma = rewriter.create(loc, gammaAttr); auto exp = rewriter.create(loc, operand); auto greaterThanZero = rewriter.create(loc, CmpFPredicate::OGT, operand, zero); auto select = rewriter.create( loc, greaterThanZero, operand, rewriter.create(loc, rewriter.create(loc, alpha, exp), alpha)); auto result = rewriter.create(loc, gamma, select); return result; } //===----------------------------------------------------------------------===// // Scalar unary ops for lowering ONNXReciprocalOp //===----------------------------------------------------------------------===// template <> Value * mapToLowerScalarOp(Operation *op, ArrayRef result_types, ArrayRef operands, ConversionPatternRewriter &rewriter) { // ONNXReciprocalOp(%X) = DivFOp(ConstantOp 1, %X) auto loc = op->getLoc(); Value *operand = operands[0]; auto one = rewriter.create(loc, rewriter.getF32FloatAttr(1.0f)); auto result = rewriter.create(loc, one, operand); return result; } //===----------------------------------------------------------------------===// // Scalar unary ops for lowering ONNXMaxOp //===----------------------------------------------------------------------===// template <> Value *mapToLowerScalarOp(Operation *op, ArrayRef result_types, ArrayRef operands, ConversionPatternRewriter &rewriter) { // ONNXMaxOp(%X, %Y) = SelectOp(CmpFOp(OGT, %X, %Y), // %X, // %Y) auto loc = op->getLoc(); Value *lhs = operands[0]; Value *rhs = operands[1]; auto max = rewriter.create(loc, CmpFPredicate::OGT, lhs, rhs); auto result = rewriter.create(loc, max, lhs, rhs); return result; } //===----------------------------------------------------------------------===// // Scalar unary ops for lowering ONNXMinOp //===----------------------------------------------------------------------===// template <> Value *mapToLowerScalarOp(Operation *op, ArrayRef result_types, ArrayRef operands, ConversionPatternRewriter &rewriter) { // ONNXMinOp(%X, %Y) = SelectOp(CmpFOp(OLT, %X, %Y), // %X, // %Y) auto loc = op->getLoc(); Value *lhs = operands[0]; Value *rhs = operands[1]; auto min = rewriter.create(loc, CmpFPredicate::OLT, lhs, rhs); auto result = rewriter.create(loc, min, lhs, rhs); return result; } // Element-wise unary ops lowering to Krnl dialect. //===----------------------------------------------------------------------===// template struct ONNXElementwiseUnaryOpLowering : public ConversionPattern { ONNXElementwiseUnaryOpLowering(MLIRContext *ctx) : ConversionPattern(ElementwiseUnaryOp::getOperationName(), 1, ctx) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { // TODO: Check that the types are valid. // An element-wise unary operation must have all operands and the result of // the same type. This should have been verified by the verifier. auto tensorType = (*op->result_type_begin()).cast(); auto loc = op->getLoc(); // Insert an allocation and deallocation for the result of this operation. auto memRefType = convertTensorToMemRef(tensorType); // If the output has a dynamic dimension, pass the operands required for // each dynamic dimension to the AllocOp. The first operand of the // operation is used. The operands of the op need to match in terms of // dimensions with the result at this pre-optimization phase. // TODO: verify that dimensions match. // TODO: can the dimension of the result differ after optimizations? Value *alloc; bool insertDealloc = checkInsertDealloc(op); if (hasAllConstantDimensions(memRefType)) alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); else alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc, {operands[0]}); // Number of loops auto memRefShape = memRefType.getShape(); int64_t rank = memRefShape.size(); // Define loops. auto loopsOp = rewriter.create(loc, rank); std::vector originalLoops; originalLoops.reserve(rank); for (auto result : loopsOp.getResults()) { originalLoops.push_back(result); } // Define loop optimization. auto optimizedLoopsOp = rewriter.create(loc, rank); std::vector optimizedLoops; optimizedLoops.reserve(rank); for (auto result : optimizedLoopsOp.getResults()) { optimizedLoops.push_back(result); } Block &optimizationBlock = optimizedLoopsOp.region().front(); KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops); // Iterate over the loop nest. // TODO (Tian): move this logic inside KrnlIterateOp. Pass MemRefShape // to KrnlIterateOp instead. for (int i = 0; i < rank; ++i) { if (memRefShape[i] < 0) { pack.pushConstantBound(0); pack.pushOperandBound( rewriter.create(loc, operands[0], i).getResult()); } else { pack.pushConstantBound(0); pack.pushConstantBound(memRefShape[i]); } } auto iterateOp = rewriter.create(loc, pack); Block &iterationBlock = iterateOp.bodyRegion().front(); // Now perform the insertions into the body of the // just generated instructions: // 1. Insert any optimizations in the KrnlOptimizeLoopsOp body. rewriter.setInsertionPointToEnd(&optimizationBlock); // Return from KrnlOptimizeLoopsOp body. // When no optimizations are present we just return the loops // unchaged. rewriter.create(loc, originalLoops); rewriter.setInsertionPoint(optimizedLoopsOp); // 2. Insert instructions inside the KernelIterateOp body. rewriter.setInsertionPointToStart(&iterationBlock); // Handle the operation: SmallVector loopIVs; for (auto arg : iterationBlock.getArguments()) loopIVs.push_back(arg); auto loadedVal = rewriter.create(loc, operands[0], loopIVs); auto loweredOpResult = mapToLowerScalarOp( op, memRefType.getElementType(), {loadedVal}, rewriter); // Store result in the resulting array. rewriter.create(loc, loweredOpResult, alloc, loopIVs); rewriter.replaceOp(op, alloc); return matchSuccess(); } }; // Element-wise variadic ops lowering to Krnl dialect. //===----------------------------------------------------------------------===// template struct ONNXElementwiseVariadicOpLowering : public ConversionPattern { ONNXElementwiseVariadicOpLowering(MLIRContext *ctx) : ConversionPattern(ElementwiseVariadicOp::getOperationName(), 1, ctx) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { // TODO: Check that the types are valid. // An element-wise variadic operation must have all operands and the result // of the same type. This should have been verified by the verifier. auto tensorType = (*op->result_type_begin()).cast(); auto loc = op->getLoc(); auto numArgs = op->getNumOperands(); // Insert an allocation and deallocation for the result of this operation. auto memRefType = convertTensorToMemRef(tensorType); Value *alloc; bool insertDealloc = checkInsertDealloc(op); // If the output has a dynamic dimension, we compute its dimension at // runtime by using dimensions from the operands. // In particular, we need to know from which operand a result dimension // comes from. // TODO: can the dimension of the result differ after optimizations? if (hasAllConstantDimensions(memRefType)) alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); else alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc, operands); // Number of loops auto memRefShape = memRefType.getShape(); int64_t rank = memRefShape.size(); // Define loops. auto loopsOp = rewriter.create(loc, rank); std::vector originalLoops; originalLoops.reserve(rank); for (auto result : loopsOp.getResults()) { originalLoops.push_back(result); } // Define loop optimization. auto optimizedLoopsOp = rewriter.create(loc, rank); std::vector optimizedLoops; optimizedLoops.reserve(rank); for (auto result : optimizedLoopsOp.getResults()) { optimizedLoops.push_back(result); } Block &optimizationBlock = optimizedLoopsOp.region().front(); KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops); // Iterate over the loop nest. // TODO (Tian): move this logic inside KrnlIterateOp. Pass MemRefShape // to KrnlIterateOp instead. for (int i = 0; i < rank; ++i) { if (memRefShape[i] < 0) { pack.pushConstantBound(0); pack.pushOperandBound( rewriter.create(loc, alloc, i).getResult()); } else { pack.pushConstantBound(0); pack.pushConstantBound(memRefShape[i]); } } // Get run-time dimension information for unknown dimensions used for // broadcasting. std::map> broadcastedDimInfo = getBroadcastedDimInfo(loc, rewriter, memRefType, operands); auto iterateOp = rewriter.create(loc, pack); Block &iterationBlock = iterateOp.bodyRegion().front(); // Now perform the insertions into the body of the // just generated instructions: // 1. Insert any optimizations in the KrnlOptimizeLoopsOp body. rewriter.setInsertionPointToEnd(&optimizationBlock); // Return from KrnlOptimizeLoopsOp body. // When no optimizations are present we just return the loops unchaged. rewriter.create(loc, originalLoops); rewriter.setInsertionPoint(optimizedLoopsOp); // 2. Insert instructions inside the KernelIterateOp body. rewriter.setInsertionPointToStart(&iterationBlock); // Handle the operation: SmallVector loopIVs; for (auto arg : iterationBlock.getArguments()) loopIVs.push_back(arg); // Fold over operands for each of their scalar values Value *accumulated, *next; auto accumulatedLoopIVs = getLoopIVsForBroadcasting( loc, rewriter, loopIVs, operands[0], broadcastedDimInfo[0]); accumulated = rewriter.create(loc, operands[0], accumulatedLoopIVs); for (unsigned i = 1; i < numArgs; i++) { auto nextLoopIVs = getLoopIVsForBroadcasting( loc, rewriter, loopIVs, operands[i], broadcastedDimInfo[i]); next = rewriter.create(loc, operands[i], nextLoopIVs); accumulated = mapToLowerScalarOp( op, memRefType.getElementType(), {accumulated, next}, rewriter); } // Store result in the resulting array. rewriter.create(loc, accumulated, alloc, loopIVs); rewriter.replaceOp(op, alloc); return matchSuccess(); } }; struct ONNXReshapeOpLowering : public ConversionPattern { ONNXReshapeOpLowering(MLIRContext *ctx) : ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { auto tensorType = (*op->result_type_begin()).cast(); auto loc = op->getLoc(); // Insert an allocation and deallocation for the result of this operation. auto memRefType = convertTensorToMemRef(tensorType); Value *alloc; // Compute size in bytes. Value *tensorSize = rewriter.create( loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), getMemRefEltSizeInBytes(memRefType))); bool insertDealloc = checkInsertDealloc(op); if (hasAllConstantDimensions(memRefType)) { alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); } else { auto memRefShape = memRefType.getShape(); SmallVector allocOperands; for (int i = 0; i < memRefShape.size(); ++i) { // The shape array can always be used to construct shape information of // the result. Value *index = rewriter.create( loc, rewriter.getIntegerAttr(rewriter.getIndexType(), i)); Value *loadedVal = rewriter.create(loc, operands[1], index); Value *int64LoadedVal = rewriter.create( loc, loadedVal, rewriter.getIntegerType(64)); tensorSize = rewriter.create(loc, tensorSize, int64LoadedVal); allocOperands.push_back(rewriter.create( loc, loadedVal, rewriter.getIndexType())); } AllocOp allocateMemref = rewriter.create(loc, memRefType, allocOperands); // Make sure to allocate at the beginning of the block if // all dimensions are known. auto *parentBlock = allocateMemref.getOperation()->getBlock(); if (insertDealloc) { auto dealloc = rewriter.create(loc, allocateMemref); dealloc.getOperation()->moveBefore(&parentBlock->back()); } alloc = allocateMemref; } rewriter.create(loc, alloc, operands[0], tensorSize); rewriter.replaceOp(op, alloc); return matchSuccess(); } }; //===----------------------------------------------------------------------===// // Conversion from Tensor type to the Standard dialect MemRef type. //===----------------------------------------------------------------------===// struct TensorTypeConverter : public TypeConverter { using TypeConverter::TypeConverter; LogicalResult convertType(Type t, SmallVectorImpl &results) override { if (auto tensor_type = t.dyn_cast()) { results.push_back(convertTensorToMemRef(tensor_type)); return success(); } results.push_back(t); return success(); } /// Return true if the inputs and outputs of the given function type are /// legal. [Taken from MLIR and adapted to only check the legality of the /// inputs. Once unranked results can be handled gracefully this /// override needs to be removed in favour of the original MLIR one.] bool isSignatureLegal(FunctionType funcType) { return llvm::all_of(funcType.getInputs(), [this](Type type) { return isLegal(type); }); } }; } // end anonymous namespace. //===----------------------------------------------------------------------===// // Frontend to Krnl Dialect lowering pass //===----------------------------------------------------------------------===// /// This is a partial lowering to Krnl loops of the ONNX operations. namespace { struct FrontendToKrnlLoweringPass : public ModulePass { void runOnModule() final; }; } // end anonymous namespace. void FrontendToKrnlLoweringPass::runOnModule() { auto module = getModule(); // The first thing to define is the conversion target. This will define the // final target for this lowering. ConversionTarget target(getContext()); // We define the specific operations, or dialects, that are legal targets for // this lowering. target .addLegalDialect(); // TODO: enable this once more ops are supported. // We also define the ONNX dialect as Illegal so that the conversion will fail // if any of these operations are *not* converted. // target.addIllegalDialect(); // TODO: add any other ops which are considered legal. // Some operations can be marked as being still legal. // Example: target.addLegalOp(); // Now that the conversion target has been defined, we just need to provide // the set of patterns that will lower the frontend operations. OwningRewritePatternList patterns; // Convert TensorType to MemRef TensorTypeConverter tensor_to_memref_converter; target.addDynamicallyLegalOp([&](FuncOp op) { // FuncOp is legal only if types have been converted to Std types. return tensor_to_memref_converter.isSignatureLegal(op.getType()); }); // Type conversion for function signatures. // Call MLIR FuncOp signature conversion when result type is // a ranked tensor. populateFuncOpTypeConversionPattern(patterns, &getContext(), tensor_to_memref_converter); // Frontent operation lowering. patterns.insert, ONNXElementwiseUnaryOpLowering, ONNXElementwiseUnaryOpLowering, ONNXElementwiseUnaryOpLowering, ONNXElementwiseUnaryOpLowering, ONNXElementwiseUnaryOpLowering, ONNXElementwiseUnaryOpLowering, ONNXElementwiseUnaryOpLowering, ONNXElementwiseUnaryOpLowering, ONNXElementwiseUnaryOpLowering, ONNXElementwiseUnaryOpLowering, ONNXElementwiseVariadicOpLowering, ONNXElementwiseVariadicOpLowering, ONNXElementwiseVariadicOpLowering, ONNXElementwiseVariadicOpLowering, ONNXElementwiseVariadicOpLowering, ONNXElementwiseVariadicOpLowering, ONNXElementwiseVariadicOpLowering, ONNXElementwiseVariadicOpLowering, ONNXElementwiseVariadicOpLowering, ONNXElementwiseVariadicOpLowering, ONNXReshapeOpLowering>(&getContext()); // With the target and rewrite patterns defined, we can now attempt the // conversion. The conversion will signal failure if any of our `illegal` // operations were not converted successfully. if (failed(applyPartialConversion(module, target, patterns))) signalPassFailure(); } std::unique_ptr mlir::createLowerToKrnlPass() { return std::make_unique(); } static PassRegistration pass("lower-frontend", "Lower frontend ops to Krnl dialect.");