diff --git a/src/pass/lower_frontend_to_krnl.cpp b/src/pass/lower_frontend_to_krnl.cpp
deleted file mode 100644
index d609bc5..0000000
--- a/src/pass/lower_frontend_to_krnl.cpp
+++ /dev/null
@@ -1,2380 +0,0 @@
-//====- lower_frontend_to_krnl.cpp - Frontend dialects to Krnl lowering ---===//
-//
-// Copyright 2019 The IBM Research Authors.
-//
-// =============================================================================
-//
-// This file implements the lowering of frontend operations to a combination of
-// Krnl IR and standard operations.
-//
-//===----------------------------------------------------------------------===//
-#include <map>
-
-#include "mlir/Dialect/AffineOps/AffineOps.h"
-#include "mlir/Dialect/StandardOps/Ops.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/Sequence.h"
-
-#include "src/dialect/krnl/krnl_helper.hpp"
-#include "src/dialect/krnl/krnl_ops.hpp"
-#include "src/dialect/onnx/onnx_ops.hpp"
-#include "src/pass/passes.hpp"
-
-using namespace mlir;
-
-//===----------------------------------------------------------------------===//
-// FrontendToAffine RewritePatterns
-//===----------------------------------------------------------------------===//
-
-/// Check if all dimensions are known at compile time.
-static bool hasAllConstantDimensions(MemRefType type) {
-  auto memRefShape = type.getShape();
-  for (int i = 0; i < memRefShape.size(); ++i)
-    if (memRefShape[i] < 0)
-      return false;
-  return true;
-}
-
-/// Get the corresponding MemRefType of a given TensorType/MemRefType.
-static MemRefType convertToMemRefType(Type type) {
-  MemRefType memRefType;
-  auto tensorType = type.dyn_cast<TensorType>();
-  if (tensorType) {
-    assert(tensorType.hasRank() && "expected only ranked shapes");
-    memRefType =
-        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
-  } else {
-    memRefType = type.dyn_cast<MemRefType>();
-  }
-  return memRefType;
-}
-
-/// Insert an allocation and deallocation for the given MemRefType.
-static Value insertAllocAndDealloc(MemRefType type, Location loc,
-                                   PatternRewriter &rewriter,
-                                   bool insertDealloc,
-                                   ArrayRef<Value> operands = {}) {
-  // Put together alloc operands for any dynamic dimensions of the memref.
-  AllocOp alloc;
-  if (!operands.empty()) {
-    auto memRefShape = type.getShape();
-    auto rank = memRefShape.size();
-
-    std::map<int, Value> fromOperands;
-    for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
-      int memRefDimIdx = rank - 1 - reversedIdx;
-      if (memRefShape[memRefDimIdx] < 0) { // unknown dimension
-        Value maxDim = nullptr;
-        for (int i = 0; i < operands.size(); i++) {
-          auto operandShape =
-              operands[i].getType().cast<MemRefType>().getShape();
-          int operandDimIdx = operandShape.size() - 1 - reversedIdx;
-
-          if (operandDimIdx < 0)
-            continue;
-
-          // In case of operations with broadcasting, the dimension of the
-          // alloc result is the maximum size along each dimension of the
-          // operands.
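-          // At runtime this is realized below as a chain of DimOp plus
-          // cmp/select pairs that keeps the running maximum across operands.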
-          auto operandDim =
-              rewriter.create<DimOp>(loc, operands[i], operandDimIdx);
-          if (maxDim) {
-            auto maxCondition = rewriter.create<CmpIOp>(
-                loc, CmpIPredicate::sgt, operandDim, maxDim);
-            maxDim = rewriter.create<SelectOp>(loc, maxCondition, operandDim,
-                                               maxDim);
-          } else {
-            maxDim = operandDim;
-          }
-        }
-        fromOperands.insert(std::make_pair(memRefDimIdx, maxDim));
-      }
-    }
-
-    SmallVector<Value, 4> allocOperands;
-    for (int i = 0; i < rank; ++i)
-      if (memRefShape[i] < 0)
-        allocOperands.push_back(fromOperands[i]);
-    alloc = rewriter.create<AllocOp>(loc, type, allocOperands);
-  } else {
-    alloc = rewriter.create<AllocOp>(loc, type);
-  }
-
-  // Make sure to allocate at the beginning of the block if
-  // all dimensions are known.
-  auto *parentBlock = alloc.getOperation()->getBlock();
-  if (hasAllConstantDimensions(type))
-    alloc.getOperation()->moveBefore(&parentBlock->front());
-
-  if (insertDealloc) {
-    auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
-    dealloc.getOperation()->moveBefore(&parentBlock->back());
-  }
-
-  return alloc;
-}
-
-// Determine if the current function returns the result value of the
-// current op being lowered. If it does, then dealloc should not be
-// inserted.
-static bool checkInsertDealloc(Operation *currentOp) {
-  auto parentBlock = currentOp->getBlock();
-
-  bool insertDealloc = true;
-  parentBlock->walk([&insertDealloc, currentOp](ReturnOp op) {
-    assert(currentOp->getNumResults() < 2 &&
-           "No more than one result supported (for now).");
-    // If there is at least one result to investigate.
-    if (currentOp->getNumResults() > 0) {
-      auto result = currentOp->getResult(0);
-      for (const auto &operand : op.getOperands())
-        if (operand == result)
-          insertDealloc = false;
-    }
-  });
-
-  return insertDealloc;
-}
-
-// Create a mapping from result type's dimensions to input type's dimensions,
-// given that the result type is the result of a reduction op over the input
-// type.
-std::map<int64_t, int64_t>
-getReductionMapping(MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims) {
-  std::map<int64_t, int64_t> OutInDimMap;
-  int64_t rank = inputTy.getRank();
-
-  // Mark reduction axes.
-  std::vector<bool> isReductionAxis;
-  for (decltype(rank) i = 0; i < rank; ++i) {
-    if (std::find(axes.begin(), axes.end(), i) != axes.end())
-      isReductionAxis.push_back(true);
-    else
-      isReductionAxis.push_back(false);
-  }
-
-  for (decltype(rank) inIndex = 0, outIndex = 0; inIndex < rank; ++inIndex) {
-    // If it is a reduction axis, there is no relationship among dimensions.
-    if (isReductionAxis[inIndex]) {
-      if (keepdims)
-        outIndex++;
-    } else {
-      OutInDimMap.insert(std::make_pair(outIndex, inIndex));
-      outIndex++;
-    }
-  }
-
-  return OutInDimMap;
-}
-
-// Add bounds associated with the op operand to the KRNL iteration pack.
-// Dynamic dimensions are supported.
-static void addDimensionToPack(ConversionPatternRewriter &rewriter,
-    Location loc, KrnlIterateOperandPack &pack, Value operand, int index) {
-  auto shape = operand.getType().cast<MemRefType>().getShape();
-  if (shape[index] < 0) {
-    pack.pushConstantBound(0);
-    pack.pushOperandBound(
-        rewriter.create<DimOp>(loc, operand, index).getResult());
-  } else {
-    pack.pushConstantBound(0);
-    pack.pushConstantBound(shape[index]);
-  }
-}
-
-// Function that defines the KRNL dialect loops and their respective
-// optimized version.
-static KrnlOptimizeLoopsOp emitOptimizedLoops(
-    ConversionPatternRewriter &rewriter, Location loc,
-    std::vector<Value> &loops, std::vector<Value> &optimizedLoops,
-    int64_t numLoops) {
-  // Define loops.
-  auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, numLoops);
-  loops.reserve(numLoops);
-  for (auto result : loopsOp.getResults())
-    loops.push_back(result);
-
-  // Define optimized version of the loops.
-  auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, numLoops);
-  optimizedLoops.reserve(numLoops);
-  for (auto result : optimizedLoopsOp.getResults())
-    optimizedLoops.push_back(result);
-
-  return optimizedLoopsOp;
-}
-
-// Function that emits the loops and their optimized version.
-// The function returns a reference to the inner optimization block.
-static Block *defineLoops(ConversionPatternRewriter &rewriter,
-    Location loc, std::vector<Value> &loops,
-    std::vector<Value> &optimizedLoops, int64_t numLoops) {
-  KrnlOptimizeLoopsOp optimizedLoopsOp = emitOptimizedLoops(
-      rewriter, loc, loops, optimizedLoops, numLoops);
-  return &optimizedLoopsOp.region().front();
-}
-
-// Function which emits a basic set of loops and optimized loops
-// for a given operation argument. A reference to the loop optimization
-// block is returned in the last argument of the function.
-static void emitKrnlLoopsAndIterationForOperand(
-    ConversionPatternRewriter &rewriter, Location loc,
-    Value operand, std::vector<Value> &originalLoops,
-    KrnlOptimizeLoopsOp &optimizedLoopsOp, KrnlIterateOp &iterateOp) {
-  // Operand shape.
-  auto shape = operand.getType().cast<MemRefType>().getShape();
-
-  // Number of loops.
-  int64_t rank = shape.size();
-
-  // Define loops and optimized loops.
-  std::vector<Value> optimizedLoops;
-  optimizedLoopsOp = emitOptimizedLoops(rewriter, loc, originalLoops,
-                                        optimizedLoops, rank);
-
-  KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
-  // Iterate over the loop nest.
-  for (int i = 0; i < rank; ++i)
-    addDimensionToPack(rewriter, loc, pack, operand, i);
-
-  iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
-}
-
-unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
-  auto elementType = memRefType.getElementType();
-
-  unsigned sizeInBits;
-  if (elementType.isIntOrFloat()) {
-    sizeInBits = elementType.getIntOrFloatBitWidth();
-  } else {
-    auto vectorType = elementType.cast<VectorType>();
-    sizeInBits =
-        vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
-  }
-  return llvm::divideCeil(sizeInBits, 8);
-}
-
-// Get run-time dimension information for unknown dimensions used for
-// broadcasting.
-std::map<int, std::map<int, Value>>
-getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
-                      MemRefType memRefType, ArrayRef<Value> operands) {
-  auto memRefShape = memRefType.getShape();
-  int64_t rank = memRefShape.size();
-  // For unknown dimensions, we need to get dimension values at runtime in
-  // order to do broadcasting.
-  std::map<int, std::map<int, Value>> DimInfo;
-  // For each result dimension, compute the number of sharing operands.
-  // Sharing operands are operands sharing the same index (counting from the
-  // rightmost to the leftmost) for a given dimension.
-  std::map<int, int> sharedDimCount;
-  for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
-    int dimIdx = rank - 1 - reversedIdx;
-    sharedDimCount[dimIdx] = 0;
-    for (int i = 0; i < operands.size(); ++i) {
-      auto shape = operands[i].getType().cast<MemRefType>().getShape();
-      if (reversedIdx <= shape.size() - 1)
-        sharedDimCount[dimIdx]++;
-    }
-  }
-  // An unknown dimension can have a value of 1 or N (N > 1).
-  // If its value is 1, it is a broadcasted dimension.
-  // Otherwise, it is a non-broadcasted dimension.
-  // We only care about unknown dimensions whose number of sharing operands is
-  // more than one, since they are potentially broadcasted dimensions.
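-  // Illustrative example: for operands of shapes (?, 4) and (1, ?, 4), the
-  // result has rank 3. Result dimension 1 is shared by both operands, so an
-  // unknown size there gets a runtime "== 1" check; result dimension 0 has
-  // a single sharing operand and needs no check.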
-  for (int i = 0; i < operands.size(); ++i) {
-    std::map<int, Value> broadcastedDims;
-    auto shape = operands[i].getType().cast<MemRefType>().getShape();
-    int size = shape.size();
-    for (int j = 0; j < shape.size(); ++j) {
-      if (shape[j] < 0 && sharedDimCount[rank - size + j] > 1) {
-        auto dim = rewriter.create<DimOp>(loc, operands[i], j).getResult();
-        auto one = rewriter.create<ConstantIndexOp>(loc, 1);
-        auto isBroadcasted =
-            rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, dim, one);
-        broadcastedDims.insert(std::make_pair(j, isBroadcasted));
-      }
-    }
-    DimInfo.insert(std::make_pair(i, broadcastedDims));
-  }
-  return DimInfo;
-}
-
-// Extract induction variables that are used for broadcasting values of a
-// given operand.
-std::vector<Value>
-getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
-                          ArrayRef<Value> loopIVs, Value operand,
-                          std::map<int, Value> broadcastedDims) {
-  // `operand` must have a ranked type. This should have been checked by the
-  // shape inference pass.
-  auto operandShape = operand.getType().cast<MemRefType>().getShape();
-  auto rank = operandShape.size();
-  auto loopCount = loopIVs.size();
-
-  std::vector<Value> newLoopIVs;
-  for (unsigned reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
-    auto dimIdx = rank - 1 - reversedIdx;
-    auto loopIdx = loopCount - 1 - reversedIdx;
-    if (operandShape[dimIdx] == 1) {
-      // Broadcasted dimension.
-      auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
-      newLoopIVs.insert(newLoopIVs.begin(), zero);
-    } else if ((operandShape[dimIdx] == -1) &&
-               (broadcastedDims.find(dimIdx) != broadcastedDims.end())) {
-      // Unknown dimension; it can have a value of 1 or N (N > 1).
-      // If its value is 1, it is a broadcasted dimension.
-      // Otherwise, it is a non-broadcasted dimension.
-      auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
-      auto idx = rewriter.create<SelectOp>(loc, broadcastedDims[dimIdx], zero,
-                                           loopIVs[loopIdx]);
-      newLoopIVs.insert(newLoopIVs.begin(), idx);
-    } else {
-      // Non-broadcasted dimension.
-      newLoopIVs.insert(newLoopIVs.begin(), loopIVs[loopIdx]);
-    }
-  }
-  return newLoopIVs;
-}
-
-namespace {
-
-template <typename ElementwiseNaryOp>
-struct ScalarOp;
-
-template <>
-struct ScalarOp<ONNXAddOp> {
-  using FOp = AddFOp;
-  using IOp = AddIOp;
-};
-
-template <>
-struct ScalarOp<ONNXMulOp> {
-  using FOp = MulFOp;
-  using IOp = MulIOp;
-};
-
-template <>
-struct ScalarOp<ONNXDivOp> {
-  using FOp = DivFOp;
-  using IOp = SignedDivIOp;
-};
-
-template <>
-struct ScalarOp<ONNXSubOp> {
-  using FOp = SubFOp;
-  using IOp = SubIOp;
-};
-
-template <>
-struct ScalarOp<ONNXAndOp> {
-  using FOp = AndOp; // not used
-  using IOp = AndOp;
-};
-
-template <>
-struct ScalarOp<ONNXOrOp> {
-  using FOp = OrOp; // not used
-  using IOp = OrOp;
-};
-
-template <>
-struct ScalarOp<ONNXXorOp> {
-  using FOp = XOrOp; // not used
-  using IOp = XOrOp;
-};
-
-template <>
-struct ScalarOp<ONNXExpOp> {
-  using FOp = ExpOp;
-  using IOp = ExpOp; // not used
-};
-
-template <>
-struct ScalarOp<ONNXSumOp> {
-  using FOp = AddFOp;
-  using IOp = AddIOp;
-};
-
-template <>
-struct ScalarOp<ONNXTanhOp> {
-  using FOp = TanhOp;
-  using IOp = TanhOp; // not used
-};
-
-template <>
-struct ScalarOp<ONNXCosOp> {
-  using FOp = CosOp;
-  using IOp = CosOp; // not used
-};
-
-template <>
-struct ScalarOp<ONNXLogOp> {
-  using FOp = LogOp;
-  using IOp = LogOp; // not used
-};
-
-template <>
-struct ScalarOp<ONNXReduceProdOp> {
-  using FOp = MulFOp;
-  using IOp = MulIOp;
-};
-
-template <>
-struct ScalarOp<ONNXReduceSumOp> {
-  using FOp = AddFOp;
-  using IOp = AddIOp;
-};
-
-template <>
-struct ScalarOp<ONNXSqrtOp> {
-  using FOp = KrnlSqrtOp;
-  using IOp = KrnlSqrtOp; // not used
-};
-
-template <typename ElementwiseNaryOp>
-using ScalarFOp = typename ScalarOp<ElementwiseNaryOp>::FOp;
-template <typename ElementwiseNaryOp>
-using ScalarIOp = typename ScalarOp<ElementwiseNaryOp>::IOp;
-
-// Get the identity element of an operation.
-// Return NULL if the operation does not have an identity.
-template <typename DataType, typename Op>
-DataType getIdentityValue() {
-  return NULL;
-}
-
-template <>
-float getIdentityValue<float, ONNXReduceMaxOp>() {
-  return (float)-std::numeric_limits<float>::infinity();
-}
-
-template <>
-int getIdentityValue<int, ONNXReduceMaxOp>() {
-  return std::numeric_limits<int>::min();
-}
-
-template <>
-float getIdentityValue<float, ONNXReduceMinOp>() {
-  return (float)std::numeric_limits<float>::infinity();
-}
-
-template <>
-int getIdentityValue<int, ONNXReduceMinOp>() {
-  return std::numeric_limits<int>::max();
-}
-
-template <>
-float getIdentityValue<float, ONNXReduceProdOp>() {
-  return (float)1.0;
-}
-
-template <>
-int getIdentityValue<int, ONNXReduceProdOp>() {
-  return 1;
-}
-
-template <>
-float getIdentityValue<float, ONNXReduceSumOp>() {
-  return (float)0;
-}
-
-template <>
-int getIdentityValue<int, ONNXReduceSumOp>() {
-  return 0;
-}
-
-//===----------------------------------------------------------------------===//
-// Scalar unary ops for lowering to Krnl dialect.
-//===----------------------------------------------------------------------===//
-template <typename UnaryOp>
-Value mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types,
-                         ArrayRef<Value> operands,
-                         ConversionPatternRewriter &rewriter) {
-  /* Lower UnaryOp to Ops in the Standard dialect.
-   */
-  auto loc = op->getLoc();
-  Type element_type = operands.front().getType();
-  if (element_type.isa<FloatType>()) {
-    return rewriter.create<ScalarFOp<UnaryOp>>(loc, result_types, operands,
-                                               mlir::None);
-  } else if (element_type.isa<IntegerType>()) {
-    return rewriter.create<ScalarIOp<UnaryOp>>(loc, result_types, operands,
-                                               mlir::None);
-  } else {
-    emitError(loc, "unsupported element type");
-    return nullptr;
-  }
-}
-
-//===----------------------------------------------------------------------===//
-// Scalar unary ops for lowering ONNXSinhOp
-//===----------------------------------------------------------------------===//
-template <>
-Value mapToLowerScalarOp<ONNXSinhOp>(Operation *op,
-                                     ArrayRef<Type> result_types,
-                                     ArrayRef<Value> operands,
-                                     ConversionPatternRewriter &rewriter) {
-  // ONNXSinhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
-  //                         ConstantOp 2)
-  auto loc = op->getLoc();
-  Value operand = operands[0];
-  auto elementType = result_types[0];
-
-  auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
-  auto two = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 2));
-  auto neg = rewriter.create<SubFOp>(loc, zero, operand);
-  auto exp = rewriter.create<ExpOp>(loc, operand);
-  auto negExp = rewriter.create<ExpOp>(loc, neg);
-  auto result = rewriter.create<DivFOp>(
-      loc, rewriter.create<SubFOp>(loc, exp, negExp), two);
-
-  return result;
-}
-
-//===----------------------------------------------------------------------===//
-// Scalar unary ops for lowering ONNXCoshOp
-//===----------------------------------------------------------------------===//
-template <>
-Value mapToLowerScalarOp<ONNXCoshOp>(Operation *op,
-                                     ArrayRef<Type> result_types,
-                                     ArrayRef<Value> operands,
-                                     ConversionPatternRewriter &rewriter) {
-  // ONNXCoshOp(%X) = DivFOp(AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
-  //                         ConstantOp 2)
-  auto loc = op->getLoc();
-  Value operand = operands[0];
-  auto elementType = result_types[0];
-
-  auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
-  auto two = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 2));
-  auto neg = rewriter.create<SubFOp>(loc, zero, operand);
-  auto exp = rewriter.create<ExpOp>(loc, operand);
-  auto negExp = rewriter.create<ExpOp>(loc, neg);
-  auto result = rewriter.create<DivFOp>(
-      loc, rewriter.create<AddFOp>(loc, exp, negExp), two);
-
-  return result;
-}
-
-//===----------------------------------------------------------------------===//
-// Scalar unary ops for lowering ONNXSigmoidOp
-//===----------------------------------------------------------------------===//
-template <>
-Value mapToLowerScalarOp<ONNXSigmoidOp>(Operation *op,
-                                        ArrayRef<Type> result_types,
-                                        ArrayRef<Value> operands,
-                                        ConversionPatternRewriter &rewriter) {
-  // ONNXSigmoidOp(%X) = DivFOp(ConstantOp 1,
-  //                            AddFOp(ConstantOp 1, ExpOp(NegFOp(%X))))
-  auto loc = op->getLoc();
-  Value operand = operands[0];
-  auto elementType = result_types[0];
-
-  auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
-  auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
-  auto neg = rewriter.create<SubFOp>(loc, zero, operand);
-  auto negExp = rewriter.create<ExpOp>(loc, neg);
-  auto result = rewriter.create<DivFOp>(
-      loc, one, rewriter.create<AddFOp>(loc, one, negExp));
-
-  return result;
-}
-
-//===----------------------------------------------------------------------===//
-// Scalar unary ops for lowering ONNXHardSigmoidOp
-//===----------------------------------------------------------------------===//
-template <>
-Value mapToLowerScalarOp<ONNXHardSigmoidOp>(
-    Operation *op, ArrayRef<Type> result_types, ArrayRef<Value> operands,
-    ConversionPatternRewriter &rewriter) {
-  // %Y = AddFOp(MulFOp(alpha, %X), beta)
-  // %Z = SelectOp(CmpFOp(OGT, %Y, Constant 0),
-  //               %Y,
-  //               Constant 0)
-  // ONNXHardSigmoidOp(%X) = SelectOp(CmpFOp(OLT, %Z, Constant 1),
-  //                                  %Z,
-  //                                  Constant 1)
-  auto loc = op->getLoc();
-  Value operand = operands[0];
-  auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
-      llvm::dyn_cast<ONNXHardSigmoidOp>(op).alpha().convertToFloat());
-  auto betaAttribute = FloatAttr::get(rewriter.getF32Type(),
-      llvm::dyn_cast<ONNXHardSigmoidOp>(op).beta().convertToFloat());
-  auto elementType = result_types[0];
-
-  auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
-  auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
-  auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
-  auto beta = rewriter.create<ConstantOp>(loc, betaAttribute);
-
-  auto add = rewriter.create<AddFOp>(
-      loc, rewriter.create<MulFOp>(loc, alpha, operand), beta);
-  auto maxPredicate =
-      rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, add, zero);
-  auto max = rewriter.create<SelectOp>(loc, maxPredicate, add, zero);
-  auto minPredicate =
-      rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, max, one);
-  auto result = rewriter.create<SelectOp>(loc, minPredicate, max, one);
-
-  return result;
-}
-
-//===----------------------------------------------------------------------===//
-// Scalar unary ops for lowering ONNXEluOp
-//===----------------------------------------------------------------------===//
-template <>
-Value mapToLowerScalarOp<ONNXEluOp>(Operation *op, ArrayRef<Type> result_types,
-                                    ArrayRef<Value> operands,
-                                    ConversionPatternRewriter &rewriter) {
-  // ONNXEluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
-  //                          MulFOp(alpha, SubFOp(ExpOp(%X), 1)),
-  //                          %X)
-  auto loc = op->getLoc();
-  Value operand = operands[0];
-  auto elementType = result_types[0];
-
-  auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
-      llvm::dyn_cast<ONNXEluOp>(op).alpha().convertToFloat());
-  auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
-  auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
-  auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
-  auto exp = rewriter.create<ExpOp>(loc, operand);
-  auto lessThanZero =
-      rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
-  auto result = rewriter.create<SelectOp>(
-      loc, lessThanZero,
-      rewriter.create<MulFOp>(loc, alpha,
-                              rewriter.create<SubFOp>(loc, exp, one)),
-      operand);
-
-  return result;
-}
-
-//===----------------------------------------------------------------------===//
-// Scalar unary ops for lowering ONNXReluOp
-//===----------------------------------------------------------------------===//
-template <>
-Value mapToLowerScalarOp<ONNXReluOp>(Operation *op,
-                                     ArrayRef<Type> result_types,
-                                     ArrayRef<Value> operands,
-                                     ConversionPatternRewriter &rewriter) {
-  // ONNXReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
-  //                           ConstantOp 0,
-  //                           %X)
-  auto loc = op->getLoc();
-  Value operand = operands[0];
-  auto elementType = result_types[0];
-
-  auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
-  auto lessThanZero =
-      rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
-  auto result = rewriter.create<SelectOp>(loc, lessThanZero, zero, operand);
-
-  return result;
-}
-
-//===----------------------------------------------------------------------===//
-// Scalar unary ops for lowering ONNXLeakyReluOp
-//===----------------------------------------------------------------------===//
-template <>
-Value mapToLowerScalarOp<ONNXLeakyReluOp>(Operation *op,
                                          ArrayRef<Type> result_types,
-                                          ArrayRef<Value> operands,
-                                          ConversionPatternRewriter &rewriter) {
-  // ONNXLeakyReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
-  //                                MulFOp(alpha, %X),
-  //                                %X)
-  auto loc = op->getLoc();
-  Value operand = operands[0];
-  auto elementType = result_types[0];
-
-  auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
-      llvm::dyn_cast<ONNXLeakyReluOp>(op).alpha().convertToFloat());
-  auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
-  auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
-  auto lessThanZero =
-      rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
-  auto result = rewriter.create<SelectOp>(
-      loc, lessThanZero, rewriter.create<MulFOp>(loc, alpha, operand), operand);
-
-  return result;
-}
-
-//===----------------------------------------------------------------------===//
-// Scalar unary ops for lowering ONNXSeluOp
-//===----------------------------------------------------------------------===//
-template <>
-Value mapToLowerScalarOp<ONNXSeluOp>(Operation *op,
-                                     ArrayRef<Type> result_types,
-                                     ArrayRef<Value> operands,
-                                     ConversionPatternRewriter &rewriter) {
-  // ONNXSeluOp(%X) = SelectOp(CmpFOp(OGT, %X, ConstantOp 0),
-  //                           MulFOp(gamma, %X),
-  //                           MulFOp(gamma,
-  //                                  SubFOp(MulFOp(alpha, ExpOp(%X)),
-  //                                         alpha)))
-  auto loc = op->getLoc();
-  Value operand = operands[0];
-  auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
-      llvm::dyn_cast<ONNXSeluOp>(op).alpha().convertToFloat());
-  auto gammaAttribute = FloatAttr::get(rewriter.getF32Type(),
-      llvm::dyn_cast<ONNXSeluOp>(op).gamma().convertToFloat());
-  auto elementType = result_types[0];
-
-  auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
-  auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
-  auto gamma = rewriter.create<ConstantOp>(loc, gammaAttribute);
-  auto exp = rewriter.create<ExpOp>(loc, operand);
-  auto greaterThanZero =
-      rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, operand, zero);
-  auto select = rewriter.create<SelectOp>(
-      loc, greaterThanZero, operand,
-      rewriter.create<SubFOp>(loc, rewriter.create<MulFOp>(loc, alpha, exp),
-                              alpha));
-  auto result = rewriter.create<MulFOp>(loc, gamma, select);
-
-  return result;
-}
-
-//===----------------------------------------------------------------------===//
-// Scalar unary ops for lowering ONNXReciprocalOp
-//===----------------------------------------------------------------------===//
-template <>
-Value mapToLowerScalarOp<ONNXReciprocalOp>(
-    Operation *op, ArrayRef<Type> result_types, ArrayRef<Value> operands,
-    ConversionPatternRewriter &rewriter) {
-  // ONNXReciprocalOp(%X) = DivFOp(ConstantOp 1, %X)
-  auto loc = op->getLoc();
-  Value operand = operands[0];
-  auto elementType = result_types[0];
-
-  auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
-  auto result = rewriter.create<DivFOp>(loc, one, operand);
-
-  return result;
-}
-
-//===----------------------------------------------------------------------===//
-// Scalar unary ops for lowering ONNXSoftplusOp
-//===----------------------------------------------------------------------===// -template <> -Value mapToLowerScalarOp( - Operation *op, ArrayRef result_types, ArrayRef operands, - ConversionPatternRewriter &rewriter) { - // ONNXSoftplusOp(%X) = LogOp(AddFOp(ExpOp(%X), ConstantOp 1)) - auto loc = op->getLoc(); - Value operand = operands[0]; - auto elementType = result_types[0]; - - auto exp = rewriter.create(loc, operand); - auto one = rewriter.create(loc, FloatAttr::get(elementType, 1)); - auto add = rewriter.create(loc, exp, one); - auto result = rewriter.create(loc, add); - - return result; -} - -//===----------------------------------------------------------------------===// -// Scalar unary ops for lowering ONNXSoftsignOp -//===----------------------------------------------------------------------===// -template <> -Value mapToLowerScalarOp( - Operation *op, ArrayRef result_types, ArrayRef operands, - ConversionPatternRewriter &rewriter) { - // ONNXSoftsignOp(%X) = DivFOp(ConstantOp 1, %X) - auto loc = op->getLoc(); - Value operand = operands[0]; - auto elementType = result_types[0]; - - auto abs = rewriter.create(loc, operand); - auto one = rewriter.create(loc, FloatAttr::get(elementType, 1)); - auto add = rewriter.create(loc, abs, one); - auto result = rewriter.create(loc, operand, add); - - return result; -} - -//===----------------------------------------------------------------------===// -// Scalar unary ops for lowering ONNXSignOp -//===----------------------------------------------------------------------===// -template <> -Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, - ArrayRef operands, - ConversionPatternRewriter &rewriter) { - - auto loc = op->getLoc(); - Value operand = operands[0]; - Type element_type = operands.front().getType(); - // TODO: unsigned int should be supported separately? 
-  if (element_type.isa<IntegerType>()) {
-    // %Y = SelectOp(CmpIOp(GT, %X, ConstantOp 0),
-    //               ConstantOp 1,
-    //               ConstantOp -1)
-    // ONNXSignOp(%X) = SelectOp(CmpIOp(EQ, %X, ConstantOp 0),
-    //                           ConstantOp 0,
-    //                           %Y)
-    auto zero = rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
-    auto one = rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(1));
-    auto minusOne =
-        rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(-1));
-    auto plusPredicate =
-        rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, operand, zero);
-    auto plusSelect =
-        rewriter.create<SelectOp>(loc, plusPredicate, one, minusOne);
-    auto zeroPredicate =
-        rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, operand, zero);
-    auto result =
-        rewriter.create<SelectOp>(loc, zeroPredicate, zero, plusSelect);
-    return result;
-  } else if (element_type.isa<FloatType>()) {
-    // %Y = SelectOp(CmpFOp(OGT, %X, ConstantOp 0),
-    //               ConstantOp 1,
-    //               ConstantOp -1)
-    // ONNXSignOp(%X) = SelectOp(CmpFOp(OEQ, %X, ConstantOp 0),
-    //                           ConstantOp 0,
-    //                           %Y)
-    auto zero =
-        rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
-    auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f));
-    auto minusOne =
-        rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(-1.0f));
-    auto plusPredicate =
-        rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, operand, zero);
-    auto plusSelect =
-        rewriter.create<SelectOp>(loc, plusPredicate, one, minusOne);
-    auto zeroPredicate =
-        rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, operand, zero);
-    auto result =
-        rewriter.create<SelectOp>(loc, zeroPredicate, zero, plusSelect);
-    return result;
-  } else {
-    emitError(loc, "unsupported element type");
-    return nullptr;
-  }
-}
-
-//===----------------------------------------------------------------------===//
-// Scalar unary ops for lowering ONNXMaxOp
-//===----------------------------------------------------------------------===//
-template <>
-Value mapToLowerScalarOp<ONNXMaxOp>(Operation *op, ArrayRef<Type> result_types,
-                                    ArrayRef<Value> operands,
-                                    ConversionPatternRewriter &rewriter) {
-  // ONNXMaxOp(%X, %Y) = SelectOp(CmpFOp(OGT, %X, %Y),
-  //                              %X,
-  //                              %Y)
-  auto loc = op->getLoc();
-  Value lhs = operands[0];
-  Value rhs = operands[1];
-  auto max = rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs);
-  auto result = rewriter.create<SelectOp>(loc, max, lhs, rhs);
-  return result;
-}
-
-//===----------------------------------------------------------------------===//
-// Scalar unary ops for lowering ONNXMinOp
-//===----------------------------------------------------------------------===//
-template <>
-Value mapToLowerScalarOp<ONNXMinOp>(Operation *op, ArrayRef<Type> result_types,
-                                    ArrayRef<Value> operands,
-                                    ConversionPatternRewriter &rewriter) {
-  // ONNXMinOp(%X, %Y) = SelectOp(CmpFOp(OLT, %X, %Y),
-  //                              %X,
-  //                              %Y)
-  auto loc = op->getLoc();
-  Value lhs = operands[0];
-  Value rhs = operands[1];
-  auto min = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, lhs, rhs);
-  auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs);
-  return result;
-}
-
-//===----------------------------------------------------------------------===//
-// Scalar unary ops for lowering ONNXReduceMaxOp
-//===----------------------------------------------------------------------===//
-template <>
-Value mapToLowerScalarOp<ONNXReduceMaxOp>(Operation *op,
-                                          ArrayRef<Type> result_types,
-                                          ArrayRef<Value> operands,
-                                          ConversionPatternRewriter &rewriter) {
-  auto loc = op->getLoc();
-  Value lhs = operands[0];
-  Value rhs = operands[1];
-  Type element_type = lhs.getType();
-  if (element_type.isa<IntegerType>()) {
-    auto max = rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, lhs, rhs);
-    auto result = rewriter.create<SelectOp>(loc, max, lhs, rhs);
-    return result;
-  } else if (element_type.isa<FloatType>()) {
-    auto max = rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs);
-    auto result = rewriter.create<SelectOp>(loc, max, lhs, rhs);
-    return result;
-  } else {
-    emitError(loc, "unsupported element type");
-    return nullptr;
-  }
-}
-
-//===----------------------------------------------------------------------===//
-// Scalar unary ops for lowering ONNXReduceMinOp
-//===----------------------------------------------------------------------===//
-template <>
-Value mapToLowerScalarOp<ONNXReduceMinOp>(Operation *op,
-                                          ArrayRef<Type> result_types,
-                                          ArrayRef<Value> operands,
-                                          ConversionPatternRewriter &rewriter) {
-  auto loc = op->getLoc();
-  Value lhs = operands[0];
-  Value rhs = operands[1];
-  Type element_type = lhs.getType();
-  if (element_type.isa<IntegerType>()) {
-    auto min = rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, lhs, rhs);
-    auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs);
-    return result;
-  } else if (element_type.isa<FloatType>()) {
-    auto min = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, lhs, rhs);
-    auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs);
-    return result;
-  } else {
-    emitError(loc, "unsupported element type");
-    return nullptr;
-  }
-}
-
-// Element-wise unary ops lowering to Krnl dialect.
-//===----------------------------------------------------------------------===//
-template <typename ElementwiseUnaryOp>
-struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
-  ONNXElementwiseUnaryOpLowering(MLIRContext *ctx)
-      : ConversionPattern(ElementwiseUnaryOp::getOperationName(), 1, ctx) {}
-  PatternMatchResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const final {
-    // TODO: Check that the types are valid.
-    // An element-wise unary operation must have all operands and the result of
-    // the same type. This should have been verified by the verifier.
-
-    auto loc = op->getLoc();
-
-    // Insert an allocation and deallocation for the result of this operation.
-    auto memRefType = convertToMemRefType(*op->result_type_begin());
-
-    // If the output has a dynamic dimension, pass the operands required for
-    // each dynamic dimension to the AllocOp. The first operand of the
-    // operation is used. The operands of the op need to match in terms of
-    // dimensions with the result at this pre-optimization phase.
-    // TODO: verify that dimensions match.
-    // TODO: can the dimension of the result differ after optimizations?
-    Value alloc;
-    bool insertDealloc = checkInsertDealloc(op);
-
-    if (hasAllConstantDimensions(memRefType))
-      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
-    else
-      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
-                                    {operands[0]});
-
-    std::vector<Value> originalLoops;
-    KrnlOptimizeLoopsOp optimizedLoopsOp;
-    KrnlIterateOp iterateOp;
-    emitKrnlLoopsAndIterationForOperand(
-        rewriter, loc, operands[0], originalLoops,
-        optimizedLoopsOp, iterateOp);
-    Block &optimizationBlock = optimizedLoopsOp.region().front();
-    Block &iterationBlock = iterateOp.bodyRegion().front();
-
-    // 1. Insert any optimizations in the KrnlOptimizeLoopsOp body.
-    rewriter.setInsertionPointToEnd(&optimizationBlock);
-    // Return from KrnlOptimizeLoopsOp body.
-    // When no optimizations are present we just return the loops
-    // unchanged.
-    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
-
-    // 2. Insert instructions inside the KrnlIterateOp body.
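-    // The iterate body receives one block argument per loop; these induction
-    // variables index both the input load and the result store below.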
-    rewriter.setInsertionPointToStart(&iterationBlock);
-
-    // Handle the operation:
-    SmallVector<Value, 4> loopIVs;
-    for (auto arg : iterationBlock.getArguments())
-      loopIVs.push_back(arg);
-
-    auto loadedVal = rewriter.create<LoadOp>(loc, operands[0], loopIVs);
-    auto loweredOpResult = mapToLowerScalarOp<ElementwiseUnaryOp>(
-        op, memRefType.getElementType(), {loadedVal}, rewriter);
-    // Store result in the resulting array.
-    rewriter.create<StoreOp>(loc, loweredOpResult, alloc, loopIVs);
-
-    rewriter.replaceOp(op, alloc);
-
-    return matchSuccess();
-  }
-};
-
-// Element-wise variadic ops lowering to Krnl dialect.
-//===----------------------------------------------------------------------===//
-template <typename ElementwiseVariadicOp>
-struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
-  ONNXElementwiseVariadicOpLowering(MLIRContext *ctx)
-      : ConversionPattern(ElementwiseVariadicOp::getOperationName(), 1, ctx) {}
-  PatternMatchResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const final {
-    // TODO: Check that the types are valid.
-    // An element-wise variadic operation must have all operands and the result
-    // of the same type. This should have been verified by the verifier.
-    auto loc = op->getLoc();
-    auto numArgs = op->getNumOperands();
-
-    // Insert an allocation and deallocation for the result of this operation.
-    auto memRefType = convertToMemRefType(*op->result_type_begin());
-
-    Value alloc;
-    bool insertDealloc = checkInsertDealloc(op);
-    // If the output has a dynamic dimension, we compute its dimension at
-    // runtime by using dimensions from the operands.
-    // In particular, we need to know from which operand a result dimension
-    // comes.
-    // TODO: can the dimension of the result differ after optimizations?
-    if (hasAllConstantDimensions(memRefType))
-      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
-    else
-      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
-                                    operands);
-
-    // Get run-time dimension information for unknown dimensions used for
-    // broadcasting.
-    std::map<int, std::map<int, Value>> broadcastedDimInfo =
-        getBroadcastedDimInfo(loc, rewriter, memRefType, operands);
-
-    std::vector<Value> originalLoops;
-    KrnlOptimizeLoopsOp optimizedLoopsOp;
-    KrnlIterateOp iterateOp;
-    emitKrnlLoopsAndIterationForOperand(
-        rewriter, loc, alloc, originalLoops,
-        optimizedLoopsOp, iterateOp);
-    Block &optimizationBlock = optimizedLoopsOp.region().front();
-    Block &iterationBlock = iterateOp.bodyRegion().front();
-
-    // 1. Insert any optimizations in the KrnlOptimizeLoopsOp body.
-    rewriter.setInsertionPointToEnd(&optimizationBlock);
-    // Return from KrnlOptimizeLoopsOp body.
-    // When no optimizations are present we just return the loops unchanged.
-    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
-
-    // 2. Insert instructions inside the KrnlIterateOp body.
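-    // Unlike the unary case, the body below folds the variadic operands
-    // pairwise; each operand's induction variables are first remapped by
-    // getLoopIVsForBroadcasting to honor broadcasting.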
-    rewriter.setInsertionPointToStart(&iterationBlock);
-
-    // Handle the operation:
-    SmallVector<Value, 4> loopIVs;
-    for (auto arg : iterationBlock.getArguments())
-      loopIVs.push_back(arg);
-
-    // Fold over operands for each of their scalar values.
-    Value accumulated, next;
-    auto accumulatedLoopIVs = getLoopIVsForBroadcasting(
-        loc, rewriter, loopIVs, operands[0], broadcastedDimInfo[0]);
-    accumulated = rewriter.create<LoadOp>(loc, operands[0], accumulatedLoopIVs);
-    for (unsigned i = 1; i < numArgs; i++) {
-      auto nextLoopIVs = getLoopIVsForBroadcasting(
-          loc, rewriter, loopIVs, operands[i], broadcastedDimInfo[i]);
-      next = rewriter.create<LoadOp>(loc, operands[i], nextLoopIVs);
-      accumulated = mapToLowerScalarOp<ElementwiseVariadicOp>(
-          op, memRefType.getElementType(), {accumulated, next}, rewriter);
-    }
-    // Store result in the resulting array.
-    rewriter.create<StoreOp>(loc, accumulated, alloc, loopIVs);
-
-    rewriter.replaceOp(op, alloc);
-
-    return matchSuccess();
-  }
-};
-
-struct ONNXSoftmaxOpLowering : public ConversionPattern {
-  ONNXSoftmaxOpLowering(MLIRContext *ctx)
-      : ConversionPattern(mlir::ONNXSoftmaxOp::getOperationName(), 1, ctx) {}
-  PatternMatchResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const final {
-    // softmax(x) = let max_x = max(x) in
-    //              let exp_x = exp(x - max_x) in
-    //              let sum = sum(exp_x) in
-    //              exp_x / sum
-    auto memRefType = convertToMemRefType(*op->result_type_begin());
-    int64_t rank = memRefType.getRank();
-    int64_t axis = llvm::dyn_cast<ONNXSoftmaxOp>(op).axis().getSExtValue();
-    assert(axis >= -rank && axis <= rank - 1);
-    axis = axis >= 0 ? axis : rank + axis;
-
-    auto loc = op->getLoc();
-
-    // Insert an allocation and deallocation for the result of this operation.
-    auto elementType = memRefType.getElementType();
-
-    Value alloc;
-    bool insertDealloc = checkInsertDealloc(op);
-    if (hasAllConstantDimensions(memRefType))
-      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
-    else
-      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
-                                    operands[0]);
-
-    // Shape of the result.
-    auto memRefShape = memRefType.getShape();
-
-    // Insert allocations and deallocations for sum and max.
-    MemRefType scalarMemRefType = MemRefType::get({}, elementType, {}, 0);
-    Value sumOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true);
-    Value maxOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true);
-    Value zero =
-        rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
-    Value negInfinity = rewriter.create<ConstantOp>(
-        loc,
-        FloatAttr::get(elementType, -std::numeric_limits<float>::infinity()));
-
-    // Define loops.
-    std::vector<Value> originalLoops;
-    std::vector<Value> optimizedLoops;
-    Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops,
-                                           optimizedLoops, rank);
-
-    // Coerce the input into a 2-D tensor. `axis` will be the coercing point.
-    // This coercing follows the softmax definition in ONNX:
-    // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Softmax
-    // Here, we create an outer loop and inner loop for handling the two
-    // dimensions. The outer loop is only created when `axis` is not zero.
-
-    // Define an outer loop with respect to axis.
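-    // The outer loops cover dimensions [0, axis) and the inner loops cover
-    // [axis, rank), matching the 2-D coercion described above.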
-    std::vector<Value> outerLoops, optimizedOuterLoops;
-    outerLoops.reserve(axis);
-    optimizedOuterLoops.reserve(axis);
-    for (int i = 0; i < axis; ++i) {
-      outerLoops.push_back(originalLoops[i]);
-      optimizedOuterLoops.push_back(optimizedLoops[i]);
-    }
-    KrnlIterateOperandPack outerPack(rewriter, outerLoops, optimizedOuterLoops);
-    for (int i = 0; i < axis; ++i)
-      addDimensionToPack(rewriter, loc, outerPack, operands[0], i);
-
-    // Define an inner loop with respect to axis.
-    std::vector<Value> innerLoops, optimizedInnerLoops;
-    innerLoops.reserve(rank - axis);
-    optimizedInnerLoops.reserve(rank - axis);
-    for (int i = axis; i < rank; ++i) {
-      innerLoops.push_back(originalLoops[i]);
-      optimizedInnerLoops.push_back(optimizedLoops[i]);
-    }
-    KrnlIterateOperandPack innerPack(rewriter, innerLoops, optimizedInnerLoops);
-    for (int i = axis; i < rank; ++i)
-      addDimensionToPack(rewriter, loc, innerPack, operands[0], i);
-
-    KrnlIterateOp outerIterateOp, maxIterateOp, sumIterateOp, softmaxIterateOp;
-    SmallVector<Value, 4> outerLoopIVs;
-    if (axis != 0) {
-      outerIterateOp = rewriter.create<KrnlIterateOp>(loc, outerPack);
-
-      // No optimization
-      rewriter.setInsertionPointToEnd(optimizationBlock);
-      rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
-
-      // Insert instructions inside the outer loop.
-      Block &outerIterationBlock = outerIterateOp.bodyRegion().front();
-      rewriter.setInsertionPointToStart(&outerIterationBlock);
-      for (auto arg : outerIterationBlock.getArguments())
-        outerLoopIVs.push_back(arg);
-
-      // Reset accumulators.
-      rewriter.create<StoreOp>(loc, zero, sumOp);
-      rewriter.create<StoreOp>(loc, negInfinity, maxOp);
-
-      // Create an inner loop to compute max.
-      maxIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
-      // Create an inner loop to compute sum.
-      sumIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
-      // Create an inner loop to compute softmax.
-      softmaxIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
-    } else {
-      // Reset accumulators.
-      rewriter.create<StoreOp>(loc, zero, sumOp);
-      rewriter.create<StoreOp>(loc, negInfinity, maxOp);
-
-      // Create an inner loop to compute max.
-      maxIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
-      // Create an inner loop to compute sum.
-      sumIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
-      // Create an inner loop to compute softmax.
-      softmaxIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
-
-      // No optimization
-      rewriter.setInsertionPointToEnd(optimizationBlock);
-      rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
-    }
-
-    // Insert instructions inside the max loop.
-    Block &maxIterationBlock = maxIterateOp.bodyRegion().front();
-    rewriter.setInsertionPointToStart(&maxIterationBlock);
-
-    // Get induction variables.
-    SmallVector<Value, 4> maxLoopIVs;
-    for (auto arg : outerLoopIVs)
-      maxLoopIVs.push_back(arg);
-    for (auto arg : maxIterationBlock.getArguments())
-      maxLoopIVs.push_back(arg);
-
-    // Compute the max value.
-    Value max = rewriter.create<LoadOp>(loc, maxOp);
-    Value nextMax = rewriter.create<LoadOp>(loc, operands[0], maxLoopIVs);
-    auto maxCond =
-        rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, max, nextMax);
-    max = rewriter.create<SelectOp>(loc, maxCond, max, nextMax);
-    rewriter.create<StoreOp>(loc, max, maxOp);
-
-    // Get the max.
-    rewriter.setInsertionPoint(sumIterateOp);
-    max = rewriter.create<LoadOp>(loc, maxOp);
-
-    // Insert instructions inside the sum loop.
-    Block &sumIterationBlock = sumIterateOp.bodyRegion().front();
-    rewriter.setInsertionPointToStart(&sumIterationBlock);
-
-    // Get induction variables.
-    SmallVector<Value, 4> sumLoopIVs;
-    for (auto arg : outerLoopIVs)
-      sumLoopIVs.push_back(arg);
-    for (auto arg : sumIterationBlock.getArguments())
-      sumLoopIVs.push_back(arg);
-
-    // Sum up values.
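-    // Subtracting the per-slice max before exponentiation keeps exp from
-    // overflowing; this is the standard numerically stable softmax.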
-    Value sum = rewriter.create<LoadOp>(loc, sumOp);
-    Value next = rewriter.create<LoadOp>(loc, operands[0], sumLoopIVs);
-    Value sub = rewriter.create<SubFOp>(loc, next, max);
-    Value exp = rewriter.create<ExpOp>(loc, sub);
-    sum = rewriter.create<AddFOp>(loc, sum, exp);
-    rewriter.create<StoreOp>(loc, sum, sumOp);
-    // Store intermediate values in the result to avoid recomputation.
-    rewriter.create<StoreOp>(loc, exp, alloc, sumLoopIVs);
-
-    // Get the sum.
-    rewriter.setInsertionPoint(softmaxIterateOp);
-    sum = rewriter.create<LoadOp>(loc, sumOp);
-
-    // Insert instructions inside the softmax loop.
-    Block &softmaxIterationBlock = softmaxIterateOp.bodyRegion().front();
-    rewriter.setInsertionPointToStart(&softmaxIterationBlock);
-
-    // Get induction variables.
-    SmallVector<Value, 4> softmaxLoopIVs;
-    for (auto arg : outerLoopIVs)
-      softmaxLoopIVs.push_back(arg);
-    for (auto arg : softmaxIterationBlock.getArguments())
-      softmaxLoopIVs.push_back(arg);
-
-    // Compute softmax.
-    Value expLoadedVal = rewriter.create<LoadOp>(loc, alloc, softmaxLoopIVs);
-    Value result = rewriter.create<DivFOp>(loc, expLoadedVal, sum);
-    rewriter.create<StoreOp>(loc, result, alloc, softmaxLoopIVs);
-
-    rewriter.replaceOp(op, alloc);
-
-    return matchSuccess();
-  }
-};
-
-struct ONNXReshapeOpLowering : public ConversionPattern {
-  ONNXReshapeOpLowering(MLIRContext *ctx)
-      : ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {}
-
-  PatternMatchResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const final {
-    auto loc = op->getLoc();
-
-    auto memRefType = convertToMemRefType(*op->result_type_begin());
-    auto memRefShape = memRefType.getShape();
-    auto inputShape = operands[0].getType().cast<MemRefType>().getShape();
-
-    // Insert an allocation and deallocation for the result of this operation.
-    Value alloc;
-
-    // Compute size in bytes using the input tensor.
-    Value tensorSize = rewriter.create<ConstantOp>(
-        loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
-                                     getMemRefEltSizeInBytes(memRefType)));
-    for (int i = 0; i < inputShape.size(); ++i) {
-      Value dimVal;
-      if (inputShape[i] < 0) {
-        Value dim = rewriter.create<DimOp>(loc, operands[0], i);
-        dimVal = rewriter.create<IndexCastOp>(loc, dim,
-                                              rewriter.getIntegerType(64));
-      } else {
-        dimVal = rewriter.create<ConstantOp>(
-            loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
-                                         inputShape[i]));
-      }
-      tensorSize = rewriter.create<MulIOp>(loc, tensorSize, dimVal);
-    }
-
-    bool insertDealloc = checkInsertDealloc(op);
-    if (hasAllConstantDimensions(memRefType)) {
-      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
-    } else {
-      // If a dimension is zero, the actual dimension value is taken from the
-      // input tensor.
-      //
-      // If the shape array has a negative dimension (-1), we compute its
-      // actual dimension value from the other dimensions. But we don't have
-      // enough information about the other dimensions at this point. So, we
-      // need to scan the shape first to calculate the product of all of the
-      // dimensions. If the product is negative, then the shape array contains
-      // a negative dimension. Otherwise, the product is the same as the one
-      // computed from the input tensor.
-      Value tensorSizeFromShape = rewriter.create<ConstantOp>(
-          loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
-                                       getMemRefEltSizeInBytes(memRefType)));
-      SmallVector<Value, 4> DimInfo;
-      for (int i = 0; i < memRefShape.size(); ++i) {
-        Value index = rewriter.create<ConstantOp>(
-            loc, rewriter.getIntegerAttr(rewriter.getIndexType(), i));
-        // Load index from array of indices.
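-        // operands[1] is the shape tensor of the Reshape op. Each entry is
-        // loaded and patched below: 0 means "copy the input dimension" and
-        // -1 means "infer from the remaining dimensions".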
-        Value loadedVal = rewriter.create<LoadOp>(loc, operands[1], index);
-        // If a dimension is zero, the actual dimension value is taken from
-        // the input tensor.
-        //
-        // If a dimension is negative, it is computed from the other
-        // dimensions. But we don't have enough information about the other
-        // dimensions at this point. So, we leave it as it is (-1), and
-        // compute it later.
-        if (i < inputShape.size()) {
-          Value dimVal;
-          auto loadedValType = loadedVal.getType().cast<IntegerType>();
-          if (inputShape[i] < 0) {
-            Value dim = rewriter.create<DimOp>(loc, operands[0], i);
-            dimVal = rewriter.create<IndexCastOp>(loc, dim, loadedValType);
-          } else {
-            dimVal = rewriter.create<ConstantOp>(
-                loc, rewriter.getIntegerAttr(loadedValType, inputShape[i]));
-          }
-          auto zero = rewriter.create<ConstantOp>(
-              loc, rewriter.getIntegerAttr(loadedValType, 0));
-          auto isZero =
-              rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, loadedVal, zero);
-          loadedVal = rewriter.create<SelectOp>(loc, isZero, dimVal, loadedVal);
-        }
-        // Check if the loaded index is already the correct width of 64 bits.
-        // Convert the value to a 64 bit integer if needed.
-        Value int64LoadedVal = loadedVal;
-        if (loadedVal.getType().cast<IntegerType>().getWidth() < 64)
-          int64LoadedVal = rewriter.create<SignExtendIOp>(
-              loc, loadedVal, rewriter.getIntegerType(64));
-        tensorSizeFromShape =
-            rewriter.create<MulIOp>(loc, tensorSizeFromShape, int64LoadedVal);
-        // Store intermediate results to use later.
-        DimInfo.emplace_back(int64LoadedVal);
-      }
-      // Negate tensorSizeFromShape, since it is negative if the shape array
-      // has a negative dimension. This is safe since we only use it to compute
-      // the actual value for the negative dimension.
-      auto zero = rewriter.create<ConstantOp>(
-          loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
-      tensorSizeFromShape =
-          rewriter.create<SubIOp>(loc, zero, tensorSizeFromShape);
-
-      // Obtain operands for AllocOp.
-      SmallVector<Value, 4> allocOperands;
-      auto negOne = rewriter.create<ConstantOp>(
-          loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), -1));
-
-      for (int i = 0; i < memRefShape.size(); ++i) {
-        auto dimVal = DimInfo[i];
-        auto isNegOne =
-            rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, dimVal, negOne);
-        // If the dimension is negative, compute its value from the other
-        // dimensions.
-        auto actualDimVal =
-            rewriter.create<SignedDivIOp>(loc, tensorSize, tensorSizeFromShape);
-        auto loadedVal =
-            rewriter.create<SelectOp>(loc, isNegOne, actualDimVal, dimVal);
-        allocOperands.push_back(rewriter.create<IndexCastOp>(
-            loc, loadedVal, rewriter.getIndexType()));
-      }
-      AllocOp allocateMemref =
-          rewriter.create<AllocOp>(loc, memRefType, allocOperands);
-
-      // Insert a deallocation for the allocated memref if this op does not
-      // produce the function's return value.
-      auto *parentBlock = allocateMemref.getOperation()->getBlock();
-      if (insertDealloc) {
-        auto dealloc = rewriter.create<DeallocOp>(loc, allocateMemref);
-        dealloc.getOperation()->moveBefore(&parentBlock->back());
-      }
-
-      alloc = allocateMemref;
-    }
-
-    rewriter.create<KrnlMemcpyOp>(loc, alloc, operands[0], tensorSize);
-    rewriter.replaceOp(op, alloc);
-
-    return matchSuccess();
-  }
-};
-
-struct ONNXGemmOpLowering : public ConversionPattern {
-  ONNXGemmOpLowering(MLIRContext *ctx)
-      : ConversionPattern(mlir::ONNXGemmOp::getOperationName(), 1, ctx) {}
-
-  PatternMatchResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const final {
-    auto loc = op->getLoc();
-
-    Value A, B, C;
-    A = operands[0];
-    B = operands[1];
-    C = operands[2];
-
-    auto memRefType = convertToMemRefType(*op->result_type_begin());
-
-    auto alphaAttr = FloatAttr::get(memRefType.getElementType(),
-        llvm::dyn_cast<ONNXGemmOp>(op).alpha().convertToFloat());
-    auto betaAttr = FloatAttr::get(memRefType.getElementType(),
-        llvm::dyn_cast<ONNXGemmOp>(op).beta().convertToFloat());
-    auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
-    auto beta = rewriter.create<ConstantOp>(loc, betaAttr);
-
-    bool isTransA = (llvm::dyn_cast<ONNXGemmOp>(op).transA() != 0);
-    bool isTransB = (llvm::dyn_cast<ONNXGemmOp>(op).transB() != 0);
-
-    // Insert an allocation and deallocation for the result of this operation.
-    Value alloc;
-    bool insertDealloc = checkInsertDealloc(op);
-    if (hasAllConstantDimensions(memRefType))
-      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
-    else {
-      auto memRefShape = memRefType.getShape();
-      SmallVector<Value, 4> allocOperands;
-      if (memRefShape[0] < 0) {
-        auto dim = rewriter.create<DimOp>(loc, A, (isTransA) ? 1 : 0);
-        allocOperands.emplace_back(dim);
-      }
-      if (memRefShape[1] < 0) {
-        auto dim = rewriter.create<DimOp>(loc, B, (isTransB) ? 0 : 1);
-        allocOperands.emplace_back(dim);
-      }
-      alloc = rewriter.create<AllocOp>(loc, memRefType, allocOperands);
-      if (insertDealloc) {
-        auto *parentBlock = alloc.getDefiningOp()->getBlock();
-        auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
-        dealloc.getOperation()->moveBefore(&parentBlock->back());
-      }
-    }
-
-    // Number of loops.
-    auto memRefShape = memRefType.getShape();
-    int64_t numLoops = 3;
-
-    // Define loops.
-    std::vector<Value> originalLoops;
-    std::vector<Value> optimizedLoops;
-    Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops,
-                                           optimizedLoops, numLoops);
-
-    // We have two Krnl loops:
-    // - Outer loop iterates over the output matrix dimensions, and
-    // - Reduction loop iterates over the reduction dimension.
-
-    // Outer loop
-    std::vector<Value> outerLoops, optimizedOuterLoops;
-    outerLoops.reserve(2);
-    optimizedOuterLoops.reserve(2);
-    for (int i = 0; i < 2; ++i) {
-      outerLoops.push_back(originalLoops[i]);
-      optimizedOuterLoops.push_back(optimizedLoops[i]);
-    }
-    KrnlIterateOperandPack outerPack(rewriter, outerLoops,
-                                     optimizedOuterLoops);
-    // Induction variables for the outer loops
-    for (int i = 0; i < 2; ++i)
-      addDimensionToPack(rewriter, loc, outerPack, alloc, i);
-
-    // Reduction loop
-    std::vector<Value> reductionLoops, optimizedReductionLoops;
-    reductionLoops.reserve(1);
-    optimizedReductionLoops.reserve(1);
-    reductionLoops.push_back(originalLoops[2]);
-    optimizedReductionLoops.push_back(optimizedLoops[2]);
-    KrnlIterateOperandPack reductionPack(rewriter, reductionLoops,
-                                         optimizedReductionLoops);
-    // Induction variable for the reduction dimension
-    // Try to find and use a static value from A or B first.
-    // If it failed then use a dynamic value.
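-    // For example, with A of shape MxK (KxM when transA) and B of shape KxN
-    // (NxK when transB), the reduction bound K is read from whichever of A
-    // or B has it as a static dimension, else from a runtime DimOp on B.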
-    auto ATy = A.getType().cast<MemRefType>();
-    auto BTy = B.getType().cast<MemRefType>();
-    int64_t K_A_Idx = (isTransA) ? 0 : 1;
-    int64_t K_B_Idx = (isTransB) ? 1 : 0;
-    reductionPack.pushConstantBound(0);
-    if (ATy.getShape()[K_A_Idx] != -1)
-      reductionPack.pushConstantBound(ATy.getShape()[K_A_Idx]);
-    else if (BTy.getShape()[K_B_Idx] != -1)
-      reductionPack.pushConstantBound(BTy.getShape()[K_B_Idx]);
-    else
-      reductionPack.pushOperandBound(
-          rewriter.create<DimOp>(loc, B, K_B_Idx).getResult());
-
-    // Get run-time dimension information for unknown dimensions used for
-    // broadcasting.
-    // GemmOp supports unidirectional broadcasting from C to A*B.
-    // Hence, it must be enough to get broadcasting information for C only.
-    std::map<int, Value> broadcastedDimInfo;
-    auto shape = C.getType().cast<MemRefType>().getShape();
-    for (int i = 0; i < shape.size(); ++i) {
-      if (shape[i] < 0) {
-        auto dim = rewriter.create<DimOp>(loc, C, i).getResult();
-        auto one = rewriter.create<ConstantIndexOp>(loc, 1);
-        auto isBroadcasted =
-            rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, dim, one);
-        broadcastedDimInfo.insert(std::make_pair(i, isBroadcasted));
-      }
-    }
-
-    auto outerIterateOp = rewriter.create<KrnlIterateOp>(loc, outerPack);
-
-    // Now perform the insertions into the body of the
-    // just generated instructions:
-
-    // No optimization
-    rewriter.setInsertionPointToEnd(optimizationBlock);
-    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
-
-    // Insert instructions inside the outer loop.
-    Block &outerIterationBlock = outerIterateOp.bodyRegion().front();
-    rewriter.setInsertionPointToStart(&outerIterationBlock);
-
-    // Induction variables
-    SmallVector<Value, 4> loopMNIVs;
-    for (auto arg : outerIterationBlock.getArguments()) {
-      loopMNIVs.emplace_back(arg);
-    }
-
-    // Initialize the output of A*B
-    auto zero = rewriter.create<ConstantOp>(
-        loc, FloatAttr::get(memRefType.getElementType(), 0));
-    rewriter.create<StoreOp>(loc, zero, alloc, loopMNIVs);
-
-    // Compute A*B
-    auto matmulIterateOp = rewriter.create<KrnlIterateOp>(loc, reductionPack);
-
-    // Compute beta*C, and add it to alpha*A*B (unidirectional broadcasting)
-    auto loopCIVs = getLoopIVsForBroadcasting(
-        loc, rewriter, loopMNIVs, C, broadcastedDimInfo);
-    auto loadedC = rewriter.create<LoadOp>(loc, C, loopCIVs);
-    auto loadedAB = rewriter.create<LoadOp>(loc, alloc, loopMNIVs);
-    auto alphaAB = rewriter.create<MulFOp>(loc, alpha, loadedAB);
-    auto betaC = rewriter.create<MulFOp>(loc, beta, loadedC);
-    auto Y = rewriter.create<AddFOp>(loc, alphaAB, betaC);
-    rewriter.create<StoreOp>(loc, Y, alloc, loopMNIVs);
-
-    // Insert instructions to do matrix multiplication: A*B
-    Block &matmulIterationBlock = matmulIterateOp.bodyRegion().front();
-    rewriter.setInsertionPointToStart(&matmulIterationBlock);
-
-    // Induction variables
-    SmallVector<Value, 4> loopKIVs, loopAIVs, loopBIVs;
-    for (auto arg : matmulIterationBlock.getArguments())
-      loopKIVs.emplace_back(arg);
-    if (isTransA) {
-      loopAIVs.emplace_back(loopKIVs[0]);
-      loopAIVs.emplace_back(loopMNIVs[0]);
-    } else {
-      loopAIVs.emplace_back(loopMNIVs[0]);
-      loopAIVs.emplace_back(loopKIVs[0]);
-    }
-    if (isTransB) {
-      loopBIVs.emplace_back(loopMNIVs[1]);
-      loopBIVs.emplace_back(loopKIVs[0]);
-    } else {
-      loopBIVs.emplace_back(loopKIVs[0]);
-      loopBIVs.emplace_back(loopMNIVs[1]);
-    }
-
-    // Matmul computation
-    auto loadedA = rewriter.create<LoadOp>(loc, A, loopAIVs);
-    auto loadedB = rewriter.create<LoadOp>(loc, B, loopBIVs);
-    auto loadedY = rewriter.create<LoadOp>(loc, alloc, loopMNIVs);
-    auto AB = rewriter.create<MulFOp>(loc, loadedA, loadedB);
-    auto accumulated = rewriter.create<AddFOp>(loc, loadedY, AB);
-    rewriter.create<StoreOp>(loc, accumulated, alloc, loopMNIVs);
-
-    rewriter.replaceOp(op, alloc);
-
-    return matchSuccess();
-  }
-};
-
-struct ONNXUnsqueezeOpLowering : public ConversionPattern {
-  ONNXUnsqueezeOpLowering(MLIRContext *ctx)
-      : ConversionPattern(mlir::ONNXUnsqueezeOp::getOperationName(), 1, ctx) {}
-
-  PatternMatchResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const final {
-    auto loc = op->getLoc();
-    auto memRefType = convertToMemRefType(*op->result_type_begin());
-    int outRank = memRefType.getRank();
-
-    // Assume that `axes` has been validated by shape inference.
-    // So, here we just read it.
-    ArrayAttr axisAttrs = llvm::dyn_cast<ONNXUnsqueezeOp>(op).axesAttr();
-    SmallVector<int, 4> axes;
-    for (auto axisAttr : axisAttrs.getValue()) {
-      int axis = axisAttr.cast<IntegerAttr>().getInt();
-      axis = axis >= 0 ? axis : (outRank + axis);
-      axes.emplace_back(axis);
-    }
-
-    // Insert an allocation and deallocation for the result of this operation.
-    Value alloc;
-
-    // Compute size in bytes.
-    Value tensorSize = rewriter.create<ConstantOp>(
-        loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
-                                     getMemRefEltSizeInBytes(memRefType)));
-
-    bool insertDealloc = checkInsertDealloc(op);
-    auto memRefShape = memRefType.getShape();
-    if (hasAllConstantDimensions(memRefType)) {
-      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
-      for (int i = 0; i < memRefShape.size(); ++i) {
-        Value dimVal = rewriter.create<ConstantOp>(
-            loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
-                                         memRefShape[i]));
-        tensorSize = rewriter.create<MulIOp>(loc, tensorSize, dimVal);
-      }
-    } else {
-      // Unknown dimensions are always the operand's dimensions.
-      SmallVector<Value, 4> allocOperands;
-      for (int outIdx = 0, inIdx = 0; outIdx < memRefShape.size(); ++outIdx) {
-        Value dimVal = nullptr;
-        if (memRefShape[outIdx] < 0) {
-          Value index = rewriter.create<DimOp>(loc, operands[0], inIdx);
-          dimVal = rewriter.create<IndexCastOp>(
-              loc, index, rewriter.getIntegerType(64));
-          allocOperands.emplace_back(index);
-        } else {
-          dimVal = rewriter.create<ConstantOp>(
-              loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
-                                           memRefShape[outIdx]));
-        }
-        tensorSize = rewriter.create<MulIOp>(loc, tensorSize, dimVal);
-        if (std::find(axes.begin(), axes.end(), outIdx) == axes.end())
-          inIdx++;
-      }
-      alloc = rewriter.create<AllocOp>(loc, memRefType, allocOperands);
-      auto *parentBlock = alloc.getDefiningOp()->getBlock();
-      if (insertDealloc) {
-        auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
-        dealloc.getOperation()->moveBefore(&parentBlock->back());
-      }
-    }
-    rewriter.create<KrnlMemcpyOp>(loc, alloc, operands[0], tensorSize);
-    rewriter.replaceOp(op, alloc);
-    return matchSuccess();
-  }
-};
-
-struct ONNXTransposeOpLowering : public ConversionPattern {
-  ONNXTransposeOpLowering(MLIRContext *ctx)
-      : ConversionPattern(mlir::ONNXTransposeOp::getOperationName(), 1, ctx) {}
-
-  PatternMatchResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const final {
-    auto loc = op->getLoc();
-    // Insert an allocation and deallocation for the result of this operation.
-    auto memRefType = convertToMemRefType(*op->result_type_begin());
-    Value alloc;
-    bool insertDealloc = checkInsertDealloc(op);
-
-    if (hasAllConstantDimensions(memRefType))
-      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
-    else
-      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
-                                    {operands[0]});
-
-    // Number of loops
-    auto memRefShape = memRefType.getShape();
-    int64_t rank = memRefShape.size();
-
-    // Define loops.
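-    // One loop per output dimension; bounds are taken from the input shape
-    // below, since a transpose only permutes dimensions.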
-struct ONNXTransposeOpLowering : public ConversionPattern {
-  ONNXTransposeOpLowering(MLIRContext *ctx)
-      : ConversionPattern(mlir::ONNXTransposeOp::getOperationName(), 1, ctx) {}
-
-  PatternMatchResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const final {
-    auto loc = op->getLoc();
-    // Insert an allocation and deallocation for the result of this operation.
-    auto memRefType = convertToMemRefType(*op->result_type_begin());
-    Value alloc;
-    bool insertDealloc = checkInsertDealloc(op);
-
-    if (hasAllConstantDimensions(memRefType))
-      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
-    else
-      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
-                                    {operands[0]});
-
-    // Number of loops.
-    auto memRefShape = memRefType.getShape();
-    int64_t rank = memRefShape.size();
-
-    // Define loops.
-    std::vector<Value> originalLoops;
-    std::vector<Value> optimizedLoops;
-    Block *optimizationBlock =
-        defineLoops(rewriter, loc, originalLoops, optimizedLoops, rank);
-
-    KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
-    // Iterate over the loop nest using the input shape.
-    for (int i = 0; i < rank; ++i)
-      addDimensionToPack(rewriter, loc, pack, operands[0], i);
-
-    auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
-    Block &iterationBlock = iterateOp.bodyRegion().front();
-
-    // Now perform the insertions into the body of the just generated
-    // instructions:
-
-    // 1. Insert any optimizations in the KrnlOptimizeLoopsOp body.
-    rewriter.setInsertionPointToEnd(optimizationBlock);
-    // Return from the KrnlOptimizeLoopsOp body.
-    // When no optimizations are present we just return the loops unchanged.
-    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
-
-    // 2. Insert instructions inside the KrnlIterateOp body.
-    rewriter.setInsertionPointToStart(&iterationBlock);
-
-    // Handle the operation.
-
-    // Read the perm attribute.
-    SmallVector<int, 4> perm;
-    auto permAttribute = llvm::dyn_cast<ONNXTransposeOp>(op).permAttr();
-    if (permAttribute) {
-      for (auto permVal : permAttribute.getValue())
-        perm.emplace_back(permVal.cast<IntegerAttr>().getInt());
-    } else {
-      // TODO: Remove when perm is guaranteed to be present (even for
-      // the default case). This means that perm was added by shape
-      // inference or another pass to contain the values corresponding
-      // to the default behavior of Transpose.
-      for (int i = iterationBlock.getArguments().size() - 1; i >= 0; i--)
-        perm.emplace_back(i);
-    }
-
-    SmallVector<Value, 4> inLoopIVs;
-    for (auto arg : iterationBlock.getArguments())
-      inLoopIVs.emplace_back(arg);
-
-    SmallVector<Value, 4> outLoopIVs;
-    for (int i = 0; i < iterationBlock.getArguments().size(); ++i)
-      outLoopIVs.emplace_back(inLoopIVs[perm[i]]);
-
-    auto inVal = rewriter.create<LoadOp>(loc, operands[0], inLoopIVs);
-    rewriter.create<StoreOp>(loc, inVal, alloc, outLoopIVs);
-
-    rewriter.replaceOp(op, alloc);
-
-    return matchSuccess();
-  }
-};
-
-struct ONNXIdentityOpLowering : public ConversionPattern {
-  ONNXIdentityOpLowering(MLIRContext *ctx)
-      : ConversionPattern(mlir::ONNXIdentityOp::getOperationName(), 1, ctx) {}
-
-  PatternMatchResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const final {
-    rewriter.replaceOp(op, operands[0]);
-    return matchSuccess();
-  }
-};
-
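
The transpose pattern above iterates over the *input* index space and scatters into the output, whose i-th index is the perm[i]-th input induction variable. A plain C++ sketch for a fixed rank of 3 (the fixed rank, names, and row-major layout are assumptions for illustration):

    #include <cstdint>

    void transpose3dReference(const float *in, float *out,
                              const int64_t inShape[3], const int64_t perm[3]) {
      // Output shape: outShape[i] = inShape[perm[i]].
      int64_t outShape[3] = {inShape[perm[0]], inShape[perm[1]],
                             inShape[perm[2]]};
      for (int64_t i0 = 0; i0 < inShape[0]; ++i0)
        for (int64_t i1 = 0; i1 < inShape[1]; ++i1)
          for (int64_t i2 = 0; i2 < inShape[2]; ++i2) {
            int64_t iv[3] = {i0, i1, i2};
            // outLoopIVs[i] = inLoopIVs[perm[i]], as in the lowering.
            int64_t ov[3] = {iv[perm[0]], iv[perm[1]], iv[perm[2]]};
            out[(ov[0] * outShape[1] + ov[1]) * outShape[2] + ov[2]] =
                in[(i0 * inShape[1] + i1) * inShape[2] + i2];
          }
    }
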
-struct ONNXConvNoBiasOpLowering : public ConversionPattern {
-  ONNXConvNoBiasOpLowering(MLIRContext *ctx)
-      : ConversionPattern(mlir::ONNXConvNoBiasOp::getOperationName(), 1, ctx) {}
-
-  PatternMatchResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const final {
-    auto loc = op->getLoc();
-    // Insert an allocation and deallocation for the result of this operation.
-    auto memRefType = convertToMemRefType(*op->result_type_begin());
-    Value alloc;
-    bool insertDealloc = checkInsertDealloc(op);
-    ONNXConvNoBiasOp convOp = llvm::dyn_cast<ONNXConvNoBiasOp>(op);
-
-    if (hasAllConstantDimensions(memRefType))
-      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
-    else
-      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
-                                    {operands[0]});
-
-    auto resultShape = memRefType.getShape();
-    auto inputShape = operands[0].getType().cast<MemRefType>().getShape();
-    auto kernelShape = operands[1].getType().cast<MemRefType>().getShape();
-
-    // R = ConvNoBias(D, K)
-    //
-    // The input/output shapes will look like this:
-    //
-    // D (NxCxHxW) x K (MxC/groupxKHxKW) -> R (NxMxRHxRW)
-    //
-    // M is a multiple of the number of groups:
-    //   M = group * kernelsPerGroup
-    //
-    // The loop nest will look as follows:
-    //
-    // strides = [s1, s2]
-    //
-    // kernelsPerGroup = M / group;
-    // for n = 0 .. N:
-    //   for g = 0 .. group:
-    //     for m = 0 .. kernelsPerGroup:
-    //       kernel = g * kernelsPerGroup + m;
-    //       for r1 = 0 .. RH:
-    //         for r2 = 0 .. RW:
-    //           R[n][kernel][r1][r2] = 0;
-    //           for c = 0 .. C/group:
-    //             for k1 = 0 .. KH:
-    //               for k2 = 0 .. KW:
-    //                 R[n][kernel][r1][r2] +=
-    //                     D[n][g * (C / group) + c][s1 * r1 + k1][s2 * r2 + k2] *
-    //                     K[kernel][c][k1][k2];
-    //
-    // Naming:
-    //   n, g, m: outer loop nest indices
-    //   r1, r2: spatial loop nest indices
-    //   c, k1, k2: inner loop nest indices
-    //
-    // TODO: handle padding.
-    //
-    // In the general case:
-    //
-    // D (NxCxD1xD2x...xDdim) x K (MxC/groupxK1xK2x...xKdim)
-    //     -> R (NxMxR1xR2x...xRdim)
-    //
-    // The above loop nest can be adapted by increasing the number of
-    // r- and k-index loops, i.e. the r1 r2 and k1 k2 loops.
-
-    // Set up the outermost loops: n g m r1 r2 ... rdim.
-    // Skip g if group is 1.
-
-    // Before we start the iteration we need to compute the number of
-    // kernels per group and fetch the number of groups from the attribute
-    // list. The group count is always a compile-time constant.
-    int64_t group = convOp.group().getSExtValue();
-    // Compute the number of kernels per group. The number of kernels
-    // must be a multiple of the number of groups.
-    int64_t kernelsPerGroup = floor(kernelShape[0] / group);
-    auto kernelsPerGroupValue =
-        rewriter.create<ConstantIndexOp>(loc, kernelsPerGroup);
-    auto zero = rewriter.create<ConstantOp>(
-        loc, FloatAttr::get(memRefType.getElementType(), 0));
-    Value subchannels;
-    if (kernelShape[1] < 0) {
-      subchannels = rewriter.create<DimOp>(loc, operands[1], 1).getResult();
-    } else {
-      subchannels = rewriter.create<ConstantIndexOp>(loc, kernelShape[1]);
-    }
-
-    // 1. Define the outer loops and emit an empty optimization block:
-    int64_t nOuterLoops = (group > 1) ? 3 : 2;
-    std::vector<Value> outerLoops;
-    std::vector<Value> optimizedOuterLoops;
-    Block *optimizationBlock = defineLoops(rewriter, loc, outerLoops,
-                                           optimizedOuterLoops, nOuterLoops);
-
-    // Prepare the iteration arguments over the outer loop nest.
-    KrnlIterateOperandPack pack(rewriter, outerLoops, optimizedOuterLoops);
-    // for n = 0 .. N:
-    pack.pushConstantBound(0);
-    if (inputShape[0] < 0)
-      pack.pushOperandBound(
-          rewriter.create<DimOp>(loc, operands[0], 0).getResult());
-    else
-      pack.pushConstantBound(inputShape[0]);
-    // for g = 0 .. group:
-    if (group > 1) {
-      pack.pushConstantBound(0);
-      pack.pushConstantBound(group);
-    }
-    // for m = 0 .. kernelsPerGroup:
-    pack.pushConstantBound(0);
-    pack.pushConstantBound(kernelsPerGroup);
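
For clarity, here is a runnable C++ rendering of the grouped-convolution loop nest described in the comment above (NCHW data, MxC/group kernels, no padding, unit dilation, strides s1/s2). All identifiers are illustrative; the accumulation is kept in a local here, whereas the lowering accumulates through the output buffer.

    #include <cstdint>

    void convNoBiasReference(const float *D, const float *K, float *R,
                             int64_t N, int64_t C, int64_t H, int64_t W,
                             int64_t M, int64_t KH, int64_t KW,
                             int64_t group, int64_t s1, int64_t s2) {
      int64_t kernelsPerGroup = M / group;
      int64_t subC = C / group;                    // channels per group
      int64_t RH = (H - KH) / s1 + 1, RW = (W - KW) / s2 + 1;
      for (int64_t n = 0; n < N; ++n)
        for (int64_t g = 0; g < group; ++g)
          for (int64_t m = 0; m < kernelsPerGroup; ++m) {
            int64_t kernel = g * kernelsPerGroup + m;
            for (int64_t r1 = 0; r1 < RH; ++r1)
              for (int64_t r2 = 0; r2 < RW; ++r2) {
                float acc = 0;                     // R[n][kernel][r1][r2] = 0;
                for (int64_t c = 0; c < subC; ++c)
                  for (int64_t k1 = 0; k1 < KH; ++k1)
                    for (int64_t k2 = 0; k2 < KW; ++k2)
                      acc += D[((n * C + g * subC + c) * H + s1 * r1 + k1) * W +
                               s2 * r2 + k2] *
                             K[((kernel * subC + c) * KH + k1) * KW + k2];
                R[((n * M + kernel) * RH + r1) * RW + r2] = acc;
              }
          }
    }
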
-    // Outer loop iteration.
-    auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
-    Block &outerIterationBlock = iterateOp.bodyRegion().front();
-    // Emit optimizations for the outer loops:
-    rewriter.setInsertionPointToEnd(optimizationBlock);
-    rewriter.create<KrnlReturnLoopsOp>(loc, outerLoops);
-    rewriter.setInsertionPointToStart(&outerIterationBlock);
-    {
-      // 2. Emit the body of the outer loop nest.
-
-      // 2.1 Compute the kernel order number: kernel = g * kernelsPerGroup + m;
-      // If group is not set then the kernel ID is identical to that of the
-      // loop over kernels.
-      Value kernel = outerIterationBlock.getArguments()[1];
-      if (group > 1) {
-        // The middle loop is over groups and the third loop is over the
-        // kernel identifiers in the current group.
-        auto kernelsOffset = rewriter.create<MulIOp>(
-            loc, outerIterationBlock.getArguments()[1], kernelsPerGroupValue);
-        kernel = rewriter.create<AddIOp>(
-            loc, kernelsOffset, outerIterationBlock.getArguments()[2]);
-      }
-
-      // 2.2 Define the spatial loops.
-      int64_t nSpatialLoops = resultShape.size() - 2;
-      std::vector<Value> spatialLoops;
-      std::vector<Value> optimizedSpatialLoops;
-      Block *optSpatialLoopBlock = defineLoops(
-          rewriter, loc, spatialLoops, optimizedSpatialLoops, nSpatialLoops);
-
-      // 2.3 Prepare the iteration arguments for the spatial loop nest.
-      KrnlIterateOperandPack spatialPack(
-          rewriter, spatialLoops, optimizedSpatialLoops);
-      for (int i = 2; i < resultShape.size(); ++i)
-        addDimensionToPack(rewriter, loc, spatialPack, alloc, i);
-
-      // 2.4 Emit the loop nest over the output spatial dimensions.
-      // for rX = 0 .. RX
-      auto spatialIterateOp = rewriter.create<KrnlIterateOp>(loc, spatialPack);
-      Block &spatialIterationBlock = spatialIterateOp.bodyRegion().front();
-      // 2.5 Emit optimizations for the spatial loops:
-      rewriter.setInsertionPointToEnd(optSpatialLoopBlock);
-      rewriter.create<KrnlReturnLoopsOp>(loc, spatialLoops);
-      rewriter.setInsertionPointToStart(&spatialIterationBlock);
-      {
-        // 3. Emit the body of the spatial loop nest.
-        // 3.1 Emit: R[n][kernel][r1][r2] = 0;
-        SmallVector<Value, 4> resultIndices;
-        // n
-        resultIndices.emplace_back(outerIterationBlock.getArguments()[0]);
-        // kernel
-        resultIndices.emplace_back(kernel);
-        // rX
-        for (auto arg : spatialIterationBlock.getArguments())
-          resultIndices.emplace_back(arg);
-        // Store the initializer value into the output location.
-        rewriter.create<StoreOp>(loc, zero, alloc, resultIndices);
-
-        // 3.2 Define the inner loops.
-        int64_t nInnerLoops = 1 + (kernelShape.size() - 2);
-        std::vector<Value> innerLoops;
-        std::vector<Value> optimizedInnerLoops;
-        Block *optInnerLoopBlock = defineLoops(
-            rewriter, loc, innerLoops, optimizedInnerLoops, nInnerLoops);
-
-        // 3.3 Prepare the iteration arguments for the inner loop nest.
-        KrnlIterateOperandPack innerPack(
-            rewriter, innerLoops, optimizedInnerLoops);
-        // for c = 0 .. C/group
-        innerPack.pushConstantBound(0);
-        innerPack.pushConstantBound(kernelShape[1]);
-        // for Kx = 0 .. KX
-        for (int i = 2; i < kernelShape.size(); ++i)
-          addDimensionToPack(rewriter, loc, innerPack, operands[1], i);
-
-        // 3.4 Emit the inner loop nest.
-        auto innerIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
-        Block &innerIterationBlock = innerIterateOp.bodyRegion().front();
-        // 3.5 Emit optimizations for the inner loops:
-        rewriter.setInsertionPointToEnd(optInnerLoopBlock);
-        rewriter.create<KrnlReturnLoopsOp>(loc, innerLoops);
-        rewriter.setInsertionPointToStart(&innerIterationBlock);
-        {
-          // 4. Emit the inner loop body:
-          // R[n][kernel][r1][r2] +=
-          //     D[n][g * (C / group) + c][s1 * r1 + k1][s2 * r2 + k2] *
-          //     K[kernel][c][k1][k2];
-
-          // 4.1 Prepare the indices for accessing the data tensor.
-          SmallVector<Value, 4> dataIndices;
-          // n
-          dataIndices.emplace_back(outerIterationBlock.getArguments()[0]);
-          // g * (C / group) + c
-          Value channelDepth = innerIterationBlock.getArguments()[0];
-          if (group > 1)
-            channelDepth = rewriter.create<AddIOp>(
-                loc, channelDepth,
-                rewriter.create<MulIOp>(loc, subchannels,
-                                        outerIterationBlock.getArguments()[1]));
-          dataIndices.emplace_back(channelDepth);
-          // sX * rX + kX
-          auto stridesAttribute = convOp.stridesAttr();
-          // Read the strides attribute.
-          SmallVector<int64_t, 4> strides;
-          if (stridesAttribute)
-            for (auto stride : stridesAttribute.getValue())
-              strides.emplace_back(stride.cast<IntegerAttr>().getInt());
-          for (int i = 0; i < kernelShape.size() - 2; ++i) {
-            Value spatialIndex = spatialIterationBlock.getArguments()[i];
-            // If strides are present then emit the correct access index.
-            if (stridesAttribute && strides[i] > 1)
-              spatialIndex = rewriter.create<MulIOp>(
-                  loc, rewriter.create<ConstantIndexOp>(loc, strides[i]),
-                  spatialIterationBlock.getArguments()[i]);
-            dataIndices.emplace_back(rewriter.create<AddIOp>(
-                loc, spatialIndex, innerIterationBlock.getArguments()[i + 1]));
-          }
-
-          // 4.2 Prepare the indices for accessing the kernel tensor.
-          SmallVector<Value, 4> kernelIndices;
-          // kernel
-          kernelIndices.emplace_back(kernel);
-          // c
-          kernelIndices.emplace_back(innerIterationBlock.getArguments()[0]);
-          // kX
-          for (int i = 0; i < kernelShape.size() - 2; ++i)
-            kernelIndices.emplace_back(
-                innerIterationBlock.getArguments()[i + 1]);
-
-          // 4.3 Compute the convolution.
-          auto loadData =
-              rewriter.create<LoadOp>(loc, operands[0], dataIndices);
-          auto loadKernel =
-              rewriter.create<LoadOp>(loc, operands[1], kernelIndices);
-          auto loadPartialSum =
-              rewriter.create<LoadOp>(loc, alloc, resultIndices);
-          Value result = rewriter.create<AddFOp>(
-              loc, loadPartialSum,
-              rewriter.create<MulFOp>(loc, loadData, loadKernel));
-          // 4.4 Store the computed value into the output location.
-          rewriter.create<StoreOp>(loc, result, alloc, resultIndices);
-        }
-      }
-    }
-    rewriter.replaceOp(op, alloc);
-
-    return matchSuccess();
-  }
-};
-
-//===----------------------------------------------------------------------===//
-// Reduction ops lowering to the Krnl dialect.
-//===----------------------------------------------------------------------===//
-template <typename ONNXReductionOp>
-struct ONNXReductionOpLowering : public ConversionPattern {
-  ONNXReductionOpLowering(MLIRContext *ctx)
-      : ConversionPattern(ONNXReductionOp::getOperationName(), 1, ctx) {}
-
-  PatternMatchResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const final {
-    /*
-     * Condition: the reduction function must be associative and commutative.
-     *
-     * Example 1 (here, the reduction function is `+`):
-     * Induction variables: (i0, i1, i2)
-     * axes = [0, 2]
-     * keepdims = true
-     * krnl.iterate() with (i0, i1, i2) {
-     *   Y(0, i1, 0) += X(i0, i1, i2)
-     * }
-     *
-     * Example 2 (here, the reduction function is `+`):
-     * Induction variables: (i0, i1, i2)
-     * axes = [0, 2]
-     * keepdims = false
-     * krnl.iterate() with (i0, i1, i2) {
-     *   Y(i1) += X(i0, i1, i2)
-     * }
-     */
-    auto loc = op->getLoc();
-    auto memRefInType = operands[0].getType().cast<MemRefType>();
-    auto memRefInShape = memRefInType.getShape();
-    auto memRefOutType = convertToMemRefType(*op->result_type_begin());
-    int64_t inRank = memRefInType.getRank();
-    int64_t outRank = memRefOutType.getRank();
-
-    // Get the attributes.
-    ArrayAttr axisAttrs = llvm::dyn_cast<ONNXReductionOp>(op).axesAttr();
-    std::vector<int64_t> axes;
-    if (axisAttrs) {
-      for (auto axisAttr : axisAttrs.getValue()) {
-        int64_t axis = axisAttr.cast<IntegerAttr>().getInt();
-        axis = axis >= 0 ? axis : (inRank + axis);
-        assert(axis >= -inRank && axis <= inRank - 1);
-        if (std::find(axes.begin(), axes.end(), axis) == axes.end())
-          axes.push_back(axis);
-      }
-    } else {
-      for (decltype(inRank) i = 0; i < inRank; ++i) {
-        axes.push_back(i);
-      }
-    }
-    // KeepDims
-    auto keepdims = llvm::dyn_cast<ONNXReductionOp>(op).keepdims();
-    bool isKeepdims = (keepdims == 1);
-
-    // Get the type information.
-    auto memRefOutShape = memRefOutType.getShape();
-    auto elementOutType = memRefOutType.getElementType();
-    std::map<int64_t, int64_t> outInDimMap =
-        getReductionMapping(memRefInType, axes, isKeepdims);
-
-    // Insert an allocation and deallocation for the result of this operation.
-    Value alloc;
-    bool insertDealloc = checkInsertDealloc(op);
-    if (hasAllConstantDimensions(memRefOutType)) {
-      alloc =
-          insertAllocAndDealloc(memRefOutType, loc, rewriter, insertDealloc);
-    } else {
-      SmallVector<Value, 2> allocOperands;
-      for (decltype(outRank) i = 0; i < outRank; ++i) {
-        if (memRefOutShape[i] < 0) {
-          auto dim = rewriter.create<DimOp>(loc, operands[0], outInDimMap[i]);
-          allocOperands.push_back(dim);
-        }
-      }
-      alloc = rewriter.create<AllocOp>(loc, memRefOutType, allocOperands);
-      if (insertDealloc) {
-        auto *parentBlock = alloc.getDefiningOp()->getBlock();
-        auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
-        dealloc.getOperation()->moveBefore(&parentBlock->back());
-      }
-    }
-
-    // There are two Krnl loops:
-    // - one to initialize the result memref, and
-    // - one to do the reduction.
-
-    // Define the loops that initialize the result.
-    std::vector<Value> originalLoopsInit;
-    std::vector<Value> optimizedLoopsInit;
-    Block *optimizationBlockInit = defineLoops(
-        rewriter, loc, originalLoopsInit, optimizedLoopsInit, outRank);
-
-    // Iteration information.
-    KrnlIterateOperandPack packInit(rewriter, originalLoopsInit,
-                                    optimizedLoopsInit);
-    for (decltype(outRank) i = 0; i < outRank; ++i) {
-      addDimensionToPack(rewriter, loc, packInit, alloc, i);
-    }
-    auto iterateOpInit = rewriter.create<KrnlIterateOp>(loc, packInit);
-    Block &iterationBlockInit = iterateOpInit.bodyRegion().front();
-
-    // Perform the insertions into the body of the initialization loop.
-    // No optimization.
-    rewriter.setInsertionPointToEnd(optimizationBlockInit);
-    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoopsInit);
-
-    // Insert instructions inside the KrnlIterateOp body.
-    rewriter.setInsertionPointToStart(&iterationBlockInit);
-
-    // Handle the operation:
-    SmallVector<Value, 4> loopIVs;
-    for (auto arg : iterationBlockInit.getArguments()) {
-      loopIVs.push_back(arg);
-    }
-
-    Value identity;
-    if (elementOutType.isa<FloatType>()) {
-      identity = rewriter.create<ConstantOp>(
-          loc, FloatAttr::get(elementOutType,
-                   getIdentityValue<ONNXReductionOp>()));
-    } else if (elementOutType.isa<IntegerType>()) {
-      identity = rewriter.create<ConstantOp>(
-          loc, IntegerAttr::get(elementOutType,
-                   getIdentityValue<ONNXReductionOp>()));
-    } else {
-      emitError(loc, "unsupported element type");
-    }
-    rewriter.create<StoreOp>(loc, identity, alloc, loopIVs);
-
-    // Define a Krnl loop to do the reduction.
-    rewriter.setInsertionPointAfter(iterateOpInit);
-    std::vector<Value> originalLoops, optimizedLoops;
-    Block *optimizationBlock =
-        defineLoops(rewriter, loc, originalLoops, optimizedLoops, inRank);
-    // Iteration information.
-    KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
-    for (decltype(inRank) i = 0; i < inRank; ++i) {
-      addDimensionToPack(rewriter, loc, pack, operands[0], i);
-    }
-    auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
-    Block &iterationBlock = iterateOp.bodyRegion().front();
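
For Example 2 from the ONNXReductionOpLowering comment above (ReduceSum, rank-3 input, axes = [0, 2], keepdims = false), the two emitted Krnl loops compute the equivalent of this C++ sketch; buffer layout, names, and the fixed shapes are assumptions for illustration only:

    #include <cstdint>

    void reduceSumExample(const float *X, float *Y,
                          int64_t d0, int64_t d1, int64_t d2) {
      // Init loop: fill the output with the identity of `+`.
      for (int64_t i1 = 0; i1 < d1; ++i1)
        Y[i1] = 0.0f;
      // Reduction loop: iterate the full input index space and accumulate
      // into the output position obtained by dropping the reduction axes.
      for (int64_t i0 = 0; i0 < d0; ++i0)
        for (int64_t i1 = 0; i1 < d1; ++i1)
          for (int64_t i2 = 0; i2 < d2; ++i2)
            Y[i1] += X[(i0 * d1 + i1) * d2 + i2];
    }
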
-    // Perform the insertions into the body of the reduction loop.
-    // No optimization.
-    rewriter.setInsertionPointToEnd(optimizationBlock);
-    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
-
-    // Insert instructions inside the KrnlIterateOp body.
-    rewriter.setInsertionPointToStart(&iterationBlock);
-
-    // Handle the operation:
-    SmallVector<Value, 4> inLoopIVs, outLoopIVs;
-    auto args = iterationBlock.getArguments();
-    for (int i = 0; i < args.size(); ++i) {
-      inLoopIVs.push_back(args[i]);
-    }
-    Value zeroIndex = nullptr;
-    for (decltype(outRank) i = 0; i < outRank; ++i) {
-      if (outInDimMap.find(i) != outInDimMap.end()) {
-        outLoopIVs.push_back(inLoopIVs[outInDimMap[i]]);
-      } else {
-        if (!zeroIndex)
-          zeroIndex = rewriter.create<ConstantIndexOp>(loc, 0);
-        outLoopIVs.push_back(zeroIndex);
-      }
-    }
-
-    Value next, accumulated;
-    next = rewriter.create<LoadOp>(loc, operands[0], inLoopIVs);
-    accumulated = rewriter.create<LoadOp>(loc, alloc, outLoopIVs);
-    accumulated = mapToLowerScalarOp<ONNXReductionOp>(
-        op, memRefOutType.getElementType(), {accumulated, next}, rewriter);
-    rewriter.create<StoreOp>(loc, accumulated, alloc, outLoopIVs);
-
-    rewriter.replaceOp(op, alloc);
-    return matchSuccess();
-  }
-};
-
-//===----------------------------------------------------------------------===//
-// EntryPoint op lowering to a Krnl entry point.
-//===----------------------------------------------------------------------===//
-
-class ONNXEntryPointLowering : public OpRewritePattern<ONNXEntryPointOp> {
-public:
-  using OpRewritePattern<ONNXEntryPointOp>::OpRewritePattern;
-
-  PatternMatchResult matchAndRewrite(ONNXEntryPointOp op,
-                                     PatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<KrnlEntryPointOp>(
-        op,
-        op.getAttrOfType<SymbolRefAttr>(
-            ONNXEntryPointOp::getEntryPointFuncAttrName()),
-        op.getAttrOfType<IntegerAttr>(
-            ONNXEntryPointOp::getNumInputsAttrName()),
-        op.getAttrOfType<IntegerAttr>(
-            ONNXEntryPointOp::getNumOutputsAttrName()));
-    return matchSuccess();
-  }
-};
-
-//===----------------------------------------------------------------------===//
-// Conversion from the Tensor type to the Standard dialect MemRef type.
-//===----------------------------------------------------------------------===//
-
-struct TensorTypeConverter : public TypeConverter {
-  using TypeConverter::TypeConverter;
-
-  LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) override {
-    if (auto type = convertToMemRefType(t)) {
-      results.push_back(type);
-      return success();
-    }
-
-    results.push_back(t);
-    return success();
-  }
-
-  /// Return true if the inputs and outputs of the given function type are
-  /// legal. [Taken from MLIR and adapted to only check the legality of the
-  /// inputs. Once unranked results can be handled gracefully this
-  /// override needs to be removed in favour of the original MLIR one.]
-  bool isSignatureLegal(FunctionType funcType) {
-    return llvm::all_of(funcType.getInputs(),
-                        [this](Type type) { return isLegal(type); });
-  }
-};
-
-} // end anonymous namespace.
-
-//===----------------------------------------------------------------------===//
-// Frontend to Krnl dialect lowering pass
-//===----------------------------------------------------------------------===//
-
-/// This is a partial lowering of ONNX operations to Krnl loops.
-namespace {
-struct FrontendToKrnlLoweringPass
-    : public ModulePass<FrontendToKrnlLoweringPass> {
-  void runOnModule() final;
-};
-} // end anonymous namespace.
-
-void FrontendToKrnlLoweringPass::runOnModule() {
-  auto module = getModule();
-
-  // The first thing to define is the conversion target. This will define the
-  // final target for this lowering.
-  ConversionTarget target(getContext());
-
-  // We define the specific operations, or dialects, that are legal targets
-  // for this lowering.
-  target.addLegalDialect<KrnlOpsDialect, AffineOpsDialect,
-                         StandardOpsDialect>();
-
-  // TODO: enable this once more ops are supported.
-  // We also define the ONNX dialect as illegal so that the conversion will
-  // fail if any of its operations are *not* converted.
-  // target.addIllegalDialect<mlir::ONNXOpsDialect>();
-
-  // TODO: add any other ops which are considered legal.
-  // Some operations can be marked as being still legal.
-  // Example: target.addLegalOp<OpName>();
-
-  // Now that the conversion target has been defined, we just need to provide
-  // the set of patterns that will lower the frontend operations.
-  OwningRewritePatternList patterns;
-
-  // Convert TensorType to MemRefType.
-  TensorTypeConverter tensor_to_memref_converter;
-  target.addDynamicallyLegalOp<mlir::FuncOp>([&](FuncOp op) {
-    // FuncOp is legal only if its types have been converted to Std types.
-    return tensor_to_memref_converter.isSignatureLegal(op.getType());
-  });
-
-  // Type conversion for function signatures.
-  // Call the MLIR FuncOp signature conversion when the result type is
-  // a ranked tensor.
-  populateFuncOpTypeConversionPattern(patterns, &getContext(),
-                                      tensor_to_memref_converter);
-
-  // Frontend operation lowering.
-  patterns.insert<ONNXElementwiseUnaryOpLowering<mlir::ONNXExpOp>,
-                  ONNXElementwiseUnaryOpLowering<mlir::ONNXTanhOp>,
-                  ONNXElementwiseUnaryOpLowering<mlir::ONNXSinhOp>,
-                  ONNXElementwiseUnaryOpLowering<mlir::ONNXCoshOp>,
-                  ONNXElementwiseUnaryOpLowering<mlir::ONNXCosOp>,
-                  ONNXElementwiseUnaryOpLowering<mlir::ONNXLogOp>,
-                  ONNXElementwiseUnaryOpLowering<mlir::ONNXSigmoidOp>,
-                  ONNXElementwiseUnaryOpLowering<mlir::ONNXHardSigmoidOp>,
-                  ONNXElementwiseUnaryOpLowering<mlir::ONNXEluOp>,
-                  ONNXElementwiseUnaryOpLowering<mlir::ONNXReluOp>,
-                  ONNXElementwiseUnaryOpLowering<mlir::ONNXLeakyReluOp>,
-                  ONNXElementwiseUnaryOpLowering<mlir::ONNXSeluOp>,
-                  ONNXElementwiseUnaryOpLowering<mlir::ONNXReciprocalOp>,
-                  ONNXElementwiseUnaryOpLowering<mlir::ONNXSoftplusOp>,
-                  ONNXElementwiseUnaryOpLowering<mlir::ONNXSoftsignOp>,
-                  ONNXElementwiseUnaryOpLowering<mlir::ONNXSqrtOp>,
-                  ONNXElementwiseVariadicOpLowering<mlir::ONNXAddOp>,
-                  ONNXElementwiseVariadicOpLowering<mlir::ONNXMulOp>,
-                  ONNXElementwiseVariadicOpLowering<mlir::ONNXDivOp>,
-                  ONNXElementwiseVariadicOpLowering<mlir::ONNXSubOp>,
-                  ONNXElementwiseVariadicOpLowering<mlir::ONNXAndOp>,
-                  ONNXElementwiseVariadicOpLowering<mlir::ONNXOrOp>,
-                  ONNXElementwiseVariadicOpLowering<mlir::ONNXXorOp>,
-                  ONNXElementwiseVariadicOpLowering<mlir::ONNXSumOp>,
-                  ONNXElementwiseVariadicOpLowering<mlir::ONNXMaxOp>,
-                  ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>,
-                  ONNXReshapeOpLowering, ONNXEntryPointLowering,
-                  ONNXReductionOpLowering<mlir::ONNXReduceMaxOp>,
-                  ONNXReductionOpLowering<mlir::ONNXReduceMinOp>,
-                  ONNXReductionOpLowering<mlir::ONNXReduceProdOp>,
-                  ONNXReductionOpLowering<mlir::ONNXReduceSumOp>,
-                  ONNXSoftmaxOpLowering, ONNXGemmOpLowering,
-                  ONNXUnsqueezeOpLowering, ONNXTransposeOpLowering,
-                  ONNXIdentityOpLowering, ONNXConvNoBiasOpLowering>(
-      &getContext());
-
-  // With the target and rewrite patterns defined, we can now attempt the
-  // conversion. The conversion will signal failure if any of our `illegal`
-  // operations were not converted successfully.
-  if (failed(applyPartialConversion(module, target, patterns)))
-    signalPassFailure();
-}
-
-std::unique_ptr<Pass> mlir::createLowerToKrnlPass() {
-  return std::make_unique<FrontendToKrnlLoweringPass>();
-}
-
-static PassRegistration<FrontendToKrnlLoweringPass>
-    pass("lower-frontend", "Lower frontend ops to Krnl dialect.");
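
A hedged usage sketch of the pass entry point declared above: scheduling the lowering on a module programmatically, equivalent to passing `-lower-frontend` to an opt-style driver via the PassRegistration. The function name is illustrative, and it assumes a ModuleOp produced by the frontend with a live MLIRContext; header paths follow this repository's include style.

    #include "mlir/IR/Module.h"
    #include "mlir/Pass/PassManager.h"
    #include "llvm/Support/raw_ostream.h"

    #include "src/pass/passes.hpp"

    void scheduleLowering(mlir::ModuleOp module) {
      mlir::PassManager pm(module.getContext());
      // Lower ONNX frontend ops to Krnl IR plus standard operations.
      pm.addPass(mlir::createLowerToKrnlPass());
      if (mlir::failed(pm.run(module)))
        llvm::errs() << "lowering to Krnl failed\n";
    }
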