diff --git a/src/conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp b/src/conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp
index 542a2c1..dec03cf 100644
--- a/src/conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp
+++ b/src/conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp
@@ -37,10 +37,18 @@ static bool hasAllConstantDimensions(MemRefType type) {
   return true;
 }
 
-/// Convert the given TensorType into the corresponding MemRefType.
-static MemRefType convertTensorToMemRef(TensorType type) {
-  assert(type.hasRank() && "expected only ranked shapes");
-  return MemRefType::get(type.getShape(), type.getElementType());
+/// Get the corresponding MemRefType of a given TensorType/MemRefType.
+static MemRefType convertToMemRefType(Type type) {
+  MemRefType memRefType;
+  auto tensorType = type.dyn_cast<TensorType>();
+  if (tensorType) {
+    assert(tensorType.hasRank() && "expected only ranked shapes");
+    memRefType =
+        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+  } else {
+    memRefType = type.dyn_cast<MemRefType>();
+  }
+  return memRefType;
 }
 
 /// Insert an allocation and deallocation for the given MemRefType.
@@ -430,8 +438,8 @@ struct TensorTypeConverter : public TypeConverter {
   }
 
   static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
-    if (auto tensor_type = t.dyn_cast<TensorType>()) {
-      results.push_back(convertTensorToMemRef(tensor_type));
+    if (auto type = convertToMemRefType(t)) {
+      results.push_back(type);
       return success();
     }
 
diff --git a/src/conversion/onnx_to_krnl/rewrite_patterns/math/elementwise.inc b/src/conversion/onnx_to_krnl/rewrite_patterns/math/elementwise.inc
index b48e23a..945d4da 100644
--- a/src/conversion/onnx_to_krnl/rewrite_patterns/math/elementwise.inc
+++ b/src/conversion/onnx_to_krnl/rewrite_patterns/math/elementwise.inc
@@ -476,11 +476,10 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
     // TODO: Check that the types are valid.
     // An element-wise unary operation must have all operands and the result of
     // the same type. This should have been verified by the verifier.
-    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
     auto loc = op->getLoc();
 
     // Insert an allocation and deallocation for the result of this operation.
-    auto memRefType = convertTensorToMemRef(tensorType);
+    auto memRefType = convertToMemRefType(*op->result_type_begin());
 
     // If the output has a dynamic dimension, pass the operands required for
     // each dynamic dimension to the AllocOp. The first operand of the
@@ -545,12 +544,11 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
     // TODO: Check that the types are valid.
     // An element-wise variadic operation must have all operands and the result
    // of the same type. This should have been verified by the verifier.
-    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
     auto loc = op->getLoc();
     auto numArgs = op->getNumOperands();
 
     // Insert an allocation and deallocation for the result of this operation.
-    auto memRefType = convertTensorToMemRef(tensorType);
+    auto memRefType = convertToMemRefType(*op->result_type_begin());
 
     Value alloc;
     bool insertDealloc = checkInsertDealloc(op);
diff --git a/src/conversion/onnx_to_krnl/rewrite_patterns/math/gemm.inc b/src/conversion/onnx_to_krnl/rewrite_patterns/math/gemm.inc
index f25dc44..d8bbc55 100644
--- a/src/conversion/onnx_to_krnl/rewrite_patterns/math/gemm.inc
+++ b/src/conversion/onnx_to_krnl/rewrite_patterns/math/gemm.inc
@@ -15,7 +15,6 @@ struct ONNXGemmOpLowering : public ConversionPattern {
   PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                                      ConversionPatternRewriter &rewriter) const final {
-    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
     auto loc = op->getLoc();
 
     Value A, B, C;
@@ -23,9 +22,11 @@ struct ONNXGemmOpLowering : public ConversionPattern {
     B = operands[1];
     C = operands[2];
 
-    auto alphaAttr = FloatAttr::get(tensorType.getElementType(),
+    auto memRefType = convertToMemRefType(*op->result_type_begin());
+
+    auto alphaAttr = FloatAttr::get(memRefType.getElementType(),
         llvm::dyn_cast<ONNXGemmOp>(op).alpha().convertToFloat());
-    auto betaAttr = FloatAttr::get(tensorType.getElementType(),
+    auto betaAttr = FloatAttr::get(memRefType.getElementType(),
         llvm::dyn_cast<ONNXGemmOp>(op).beta().convertToFloat());
     auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
     auto beta = rewriter.create<ConstantOp>(loc, betaAttr);
@@ -33,9 +34,6 @@ struct ONNXGemmOpLowering : public ConversionPattern {
     bool isTransA = (llvm::dyn_cast<ONNXGemmOp>(op).transA() != 0);
     bool isTransB = (llvm::dyn_cast<ONNXGemmOp>(op).transB() != 0);
 
-    // Result type
-    auto memRefType = convertTensorToMemRef(tensorType);
-
     // Insert an allocation and deallocation for the result of this operation.
     Value alloc;
     bool insertDealloc = checkInsertDealloc(op);
diff --git a/src/conversion/onnx_to_krnl/rewrite_patterns/math/matmul.inc b/src/conversion/onnx_to_krnl/rewrite_patterns/math/matmul.inc
index 5c6ebd7..1af1f1b 100644
--- a/src/conversion/onnx_to_krnl/rewrite_patterns/math/matmul.inc
+++ b/src/conversion/onnx_to_krnl/rewrite_patterns/math/matmul.inc
@@ -15,7 +15,6 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
   PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                                      ConversionPatternRewriter &rewriter) const final {
-    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
     auto loc = op->getLoc();
 
     Value A = operands[0];
@@ -29,7 +28,7 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
     //  - Both arguments are 1-D
 
     // Result type
-    auto memRefType = convertTensorToMemRef(tensorType);
+    auto memRefType = convertToMemRefType(*op->result_type_begin());
     auto elementType = memRefType.getElementType();
     auto memRefShape = memRefType.getShape();
diff --git a/src/conversion/onnx_to_krnl/rewrite_patterns/math/reduction.inc b/src/conversion/onnx_to_krnl/rewrite_patterns/math/reduction.inc
index 27f594e..9b94861 100644
--- a/src/conversion/onnx_to_krnl/rewrite_patterns/math/reduction.inc
+++ b/src/conversion/onnx_to_krnl/rewrite_patterns/math/reduction.inc
@@ -145,9 +145,9 @@ struct ONNXReductionOpLowering : public ConversionPattern {
     auto loc = op->getLoc();
     auto memRefInType = operands[0].getType().cast<MemRefType>();
     auto memRefInShape = memRefInType.getShape();
-    auto tensorOutType = (*op->result_type_begin()).cast<TensorType>();
+    auto memRefOutType = convertToMemRefType(*op->result_type_begin());
     int64_t inRank = memRefInType.getRank();
-    int64_t outRank = tensorOutType.getRank();
+    int64_t outRank = memRefOutType.getRank();
 
     // Get attributes
     ArrayAttr axisAttrs = llvm::dyn_cast<ONNXReductionOp>(op).axesAttr();
@@ -171,7 +171,6 @@ struct ONNXReductionOpLowering : public ConversionPattern {
     bool isKeepdims = (keepdims == 1) ? true : false;
 
     // Get type information
-    auto memRefOutType = convertTensorToMemRef(tensorOutType);
     auto memRefOutShape = memRefOutType.getShape();
     auto elementOutType = memRefOutType.getElementType();
     std::map<int64_t, int64_t> outInDimMap =
diff --git a/src/conversion/onnx_to_krnl/rewrite_patterns/math/softmax.inc b/src/conversion/onnx_to_krnl/rewrite_patterns/math/softmax.inc
index eb126c0..3f24a6e 100644
--- a/src/conversion/onnx_to_krnl/rewrite_patterns/math/softmax.inc
+++ b/src/conversion/onnx_to_krnl/rewrite_patterns/math/softmax.inc
@@ -18,8 +18,8 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
     //              let exp_x = exp(x - max_x) in
     //              let sum = sum(exp_x) in
     //              exp_x / sum
-    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
-    int64_t rank = tensorType.getRank();
+    auto memRefType = convertToMemRefType(*op->result_type_begin());
+    int64_t rank = memRefType.getRank();
     int64_t axis = llvm::dyn_cast<ONNXSoftmaxOp>(op).axis().getSExtValue();
     axis = axis >= 0 ? axis : rank + axis;
     assert(axis >= -rank && axis <= rank - 1);
@@ -27,7 +27,6 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
     auto loc = op->getLoc();
 
     // Insert an allocation and deallocation for the result of this operation.
-    auto memRefType = convertTensorToMemRef(tensorType);
     auto elementType = memRefType.getElementType();
 
     Value alloc;
diff --git a/src/conversion/onnx_to_krnl/rewrite_patterns/nn/conv.inc b/src/conversion/onnx_to_krnl/rewrite_patterns/nn/conv.inc
index 3ecfa3e..20ac5e8 100644
--- a/src/conversion/onnx_to_krnl/rewrite_patterns/nn/conv.inc
+++ b/src/conversion/onnx_to_krnl/rewrite_patterns/nn/conv.inc
@@ -15,10 +15,9 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
   PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                                      ConversionPatternRewriter &rewriter) const final {
-    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
     auto loc = op->getLoc();
     // Insert an allocation and deallocation for the result of this operation.
-    auto memRefType = convertTensorToMemRef(tensorType);
+    auto memRefType = convertToMemRefType(*op->result_type_begin());
     Value alloc;
     bool insertDealloc = checkInsertDealloc(op);
     ONNXConvNoBiasOp convOp = llvm::dyn_cast<ONNXConvNoBiasOp>(op);
diff --git a/src/conversion/onnx_to_krnl/rewrite_patterns/tensor/reshape.inc b/src/conversion/onnx_to_krnl/rewrite_patterns/tensor/reshape.inc
index ed2b185..b64494f 100644
--- a/src/conversion/onnx_to_krnl/rewrite_patterns/tensor/reshape.inc
+++ b/src/conversion/onnx_to_krnl/rewrite_patterns/tensor/reshape.inc
@@ -15,12 +15,11 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
   PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                                      ConversionPatternRewriter &rewriter) const final {
-    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
     auto inputShape = operands[0].getType().cast<MemRefType>().getShape();
     auto loc = op->getLoc();
 
     // Insert an allocation and deallocation for the result of this operation.
-    auto memRefType = convertTensorToMemRef(tensorType);
+    auto memRefType = convertToMemRefType(*op->result_type_begin());
     auto memRefShape = memRefType.getShape();
     Value alloc;
diff --git a/src/conversion/onnx_to_krnl/rewrite_patterns/tensor/transpose.inc b/src/conversion/onnx_to_krnl/rewrite_patterns/tensor/transpose.inc
index 39cfa8c..3bb897a 100644
--- a/src/conversion/onnx_to_krnl/rewrite_patterns/tensor/transpose.inc
+++ b/src/conversion/onnx_to_krnl/rewrite_patterns/tensor/transpose.inc
@@ -15,10 +15,9 @@ struct ONNXTransposeOpLowering : public ConversionPattern {
   PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                                      ConversionPatternRewriter &rewriter) const final {
-    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
     auto loc = op->getLoc();
 
     // Insert an allocation and deallocation for the result of this operation.
-    auto memRefType = convertTensorToMemRef(tensorType);
+    auto memRefType = convertToMemRefType(*op->result_type_begin());
     Value alloc;
     bool insertDealloc = checkInsertDealloc(op);
diff --git a/src/conversion/onnx_to_krnl/rewrite_patterns/tensor/unsqueeze.inc b/src/conversion/onnx_to_krnl/rewrite_patterns/tensor/unsqueeze.inc
index 18b9f8b..6d5289d 100644
--- a/src/conversion/onnx_to_krnl/rewrite_patterns/tensor/unsqueeze.inc
+++ b/src/conversion/onnx_to_krnl/rewrite_patterns/tensor/unsqueeze.inc
@@ -16,8 +16,8 @@ struct ONNXUnsqueezeOpLowering : public ConversionPattern {
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
     auto loc = op->getLoc();
-    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
-    int outRank = tensorType.getRank();
+    auto memRefType = convertToMemRefType(*op->result_type_begin());
+    int outRank = memRefType.getRank();
 
     // Assume that `axes` has been validated by shape inference.
     // So, here we just get it.
@@ -30,7 +30,6 @@ struct ONNXUnsqueezeOpLowering : public ConversionPattern {
     }
 
     // Insert an allocation and deallocation for the result of this operation.
-    auto memRefType = convertTensorToMemRef(tensorType);
     Value alloc;
 
     // Compute size in bytes.
diff --git a/src/pass/lower_frontend_to_krnl.cpp b/src/pass/lower_frontend_to_krnl.cpp
new file mode 100644
index 0000000..d609bc5
--- /dev/null
+++ b/src/pass/lower_frontend_to_krnl.cpp
@@ -0,0 +1,2380 @@
+//====- lower_frontend_to_krnl.cpp - Frontend dialects to Krnl lowering ---===//
+//
+// Copyright 2019 The IBM Research Authors.
+//
+// =============================================================================
+//
+// This file implements the lowering of frontend operations to a combination of
+// Krnl IR and standard operations.
+//
+//===----------------------------------------------------------------------===//
+#include <map>
+
+#include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/Sequence.h"
+
+#include "src/dialect/krnl/krnl_helper.hpp"
+#include "src/dialect/krnl/krnl_ops.hpp"
+#include "src/dialect/onnx/onnx_ops.hpp"
+#include "src/pass/passes.hpp"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// FrontendToAffine RewritePatterns
+//===----------------------------------------------------------------------===//
+
+/// Check if all dimensions are known at compile time.
+static bool hasAllConstantDimensions(MemRefType type) { + auto memRefShape = type.getShape(); + for (int i = 0; i < memRefShape.size(); ++i) + if (memRefShape[i] < 0) + return false; + return true; +} + +/// Get the corresponding MemRefType of a given TensorType/MemRefType. +static MemRefType convertToMemRefType(Type type) { + MemRefType memRefType; + auto tensorType = type.dyn_cast(); + if (tensorType) { + assert(tensorType.hasRank() && "expected only ranked shapes"); + memRefType = + MemRefType::get(tensorType.getShape(), tensorType.getElementType()); + } else { + memRefType = type.dyn_cast(); + } + return memRefType; +} + +/// Insert an allocation and deallocation for the given MemRefType. +static Value insertAllocAndDealloc(MemRefType type, Location loc, + PatternRewriter &rewriter, + bool insertDealloc, + ArrayRef operands = {}) { + // Put together alloc operands for any dynamic dimensions of the memref. + AllocOp alloc; + if (!operands.empty()) { + auto memRefShape = type.getShape(); + auto rank = memRefShape.size(); + + std::map fromOperands; + for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) { + int memRefDimIdx = rank - 1 - reversedIdx; + if (memRefShape[memRefDimIdx] < 0) { // unknown dimension + Value maxDim = nullptr; + for (int i = 0; i < operands.size(); i++) { + auto operandShape = + operands[i].getType().cast().getShape(); + int operandDimIdx = operandShape.size() - 1 - reversedIdx; + + if (operandDimIdx < 0) + continue; + + // In case of operations with broadcasting, the dimension of the + // alloc result is the maximum size along each dimension of the + // operands. + auto operandDim = + rewriter.create(loc, operands[i], operandDimIdx); + if (maxDim) { + auto maxCondition = rewriter.create(loc, CmpIPredicate::sgt, + operandDim, maxDim); + maxDim = rewriter.create(loc, maxCondition, operandDim, + maxDim); + } else { + maxDim = operandDim; + } + } + fromOperands.insert(std::make_pair(memRefDimIdx, maxDim)); + } + } + + SmallVector allocOperands; + for (int i = 0; i < rank; ++i) + if (memRefShape[i] < 0) + allocOperands.push_back(fromOperands[i]); + alloc = rewriter.create(loc, type, allocOperands); + } else { + alloc = rewriter.create(loc, type); + } + + // Make sure to allocate at the beginning of the block if + // all dimensions are known. + auto *parentBlock = alloc.getOperation()->getBlock(); + if (hasAllConstantDimensions(type)) + alloc.getOperation()->moveBefore(&parentBlock->front()); + + if (insertDealloc) { + auto dealloc = rewriter.create(loc, alloc); + dealloc.getOperation()->moveBefore(&parentBlock->back()); + } + + return alloc; +} + +// Determine if current function returns the result value of the +// current op being lowered. If it does then dealloc should not be +// inserted. +static bool checkInsertDealloc(Operation *currentOp) { + auto parentBlock = currentOp->getBlock(); + + bool insertDealloc = true; + parentBlock->walk([&insertDealloc, currentOp](ReturnOp op) { + assert(currentOp->getNumResults() < 2 && + "No more than one result supported (for now)."); + // If there is at least one result to investigate. + if (currentOp->getNumResults() > 0) { + auto result = currentOp->getResult(0); + for (const auto &operand : op.getOperands()) + if (operand == result) + insertDealloc = false; + } + }); + + return insertDealloc; +} + +// Create a mapping from result type's dimensions to input type's dimensions, +// given that the result type is the result of a reduction op over the input +// type. 
+std::map +getReductionMapping(MemRefType inputTy, ArrayRef axes, bool keepdims) { + std::map OutInDimMap; + int64_t rank = inputTy.getRank(); + + // Mark reduction axes. + std::vector isReductionAxis; + for (decltype(rank) i = 0; i < rank; ++i) { + if (std::find(axes.begin(), axes.end(), i) != axes.end()) + isReductionAxis.push_back(true); + else + isReductionAxis.push_back(false); + } + + for (decltype(rank) inIndex = 0, outIndex = 0; inIndex < rank; ++inIndex) { + // If it is a reduction axis, there is no relationship among dimensions. + if (isReductionAxis[inIndex]) { + if (keepdims) + outIndex++; + } else { + OutInDimMap.insert(std::make_pair(outIndex, inIndex)); + outIndex++; + } + } + + return OutInDimMap; +} + +// Add bounds associated with the op operand to the KRNL iteration pack. +// Dynamic dimenions are supported. +static void addDimensionToPack(ConversionPatternRewriter &rewriter, + Location loc, KrnlIterateOperandPack &pack, Value operand, int index) { + auto shape = operand.getType().cast().getShape(); + if (shape[index] < 0) { + pack.pushConstantBound(0); + pack.pushOperandBound( + rewriter.create(loc, operand, index).getResult()); + } else { + pack.pushConstantBound(0); + pack.pushConstantBound(shape[index]); + } +} + +// Function that defines the KRNL dialect loops and their respective +// optimized version. +static KrnlOptimizeLoopsOp emitOptimizedLoops( + ConversionPatternRewriter &rewriter, Location loc, + std::vector &loops, std::vector &optimizedLoops, + int64_t numLoops) { + // Define loops. + auto loopsOp = rewriter.create(loc, numLoops); + loops.reserve(numLoops); + for (auto result : loopsOp.getResults()) + loops.push_back(result); + + // Define optimized version of the loops. + auto optimizedLoopsOp = rewriter.create(loc, numLoops); + optimizedLoops.reserve(numLoops); + for (auto result : optimizedLoopsOp.getResults()) + optimizedLoops.push_back(result); + + return optimizedLoopsOp; +} + +// Function that emits the loops and their optimized version. +// The function returns a reference to the inner optimization block. +static Block* defineLoops(ConversionPatternRewriter &rewriter, + Location loc, std::vector &loops, + std::vector &optimizedLoops, int64_t numLoops) { + KrnlOptimizeLoopsOp optimizedLoopsOp = emitOptimizedLoops( + rewriter, loc, loops, optimizedLoops, numLoops); + return &optimizedLoopsOp.region().front(); +} + +// Function which emits a basic set of loops and optimized loops +// for a given operation argument. A reference to the loop optimization +// block is returned in the last argument of the function. +static void emitKrnlLoopsAndIterationForOperand( + ConversionPatternRewriter &rewriter, Location loc, + Value operand, std::vector &originalLoops, + KrnlOptimizeLoopsOp &optimizedLoopsOp, KrnlIterateOp &iterateOp) { + // Operand shape. + auto shape = operand.getType().cast().getShape(); + + // Number of loops. + int64_t rank = shape.size(); + + // Define loops and optimized loops. + std::vector optimizedLoops; + optimizedLoopsOp = emitOptimizedLoops(rewriter, loc, originalLoops, + optimizedLoops, rank); + + KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops); + // Iterate over the loop nest. 
+ for (int i = 0; i < rank; ++i) + addDimensionToPack(rewriter, loc, pack, operand, i); + + iterateOp = rewriter.create(loc, pack); +} + +unsigned getMemRefEltSizeInBytes(MemRefType memRefType) { + auto elementType = memRefType.getElementType(); + + unsigned sizeInBits; + if (elementType.isIntOrFloat()) { + sizeInBits = elementType.getIntOrFloatBitWidth(); + } else { + auto vectorType = elementType.cast(); + sizeInBits = + vectorType.getElementTypeBitWidth() * vectorType.getNumElements(); + } + return llvm::divideCeil(sizeInBits, 8); +} + +// Get run-time dimension information for unknown dimensions used for +// broadcasting. +std::map> +getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter, + MemRefType memRefType, ArrayRef operands) { + auto memRefShape = memRefType.getShape(); + int64_t rank = memRefShape.size(); + // For unknown dimensions, we need to get dimension values at runtime in + // order to do broadcasting. + std::map> DimInfo; + // For each result dimension, compute the number of sharing operands. + // Sharing operands are operands sharing the same index (counting from the + // rightmost to the leftmost) for a given dimension. + std::map sharedDimCount; + for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) { + int dimIdx = rank - 1 - reversedIdx; + sharedDimCount[dimIdx] = 0; + for (int i = 0; i < operands.size(); ++i) { + auto shape = operands[i].getType().cast().getShape(); + if (reversedIdx <= shape.size() - 1) + sharedDimCount[dimIdx]++; + } + } + // An unknown dimension can have a value of 1 or N (N > 1). + // If its value is 1, it is broadcasted dimension. + // Otherwise, non-broadcasted dimension. + // We only care about unknown dimensions whose number of sharing operands is + // more than one, since they are potentially broadcasted dimensions. + for (int i = 0; i < operands.size(); ++i) { + std::map broadcastedDims; + auto shape = operands[i].getType().cast().getShape(); + int size = shape.size(); + for (int j = 0; j < shape.size(); ++j) { + if (shape[j] < 0 and sharedDimCount[rank - size + j] > 1) { + auto dim = rewriter.create(loc, operands[i], j).getResult(); + auto one = rewriter.create(loc, 1); + auto isBroadcasted = + rewriter.create(loc, CmpIPredicate::eq, dim, one); + broadcastedDims.insert(std::make_pair(j, isBroadcasted)); + } + } + DimInfo.insert(std::make_pair(i, broadcastedDims)); + } + return DimInfo; +} + +// Extract induction variables that are used for broadcasting values of a +// given operand. +std::vector +getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter, + ArrayRef loopIVs, Value operand, + std::map broadcastedDims) { + // `operand` must has a ranked type. This should have been checked by the + // shape inference pass. + auto operandShape = operand.getType().cast().getShape(); + auto rank = operandShape.size(); + auto loopCount = loopIVs.size(); + + std::vector newLoopIVs; + for (unsigned reversedIdx = 0; reversedIdx < rank; ++reversedIdx) { + auto dimIdx = rank - 1 - reversedIdx; + auto loopIdx = loopCount - 1 - reversedIdx; + if (operandShape[dimIdx] == 1) { + // Broadcasted dimension + auto zero = rewriter.create(loc, 0); + newLoopIVs.insert(newLoopIVs.begin(), zero); + } else if ((operandShape[dimIdx] == -1) && + (broadcastedDims.find(dimIdx) != broadcastedDims.end())) { + // Unknown dimension, it can have a value of 1 or N (N > 1). + // If its value is 1, it is broadcasted dimension. + // Otherwise, non-broadcasted dimension. 
+ auto zero = rewriter.create(loc, 0); + auto idx = rewriter.create(loc, broadcastedDims[dimIdx], zero, + loopIVs[loopIdx]); + newLoopIVs.insert(newLoopIVs.begin(), idx); + } else { + // Non-broadcasted dimension + newLoopIVs.insert(newLoopIVs.begin(), loopIVs[loopIdx]); + } + } + return newLoopIVs; +} + +namespace { + +template +struct ScalarOp; + +template <> +struct ScalarOp { + using FOp = AddFOp; + using IOp = AddIOp; +}; + +template <> +struct ScalarOp { + using FOp = MulFOp; + using IOp = MulIOp; +}; + +template <> +struct ScalarOp { + using FOp = DivFOp; + using IOp = SignedDivIOp; +}; + +template <> +struct ScalarOp { + using FOp = SubFOp; + using IOp = SubIOp; +}; + +template <> +struct ScalarOp { + using FOp = AndOp; // not use + using IOp = AndOp; +}; + +template <> +struct ScalarOp { + using FOp = OrOp; // not use + using IOp = OrOp; +}; + +template <> +struct ScalarOp { + using FOp = XOrOp; // not use + using IOp = XOrOp; +}; + +template <> +struct ScalarOp { + using FOp = ExpOp; + using IOp = ExpOp; // not use +}; + +template <> +struct ScalarOp { + using FOp = AddFOp; + using IOp = AddIOp; +}; + +template <> +struct ScalarOp { + using FOp = TanhOp; + using IOp = TanhOp; // not use +}; + +template <> +struct ScalarOp { + using FOp = CosOp; + using IOp = CosOp; // not use +}; + +template <> +struct ScalarOp { + using FOp = LogOp; + using IOp = LogOp; // not use +}; + +template <> +struct ScalarOp { + using FOp = MulFOp; + using IOp = MulIOp; +}; + +template <> +struct ScalarOp { + using FOp = AddFOp; + using IOp = AddIOp; +}; + +template <> +struct ScalarOp { + using FOp = KrnlSqrtOp; + using IOp = KrnlSqrtOp; // not use +}; + +template +using ScalarFOp = typename ScalarOp::FOp; +template +using ScalarIOp = typename ScalarOp::IOp; + +// Get the identity element of a operation. +// Return NULL if the function does not have identity. +template +DataType getIdentityValue() { + return NULL; +} + +template <> +float getIdentityValue(){ + return (float)-std::numeric_limits::infinity(); +} + +template <> +int getIdentityValue(){ + return std::numeric_limits::min(); +} + +template <> +float getIdentityValue(){ + return (float)std::numeric_limits::infinity(); +} + +template <> +int getIdentityValue(){ + return std::numeric_limits::max(); +} + +template <> +float getIdentityValue(){ + return (float)1.0; +} + +template <> +int getIdentityValue(){ + return 1; +} + +template <> +float getIdentityValue(){ + return (float)0; +} + +template <> +int getIdentityValue(){ + return 0; +} + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering to Krnl dialect. +//===----------------------------------------------------------------------===// +template +Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, + ArrayRef operands, + ConversionPatternRewriter &rewriter) { + /* Lower UnaryOp to Ops in the Standard dialect. 
+ */ + auto loc = op->getLoc(); + Type element_type = operands.front().getType(); + if (element_type.isa()) { + return rewriter.create>(loc, result_types, operands, + mlir::None); + } else if (element_type.isa()) { + return rewriter.create>(loc, result_types, operands, + mlir::None); + } else { + emitError(loc, "unsupported element type"); + return nullptr; + } +} + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXSinhOp +//===----------------------------------------------------------------------===// +template <> +Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, + ArrayRef operands, + ConversionPatternRewriter &rewriter) { + // ONNXSinhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))), + // ConstantOp 2) + auto loc = op->getLoc(); + Value operand = operands[0]; + auto elementType = result_types[0]; + + auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); + auto two = rewriter.create(loc, FloatAttr::get(elementType, 2)); + auto neg = rewriter.create(loc, zero, operand); + auto exp = rewriter.create(loc, operand); + auto negExp = rewriter.create(loc, neg); + auto result = rewriter.create( + loc, rewriter.create(loc, exp, negExp), two); + + return result; +} + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXCoshOp +//===----------------------------------------------------------------------===// +template <> +Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, + ArrayRef operands, + ConversionPatternRewriter &rewriter) { + // ONNXCoshOp(%X) = DivFOp(AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))), + // ConstantOp 2) + auto loc = op->getLoc(); + Value operand = operands[0]; + auto elementType = result_types[0]; + + auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); + auto two = rewriter.create(loc, FloatAttr::get(elementType, 2)); + auto neg = rewriter.create(loc, zero, operand); + auto exp = rewriter.create(loc, operand); + auto negExp = rewriter.create(loc, neg); + auto result = rewriter.create( + loc, rewriter.create(loc, exp, negExp), two); + + return result; +} + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXSigmoidOp +//===----------------------------------------------------------------------===// +template <> +Value mapToLowerScalarOp(Operation *op, + ArrayRef result_types, + ArrayRef operands, + ConversionPatternRewriter &rewriter) { + // ONNXSigmoidOp(%X) = DivFOp(ConstantOp 1, + // AddFOp(ConstantOp 1, ExpOp(NegFOp(%X)))) + auto loc = op->getLoc(); + Value operand = operands[0]; + auto elementType = result_types[0]; + + auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); + auto one = rewriter.create(loc, FloatAttr::get(elementType, 1)); + auto neg = rewriter.create(loc, zero, operand); + auto negExp = rewriter.create(loc, neg); + auto result = rewriter.create( + loc, one, rewriter.create(loc, one, negExp)); + + return result; +} + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXHardSigmoidOp +//===----------------------------------------------------------------------===// +template <> +Value mapToLowerScalarOp( + Operation *op, ArrayRef result_types, ArrayRef operands, + ConversionPatternRewriter &rewriter) { + // %Y = AddFOp(MulFOp(alpha, %X), beta) + // %Z = SelectOp(CmpFOp(OGT, %Y, Constant 0), + // %Y, + // Constant 0) + // 
ONNXHardSigmoidOp(%X) = SelectOp(CmpFOp(OLT, %Z, Constant 1), + // %Z, + // Constant 1) + auto loc = op->getLoc(); + Value operand = operands[0]; + auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(), + llvm::dyn_cast(op).alpha().convertToFloat()); + auto betaAttribute = FloatAttr::get(rewriter.getF32Type(), + llvm::dyn_cast(op).beta().convertToFloat()); + auto elementType = result_types[0]; + + auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); + auto one = rewriter.create(loc, FloatAttr::get(elementType, 1)); + auto alpha = rewriter.create(loc, alphaAttribute); + auto beta = rewriter.create(loc, betaAttribute); + + auto add = rewriter.create( + loc, rewriter.create(loc, alpha, operand), beta); + auto maxPredicate = + rewriter.create(loc, CmpFPredicate::OGT, add, zero); + auto max = rewriter.create(loc, maxPredicate, add, zero); + auto minPredicate = + rewriter.create(loc, CmpFPredicate::OLT, max, one); + auto result = rewriter.create(loc, minPredicate, max, one); + + return result; +} + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXEluOp +//===----------------------------------------------------------------------===// +template <> +Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, + ArrayRef operands, + ConversionPatternRewriter &rewriter) { + // ONNXEluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0), + // MulFOp(alpha, SubFOp(ExpOp(%X), 1)), + // %X) + auto loc = op->getLoc(); + Value operand = operands[0]; + auto elementType = result_types[0]; + + auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(), + llvm::dyn_cast(op).alpha().convertToFloat()); + auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); + auto one = rewriter.create(loc, FloatAttr::get(elementType, 1)); + auto alpha = rewriter.create(loc, alphaAttribute); + auto exp = rewriter.create(loc, operand); + auto lessThanZero = + rewriter.create(loc, CmpFPredicate::OLT, operand, zero); + auto result = rewriter.create( + loc, lessThanZero, + rewriter.create(loc, alpha, + rewriter.create(loc, exp, one)), + operand); + + return result; +} + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXReluOp +//===----------------------------------------------------------------------===// +template <> +Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, + ArrayRef operands, + ConversionPatternRewriter &rewriter) { + // ONNXReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0), + // ConstantOp 0, + // %X) + auto loc = op->getLoc(); + Value operand = operands[0]; + auto elementType = result_types[0]; + + auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); + auto lessThanZero = + rewriter.create(loc, CmpFPredicate::OLT, operand, zero); + auto result = rewriter.create(loc, lessThanZero, zero, operand); + + return result; +} + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXLeakyReluOp +//===----------------------------------------------------------------------===// +template <> +Value mapToLowerScalarOp(Operation *op, + ArrayRef result_types, + ArrayRef operands, + ConversionPatternRewriter &rewriter) { + // ONNXLeakyReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0), + // MulFOp(alpha, %X), + // %X) + auto loc = op->getLoc(); + Value operand = operands[0]; + auto elementType = result_types[0]; + + auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(), + 
llvm::dyn_cast(op).alpha().convertToFloat()); + auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); + auto alpha = rewriter.create(loc, alphaAttribute); + auto lessThanZero = + rewriter.create(loc, CmpFPredicate::OLT, operand, zero); + auto result = rewriter.create( + loc, lessThanZero, rewriter.create(loc, alpha, operand), operand); + + return result; +} + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXSeluOp +//===----------------------------------------------------------------------===// +template <> +Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, + ArrayRef operands, + ConversionPatternRewriter &rewriter) { + // ONNXSeluOp(%X) = SelectOp(CmpFOp(OGT, %X, ConstantOp 0), + // MulFOp(gamma, %X), + // MulFOp(gamma, + // SubFOp(MulFOp(alpha, ExpOp(%X)), + // alpha))) + auto loc = op->getLoc(); + Value operand = operands[0]; + auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(), + llvm::dyn_cast(op).alpha().convertToFloat()); + auto gammaAttribute = FloatAttr::get(rewriter.getF32Type(), + llvm::dyn_cast(op).gamma().convertToFloat()); + auto elementType = result_types[0]; + + auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); + auto alpha = rewriter.create(loc, alphaAttribute); + auto gamma = rewriter.create(loc, gammaAttribute); + auto exp = rewriter.create(loc, operand); + auto greaterThanZero = + rewriter.create(loc, CmpFPredicate::OGT, operand, zero); + auto select = rewriter.create( + loc, greaterThanZero, operand, + rewriter.create(loc, rewriter.create(loc, alpha, exp), + alpha)); + auto result = rewriter.create(loc, gamma, select); + + return result; +} + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXReciprocalOp +//===----------------------------------------------------------------------===// +template <> +Value mapToLowerScalarOp( + Operation *op, ArrayRef result_types, ArrayRef operands, + ConversionPatternRewriter &rewriter) { + // ONNXReciprocalOp(%X) = DivFOp(ConstantOp 1, %X) + auto loc = op->getLoc(); + Value operand = operands[0]; + auto elementType = result_types[0]; + + auto one = rewriter.create(loc, FloatAttr::get(elementType, 1)); + auto result = rewriter.create(loc, one, operand); + + return result; +} + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXSoftplusOp +//===----------------------------------------------------------------------===// +template <> +Value mapToLowerScalarOp( + Operation *op, ArrayRef result_types, ArrayRef operands, + ConversionPatternRewriter &rewriter) { + // ONNXSoftplusOp(%X) = LogOp(AddFOp(ExpOp(%X), ConstantOp 1)) + auto loc = op->getLoc(); + Value operand = operands[0]; + auto elementType = result_types[0]; + + auto exp = rewriter.create(loc, operand); + auto one = rewriter.create(loc, FloatAttr::get(elementType, 1)); + auto add = rewriter.create(loc, exp, one); + auto result = rewriter.create(loc, add); + + return result; +} + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXSoftsignOp +//===----------------------------------------------------------------------===// +template <> +Value mapToLowerScalarOp( + Operation *op, ArrayRef result_types, ArrayRef operands, + ConversionPatternRewriter &rewriter) { + // ONNXSoftsignOp(%X) = DivFOp(ConstantOp 1, %X) + auto loc = op->getLoc(); + Value operand = 
operands[0]; + auto elementType = result_types[0]; + + auto abs = rewriter.create(loc, operand); + auto one = rewriter.create(loc, FloatAttr::get(elementType, 1)); + auto add = rewriter.create(loc, abs, one); + auto result = rewriter.create(loc, operand, add); + + return result; +} + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXSignOp +//===----------------------------------------------------------------------===// +template <> +Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, + ArrayRef operands, + ConversionPatternRewriter &rewriter) { + + auto loc = op->getLoc(); + Value operand = operands[0]; + Type element_type = operands.front().getType(); + // TODO: unsigned int should be supported separately? + if (element_type.isa()) { + // %Y = SelectOP(CmpIOp(GT, %X, ConstantOp 0), + // ConstantOp 1, + // COnstantOp -1) + // ONNXSignOp(%X) = SelectOP(CmpIOp(EQ, %X, ConstantOp 0), + // ConstantOp 0, + // %Y) + auto zero = rewriter.create(loc, rewriter.getI32IntegerAttr(0)); + auto one = rewriter.create(loc, rewriter.getI32IntegerAttr(1)); + auto minusOne = + rewriter.create(loc, rewriter.getI32IntegerAttr(-1)); + auto plusPredicate = + rewriter.create(loc, CmpIPredicate::sgt, operand, zero); + auto plusSelect = + rewriter.create(loc, plusPredicate, one, minusOne); + auto zeroPredicate = + rewriter.create(loc, CmpIPredicate::eq, operand, zero); + auto result = + rewriter.create(loc, zeroPredicate, zero, plusSelect); + return result; + } else if (element_type.isa()) { + // %Y = SelectOP(CmpFOp(OGT, %X, ConstantOp 0), + // ConstantOp 1, + // ConstantOp -1) + // ONNXSignOp(%X) = SelectOP(CmpFOp(OEQ, %X, ConstantOp 0), + // ConstantOp 0, + // %Y) + auto zero = + rewriter.create(loc, rewriter.getF32FloatAttr(0.0f)); + auto one = rewriter.create(loc, rewriter.getF32FloatAttr(1.0f)); + auto minusOne = + rewriter.create(loc, rewriter.getF32FloatAttr(-1.0f)); + auto plusPredicate = + rewriter.create(loc, CmpFPredicate::OGT, operand, zero); + auto plusSelect = + rewriter.create(loc, plusPredicate, one, minusOne); + auto zeroPredicate = + rewriter.create(loc, CmpFPredicate::OEQ, operand, zero); + auto result = + rewriter.create(loc, zeroPredicate, zero, plusSelect); + return result; + } else { + emitError(loc, "unsupported element type"); + } +} + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXMaxOp +//===----------------------------------------------------------------------===// +template <> +Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, + ArrayRef operands, + ConversionPatternRewriter &rewriter) { + // ONNXMaxOp(%X, %Y) = SelectOp(CmpFOp(OGT, %X, %Y), + // %X, + // %Y) + auto loc = op->getLoc(); + Value lhs = operands[0]; + Value rhs = operands[1]; + auto max = rewriter.create(loc, CmpFPredicate::OGT, lhs, rhs); + auto result = rewriter.create(loc, max, lhs, rhs); + return result; +} + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXMinOp +//===----------------------------------------------------------------------===// +template <> +Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, + ArrayRef operands, + ConversionPatternRewriter &rewriter) { + // ONNXMinOp(%X, %Y) = SelectOp(CmpFOp(OLT, %X, %Y), + // %X, + // %Y) + auto loc = op->getLoc(); + Value lhs = operands[0]; + Value rhs = operands[1]; + auto min = rewriter.create(loc, CmpFPredicate::OLT, 
lhs, rhs); + auto result = rewriter.create(loc, min, lhs, rhs); + return result; +} + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXReduceMaxOp +//===----------------------------------------------------------------------===// +template <> +Value mapToLowerScalarOp(Operation *op, + ArrayRef result_types, + ArrayRef operands, + ConversionPatternRewriter &rewriter) { + auto loc = op->getLoc(); + Value lhs = operands[0]; + Value rhs = operands[1]; + Type element_type = lhs.getType(); + if (element_type.isa()) { + auto max = rewriter.create(loc, CmpIPredicate::sgt, lhs, rhs); + auto result = rewriter.create(loc, max, lhs, rhs); + return result; + } else if (element_type.isa()) { + auto max = rewriter.create(loc, CmpFPredicate::OGT, lhs, rhs); + auto result = rewriter.create(loc, max, lhs, rhs); + return result; + } else { + emitError(loc, "unsupported element type"); + return nullptr; + } +} + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXReduceMinOp +//===----------------------------------------------------------------------===// +template <> +Value mapToLowerScalarOp(Operation *op, + ArrayRef result_types, + ArrayRef operands, + ConversionPatternRewriter &rewriter) { + auto loc = op->getLoc(); + Value lhs = operands[0]; + Value rhs = operands[1]; + Type element_type = lhs.getType(); + if (element_type.isa()) { + auto min = rewriter.create(loc, CmpIPredicate::slt, lhs, rhs); + auto result = rewriter.create(loc, min, lhs, rhs); + return result; + } else if (element_type.isa()) { + auto min = rewriter.create(loc, CmpFPredicate::OLT, lhs, rhs); + auto result = rewriter.create(loc, min, lhs, rhs); + return result; + } else { + emitError(loc, "unsupported element type"); + return nullptr; + } +} + +// Element-wise unary ops lowering to Krnl dialect. +//===----------------------------------------------------------------------===// +template +struct ONNXElementwiseUnaryOpLowering : public ConversionPattern { + ONNXElementwiseUnaryOpLowering(MLIRContext *ctx) + : ConversionPattern(ElementwiseUnaryOp::getOperationName(), 1, ctx) {} + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + // TODO: Check that the types are valid. + // An element-wise unary operation must have all operands and the result of + // the same type. This should have been verified by the verifier. + + auto loc = op->getLoc(); + + // Insert an allocation and deallocation for the result of this operation. + auto memRefType = convertToMemRefType(*op->result_type_begin()); + + // If the output has a dynamic dimension, pass the operands required for + // each dynamic dimension to the AllocOp. The first operand of the + // operation is used. The operands of the op need to match in terms of + // dimensions with the result at this pre-optimization phase. + // TODO: verify that dimensions match. + // TODO: can the dimension of the result differ after optimizations? 
+ Value alloc; + bool insertDealloc = checkInsertDealloc(op); + + if (hasAllConstantDimensions(memRefType)) + alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); + else + alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc, + {operands[0]}); + + std::vector originalLoops; + KrnlOptimizeLoopsOp optimizedLoopsOp; + KrnlIterateOp iterateOp; + emitKrnlLoopsAndIterationForOperand( + rewriter, loc, operands[0], originalLoops, + optimizedLoopsOp, iterateOp); + Block &optimizationBlock = optimizedLoopsOp.region().front(); + Block &iterationBlock = iterateOp.bodyRegion().front(); + + // 1. Insert any optimizations in the KrnlOptimizeLoopsOp body. + rewriter.setInsertionPointToEnd(&optimizationBlock); + // Return from KrnlOptimizeLoopsOp body. + // When no optimizations are present we just return the loops + // unchaged. + rewriter.create(loc, originalLoops); + + // 2. Insert instructions inside the KernelIterateOp body. + rewriter.setInsertionPointToStart(&iterationBlock); + + // Handle the operation: + SmallVector loopIVs; + for (auto arg : iterationBlock.getArguments()) + loopIVs.push_back(arg); + + auto loadedVal = rewriter.create(loc, operands[0], loopIVs); + auto loweredOpResult = mapToLowerScalarOp( + op, memRefType.getElementType(), {loadedVal}, rewriter); + // Store result in the resulting array. + rewriter.create(loc, loweredOpResult, alloc, loopIVs); + + rewriter.replaceOp(op, alloc); + + return matchSuccess(); + } +}; + +// Element-wise variadic ops lowering to Krnl dialect. +//===----------------------------------------------------------------------===// +template +struct ONNXElementwiseVariadicOpLowering : public ConversionPattern { + ONNXElementwiseVariadicOpLowering(MLIRContext *ctx) + : ConversionPattern(ElementwiseVariadicOp::getOperationName(), 1, ctx) {} + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + // TODO: Check that the types are valid. + // An element-wise variadic operation must have all operands and the result + // of the same type. This should have been verified by the verifier. + auto loc = op->getLoc(); + auto numArgs = op->getNumOperands(); + + // Insert an allocation and deallocation for the result of this operation. + auto memRefType = convertToMemRefType(*op->result_type_begin()); + + Value alloc; + bool insertDealloc = checkInsertDealloc(op); + // If the output has a dynamic dimension, we compute its dimension at + // runtime by using dimensions from the operands. + // In particular, we need to know from which operand a result dimension + // comes from. + // TODO: can the dimension of the result differ after optimizations? + if (hasAllConstantDimensions(memRefType)) + alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); + else + alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc, + operands); + + // Get run-time dimension information for unknown dimensions used for + // broadcasting. + std::map> broadcastedDimInfo = + getBroadcastedDimInfo(loc, rewriter, memRefType, operands); + + std::vector originalLoops; + KrnlOptimizeLoopsOp optimizedLoopsOp; + KrnlIterateOp iterateOp; + emitKrnlLoopsAndIterationForOperand( + rewriter, loc, alloc, originalLoops, + optimizedLoopsOp, iterateOp); + Block &optimizationBlock = optimizedLoopsOp.region().front(); + Block &iterationBlock = iterateOp.bodyRegion().front(); + + // 1. Insert any optimizations in the KrnlOptimizeLoopsOp body. 
+ rewriter.setInsertionPointToEnd(&optimizationBlock); + // Return from KrnlOptimizeLoopsOp body. + // When no optimizations are present we just return the loops unchaged. + rewriter.create(loc, originalLoops); + + // 2. Insert instructions inside the KernelIterateOp body. + rewriter.setInsertionPointToStart(&iterationBlock); + + // Handle the operation: + SmallVector loopIVs; + for (auto arg : iterationBlock.getArguments()) + loopIVs.push_back(arg); + + // Fold over operands for each of their scalar values + Value accumulated, next; + auto accumulatedLoopIVs = getLoopIVsForBroadcasting( + loc, rewriter, loopIVs, operands[0], broadcastedDimInfo[0]); + accumulated = rewriter.create(loc, operands[0], accumulatedLoopIVs); + for (unsigned i = 1; i < numArgs; i++) { + auto nextLoopIVs = getLoopIVsForBroadcasting( + loc, rewriter, loopIVs, operands[i], broadcastedDimInfo[i]); + next = rewriter.create(loc, operands[i], nextLoopIVs); + accumulated = mapToLowerScalarOp( + op, memRefType.getElementType(), {accumulated, next}, rewriter); + } + // Store result in the resulting array. + rewriter.create(loc, accumulated, alloc, loopIVs); + + rewriter.replaceOp(op, alloc); + + return matchSuccess(); + } +}; + +struct ONNXSoftmaxOpLowering : public ConversionPattern { + ONNXSoftmaxOpLowering(MLIRContext *ctx) + : ConversionPattern(mlir::ONNXSoftmaxOp::getOperationName(), 1, ctx) {} + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + // softmax(x) = let max_x = max(x) in + // let exp_x = exp(x - max_x) in + // let sum = sum(exp_x) in + // exp_x / sum + auto memRefType = convertToMemRefType(*op->result_type_begin()); + int64_t rank = memRefType.getRank(); + int64_t axis = llvm::dyn_cast(op).axis().getSExtValue(); + axis = axis >= 0 ? axis : rank + axis; + assert(axis >= -rank && axis <= rank - 1); + + auto loc = op->getLoc(); + + // Insert an allocation and deallocation for the result of this operation. + auto elementType = memRefType.getElementType(); + + Value alloc; + bool insertDealloc = checkInsertDealloc(op); + if (hasAllConstantDimensions(memRefType)) + alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); + else + alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc, + operands[0]); + + // Shape of the result + auto memRefShape = memRefType.getShape(); + + // Insert allocations and deallocations for sum and max. + MemRefType scalarMemRefType = MemRefType::get({}, elementType, {}, 0); + Value sumOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true); + Value maxOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true); + Value zero = + rewriter.create(loc, FloatAttr::get(elementType, 0)); + Value negInfinity = rewriter.create( + loc, + FloatAttr::get(elementType, -std::numeric_limits::infinity())); + + // Define loops. + std::vector originalLoops; + std::vector optimizedLoops; + Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops, + optimizedLoops, rank); + + // Coerce the input into a 2-D tensor. `axis` will be the coercing point. + // This coercing follows the softmax definition in ONNX: + // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Softmax + // Here, we create an outer loop and inner loop for handling the two + // dimensions. The outer loop is only created once `axis` is not zero. + + // Define an outer loop with respect to axis. 
+ std::vector outerLoops, optimizedOuterLoops; + outerLoops.reserve(axis); + optimizedOuterLoops.reserve(axis); + for (int i = 0; i < axis; ++i) { + outerLoops.push_back(originalLoops[i]); + optimizedOuterLoops.push_back(optimizedLoops[i]); + } + KrnlIterateOperandPack outerPack(rewriter, outerLoops, optimizedOuterLoops); + for (int i = 0; i < axis; ++i) + addDimensionToPack(rewriter, loc, outerPack, operands[0], i); + + // Define an inner loop with respect to axis. + std::vector innerLoops, optimizedInnerLoops; + innerLoops.reserve(rank - axis); + optimizedInnerLoops.reserve(rank - axis); + for (int i = axis; i < rank; ++i) { + innerLoops.push_back(originalLoops[i]); + optimizedInnerLoops.push_back(optimizedLoops[i]); + } + KrnlIterateOperandPack innerPack(rewriter, innerLoops, optimizedInnerLoops); + for (int i = axis; i < rank; ++i) + addDimensionToPack(rewriter, loc, innerPack, operands[0], i); + + KrnlIterateOp outerIterateOp, maxIterateOp, sumIterateOp, softmaxIterateOp; + SmallVector outerLoopIVs; + if (axis != 0) { + outerIterateOp = rewriter.create(loc, outerPack); + + // No optimization + rewriter.setInsertionPointToEnd(optimizationBlock); + rewriter.create(loc, originalLoops); + + // Insert instructions inside the outer loop. + Block &outerIterationBlock = outerIterateOp.bodyRegion().front(); + rewriter.setInsertionPointToStart(&outerIterationBlock); + for (auto arg : outerIterationBlock.getArguments()) + outerLoopIVs.push_back(arg); + + // Reset accumulators. + rewriter.create(loc, zero, sumOp); + rewriter.create(loc, negInfinity, maxOp); + + // Create an inner loop to compute max. + maxIterateOp = rewriter.create(loc, innerPack); + // Create an inner loop to compute sum. + sumIterateOp = rewriter.create(loc, innerPack); + // Create an inner loop to compute softmax. + softmaxIterateOp = rewriter.create(loc, innerPack); + } else { + // Reset accumulators. + rewriter.create(loc, zero, sumOp); + rewriter.create(loc, negInfinity, maxOp); + + // Create an inner loop to compute max. + maxIterateOp = rewriter.create(loc, innerPack); + // Create an inner loop to compute sum. + sumIterateOp = rewriter.create(loc, innerPack); + // Create an inner loop to compute softmax. + softmaxIterateOp = rewriter.create(loc, innerPack); + + // No optimization + rewriter.setInsertionPointToEnd(optimizationBlock); + rewriter.create(loc, originalLoops); + } + + // Insert instructions inside the max loop. + Block &maxIterationBlock = maxIterateOp.bodyRegion().front(); + rewriter.setInsertionPointToStart(&maxIterationBlock); + + // Get induction variables. + SmallVector maxLoopIVs; + for (auto arg : outerLoopIVs) + maxLoopIVs.push_back(arg); + for (auto arg : maxIterationBlock.getArguments()) + maxLoopIVs.push_back(arg); + + // Compute the max value. + Value max = rewriter.create(loc, maxOp); + Value nextMax = rewriter.create(loc, operands[0], maxLoopIVs); + auto maxCond = + rewriter.create(loc, CmpFPredicate::OGT, max, nextMax); + max = rewriter.create(loc, maxCond, max, nextMax); + rewriter.create(loc, max, maxOp); + + // Get the max. + rewriter.setInsertionPoint(sumIterateOp); + max = rewriter.create(loc, maxOp); + + // Insert instructions inside the sum loop. + Block &sumIterationBlock = sumIterateOp.bodyRegion().front(); + rewriter.setInsertionPointToStart(&sumIterationBlock); + + // Get induction variables. + SmallVector sumLoopIVs; + for (auto arg : outerLoopIVs) + sumLoopIVs.push_back(arg); + for (auto arg : sumIterationBlock.getArguments()) + sumLoopIVs.push_back(arg); + + // Sum up values. 
+ Value sum = rewriter.create(loc, sumOp); + Value next = rewriter.create(loc, operands[0], sumLoopIVs); + Value sub = rewriter.create(loc, next, max); + Value exp = rewriter.create(loc, sub); + sum = rewriter.create(loc, sum, exp); + rewriter.create(loc, sum, sumOp); + // Store intermediate values in the result to avoid recomputation. + rewriter.create(loc, exp, alloc, sumLoopIVs); + + // Get the sum. + rewriter.setInsertionPoint(softmaxIterateOp); + sum = rewriter.create(loc, sumOp); + + // Insert instructions inside the softmax loop. + Block &softmaxIterationBlock = softmaxIterateOp.bodyRegion().front(); + rewriter.setInsertionPointToStart(&softmaxIterationBlock); + + // Get induction variables. + SmallVector softmaxLoopIVs; + for (auto arg : outerLoopIVs) + softmaxLoopIVs.push_back(arg); + for (auto arg : softmaxIterationBlock.getArguments()) + softmaxLoopIVs.push_back(arg); + + // Compute softmax. + Value expLoadedVal = rewriter.create(loc, alloc, softmaxLoopIVs); + Value result = rewriter.create(loc, expLoadedVal, sum); + rewriter.create(loc, result, alloc, softmaxLoopIVs); + + rewriter.replaceOp(op, alloc); + + return matchSuccess(); + } +}; + +struct ONNXReshapeOpLowering : public ConversionPattern { + ONNXReshapeOpLowering(MLIRContext *ctx) + : ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + auto loc = op->getLoc(); + + auto memRefType = convertToMemRefType(*op->result_type_begin()); + auto memRefShape = memRefType.getShape(); + auto inputShape = operands[0].getType().cast().getShape(); + + // Insert an allocation and deallocation for the result of this operation. + Value alloc; + + // Compute size in bytes using the input tensor. + Value tensorSize = rewriter.create( + loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), + getMemRefEltSizeInBytes(memRefType))); + for (int i = 0; i < inputShape.size(); ++i) { + Value dimVal; + if (inputShape[i] < 0) { + Value dim = rewriter.create(loc, operands[0], i); + dimVal = + rewriter.create(loc, dim, rewriter.getIntegerType(64)); + } else { + dimVal = rewriter.create( + loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), + inputShape[i])); + } + tensorSize = rewriter.create(loc, tensorSize, dimVal); + } + + bool insertDealloc = checkInsertDealloc(op); + if (hasAllConstantDimensions(memRefType)) { + alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); + } else { + // If a dimension is zero, the actual dimension value is taken from the + // input tensor. + // + // If the shape array has a negative dimension (-1), we compute its actual + // dimension value from the other dimensions. But we don't have enough + // information about the other dimensions at this point. So, we need to + // scan the shape first to calculate reduction of all of the dimensions. + // If the reduction is negative, then the shape array contains a negative + // dimension. Otherwise, the reduction is the same as the one computed + // from the input tensor. + Value tensorSizeFromShape = rewriter.create( + loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), + getMemRefEltSizeInBytes(memRefType))); + SmallVector DimInfo; + for (int i = 0; i < memRefShape.size(); ++i) { + Value index = rewriter.create( + loc, rewriter.getIntegerAttr(rewriter.getIndexType(), i)); + // Load index from array of indices. 
+        Value loadedVal = rewriter.create(loc, operands[1], index);
+        // If a dimension is zero, the actual dimension value is taken from
+        // the input tensor.
+        //
+        // If a dimension is negative, it is computed from the other
+        // dimensions. But we don't have enough information about the other
+        // dimensions at this point, so we leave it as -1 and compute its
+        // actual value later.
+        if (i < inputShape.size()) {
+          Value dimVal;
+          auto loadedValType = loadedVal.getType().cast();
+          if (inputShape[i] < 0) {
+            Value dim = rewriter.create(loc, operands[0], i);
+            dimVal = rewriter.create(loc, dim, loadedValType);
+          } else {
+            dimVal = rewriter.create(
+                loc, rewriter.getIntegerAttr(loadedValType, inputShape[i]));
+          }
+          auto zero = rewriter.create(
+              loc, rewriter.getIntegerAttr(loadedValType, 0));
+          auto isZero =
+              rewriter.create(loc, CmpIPredicate::eq, loadedVal, zero);
+          loadedVal = rewriter.create(loc, isZero, dimVal, loadedVal);
+        }
+        // Check if the loaded index is already the correct width of 64 bits.
+        // Convert the value to a 64 bit integer if needed.
+        Value int64LoadedVal = loadedVal;
+        if (loadedVal.getType().cast().getWidth() < 64)
+          int64LoadedVal = rewriter.create(
+              loc, loadedVal, rewriter.getIntegerType(64));
+        tensorSizeFromShape =
+            rewriter.create(loc, tensorSizeFromShape, int64LoadedVal);
+        // Store intermediate results to use later.
+        DimInfo.emplace_back(int64LoadedVal);
+      }
+      // Negate tensorSizeFromShape: if the shape array contains a -1, the
+      // accumulated product is negative, so negating it yields the product of
+      // the known dimensions. This is safe because the value is only used to
+      // compute the actual size of the negative dimension.
+      auto zero = rewriter.create(
+          loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
+      tensorSizeFromShape =
+          rewriter.create(loc, zero, tensorSizeFromShape);
+
+      // Obtain operands for AllocOp.
+      SmallVector allocOperands;
+      auto negOne = rewriter.create(
+          loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), -1));
+
+      for (int i = 0; i < memRefShape.size(); ++i) {
+        auto dimVal = DimInfo[i];
+        auto isNegOne =
+            rewriter.create(loc, CmpIPredicate::eq, dimVal, negOne);
+        // If the dimension is negative, compute its value from the other
+        // dimensions.
+        auto actualDimVal =
+            rewriter.create(loc, tensorSize, tensorSizeFromShape);
+        auto loadedVal =
+            rewriter.create(loc, isNegOne, actualDimVal, dimVal);
+        allocOperands.push_back(rewriter.create(
+            loc, loadedVal, rewriter.getIndexType()));
+      }
+      AllocOp allocateMemref =
+          rewriter.create(loc, memRefType, allocOperands);
+
+      // If a deallocation is required for this buffer, place it just before
+      // the terminator of the enclosing block.
+ auto *parentBlock = allocateMemref.getOperation()->getBlock(); + if (insertDealloc) { + auto dealloc = rewriter.create(loc, allocateMemref); + dealloc.getOperation()->moveBefore(&parentBlock->back()); + } + + alloc = allocateMemref; + } + + rewriter.create(loc, alloc, operands[0], tensorSize); + rewriter.replaceOp(op, alloc); + + return matchSuccess(); + } +}; + +struct ONNXGemmOpLowering : public ConversionPattern { + ONNXGemmOpLowering(MLIRContext *ctx) + : ConversionPattern(mlir::ONNXGemmOp::getOperationName(), 1, ctx) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + auto loc = op->getLoc(); + + Value A, B, C; + A = operands[0]; + B = operands[1]; + C = operands[2]; + + auto memRefType = convertToMemRefType(*op->result_type_begin()); + + auto alphaAttr = FloatAttr::get(memRefType.getElementType(), + llvm::dyn_cast(op).alpha().convertToFloat()); + auto betaAttr = FloatAttr::get(memRefType.getElementType(), + llvm::dyn_cast(op).beta().convertToFloat()); + auto alpha = rewriter.create(loc, alphaAttr); + auto beta = rewriter.create(loc, betaAttr); + + bool isTransA = (llvm::dyn_cast(op).transA() != 0); + bool isTransB = (llvm::dyn_cast(op).transB() != 0); + + // Insert an allocation and deallocation for the result of this operation. + Value alloc; + bool insertDealloc = checkInsertDealloc(op); + if (hasAllConstantDimensions(memRefType)) + alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); + else { + auto memRefShape = memRefType.getShape(); + SmallVector allocOperands; + if (memRefShape[0] < 0) { + auto dim = rewriter.create(loc, A, (isTransA) ? 1 : 0); + allocOperands.emplace_back(dim); + } + if (memRefShape[1] < 0) { + auto dim = rewriter.create(loc, B, (isTransB) ? 0 : 1); + allocOperands.emplace_back(dim); + } + alloc = rewriter.create(loc, memRefType, allocOperands); + if (insertDealloc) { + auto *parentBlock = alloc.getDefiningOp()->getBlock(); + auto dealloc = rewriter.create(loc, alloc); + dealloc.getOperation()->moveBefore(&parentBlock->back()); + } + } + + // Number of loops + auto memRefShape = memRefType.getShape(); + int64_t numLoops = 3; + + // Define loops. + std::vector originalLoops; + std::vector optimizedLoops; + Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops, + optimizedLoops, numLoops); + + // We have two Krnl loops: + // - Outer loop iterates over the output matrix dimensions, and + // - Reduction loop iterates over the reduction dimension. + + // Outer loop + std::vector outerLoops, optimizedOuterLoops; + outerLoops.reserve(2); + optimizedOuterLoops.reserve(2); + for (int i = 0; i < 2; ++i) { + outerLoops.push_back(originalLoops[i]); + optimizedOuterLoops.push_back(optimizedLoops[i]); + } + KrnlIterateOperandPack outerPack(rewriter, outerLoops, + optimizedOuterLoops); + // Induction variables for the outer loops + for (int i = 0; i < 2; ++i) + addDimensionToPack(rewriter, loc, outerPack, alloc, i); + + // Reduction loop + std::vector reductionLoops, optimizedReductionLoops; + reductionLoops.reserve(1); + optimizedReductionLoops.reserve(1); + reductionLoops.push_back(originalLoops[2]); + optimizedReductionLoops.push_back(optimizedLoops[2]); + KrnlIterateOperandPack reductionPack(rewriter, reductionLoops, + optimizedReductionLoops); + // Induction variable for the reduction dimension + // Try to find and use a static value from A or B first. + // If it failed then use a dynamic value. 
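+    // Gemm computes Y = alpha * A' * B' + beta * C, where A' = transpose(A)
+    // when transA is set (and similarly for B'). The reduction dimension K is
+    // A'.shape[1], which must equal B'.shape[0].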
+ auto ATy = A.getType().cast(); + auto BTy = B.getType().cast(); + int64_t K_A_Idx = (isTransA) ? 0 : 1; + int64_t K_B_Idx = (isTransB) ? 1 : 0; + reductionPack.pushConstantBound(0); + if (ATy.getShape()[K_A_Idx] != -1) + reductionPack.pushConstantBound(ATy.getShape()[K_A_Idx]); + else + if (BTy.getShape()[K_B_Idx] != -1) + reductionPack.pushConstantBound(BTy.getShape()[K_B_Idx]); + else + reductionPack.pushOperandBound( + rewriter.create(loc, B, K_B_Idx).getResult()); + + // Get run-time dimension information for unknown dimensions used for + // broadcasting. + // GemmOp supports unidirectional broadcasting from C to A*B. + // Hence, it must be enough to get broadcasting information for C only. + std::map broadcastedDimInfo; + auto shape = C.getType().cast().getShape(); + for (int i = 0; i < shape.size(); ++i) { + if (shape[i] < 0) { + auto dim = rewriter.create(loc, C, i).getResult(); + auto one = rewriter.create(loc, 1); + auto isBroadcasted = + rewriter.create(loc, CmpIPredicate::eq, dim, one); + broadcastedDimInfo.insert(std::make_pair(i, isBroadcasted)); + } + } + + auto outerIterateOp = rewriter.create(loc, outerPack); + + // Now perform the insertions into the body of the + // just generated instructions: + + // No optimization + rewriter.setInsertionPointToEnd(optimizationBlock); + rewriter.create(loc, originalLoops); + + // Insert instructions inside the outer loop. + Block &outerIterationBlock = outerIterateOp.bodyRegion().front(); + rewriter.setInsertionPointToStart(&outerIterationBlock); + + // Induction variables + SmallVector loopMNIVs; + for (auto arg : outerIterationBlock.getArguments()) { + loopMNIVs.emplace_back(arg); + } + + // Initialize the output of A*B + auto zero = rewriter.create( + loc, FloatAttr::get(memRefType.getElementType(), 0)); + rewriter.create(loc, zero, alloc, loopMNIVs); + + // Compute A*B + auto matmulIterateOp = rewriter.create(loc, reductionPack); + + // Compute beta*C, and add up to alpha*A*B (unidirectional broadcasting) + auto loopCIVs = getLoopIVsForBroadcasting( + loc, rewriter, loopMNIVs, C, broadcastedDimInfo); + auto loadedC = rewriter.create(loc, C, loopCIVs); + auto loadedAB = rewriter.create(loc, alloc, loopMNIVs); + auto alphaAB = rewriter.create(loc, alpha, loadedAB); + auto betaC = rewriter.create(loc, beta, loadedC); + auto Y = rewriter.create(loc, alphaAB, betaC); + rewriter.create(loc, Y, alloc, loopMNIVs); + + // Insert instructions to do matrix multiplication: A*B + Block &matmulIterationBlock = matmulIterateOp.bodyRegion().front(); + rewriter.setInsertionPointToStart(&matmulIterationBlock); + + // Induction variables + SmallVector loopKIVs, loopAIVs, loopBIVs; + for (auto arg : matmulIterationBlock.getArguments()) + loopKIVs.emplace_back(arg); + if (isTransA) { + loopAIVs.emplace_back(loopKIVs[0]); + loopAIVs.emplace_back(loopMNIVs[0]); + } else { + loopAIVs.emplace_back(loopMNIVs[0]); + loopAIVs.emplace_back(loopKIVs[0]); + } + if (isTransB) { + loopBIVs.emplace_back(loopMNIVs[1]); + loopBIVs.emplace_back(loopKIVs[0]); + } else { + loopBIVs.emplace_back(loopKIVs[0]); + loopBIVs.emplace_back(loopMNIVs[1]); + } + + // Matmul computation + auto loadedA = rewriter.create(loc, A, loopAIVs); + auto loadedB = rewriter.create(loc, B, loopBIVs); + auto loadedY = rewriter.create(loc, alloc, loopMNIVs); + auto AB = rewriter.create(loc, loadedA, loadedB); + auto accumulated = rewriter.create(loc, loadedY, AB); + rewriter.create(loc, accumulated, alloc, loopMNIVs); + + rewriter.replaceOp(op, alloc); + + return matchSuccess(); + } +}; + 
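+// Unsqueeze inserts size-1 dimensions at the positions given by the 'axes'
+// attribute; e.g. unsqueezing a 3x4 tensor with axes = [0, 3] produces a
+// 1x3x4x1 tensor. The lowering below allocates the result buffer and then
+// copies the input data into it unchanged.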
+struct ONNXUnsqueezeOpLowering : public ConversionPattern { + ONNXUnsqueezeOpLowering(MLIRContext *ctx) + : ConversionPattern(mlir::ONNXUnsqueezeOp::getOperationName(), 1, ctx) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + auto loc = op->getLoc(); + auto memRefType = convertToMemRefType(*op->result_type_begin()); + int outRank = memRefType.getRank(); + + // Assume that `axes` has been validated by shape inference. + // So, here we just get it. + ArrayAttr axisAttrs = llvm::dyn_cast(op).axesAttr(); + SmallVector axes; + for (auto axisAttr : axisAttrs.getValue()) { + int axis = axisAttr.cast().getInt(); + axis = axis >= 0 ? axis : (outRank + axis); + axes.emplace_back(axis); + } + + // Insert an allocation and deallocation for the result of this operation. + Value alloc; + + // Compute size in bytes. + Value tensorSize = rewriter.create( + loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), + getMemRefEltSizeInBytes(memRefType))); + + bool insertDealloc = checkInsertDealloc(op); + auto memRefShape = memRefType.getShape(); + if (hasAllConstantDimensions(memRefType)) { + alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); + for (int i = 0; i < memRefShape.size(); ++i) { + Value dimVal = rewriter.create( + loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), + memRefShape[i])); + tensorSize = rewriter.create(loc, tensorSize, dimVal); + } + } else { + // Unknown dimensions are always the operand's dimensions. + SmallVector allocOperands; + for (int outIdx = 0, inIdx = 0; outIdx < memRefShape.size(); ++outIdx) { + Value dimVal = nullptr; + if (memRefShape[outIdx] < 0) { + Value index = rewriter.create(loc, operands[0], inIdx); + dimVal = rewriter.create( + loc, index, rewriter.getIntegerType(64)); + allocOperands.emplace_back(index); + } else { + dimVal = rewriter.create( + loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), + memRefShape[outIdx])); + } + tensorSize = rewriter.create(loc, tensorSize, dimVal); + if (std::find(axes.begin(), axes.end(), outIdx) == axes.end()) + inIdx++; + } + alloc = rewriter.create(loc, memRefType, allocOperands); + auto *parentBlock = alloc.getDefiningOp()->getBlock(); + if (insertDealloc) { + auto dealloc = rewriter.create(loc, alloc); + dealloc.getOperation()->moveBefore(&parentBlock->back()); + } + } + rewriter.create(loc, alloc, operands[0], tensorSize); + rewriter.replaceOp(op, alloc); + return matchSuccess(); + } +}; + +struct ONNXTransposeOpLowering : public ConversionPattern { + ONNXTransposeOpLowering(MLIRContext *ctx) + : ConversionPattern(mlir::ONNXTransposeOp::getOperationName(), 1, ctx) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + auto loc = op->getLoc(); + // Insert an allocation and deallocation for the result of this operation. + auto memRefType = convertToMemRefType(*op->result_type_begin()); + Value alloc; + bool insertDealloc = checkInsertDealloc(op); + + if (hasAllConstantDimensions(memRefType)) + alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); + else + alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc, + {operands[0]}); + + // Number of loops + auto memRefShape = memRefType.getShape(); + int64_t rank = memRefShape.size(); + + // Define loops. 
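+    // The loop nest below runs over the *input* shape: each element is read
+    // once and stored to its permuted position in the result, following the
+    // 'perm' attribute (output axis i takes its index from input axis
+    // perm[i]).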
+    std::vector originalLoops;
+    std::vector optimizedLoops;
+    Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops,
+        optimizedLoops, rank);
+
+    KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
+    // Iterate over the loop nest using the input shape.
+    for (int i = 0; i < rank; ++i)
+      addDimensionToPack(rewriter, loc, pack, operands[0], i);
+
+    auto iterateOp = rewriter.create(loc, pack);
+    Block &iterationBlock = iterateOp.bodyRegion().front();
+
+    // Now perform the insertions into the body of the
+    // just generated instructions:
+
+    // 1. Insert any optimizations in the KrnlOptimizeLoopsOp body.
+    rewriter.setInsertionPointToEnd(optimizationBlock);
+    // Return from the KrnlOptimizeLoopsOp body.
+    // When no optimizations are present we just return the loops
+    // unchanged.
+    rewriter.create(loc, originalLoops);
+
+    // 2. Insert instructions inside the KrnlIterateOp body.
+    rewriter.setInsertionPointToStart(&iterationBlock);
+
+    // Handle the operation.
+
+    // Read perm attribute.
+    SmallVector perm;
+    auto permAttribute = llvm::dyn_cast(op).permAttr();
+    if (permAttribute) {
+      for (auto permVal : permAttribute.getValue())
+        perm.emplace_back(permVal.cast().getInt());
+    } else {
+      // TODO: Remove when perm is guaranteed to be present (even for
+      // the default case). This means that perm was added by shape
+      // inference or another pass to contain the values corresponding
+      // to the default behavior of Transpose.
+      for (int i = iterationBlock.getArguments().size() - 1; i >= 0; i--)
+        perm.emplace_back(i);
+    }
+
+    SmallVector inLoopIVs;
+    for (auto arg : iterationBlock.getArguments())
+      inLoopIVs.emplace_back(arg);
+
+    SmallVector outLoopIVs;
+    for (int i = 0; i < iterationBlock.getArguments().size(); ++i)
+      outLoopIVs.emplace_back(inLoopIVs[perm[i]]);
+
+    auto inVal = rewriter.create(loc, operands[0], inLoopIVs);
+    rewriter.create(loc, inVal, alloc, outLoopIVs);
+
+    rewriter.replaceOp(op, alloc);
+
+    return matchSuccess();
+  }
+};
+
+struct ONNXIdentityOpLowering : public ConversionPattern {
+  ONNXIdentityOpLowering(MLIRContext *ctx)
+      : ConversionPattern(mlir::ONNXIdentityOp::getOperationName(), 1, ctx) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    rewriter.replaceOp(op, operands[0]);
+    return matchSuccess();
+  }
+};
+
+struct ONNXConvNoBiasOpLowering : public ConversionPattern {
+  ONNXConvNoBiasOpLowering(MLIRContext *ctx)
+      : ConversionPattern(mlir::ONNXConvNoBiasOp::getOperationName(), 1, ctx) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    auto loc = op->getLoc();
+    // Insert an allocation and deallocation for the result of this operation.
+    auto memRefType = convertToMemRefType(*op->result_type_begin());
+    Value alloc;
+    bool insertDealloc = checkInsertDealloc(op);
+    ONNXConvNoBiasOp convOp = llvm::dyn_cast(op);
+
+    if (hasAllConstantDimensions(memRefType))
+      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
+    else
+      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
+                                    {operands[0]});
+
+    auto resultShape = memRefType.getShape();
+    auto inputShape = operands[0].getType().cast().getShape();
+    auto kernelShape = operands[1].getType().cast().getShape();
+
+    // R = ConvNoBias(D, K)
+    //
+    // The input/output shapes will look like this:
+    //
+    // D (NxCxHxW) x K (MxC/groupxKHxKW) -> R (NxMxRHxRW)
+    //
+    // M is a multiple of the number of groups:
+    //   M = group * kernelsPerGroup
+    //
+    // The loop nest will look as follows:
+    //
+    // strides = [s1, s2]
+    //
+    // kernelsPerGroup = M / group;
+    // for n = 0 .. N:
+    //   for g = 0 .. group:
+    //     for m = 0 .. kernelsPerGroup:
+    //       kernel = g * kernelsPerGroup + m;
+    //       for r1 = 0 .. RH:
+    //         for r2 = 0 .. RW:
+    //           R[n][kernel][r1][r2] = 0;
+    //           for c = 0 .. C/group:
+    //             for k1 = 0 .. KH:
+    //               for k2 = 0 .. KW:
+    //                 R[n][kernel][r1][r2] +=
+    //                     D[n][g * (C / group) + c][s1 * r1 + k1][s2 * r2 + k2] *
+    //                     K[kernel][c][k1][k2];
+    //
+    // Naming:
+    //   n, g, m: outer loop nest indices
+    //   r1, r2: spatial loop nest indices
+    //   c, k1, k2: inner loop nest indices
+    //
+    // TODO: handle padding.
+    //
+    // In the general case:
+    //
+    // D (NxCxD1xD2x...xDdim) x K (MxC/groupxK1xK2x...xKdim)
+    //     -> R (NxMxR1xR2x...xRdim)
+    //
+    // The above loop nest can be adapted by increasing the number of r- and
+    // k-index loops, i.e. the r1, r2, ... and k1, k2, ... loops.
+
+    // Set up the outermost loop nest: n, g, m.
+    // Skip g if group is 1.
+
+    // Before we start the iteration we need to compute the number of
+    // kernels per group and fetch the number of groups from the attribute
+    // list. Group is always a compile-time constant.
+    int64_t group = convOp.group().getSExtValue();
+    // Compute the number of kernels per group. The total number of kernels
+    // must be a multiple of the number of groups.
+    int64_t kernelsPerGroup = floor(kernelShape[0] / group);
+    auto kernelsPerGroupValue =
+        rewriter.create(loc, kernelsPerGroup);
+    auto zero = rewriter.create(
+        loc, FloatAttr::get(memRefType.getElementType(), 0));
+    Value subchannels;
+    if (kernelShape[1] < 0) {
+      subchannels =
+          rewriter.create(loc, operands[1], 1).getResult();
+    } else {
+      subchannels = rewriter.create(
+          loc, kernelShape[1]);
+    }
+
+    // 1. Define outer loops and emit empty optimization block:
+    int64_t nOuterLoops = (group > 1) ? 3 : 2;
+    std::vector outerLoops;
+    std::vector optimizedOuterLoops;
+    Block *optimizationBlock = defineLoops(rewriter, loc, outerLoops,
+        optimizedOuterLoops, nOuterLoops);
+
+    // Prepare iteration arguments over outer loop nest.
+    KrnlIterateOperandPack pack(
+        rewriter, outerLoops, optimizedOuterLoops);
+    // for n = 0 .. N:
+    pack.pushConstantBound(0);
+    if (inputShape[0] < 0)
+      pack.pushOperandBound(
+          rewriter.create(loc, operands[0], 0).getResult());
+    else
+      pack.pushConstantBound(inputShape[0]);
+    // for g = 0 .. group:
+    if (group > 1) {
+      pack.pushConstantBound(0);
+      pack.pushConstantBound(group);
+    }
+    // for m = 0 .. kernelsPerGroup:
+    pack.pushConstantBound(0);
+    pack.pushConstantBound(kernelsPerGroup);
+    // Outer loop iteration.
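+    // The block arguments of this outer iterate op are, in order:
+    // n, then g (only present when group > 1), then m.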
+    auto iterateOp = rewriter.create(loc, pack);
+    Block &outerIterationBlock = iterateOp.bodyRegion().front();
+    // Emit optimizations for outer loops:
+    rewriter.setInsertionPointToEnd(optimizationBlock);
+    rewriter.create(loc, outerLoops);
+    rewriter.setInsertionPointToStart(&outerIterationBlock);
+    {
+      // 2. Emit the body of the outer loop nest.
+
+      // 2.1 Compute kernel order number: kernel = g * kernelsPerGroup + m;
+      // If group is 1 then the kernel ID is simply the induction variable of
+      // the loop over kernels.
+      Value kernel = outerIterationBlock.getArguments()[1];
+      if (group > 1) {
+        // Middle loop is over groups and third loop is over the
+        // kernel identifiers in the current group.
+        auto kernelsOffset = rewriter.create(loc,
+            outerIterationBlock.getArguments()[1],
+            kernelsPerGroupValue);
+        kernel = rewriter.create(loc, kernelsOffset,
+            outerIterationBlock.getArguments()[2]);
+      }
+
+      // 2.2 Define spatial loops
+      int64_t nSpatialLoops = resultShape.size() - 2;
+      std::vector spatialLoops;
+      std::vector optimizedSpatialLoops;
+      Block *optSpatialLoopBlock = defineLoops(rewriter, loc, spatialLoops,
+          optimizedSpatialLoops, nSpatialLoops);
+
+      // 2.3 Prepare iteration arguments for spatial loop nest.
+      KrnlIterateOperandPack spatialPack(
+          rewriter, spatialLoops, optimizedSpatialLoops);
+      for (int i = 2; i < resultShape.size(); ++i)
+        addDimensionToPack(rewriter, loc, spatialPack, alloc, i);
+
+      // 2.4 Emit loop nest over output spatial dimensions.
+      // for rX = 0 .. RX
+      auto spatialIterateOp =
+          rewriter.create(loc, spatialPack);
+      Block &spatialIterationBlock = spatialIterateOp.bodyRegion().front();
+      // 2.5 Emit optimizations for the spatial loops:
+      rewriter.setInsertionPointToEnd(optSpatialLoopBlock);
+      rewriter.create(loc, spatialLoops);
+      rewriter.setInsertionPointToStart(&spatialIterationBlock);
+      {
+        // 3. Emit the body of the spatial loop nest.
+        // 3.1 Emit: R[n][kernel][r1][r2] = 0;
+        SmallVector resultIndices;
+        // n
+        resultIndices.emplace_back(outerIterationBlock.getArguments()[0]);
+        // kernel
+        resultIndices.emplace_back(kernel);
+        // rX
+        for (auto arg : spatialIterationBlock.getArguments())
+          resultIndices.emplace_back(arg);
+        // Store initializer value into output location.
+        rewriter.create(loc, zero, alloc, resultIndices);
+
+        // 3.2 Define inner loops.
+        int64_t nInnerLoops = 1 + (kernelShape.size() - 2);
+        std::vector innerLoops;
+        std::vector optimizedInnerLoops;
+        Block *optInnerLoopBlock = defineLoops(rewriter, loc, innerLoops,
+            optimizedInnerLoops, nInnerLoops);
+
+        // 3.3 Prepare iteration arguments for inner loop nest.
+        KrnlIterateOperandPack innerPack(
+            rewriter, innerLoops, optimizedInnerLoops);
+        // for c = 0 .. C/group
+        innerPack.pushConstantBound(0);
+        innerPack.pushConstantBound(kernelShape[1]);
+        // for Kx = 0 .. KX
+        for (int i = 2; i < kernelShape.size(); ++i)
+          addDimensionToPack(rewriter, loc, innerPack, operands[1], i);
+
+        // 3.4 Emit inner loop nest.
+        auto innerIterateOp =
+            rewriter.create(loc, innerPack);
+        Block &innerIterationBlock = innerIterateOp.bodyRegion().front();
+        // 3.5 Emit optimizations for the inner loops:
+        rewriter.setInsertionPointToEnd(optInnerLoopBlock);
+        rewriter.create(loc, innerLoops);
+        rewriter.setInsertionPointToStart(&innerIterationBlock);
+        {
+          // 4. Emit inner loop body:
+          //   R[n][kernel][r1][r2] +=
+          //       D[n][g * (C / group) + c][s1 * r1 + k1][s2 * r2 + k2] *
+          //       K[kernel][c][k1][k2];
+
+          // 4.1 Prepare indices for accessing the data tensor.
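+          // For example, with C = 4 and group = 2 there are C/group = 2
+          // subchannels; for g = 1, c = 1 the data channel index is
+          // g * (C/group) + c = 1 * 2 + 1 = 3.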
+ SmallVector dataIndices; + // n + dataIndices.emplace_back(outerIterationBlock.getArguments()[0]); + // g * (C / group) + c + Value channelDepth = innerIterationBlock.getArguments()[0]; + if (group > 1) + channelDepth = rewriter.create(loc, channelDepth, + rewriter.create(loc, subchannels, + outerIterationBlock.getArguments()[1])); + dataIndices.emplace_back(channelDepth); + // sX * rX + kX + auto stridesAttribute = convOp.stridesAttr(); + // Read strides attribute + SmallVector strides; + if (stridesAttribute) + for (auto stride : stridesAttribute.getValue()) + strides.emplace_back(stride.cast().getInt()); + for (int i = 0; i < kernelShape.size() - 2; ++i) { + Value spatialIndex = spatialIterationBlock.getArguments()[i]; + // If strides are present then emit the correct access index. + if (stridesAttribute && strides[i] > 1) + spatialIndex = rewriter.create(loc, + rewriter.create(loc, strides[i]), + spatialIterationBlock.getArguments()[i]); + dataIndices.emplace_back( + rewriter.create(loc, spatialIndex, + innerIterationBlock.getArguments()[i+1])); + } + + // 4.2 Prepare indices for accessing the kernel tensor. + SmallVector kernelIndices; + // kernel + kernelIndices.emplace_back(kernel); + // c + kernelIndices.emplace_back(innerIterationBlock.getArguments()[0]); + // kX + for (int i = 0; i < kernelShape.size() - 2; ++i) + kernelIndices.emplace_back( + innerIterationBlock.getArguments()[i+1]); + + // 4.3 Compute convolution. + auto loadData = + rewriter.create(loc, operands[0], dataIndices); + auto loadKernel = + rewriter.create(loc, operands[1], kernelIndices); + auto loadPartialSum = + rewriter.create(loc, alloc, resultIndices); + Value result = rewriter.create(loc, loadPartialSum, + rewriter.create(loc, loadData, loadKernel)); + // 4.4 Store computed value into output location. + rewriter.create(loc, result, alloc, resultIndices); + } + } + } + rewriter.replaceOp(op, alloc); + + return matchSuccess(); + } +}; + +//===----------------------------------------------------------------------===// +// Reduction ops lowering to Krnl dialect. +//===----------------------------------------------------------------------===// +template +struct ONNXReductionOpLowering : public ConversionPattern { + ONNXReductionOpLowering(MLIRContext *ctx) + : ConversionPattern(ONNXReductionOp::getOperationName(), 1, ctx) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + /* + * Condition: reduction function must be associative and commutative. + * + * Example 1 (here, reduction function is `+`): + * Induction variables: (i0, i1, i2) + * axes = [0, 2] + * keepdims = true + * krnl.iterate() with (i0, i1, i2) { + * Y(0, i1, 0) += X(i0, i1, i2) + * } + * + * Example 2 (here, reduction function is `+`): + * Induction variables: (i0, i1, i2) + * axes = [0, 2] + * keepdims = false + * krnl.iterate() with (i0, i1, i2) { + * Y(i1) += X(i0, i1, i2) + * } + * + */ + auto loc = op->getLoc(); + auto memRefInType = operands[0].getType().cast(); + auto memRefInShape = memRefInType.getShape(); + auto memRefOutType = convertToMemRefType(*op->result_type_begin()); + int64_t inRank = memRefInType.getRank(); + int64_t outRank = memRefOutType.getRank(); + + // Get attributes + ArrayAttr axisAttrs = llvm::dyn_cast(op).axesAttr(); + std::vector axes; + if (axisAttrs) { + for (auto axisAttr : axisAttrs.getValue()) { + int64_t axis = axisAttr.cast().getInt(); + axis = axis >= 0 ? 
axis : (inRank + axis); + assert(axis >= -inRank && axis <= inRank - 1); + if (std::find(axes.begin(), axes.end(), axis) == axes.end()) + axes.push_back(axis); + } + } else { + for (decltype(inRank) i = 0; i < inRank; ++i) { + axes.push_back(i); + } + } + // KeepDims + auto keepdims = + llvm::dyn_cast(op).keepdims(); + bool isKeepdims = (keepdims == 1) ? true : false; + + // Get type information + auto memRefOutShape = memRefOutType.getShape(); + auto elementOutType = memRefOutType.getElementType(); + std::map outInDimMap = + getReductionMapping(memRefInType, axes, isKeepdims); + + // Insert an allocation and deallocation for the result of this operation. + Value alloc; + bool insertDealloc = checkInsertDealloc(op); + if (hasAllConstantDimensions(memRefOutType)) { + alloc = insertAllocAndDealloc(memRefOutType, loc, rewriter, insertDealloc); + } else { + SmallVector allocOperands; + for (decltype(outRank) i = 0; i < outRank; ++i) { + if (memRefOutShape[i] < 0) { + auto dim = rewriter.create(loc, operands[0], outInDimMap[i]); + allocOperands.push_back(dim); + } + } + alloc = rewriter.create(loc, memRefOutType, allocOperands); + if (insertDealloc) { + auto *parentBlock = alloc.getDefiningOp()->getBlock(); + auto dealloc = rewriter.create(loc, alloc); + dealloc.getOperation()->moveBefore(&parentBlock->back()); + } + } + + // There are two Krnl loops: + // - One to initialize the result memref, and + // - One to do reduction + + // Define loops to initialize the result. + std::vector originalLoopsInit; + std::vector optimizedLoopsInit; + Block *optimizationBlockInit = defineLoops(rewriter, loc, originalLoopsInit, + optimizedLoopsInit, outRank); + + // Iteration information + KrnlIterateOperandPack packInit(rewriter, originalLoopsInit, + optimizedLoopsInit); + for (decltype(outRank) i = 0; i < outRank; ++i) { + addDimensionToPack(rewriter, loc, packInit, alloc, i); + } + auto iterateOpInit = rewriter.create(loc, packInit); + Block &iterationBlockInit = iterateOpInit.bodyRegion().front(); + + // Perform the insertions into the body of the initialization loop. + // No optimization + rewriter.setInsertionPointToEnd(optimizationBlockInit); + rewriter.create(loc, originalLoopsInit); + + // Insert instructions inside the KernelIterateOp body. + rewriter.setInsertionPointToStart(&iterationBlockInit); + + // Handle the operation: + SmallVector loopIVs; + for (auto arg : iterationBlockInit.getArguments()) { + loopIVs.push_back(arg); + } + + Value identity; + if (elementOutType.isa()) { + identity = rewriter.create( + loc, FloatAttr::get(elementOutType, + getIdentityValue())); + } else if (elementOutType.isa()) { + identity = rewriter.create( + loc, IntegerAttr::get(elementOutType, + getIdentityValue())); + } else { + emitError(loc, "unsupported element type"); + } + rewriter.create(loc, identity, alloc, loopIVs); + + // Define an Krnl loop to do reduction. + rewriter.setInsertionPointAfter(iterateOpInit); + std::vector originalLoops, optimizedLoops; + Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops, + optimizedLoops, inRank); + // Iteration information + KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops); + for (decltype(inRank) i = 0; i < inRank; ++i) { + addDimensionToPack(rewriter, loc, pack, operands[0], i); + } + auto iterateOp = rewriter.create(loc, pack); + Block &iterationBlock = iterateOp.bodyRegion().front(); + + // Perform the insertions into the body of the reduction loop. 
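+    // Every input element X(i0, ..., i{inRank-1}) is accumulated into the
+    // output element indexed by the surviving (non-reduced) induction
+    // variables; with keepdims, the reduced axes simply map to index 0 in
+    // the output (see outInDimMap above).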
+ // No optimization + rewriter.setInsertionPointToEnd(optimizationBlock); + rewriter.create(loc, originalLoops); + + // Insert instructions inside the KernelIterateOp body. + rewriter.setInsertionPointToStart(&iterationBlock); + + // Handle the operation: + SmallVector inLoopIVs, outLoopIVs; + auto args = iterationBlock.getArguments(); + for (int i = 0; i < args.size(); ++i) { + inLoopIVs.push_back(args[i]); + } + Value zeroIndex = nullptr; + for (decltype(inRank) i = 0; i < outRank; ++i) { + if (outInDimMap.find(i) != outInDimMap.end()) { + outLoopIVs.push_back(inLoopIVs[outInDimMap[i]]); + } else { + if (zeroIndex) { + outLoopIVs.push_back(zeroIndex); + } else { + zeroIndex = rewriter.create(loc, 0); + outLoopIVs.push_back(zeroIndex); + } + } + } + + Value next, accumulated; + next = rewriter.create(loc, operands[0], inLoopIVs); + accumulated = rewriter.create(loc, alloc, outLoopIVs); + accumulated = mapToLowerScalarOp( + op, memRefOutType.getElementType(), {accumulated, next}, rewriter); + rewriter.create(loc, accumulated, alloc, outLoopIVs); + + rewriter.replaceOp(op, alloc); + return matchSuccess(); + } +}; + +//===----------------------------------------------------------------------===// +// EntryPoint Op lowering to Krnl Entry Point. +//===----------------------------------------------------------------------===// + +class ONNXEntryPointLowering : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(ONNXEntryPointOp op, + PatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp( + op, + op.getAttrOfType( + ONNXEntryPointOp::getEntryPointFuncAttrName()), + op.getAttrOfType(ONNXEntryPointOp::getNumInputsAttrName()), + op.getAttrOfType( + ONNXEntryPointOp::getNumOutputsAttrName())); + return matchSuccess(); + } +}; + +//===----------------------------------------------------------------------===// +// Conversion from Tensor type to the Standard dialect MemRef type. +//===----------------------------------------------------------------------===// + +struct TensorTypeConverter : public TypeConverter { + using TypeConverter::TypeConverter; + + LogicalResult convertType(Type t, SmallVectorImpl &results) override { + if (auto type = convertToMemRefType(t)) { + results.push_back(type); + return success(); + } + + results.push_back(t); + return success(); + } + + /// Return true if the inputs and outputs of the given function type are + /// legal. [Taken from MLIR and adapted to only check the legality of the + /// inputs. Once unranked results can be handled gracefully this + /// override needs to be removed in favour of the original MLIR one.] + bool isSignatureLegal(FunctionType funcType) { + return llvm::all_of(funcType.getInputs(), + [this](Type type) { return isLegal(type); }); + } +}; + +} // end anonymous namespace. + +//===----------------------------------------------------------------------===// +// Frontend to Krnl Dialect lowering pass +//===----------------------------------------------------------------------===// + +/// This is a partial lowering to Krnl loops of the ONNX operations. +namespace { +struct FrontendToKrnlLoweringPass + : public ModulePass { + void runOnModule() final; +}; +} // end anonymous namespace. + +void FrontendToKrnlLoweringPass::runOnModule() { + auto module = getModule(); + + // The first thing to define is the conversion target. This will define the + // final target for this lowering. 
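+  // Note that this is a *partial* conversion: operations without a matching
+  // pattern are left untouched instead of causing the conversion to fail
+  // (see applyPartialConversion below).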
+ ConversionTarget target(getContext()); + + // We define the specific operations, or dialects, that are legal targets for + // this lowering. + target + .addLegalDialect(); + + // TODO: enable this once more ops are supported. + // We also define the ONNX dialect as Illegal so that the conversion will fail + // if any of these operations are *not* converted. + // target.addIllegalDialect(); + + // TODO: add any other ops which are considered legal. + // Some operations can be marked as being still legal. + // Example: target.addLegalOp(); + + // Now that the conversion target has been defined, we just need to provide + // the set of patterns that will lower the frontend operations. + OwningRewritePatternList patterns; + + // Convert TensorType to MemRef + TensorTypeConverter tensor_to_memref_converter; + target.addDynamicallyLegalOp([&](FuncOp op) { + // FuncOp is legal only if types have been converted to Std types. + return tensor_to_memref_converter.isSignatureLegal(op.getType()); + }); + + // Type conversion for function signatures. + // Call MLIR FuncOp signature conversion when result type is + // a ranked tensor. + populateFuncOpTypeConversionPattern(patterns, &getContext(), + tensor_to_memref_converter); + + // Frontent operation lowering. + patterns.insert, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseVariadicOpLowering, + ONNXElementwiseVariadicOpLowering, + ONNXElementwiseVariadicOpLowering, + ONNXElementwiseVariadicOpLowering, + ONNXElementwiseVariadicOpLowering, + ONNXElementwiseVariadicOpLowering, + ONNXElementwiseVariadicOpLowering, + ONNXElementwiseVariadicOpLowering, + ONNXElementwiseVariadicOpLowering, + ONNXElementwiseVariadicOpLowering, + ONNXReshapeOpLowering, ONNXEntryPointLowering, + ONNXReductionOpLowering, + ONNXReductionOpLowering, + ONNXReductionOpLowering, + ONNXReductionOpLowering, + ONNXSoftmaxOpLowering, ONNXGemmOpLowering, + ONNXUnsqueezeOpLowering, ONNXTransposeOpLowering, + ONNXIdentityOpLowering, ONNXConvNoBiasOpLowering + >(&getContext()); + + // With the target and rewrite patterns defined, we can now attempt the + // conversion. The conversion will signal failure if any of our `illegal` + // operations were not converted successfully. + if (failed(applyPartialConversion(module, target, patterns))) + signalPassFailure(); +} + +std::unique_ptr mlir::createLowerToKrnlPass() { + return std::make_unique(); +} + +static PassRegistration + pass("lower-frontend", "Lower frontend ops to Krnl dialect.");
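+// The pass is registered under the "lower-frontend" command-line flag; tools
+// that build their own pipelines can add it programmatically via
+// createLowerToKrnlPass().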