diff --git a/.circleci/config.yml b/.circleci/config.yml index 5dc118d..48fda88 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -18,7 +18,7 @@ jobs: git submodule update --init --recursive # Use cached mlir installation if possible. - restore_cache: - key: V4-LLVM-PROJECT-{{ arch }} + key: V6-LLVM-PROJECT-{{ arch }} - run: name: Install MLIR command: | @@ -29,7 +29,7 @@ jobs: source ONNF/utils/install-mlir.sh fi - save_cache: - key: V4-LLVM-PROJECT-{{ arch }} + key: V6-LLVM-PROJECT-{{ arch }} paths: - llvm-project - run: diff --git a/MLIR.cmake b/MLIR.cmake index f7e153d..e330316 100644 --- a/MLIR.cmake +++ b/MLIR.cmake @@ -58,9 +58,11 @@ find_mlir_lib(MLIRAffineOps) find_mlir_lib(MLIRAffineToStandard) find_mlir_lib(MLIRAnalysis) find_mlir_lib(MLIRDialect) +find_mlir_lib(MLIREDSC) find_mlir_lib(MLIRExecutionEngine) find_mlir_lib(MLIRIR) find_mlir_lib(MLIRLLVMIR) +find_mlir_lib(MLIRLoopAnalysis) find_mlir_lib(MLIRLoopToStandard) find_mlir_lib(MLIRLoopOps) find_mlir_lib(MLIRParser) @@ -71,7 +73,8 @@ find_mlir_lib(MLIRTargetLLVMIR) find_mlir_lib(MLIRTransforms) find_mlir_lib(MLIRTransformUtils) find_mlir_lib(MLIRSupport) -find_mlir_lib(MLIROptMain) +find_mlir_lib(MLIRMlirOptMain) +find_mlir_lib(MLIROptLib) find_mlir_lib(MLIRTargetLLVMIRModuleTranslation) find_mlir_lib(MLIRTargetLLVMIR) find_mlir_lib(MLIRTransformUtils) @@ -117,12 +120,15 @@ set(MLIRLibsOnce ${MLIRAffineToStandard} ${MLIRAnalysis} ${MLIRDialect} + ${MLIREDSC} ${MLIRExecutionEngine} ${MLIRIR} ${MLIRLLVMIR} ${MLIRLoopToStandard} ${MLIRLoopOps} - ${MLIROptMain} + ${MLIRLoopAnalysis} + ${MLIRMlirOptMain} + ${MLIROptLib} ${MLIRParser} ${MLIRPass} ${MLIRStandardOps} @@ -226,4 +232,4 @@ function(add_onnf_dialect_doc dialect dialect_tablegen_file) add_dependencies(onnf-doc ${dialect}DocGen) endfunction() -add_custom_target(onnf-doc) \ No newline at end of file +add_custom_target(onnf-doc) diff --git a/doc/Dialects/onnx.md b/doc/Dialects/onnx.md index f67b1ae..cc52df6 100644 --- a/doc/Dialects/onnx.md +++ b/doc/Dialects/onnx.md @@ -332,6 +332,42 @@ ONNX BatchNormalization operation 1. `saved_mean`: memref of any type values or tensor of any type values 1. `saved_var`: memref of any type values or tensor of any type values +### onnx.BatchNormalizationTestMode (ONNXBatchNormalizationTestModeOp) +ONNX BatchNormalization operation in test mode + +#### Description: + + +"Carries out batch normalization as described in the paper" +"https://arxiv.org/abs/1502.03167. Depending on the mode it is being run," +"there are multiple cases for the number of outputs, which we list below:" +"" +"Output case #1: Y, mean, var, saved_mean, saved_var (training mode)" +"Output case #2: Y (test mode)" +"" +"For previous (depreciated) non-spatial cases, implementors are suggested" +"to flatten the input shape to (N x C*D1*D2 ..*Dn) before a BatchNormalization Op." +"This operator has **optional** inputs/outputs. See [the doc](IR.md) for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted." + +#### Operands: + +1. `X`: memref of any type values or tensor of any type values +1. `scale`: memref of any type values or tensor of any type values +1. `B`: memref of any type values or tensor of any type values +1. `mean`: memref of any type values or tensor of any type values +1. 
`var`: memref of any type values or tensor of any type values + +#### Attributes: + +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `epsilon` | `FloatAttr` | 32-bit float attribute attribute | +| `momentum` | `FloatAttr` | 32-bit float attribute attribute | + +#### Results: + +1. `o_Y`: memref of any type values or tensor of any type values + ### onnx.BitShift (ONNXBitShiftOp) ONNX BitShift operation diff --git a/doc/gen_doc.py b/doc/gen_doc.py index 04d456b..df1337d 100644 --- a/doc/gen_doc.py +++ b/doc/gen_doc.py @@ -36,6 +36,7 @@ special_attr_defaults = dict([ special_op_handler = dict([ ("Conv", "ImportNodeConv"), ("MaxPool", "ImportNodeMaxPool"), + ("BatchNormalization", "ImportNodeBatchNormalization"), ("Gemm", "ImportNodeGemm"), ("Pad", "ImportNodePad"), #("Transpose", "ImportNodeTranspose") diff --git a/src/builder/frontend_dialect_transformer.cpp b/src/builder/frontend_dialect_transformer.cpp index 29b7798..9cadad8 100644 --- a/src/builder/frontend_dialect_transformer.cpp +++ b/src/builder/frontend_dialect_transformer.cpp @@ -434,6 +434,21 @@ private: } } + /*! + * Special handle for BatchNormalization operations. + */ + void ImportNodeBatchNormalization(onnx::NodeProto node, int nIn, int nOut) { + int nOuts = node.output().size(); + if (nOuts == 1) { + // Test mode with one output. + ImportNodeOneOut(node, nIn, + nOuts); + } else { + // Training mode with four trailing optional outputs. Not handled yet. + ImportNodeMultipleOuts(node, nIn, nOuts); + } + } + /*! * Special handle for Gemm operations. */ diff --git a/src/builder/op_build_table.inc b/src/builder/op_build_table.inc index e6d97c5..c0b2ca6 100644 --- a/src/builder/op_build_table.inc +++ b/src/builder/op_build_table.inc @@ -30,7 +30,7 @@ }else if (OpName == "AveragePool") { ImportNodeOneOut(node, 1, 1); }else if (OpName == "BatchNormalization") { - ImportNodeMultipleOuts(node, 5, 5); + ImportNodeBatchNormalization(node, 5, 5); }else if (OpName == "BitShift") { ImportNodeOneOut(node, 2, 1); }else if (OpName == "Cast") { diff --git a/src/conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp b/src/conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp index 9c9b826..84d4be8 100644 --- a/src/conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp +++ b/src/conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp @@ -37,10 +37,18 @@ static bool hasAllConstantDimensions(MemRefType type) { return true; } -/// Convert the given TensorType into the corresponding MemRefType. -static MemRefType convertTensorToMemRef(TensorType type) { - assert(type.hasRank() && "expected only ranked shapes"); - return MemRefType::get(type.getShape(), type.getElementType()); +/// Get the corresponding MemRefType of a given TensorType/MemRefType. +static MemRefType convertToMemRefType(Type type) { + MemRefType memRefType; + auto tensorType = type.dyn_cast(); + if (tensorType) { + assert(tensorType.hasRank() && "expected only ranked shapes"); + memRefType = + MemRefType::get(tensorType.getShape(), tensorType.getElementType()); + } else { + memRefType = type.dyn_cast(); + } + return memRefType; } /// Insert an allocation and deallocation for the given MemRefType. 
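A minimal usage sketch of the new helper (illustration only, not part of the patch): because convertToMemRefType returns a null MemRefType when the input type is neither a ranked tensor nor a memref, lowering code can branch on the result directly, as the updated TensorTypeConverter further down does.

// --- illustration, not part of the patch ---
// `op` stands for any ONNX op handled by this pass; names follow the diff.
Type resultType = *op->result_type_begin();
if (auto memRefType = convertToMemRefType(resultType)) {
  // Tensor results are mapped to a memref with the same shape and element
  // type; results that are already memrefs are passed through unchanged.
  auto outputShape = memRefType.getShape();
} else {
  // Unranked or otherwise unsupported type: nothing to lower here.
}
// --- end illustration ---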
@@ -396,6 +404,7 @@ Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, #include "src/conversion/onnx_to_krnl/rewrite_patterns/tensor/unsqueeze.inc" // Neural network #include "src/conversion/onnx_to_krnl/rewrite_patterns/nn/conv.inc" +#include "src/conversion/onnx_to_krnl/rewrite_patterns/nn/normalization.inc" //===----------------------------------------------------------------------===// // EntryPoint Op lowering to Krnl Entry Point. @@ -425,9 +434,13 @@ public: struct TensorTypeConverter : public TypeConverter { using TypeConverter::TypeConverter; - LogicalResult convertType(Type t, SmallVectorImpl &results) override { - if (auto tensor_type = t.dyn_cast()) { - results.push_back(convertTensorToMemRef(tensor_type)); + TensorTypeConverter() { + addConversion(convertType); + } + + static LogicalResult convertType(Type t, SmallVectorImpl &results) { + if (auto type = convertToMemRefType(t)) { + results.push_back(type); return success(); } @@ -511,6 +524,7 @@ void FrontendToKrnlLoweringPass::runOnModule() { populateLoweringONNXIdentityOpPattern(patterns, &getContext()); // Neural network populateLoweringONNXConvOpPattern(patterns, &getContext()); + populateLoweringONNXNormalizationOpPattern(patterns, &getContext()); // Entry point patterns.insert(&getContext()); diff --git a/src/conversion/onnx_to_krnl/rewrite_patterns/math/elementwise.inc b/src/conversion/onnx_to_krnl/rewrite_patterns/math/elementwise.inc index b48e23a..945d4da 100644 --- a/src/conversion/onnx_to_krnl/rewrite_patterns/math/elementwise.inc +++ b/src/conversion/onnx_to_krnl/rewrite_patterns/math/elementwise.inc @@ -476,11 +476,10 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern { // TODO: Check that the types are valid. // An element-wise unary operation must have all operands and the result of // the same type. This should have been verified by the verifier. - auto tensorType = (*op->result_type_begin()).cast(); auto loc = op->getLoc(); // Insert an allocation and deallocation for the result of this operation. - auto memRefType = convertTensorToMemRef(tensorType); + auto memRefType = convertToMemRefType(*op->result_type_begin()); // If the output has a dynamic dimension, pass the operands required for // each dynamic dimension to the AllocOp. The first operand of the @@ -545,12 +544,11 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern { // TODO: Check that the types are valid. // An element-wise variadic operation must have all operands and the result // of the same type. This should have been verified by the verifier. - auto tensorType = (*op->result_type_begin()).cast(); auto loc = op->getLoc(); auto numArgs = op->getNumOperands(); // Insert an allocation and deallocation for the result of this operation. 
- auto memRefType = convertTensorToMemRef(tensorType); + auto memRefType = convertToMemRefType(*op->result_type_begin()); Value alloc; bool insertDealloc = checkInsertDealloc(op); diff --git a/src/conversion/onnx_to_krnl/rewrite_patterns/math/gemm.inc b/src/conversion/onnx_to_krnl/rewrite_patterns/math/gemm.inc index f25dc44..af1da9e 100644 --- a/src/conversion/onnx_to_krnl/rewrite_patterns/math/gemm.inc +++ b/src/conversion/onnx_to_krnl/rewrite_patterns/math/gemm.inc @@ -8,33 +8,34 @@ // //===----------------------------------------------------------------------===// +template struct ONNXGemmOpLowering : public ConversionPattern { ONNXGemmOpLowering(MLIRContext *ctx) - : ConversionPattern(mlir::ONNXGemmOp::getOperationName(), 1, ctx) {} + : ConversionPattern(GemmOp::getOperationName(), 1, ctx) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { - auto tensorType = (*op->result_type_begin()).cast(); auto loc = op->getLoc(); + auto has_bias = (operands.size() == 3); Value A, B, C; A = operands[0]; B = operands[1]; - C = operands[2]; + if (has_bias) + C = operands[2]; - auto alphaAttr = FloatAttr::get(tensorType.getElementType(), - llvm::dyn_cast(op).alpha().convertToFloat()); - auto betaAttr = FloatAttr::get(tensorType.getElementType(), - llvm::dyn_cast(op).beta().convertToFloat()); + auto memRefType = convertToMemRefType(*op->result_type_begin()); + + auto alphaAttr = FloatAttr::get(memRefType.getElementType(), + llvm::dyn_cast(op).alpha().convertToFloat()); + auto betaAttr = FloatAttr::get(memRefType.getElementType(), + llvm::dyn_cast(op).beta().convertToFloat()); auto alpha = rewriter.create(loc, alphaAttr); auto beta = rewriter.create(loc, betaAttr); - bool isTransA = (llvm::dyn_cast(op).transA() != 0); - bool isTransB = (llvm::dyn_cast(op).transB() != 0); - - // Result type - auto memRefType = convertTensorToMemRef(tensorType); + bool isTransA = (llvm::dyn_cast(op).transA() != 0); + bool isTransB = (llvm::dyn_cast(op).transB() != 0); // Insert an allocation and deallocation for the result of this operation. Value alloc; @@ -118,14 +119,16 @@ struct ONNXGemmOpLowering : public ConversionPattern { // GemmOp supports unidirectional broadcasting from C to A*B. // Hence, it must be enough to get broadcasting information for C only. 
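// Added note (illustration, not part of the patch): with the new template
// parameter the same pattern covers ONNXGemmOp (A, B, C) and ONNXGemmNoBiasOp
// (A, B only), computing Y = alpha * op(A) * op(B) [+ beta * C when a bias C
// is present]. ONNX allows unidirectional broadcasting of C to the MxN shape
// of A*B, e.g. C of shape MxN, 1xN, Mx1, N, or 1. When a dimension of C is
// dynamic (-1 in the static shape), whether it broadcasts can only be decided
// at run time, which is why the code below emits a DimOp and compares the
// result against 1, keeping the CmpIOp in broadcastedDimInfo.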
std::map broadcastedDimInfo; - auto shape = C.getType().cast().getShape(); - for (int i = 0; i < shape.size(); ++i) { - if (shape[i] < 0) { - auto dim = rewriter.create(loc, C, i).getResult(); - auto one = rewriter.create(loc, 1); - auto isBroadcasted = - rewriter.create(loc, CmpIPredicate::eq, dim, one); - broadcastedDimInfo.insert(std::make_pair(i, isBroadcasted)); + if (has_bias) { + auto shape = C.getType().cast().getShape(); + for (int i = 0; i < shape.size(); ++i) { + if (shape[i] < 0) { + auto dim = rewriter.create(loc, C, i).getResult(); + auto one = rewriter.create(loc, 1); + auto isBroadcasted = + rewriter.create(loc, CmpIPredicate::eq, dim, one); + broadcastedDimInfo.insert(std::make_pair(i, isBroadcasted)); + } } } @@ -157,14 +160,18 @@ struct ONNXGemmOpLowering : public ConversionPattern { auto matmulIterateOp = rewriter.create(loc, reductionPack); // Compute beta*C, and add up to alpha*A*B (unidirectional broadcasting) - auto loopCIVs = getLoopIVsForBroadcasting( - loc, rewriter, loopMNIVs, C, broadcastedDimInfo); - auto loadedC = rewriter.create(loc, C, loopCIVs); auto loadedAB = rewriter.create(loc, alloc, loopMNIVs); auto alphaAB = rewriter.create(loc, alpha, loadedAB); - auto betaC = rewriter.create(loc, beta, loadedC); - auto Y = rewriter.create(loc, alphaAB, betaC); - rewriter.create(loc, Y, alloc, loopMNIVs); + if (has_bias) { + auto loopCIVs = getLoopIVsForBroadcasting(loc, rewriter, loopMNIVs, C, + broadcastedDimInfo); + auto loadedC = rewriter.create(loc, C, loopCIVs); + auto betaC = rewriter.create(loc, beta, loadedC); + auto Y = rewriter.create(loc, alphaAB, betaC); + rewriter.create(loc, Y, alloc, loopMNIVs); + } else { + rewriter.create(loc, alphaAB, alloc, loopMNIVs); + } // Insert instructions to do matrix multiplication: A*B Block &matmulIterationBlock = matmulIterateOp.bodyRegion().front(); @@ -205,5 +212,6 @@ struct ONNXGemmOpLowering : public ConversionPattern { void populateLoweringONNXGemmOpPattern( OwningRewritePatternList &patterns, MLIRContext *ctx) { - patterns.insert(ctx); + patterns.insert>(ctx); + patterns.insert>(ctx); } diff --git a/src/conversion/onnx_to_krnl/rewrite_patterns/math/matmul.inc b/src/conversion/onnx_to_krnl/rewrite_patterns/math/matmul.inc index 5c6ebd7..1af1f1b 100644 --- a/src/conversion/onnx_to_krnl/rewrite_patterns/math/matmul.inc +++ b/src/conversion/onnx_to_krnl/rewrite_patterns/math/matmul.inc @@ -15,7 +15,6 @@ struct ONNXMatMulOpLowering : public ConversionPattern { PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { - auto tensorType = (*op->result_type_begin()).cast(); auto loc = op->getLoc(); Value A = operands[0]; @@ -29,7 +28,7 @@ struct ONNXMatMulOpLowering : public ConversionPattern { // - Both arguments are 1-D // Result type - auto memRefType = convertTensorToMemRef(tensorType); + auto memRefType = convertToMemRefType(*op->result_type_begin()); auto elementType = memRefType.getElementType(); auto memRefShape = memRefType.getShape(); diff --git a/src/conversion/onnx_to_krnl/rewrite_patterns/math/reduction.inc b/src/conversion/onnx_to_krnl/rewrite_patterns/math/reduction.inc index 27f594e..9b94861 100644 --- a/src/conversion/onnx_to_krnl/rewrite_patterns/math/reduction.inc +++ b/src/conversion/onnx_to_krnl/rewrite_patterns/math/reduction.inc @@ -145,9 +145,9 @@ struct ONNXReductionOpLowering : public ConversionPattern { auto loc = op->getLoc(); auto memRefInType = operands[0].getType().cast(); auto memRefInShape = memRefInType.getShape(); - auto 
tensorOutType = (*op->result_type_begin()).cast(); + auto memRefOutType = convertToMemRefType(*op->result_type_begin()); int64_t inRank = memRefInType.getRank(); - int64_t outRank = tensorOutType.getRank(); + int64_t outRank = memRefOutType.getRank(); // Get attributes ArrayAttr axisAttrs = llvm::dyn_cast(op).axesAttr(); @@ -171,7 +171,6 @@ struct ONNXReductionOpLowering : public ConversionPattern { bool isKeepdims = (keepdims == 1) ? true : false; // Get type information - auto memRefOutType = convertTensorToMemRef(tensorOutType); auto memRefOutShape = memRefOutType.getShape(); auto elementOutType = memRefOutType.getElementType(); std::map outInDimMap = diff --git a/src/conversion/onnx_to_krnl/rewrite_patterns/math/softmax.inc b/src/conversion/onnx_to_krnl/rewrite_patterns/math/softmax.inc index eb126c0..3f24a6e 100644 --- a/src/conversion/onnx_to_krnl/rewrite_patterns/math/softmax.inc +++ b/src/conversion/onnx_to_krnl/rewrite_patterns/math/softmax.inc @@ -18,8 +18,8 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern { // let exp_x = exp(x - max_x) in // let sum = sum(exp_x) in // exp_x / sum - auto tensorType = (*op->result_type_begin()).cast(); - int64_t rank = tensorType.getRank(); + auto memRefType = convertToMemRefType(*op->result_type_begin()); + int64_t rank = memRefType.getRank(); int64_t axis = llvm::dyn_cast(op).axis().getSExtValue(); axis = axis >= 0 ? axis : rank + axis; assert(axis >= -rank && axis <= rank - 1); @@ -27,7 +27,6 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern { auto loc = op->getLoc(); // Insert an allocation and deallocation for the result of this operation. - auto memRefType = convertTensorToMemRef(tensorType); auto elementType = memRefType.getElementType(); Value alloc; diff --git a/src/conversion/onnx_to_krnl/rewrite_patterns/nn/conv.inc b/src/conversion/onnx_to_krnl/rewrite_patterns/nn/conv.inc index 3ecfa3e..20ac5e8 100644 --- a/src/conversion/onnx_to_krnl/rewrite_patterns/nn/conv.inc +++ b/src/conversion/onnx_to_krnl/rewrite_patterns/nn/conv.inc @@ -15,10 +15,9 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern { PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { - auto tensorType = (*op->result_type_begin()).cast(); auto loc = op->getLoc(); // Insert an allocation and deallocation for the result of this operation. - auto memRefType = convertTensorToMemRef(tensorType); + auto memRefType = convertToMemRefType(*op->result_type_begin()); Value alloc; bool insertDealloc = checkInsertDealloc(op); ONNXConvNoBiasOp convOp = llvm::dyn_cast(op); diff --git a/src/conversion/onnx_to_krnl/rewrite_patterns/nn/normalization.inc b/src/conversion/onnx_to_krnl/rewrite_patterns/nn/normalization.inc new file mode 100644 index 0000000..cb98b13 --- /dev/null +++ b/src/conversion/onnx_to_krnl/rewrite_patterns/nn/normalization.inc @@ -0,0 +1,140 @@ +//===----- normalization.inc - Lowering Normalization Ops -----------------===// +// +// Copyright 2019 The IBM Research Authors. +// +// ============================================================================= +// +// This file lowers ONNX Normalization Operators to Krnl dialect. 
+//
+//===----------------------------------------------------------------------===//
+
+struct ONNXBatchNormalizationTestModeOpLowering : public ConversionPattern {
+  ONNXBatchNormalizationTestModeOpLowering(MLIRContext *ctx)
+      : ConversionPattern(
+            mlir::ONNXBatchNormalizationTestModeOp::getOperationName(), 1,
+            ctx) {}
+  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const final {
+    // batchnorm{epsilon}(x, scale, bias, mean, variance) =
+    //      scale * (x - mean) / sqrt(variance + epsilon) + bias
+    auto loc = op->getLoc();
+
+    auto memRefType = convertToMemRefType(*op->result_type_begin());
+    auto epsilonAttr =
+        FloatAttr::get(memRefType.getElementType(),
+            llvm::dyn_cast<ONNXBatchNormalizationTestModeOp>(op)
+                .epsilon()
+                .convertToFloat());
+    auto epsilon = rewriter.create<ConstantOp>(loc, epsilonAttr);
+
+    auto operand = operands[0];
+    auto scale = operands[1];
+    auto bias = operands[2];
+    auto mean = operands[3];
+    auto variance = operands[4];
+
+    // Insert an allocation and deallocation for the result of this operation.
+    Value alloc;
+    bool insertDealloc = checkInsertDealloc(op);
+
+    if (hasAllConstantDimensions(memRefType))
+      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
+    else
+      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
+                                    {operand});
+
+    // Operand's dimensions can be in the form of NxCxD1xD2x...xDn or N.
+    // In case of N, C is assumed to be 1.
+    // Shapes of scale, bias, mean and variance must be C.
+    // Computation of BatchNormalization is done as if scale, bias, mean, and
+    // variance are reshaped to Cx1x1x...x1.
+
+    // rank
+    int64_t rank = memRefType.getRank();
+
+    std::vector<Value> originalLoops;
+    std::vector<Value> optimizedLoops;
+    Block *optimizationBlock =
+        defineLoops(rewriter, loc, originalLoops, optimizedLoops, rank);
+
+    // Create a KrnlIterateOp along C dimension.
+    // This will be the outer-most loop in order to re-use scale, bias,
+    // mean and variance.
+
+    SmallVector<Value, 1> loopCIVs;
+    if (rank > 1) {
+      KrnlIterateOperandPack cPack(rewriter, originalLoops[1],
+                                   optimizedLoops[1]);
+      addDimensionToPack(rewriter, loc, cPack, operand, 1);
+      auto cIterateOp = rewriter.create<KrnlIterateOp>(loc, cPack);
+      Block &cIterationBlock = cIterateOp.bodyRegion().front();
+      rewriter.setInsertionPointToStart(&cIterationBlock);
+      for (auto arg : cIterationBlock.getArguments())
+        loopCIVs.emplace_back(arg);
+    } else {
+      loopCIVs.emplace_back(rewriter.create<ConstantIndexOp>(loc, 0));
+    }
+
+    auto scaleVal = rewriter.create<LoadOp>(loc, scale, loopCIVs);
+    auto biasVal = rewriter.create<LoadOp>(loc, bias, loopCIVs);
+    auto meanVal = rewriter.create<LoadOp>(loc, mean, loopCIVs);
+    auto varianceVal = rewriter.create<LoadOp>(loc, variance, loopCIVs);
+
+    // Create a KrnlIterateOp along the other dimensions.
+    SmallVector<int64_t, 4> axes;
+    axes.emplace_back(0);
+    for (int64_t i = 2; i < rank; ++i)
+      axes.emplace_back(i);
+    std::vector<Value> packLoops, packOptimizedLoops;
+    for (int i = 0; i < axes.size(); ++i) {
+      packLoops.emplace_back(originalLoops[axes[i]]);
+      packOptimizedLoops.emplace_back(optimizedLoops[axes[i]]);
+    }
+    KrnlIterateOperandPack pack(rewriter, packLoops, packOptimizedLoops);
+    for (int i = 0; i < axes.size(); ++i) {
+      addDimensionToPack(rewriter, loc, pack, operand, axes[i]);
+    }
+    auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
+
+    // No optimization
+    rewriter.setInsertionPointToEnd(optimizationBlock);
+    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
+
+    Block &iterationBlock = iterateOp.bodyRegion().front();
+    rewriter.setInsertionPointToStart(&iterationBlock);
+
+    SmallVector<Value, 4> loopIVs;
+    auto args = iterationBlock.getArguments();
+    if (args.size() > 1) {
+      loopIVs.emplace_back(args[0]);
+      loopIVs.emplace_back(loopCIVs[0]); // Insert C back.
+      for (int i = 1; i < args.size(); ++i)
+        loopIVs.emplace_back(args[i]);
+    } else {
+      loopIVs.emplace_back(args[0]);
+    }
+
+    auto xVal = rewriter.create<LoadOp>(loc, operand, loopIVs);
+    // normalize
+    auto dividend = rewriter.create<SubFOp>(loc, xVal, meanVal);
+    auto adjustedVarianceVal =
+        rewriter.create<AddFOp>(loc, varianceVal, epsilon);
+    auto divisor = rewriter.create<KrnlSqrtOp>(
+        loc, memRefType.getElementType(), adjustedVarianceVal);
+    auto normVal = rewriter.create<DivFOp>(loc, dividend, divisor);
+    // scale and shift
+    auto scaleNormVal = rewriter.create<MulFOp>(loc, scaleVal, normVal);
+    auto shiftScaleNormVal =
+        rewriter.create<AddFOp>(loc, scaleNormVal, biasVal);
+    rewriter.create<StoreOp>(loc, shiftScaleNormVal, alloc, loopIVs);
+
+    rewriter.replaceOp(op, alloc);
+
+    return matchSuccess();
+  }
+};
+
+void populateLoweringONNXNormalizationOpPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx) {
+  patterns.insert<ONNXBatchNormalizationTestModeOpLowering>(ctx);
+}
diff --git a/src/conversion/onnx_to_krnl/rewrite_patterns/tensor/reshape.inc b/src/conversion/onnx_to_krnl/rewrite_patterns/tensor/reshape.inc
index ed2b185..b64494f 100644
--- a/src/conversion/onnx_to_krnl/rewrite_patterns/tensor/reshape.inc
+++ b/src/conversion/onnx_to_krnl/rewrite_patterns/tensor/reshape.inc
@@ -15,12 +15,11 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
   PatternMatchResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
-    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
     auto inputShape = operands[0].getType().cast<MemRefType>().getShape();
     auto loc = op->getLoc();
 
     // Insert an allocation and deallocation for the result of this operation.
-    auto memRefType = convertTensorToMemRef(tensorType);
+    auto memRefType = convertToMemRefType(*op->result_type_begin());
     auto memRefShape = memRefType.getShape();
 
     Value alloc;
diff --git a/src/conversion/onnx_to_krnl/rewrite_patterns/tensor/transpose.inc b/src/conversion/onnx_to_krnl/rewrite_patterns/tensor/transpose.inc
index 39cfa8c..3bb897a 100644
--- a/src/conversion/onnx_to_krnl/rewrite_patterns/tensor/transpose.inc
+++ b/src/conversion/onnx_to_krnl/rewrite_patterns/tensor/transpose.inc
@@ -15,10 +15,9 @@ struct ONNXTransposeOpLowering : public ConversionPattern {
   PatternMatchResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
-    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
     auto loc = op->getLoc();
 
     // Insert an allocation and deallocation for the result of this operation.
- auto memRefType = convertTensorToMemRef(tensorType); + auto memRefType = convertToMemRefType(*op->result_type_begin()); Value alloc; bool insertDealloc = checkInsertDealloc(op); diff --git a/src/conversion/onnx_to_krnl/rewrite_patterns/tensor/unsqueeze.inc b/src/conversion/onnx_to_krnl/rewrite_patterns/tensor/unsqueeze.inc index 18b9f8b..6d5289d 100644 --- a/src/conversion/onnx_to_krnl/rewrite_patterns/tensor/unsqueeze.inc +++ b/src/conversion/onnx_to_krnl/rewrite_patterns/tensor/unsqueeze.inc @@ -16,8 +16,8 @@ struct ONNXUnsqueezeOpLowering : public ConversionPattern { matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { auto loc = op->getLoc(); - auto tensorType = (*op->result_type_begin()).cast(); - int outRank = tensorType.getRank(); + auto memRefType = convertToMemRefType(*op->result_type_begin()); + int outRank = memRefType.getRank(); // Assume that `axes` has been validated by shape inference. // So, here we just get it. @@ -30,7 +30,6 @@ struct ONNXUnsqueezeOpLowering : public ConversionPattern { } // Insert an allocation and deallocation for the result of this operation. - auto memRefType = convertTensorToMemRef(tensorType); Value alloc; // Compute size in bytes. diff --git a/src/dialect/onnx/onnx.td b/src/dialect/onnx/onnx.td index c09d910..ef4e62a 100644 --- a/src/dialect/onnx/onnx.td +++ b/src/dialect/onnx/onnx.td @@ -146,6 +146,31 @@ def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut", let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_Y); } +def ONNXBatchNormalizationTestModeOp: ONNX_Op<"BatchNormalizationTestMode", + [NoSideEffect, DeclareOpInterfaceMethods]> { + let summary = "ONNX BatchNormalization operation in test mode"; + let description = [{ + "Carries out batch normalization as described in the paper" + "https://arxiv.org/abs/1502.03167. Depending on the mode it is being run," + "there are multiple cases for the number of outputs, which we list below:" + "" + "Output case #1: Y, mean, var, saved_mean, saved_var (training mode)" + "Output case #2: Y (test mode)" + "" + "For previous (depreciated) non-spatial cases, implementors are suggested" + "to flatten the input shape to (N x C*D1*D2 ..*Dn) before a BatchNormalization Op." + "This operator has **optional** inputs/outputs. See [the doc](IR.md) for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted." + }]; + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, + AnyTypeOf<[AnyMemRef, AnyTensor]>:$scale, + AnyTypeOf<[AnyMemRef, AnyTensor]>:$B, + AnyTypeOf<[AnyMemRef, AnyTensor]>:$mean, + AnyTypeOf<[AnyMemRef, AnyTensor]>:$var, + DefaultValuedAttr:$epsilon, + DefaultValuedAttr:$momentum); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_Y); +} + def ONNXPadConstantValueOp : ONNX_Op<"PadConstantValue", [NoSideEffect ]> { let summary = "ONNX Pad operation with constant padding value"; diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp index 2031a9a..25a86e9 100644 --- a/src/dialect/onnx/onnx_ops.cpp +++ b/src/dialect/onnx/onnx_ops.cpp @@ -586,12 +586,72 @@ void ONNXGemmNoBiasOp::inferShapes() { return; auto lhsTy = getOperand(0).getType().cast(); auto rhsTy = getOperand(1).getType().cast(); + + int64_t M, N, K_A, K_B; + M = (transA() == 0) ? 
lhsTy.getShape()[0] : lhsTy.getShape()[1]; + K_A = (transA() == 0) ? lhsTy.getShape()[1] : lhsTy.getShape()[0]; + N = (transB() == 0) ? rhsTy.getShape()[1] : rhsTy.getShape()[0]; + K_B = (transB() == 0) ? rhsTy.getShape()[0] : rhsTy.getShape()[1]; + + if ((K_A != -1) and (K_B != -1) and (K_A != K_B)) { + emitError("Tensor shapes mismatched."); + } + SmallVector dims; - dims.emplace_back(lhsTy.getShape()[0]); - dims.emplace_back(rhsTy.getShape()[1]); + dims.emplace_back(M); + dims.emplace_back(N); getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType())); } +/// BatchNormalizationTestMode +void ONNXBatchNormalizationTestModeOp::inferShapes() { + // Cannot infer shape if no shape exists. + if (!getOperand(0).getType().isa() || + !getOperand(1).getType().isa() || + !getOperand(2).getType().isa() || + !getOperand(3).getType().isa() || + !getOperand(4).getType().isa()) + return; + + auto input = getOperand(0).getType().cast(); + auto scale = getOperand(1).getType().cast(); + auto bias = getOperand(2).getType().cast(); + auto mean = getOperand(3).getType().cast(); + auto variance = getOperand(4).getType().cast(); + + // Check whether the shapes of scale, bias, mean and variance are valid. + // Operand's dimensions can be in the form of NxCxD1xD2x...xDn or N. + // In case of N, C is assumed to be 1. + // Shapes of scale, bias, mean and variance must be C. + int64_t c = -1; + if (input.getShape().size() == 1) { + c = 1; + } else if (input.getShape().size() > 2) { + c = (input.getShape()[1] != -1) ? input.getShape()[1] : -1; + } else { + emitError("Wrong rank for the input."); + } + + if (c != -1) { + auto s = scale.getShape(); + auto b = bias.getShape(); + auto m = mean.getShape(); + auto v = variance.getShape(); + + if ((s.size() != 1) || (s[0] != -1 && s[0] != c)) + emitError("Wrong rank for the scale."); + if ((b.size() != 1) || (b[0] != -1 && b[0] != c)) + emitError("Wrong rank for the bias."); + if ((m.size() != 1) || (m[0] != -1 && m[0] != c)) + emitError("Wrong rank for the mean."); + if ((v.size() != 1) || (v[0] != -1 && v[0] != c)) + emitError("Wrong rank for the variance."); + } + + // The output tensor of the same shape as the input. + getResult().setType(getOperand(0).getType()); +} + // TODO: // Verify that matrix sizes are valid for multiplication and addition. // Take into account the dimensionality of the matrix. diff --git a/src/main.cpp b/src/main.cpp index f0de7e9..e3a36c5 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -24,6 +24,7 @@ #include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h" #include "mlir/ExecutionEngine/ExecutionEngine.h" #include "mlir/ExecutionEngine/OptUtils.h" +#include "mlir/InitAllDialects.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Module.h" #include "mlir/Parser.h" @@ -69,6 +70,10 @@ void EmitLLVMBitCode(const mlir::OwningModuleRef &module) { } int main(int argc, char *argv[]) { + mlir::registerDialect(); + mlir::registerDialect(); + mlir::registerDialect(); + mlir::registerDialect(); mlir::registerDialect(); mlir::registerDialect(); diff --git a/src/pass/lower_frontend_to_krnl.cpp b/src/pass/lower_frontend_to_krnl.cpp new file mode 100644 index 0000000..d609bc5 --- /dev/null +++ b/src/pass/lower_frontend_to_krnl.cpp @@ -0,0 +1,2380 @@ +//====- lower_frontend_to_krnl.cpp - Frontend dialects to Krnl lowering ---===// +// +// Copyright 2019 The IBM Research Authors. 
+// +// ============================================================================= +// +// This file implements the lowering of frontend operations to a combination of +// Krnl IR and standard operations. +// +//===----------------------------------------------------------------------===// +#include + +#include "mlir/Dialect/AffineOps/AffineOps.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/Sequence.h" + +#include "src/dialect/krnl/krnl_helper.hpp" +#include "src/dialect/krnl/krnl_ops.hpp" +#include "src/dialect/onnx/onnx_ops.hpp" +#include "src/pass/passes.hpp" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// FrontendToAffine RewritePatterns +//===----------------------------------------------------------------------===// + +/// Check is all dimensions are known at compile time. +static bool hasAllConstantDimensions(MemRefType type) { + auto memRefShape = type.getShape(); + for (int i = 0; i < memRefShape.size(); ++i) + if (memRefShape[i] < 0) + return false; + return true; +} + +/// Get the corresponding MemRefType of a given TensorType/MemRefType. +static MemRefType convertToMemRefType(Type type) { + MemRefType memRefType; + auto tensorType = type.dyn_cast(); + if (tensorType) { + assert(tensorType.hasRank() && "expected only ranked shapes"); + memRefType = + MemRefType::get(tensorType.getShape(), tensorType.getElementType()); + } else { + memRefType = type.dyn_cast(); + } + return memRefType; +} + +/// Insert an allocation and deallocation for the given MemRefType. +static Value insertAllocAndDealloc(MemRefType type, Location loc, + PatternRewriter &rewriter, + bool insertDealloc, + ArrayRef operands = {}) { + // Put together alloc operands for any dynamic dimensions of the memref. + AllocOp alloc; + if (!operands.empty()) { + auto memRefShape = type.getShape(); + auto rank = memRefShape.size(); + + std::map fromOperands; + for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) { + int memRefDimIdx = rank - 1 - reversedIdx; + if (memRefShape[memRefDimIdx] < 0) { // unknown dimension + Value maxDim = nullptr; + for (int i = 0; i < operands.size(); i++) { + auto operandShape = + operands[i].getType().cast().getShape(); + int operandDimIdx = operandShape.size() - 1 - reversedIdx; + + if (operandDimIdx < 0) + continue; + + // In case of operations with broadcasting, the dimension of the + // alloc result is the maximum size along each dimension of the + // operands. + auto operandDim = + rewriter.create(loc, operands[i], operandDimIdx); + if (maxDim) { + auto maxCondition = rewriter.create(loc, CmpIPredicate::sgt, + operandDim, maxDim); + maxDim = rewriter.create(loc, maxCondition, operandDim, + maxDim); + } else { + maxDim = operandDim; + } + } + fromOperands.insert(std::make_pair(memRefDimIdx, maxDim)); + } + } + + SmallVector allocOperands; + for (int i = 0; i < rank; ++i) + if (memRefShape[i] < 0) + allocOperands.push_back(fromOperands[i]); + alloc = rewriter.create(loc, type, allocOperands); + } else { + alloc = rewriter.create(loc, type); + } + + // Make sure to allocate at the beginning of the block if + // all dimensions are known. 
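// Added note (illustration, not from the original file): hoisting is only
// safe for statically shaped allocs because they take no operands; a dynamic
// alloc depends on the DimOp/SelectOp values computed just above, which would
// not dominate the block entry. For each unknown dimension the loop above
// emits, per contributing operand, roughly:
//   %d   = dim %operand, <idx> : memref<...>
//   %gt  = cmpi "sgt", %d, %maxSoFar : index
//   %max = select %gt, %d, %maxSoFar : index
// and the running maximum feeds the AllocOp operand list.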
+ auto *parentBlock = alloc.getOperation()->getBlock(); + if (hasAllConstantDimensions(type)) + alloc.getOperation()->moveBefore(&parentBlock->front()); + + if (insertDealloc) { + auto dealloc = rewriter.create(loc, alloc); + dealloc.getOperation()->moveBefore(&parentBlock->back()); + } + + return alloc; +} + +// Determine if current function returns the result value of the +// current op being lowered. If it does then dealloc should not be +// inserted. +static bool checkInsertDealloc(Operation *currentOp) { + auto parentBlock = currentOp->getBlock(); + + bool insertDealloc = true; + parentBlock->walk([&insertDealloc, currentOp](ReturnOp op) { + assert(currentOp->getNumResults() < 2 && + "No more than one result supported (for now)."); + // If there is at least one result to investigate. + if (currentOp->getNumResults() > 0) { + auto result = currentOp->getResult(0); + for (const auto &operand : op.getOperands()) + if (operand == result) + insertDealloc = false; + } + }); + + return insertDealloc; +} + +// Create a mapping from result type's dimensions to input type's dimensions, +// given that the result type is the result of a reduction op over the input +// type. +std::map +getReductionMapping(MemRefType inputTy, ArrayRef axes, bool keepdims) { + std::map OutInDimMap; + int64_t rank = inputTy.getRank(); + + // Mark reduction axes. + std::vector isReductionAxis; + for (decltype(rank) i = 0; i < rank; ++i) { + if (std::find(axes.begin(), axes.end(), i) != axes.end()) + isReductionAxis.push_back(true); + else + isReductionAxis.push_back(false); + } + + for (decltype(rank) inIndex = 0, outIndex = 0; inIndex < rank; ++inIndex) { + // If it is a reduction axis, there is no relationship among dimensions. + if (isReductionAxis[inIndex]) { + if (keepdims) + outIndex++; + } else { + OutInDimMap.insert(std::make_pair(outIndex, inIndex)); + outIndex++; + } + } + + return OutInDimMap; +} + +// Add bounds associated with the op operand to the KRNL iteration pack. +// Dynamic dimenions are supported. +static void addDimensionToPack(ConversionPatternRewriter &rewriter, + Location loc, KrnlIterateOperandPack &pack, Value operand, int index) { + auto shape = operand.getType().cast().getShape(); + if (shape[index] < 0) { + pack.pushConstantBound(0); + pack.pushOperandBound( + rewriter.create(loc, operand, index).getResult()); + } else { + pack.pushConstantBound(0); + pack.pushConstantBound(shape[index]); + } +} + +// Function that defines the KRNL dialect loops and their respective +// optimized version. +static KrnlOptimizeLoopsOp emitOptimizedLoops( + ConversionPatternRewriter &rewriter, Location loc, + std::vector &loops, std::vector &optimizedLoops, + int64_t numLoops) { + // Define loops. + auto loopsOp = rewriter.create(loc, numLoops); + loops.reserve(numLoops); + for (auto result : loopsOp.getResults()) + loops.push_back(result); + + // Define optimized version of the loops. + auto optimizedLoopsOp = rewriter.create(loc, numLoops); + optimizedLoops.reserve(numLoops); + for (auto result : optimizedLoopsOp.getResults()) + optimizedLoops.push_back(result); + + return optimizedLoopsOp; +} + +// Function that emits the loops and their optimized version. +// The function returns a reference to the inner optimization block. 
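// Typical use (sketch mirroring the normalization lowering added in this
// patch; not part of the original comment):
//   std::vector<Value> loops, optLoops;
//   Block *optBlock = defineLoops(rewriter, loc, loops, optLoops, rank);
//   KrnlIterateOperandPack pack(rewriter, loops, optLoops);
//   for (int i = 0; i < rank; ++i)
//     addDimensionToPack(rewriter, loc, pack, operand, i);
//   auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
//   // With no optimizations, just return the loops unchanged:
//   rewriter.setInsertionPointToEnd(optBlock);
//   rewriter.create<KrnlReturnLoopsOp>(loc, loops);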
+static Block* defineLoops(ConversionPatternRewriter &rewriter, + Location loc, std::vector &loops, + std::vector &optimizedLoops, int64_t numLoops) { + KrnlOptimizeLoopsOp optimizedLoopsOp = emitOptimizedLoops( + rewriter, loc, loops, optimizedLoops, numLoops); + return &optimizedLoopsOp.region().front(); +} + +// Function which emits a basic set of loops and optimized loops +// for a given operation argument. A reference to the loop optimization +// block is returned in the last argument of the function. +static void emitKrnlLoopsAndIterationForOperand( + ConversionPatternRewriter &rewriter, Location loc, + Value operand, std::vector &originalLoops, + KrnlOptimizeLoopsOp &optimizedLoopsOp, KrnlIterateOp &iterateOp) { + // Operand shape. + auto shape = operand.getType().cast().getShape(); + + // Number of loops. + int64_t rank = shape.size(); + + // Define loops and optimized loops. + std::vector optimizedLoops; + optimizedLoopsOp = emitOptimizedLoops(rewriter, loc, originalLoops, + optimizedLoops, rank); + + KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops); + // Iterate over the loop nest. + for (int i = 0; i < rank; ++i) + addDimensionToPack(rewriter, loc, pack, operand, i); + + iterateOp = rewriter.create(loc, pack); +} + +unsigned getMemRefEltSizeInBytes(MemRefType memRefType) { + auto elementType = memRefType.getElementType(); + + unsigned sizeInBits; + if (elementType.isIntOrFloat()) { + sizeInBits = elementType.getIntOrFloatBitWidth(); + } else { + auto vectorType = elementType.cast(); + sizeInBits = + vectorType.getElementTypeBitWidth() * vectorType.getNumElements(); + } + return llvm::divideCeil(sizeInBits, 8); +} + +// Get run-time dimension information for unknown dimensions used for +// broadcasting. +std::map> +getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter, + MemRefType memRefType, ArrayRef operands) { + auto memRefShape = memRefType.getShape(); + int64_t rank = memRefShape.size(); + // For unknown dimensions, we need to get dimension values at runtime in + // order to do broadcasting. + std::map> DimInfo; + // For each result dimension, compute the number of sharing operands. + // Sharing operands are operands sharing the same index (counting from the + // rightmost to the leftmost) for a given dimension. + std::map sharedDimCount; + for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) { + int dimIdx = rank - 1 - reversedIdx; + sharedDimCount[dimIdx] = 0; + for (int i = 0; i < operands.size(); ++i) { + auto shape = operands[i].getType().cast().getShape(); + if (reversedIdx <= shape.size() - 1) + sharedDimCount[dimIdx]++; + } + } + // An unknown dimension can have a value of 1 or N (N > 1). + // If its value is 1, it is broadcasted dimension. + // Otherwise, non-broadcasted dimension. + // We only care about unknown dimensions whose number of sharing operands is + // more than one, since they are potentially broadcasted dimensions. 
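// Worked example (added illustration): for operands of shapes (?, 3) and
// (5, 3) the result has rank 2 and both dimensions are shared by the two
// operands. Dimension 0 of the first operand is unknown, so the loop below
// emits a `dim` of that operand and a `cmpi eq` against 1; at run time the
// comparison tells the loop-IV selection code whether index 0 must be forced
// to zero (broadcast) or passed through unchanged.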
+ for (int i = 0; i < operands.size(); ++i) { + std::map broadcastedDims; + auto shape = operands[i].getType().cast().getShape(); + int size = shape.size(); + for (int j = 0; j < shape.size(); ++j) { + if (shape[j] < 0 and sharedDimCount[rank - size + j] > 1) { + auto dim = rewriter.create(loc, operands[i], j).getResult(); + auto one = rewriter.create(loc, 1); + auto isBroadcasted = + rewriter.create(loc, CmpIPredicate::eq, dim, one); + broadcastedDims.insert(std::make_pair(j, isBroadcasted)); + } + } + DimInfo.insert(std::make_pair(i, broadcastedDims)); + } + return DimInfo; +} + +// Extract induction variables that are used for broadcasting values of a +// given operand. +std::vector +getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter, + ArrayRef loopIVs, Value operand, + std::map broadcastedDims) { + // `operand` must has a ranked type. This should have been checked by the + // shape inference pass. + auto operandShape = operand.getType().cast().getShape(); + auto rank = operandShape.size(); + auto loopCount = loopIVs.size(); + + std::vector newLoopIVs; + for (unsigned reversedIdx = 0; reversedIdx < rank; ++reversedIdx) { + auto dimIdx = rank - 1 - reversedIdx; + auto loopIdx = loopCount - 1 - reversedIdx; + if (operandShape[dimIdx] == 1) { + // Broadcasted dimension + auto zero = rewriter.create(loc, 0); + newLoopIVs.insert(newLoopIVs.begin(), zero); + } else if ((operandShape[dimIdx] == -1) && + (broadcastedDims.find(dimIdx) != broadcastedDims.end())) { + // Unknown dimension, it can have a value of 1 or N (N > 1). + // If its value is 1, it is broadcasted dimension. + // Otherwise, non-broadcasted dimension. + auto zero = rewriter.create(loc, 0); + auto idx = rewriter.create(loc, broadcastedDims[dimIdx], zero, + loopIVs[loopIdx]); + newLoopIVs.insert(newLoopIVs.begin(), idx); + } else { + // Non-broadcasted dimension + newLoopIVs.insert(newLoopIVs.begin(), loopIVs[loopIdx]); + } + } + return newLoopIVs; +} + +namespace { + +template +struct ScalarOp; + +template <> +struct ScalarOp { + using FOp = AddFOp; + using IOp = AddIOp; +}; + +template <> +struct ScalarOp { + using FOp = MulFOp; + using IOp = MulIOp; +}; + +template <> +struct ScalarOp { + using FOp = DivFOp; + using IOp = SignedDivIOp; +}; + +template <> +struct ScalarOp { + using FOp = SubFOp; + using IOp = SubIOp; +}; + +template <> +struct ScalarOp { + using FOp = AndOp; // not use + using IOp = AndOp; +}; + +template <> +struct ScalarOp { + using FOp = OrOp; // not use + using IOp = OrOp; +}; + +template <> +struct ScalarOp { + using FOp = XOrOp; // not use + using IOp = XOrOp; +}; + +template <> +struct ScalarOp { + using FOp = ExpOp; + using IOp = ExpOp; // not use +}; + +template <> +struct ScalarOp { + using FOp = AddFOp; + using IOp = AddIOp; +}; + +template <> +struct ScalarOp { + using FOp = TanhOp; + using IOp = TanhOp; // not use +}; + +template <> +struct ScalarOp { + using FOp = CosOp; + using IOp = CosOp; // not use +}; + +template <> +struct ScalarOp { + using FOp = LogOp; + using IOp = LogOp; // not use +}; + +template <> +struct ScalarOp { + using FOp = MulFOp; + using IOp = MulIOp; +}; + +template <> +struct ScalarOp { + using FOp = AddFOp; + using IOp = AddIOp; +}; + +template <> +struct ScalarOp { + using FOp = KrnlSqrtOp; + using IOp = KrnlSqrtOp; // not use +}; + +template +using ScalarFOp = typename ScalarOp::FOp; +template +using ScalarIOp = typename ScalarOp::IOp; + +// Get the identity element of a operation. +// Return NULL if the function does not have identity. 
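// Added note: the identity is meant to be used as the initial value of the
// reduction accumulator, e.g. -infinity / INT_MIN for ReduceMax,
// +infinity / INT_MAX for ReduceMin, 1 for ReduceProd and 0 for ReduceSum,
// as the specializations below spell out.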
+template +DataType getIdentityValue() { + return NULL; +} + +template <> +float getIdentityValue(){ + return (float)-std::numeric_limits::infinity(); +} + +template <> +int getIdentityValue(){ + return std::numeric_limits::min(); +} + +template <> +float getIdentityValue(){ + return (float)std::numeric_limits::infinity(); +} + +template <> +int getIdentityValue(){ + return std::numeric_limits::max(); +} + +template <> +float getIdentityValue(){ + return (float)1.0; +} + +template <> +int getIdentityValue(){ + return 1; +} + +template <> +float getIdentityValue(){ + return (float)0; +} + +template <> +int getIdentityValue(){ + return 0; +} + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering to Krnl dialect. +//===----------------------------------------------------------------------===// +template +Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, + ArrayRef operands, + ConversionPatternRewriter &rewriter) { + /* Lower UnaryOp to Ops in the Standard dialect. + */ + auto loc = op->getLoc(); + Type element_type = operands.front().getType(); + if (element_type.isa()) { + return rewriter.create>(loc, result_types, operands, + mlir::None); + } else if (element_type.isa()) { + return rewriter.create>(loc, result_types, operands, + mlir::None); + } else { + emitError(loc, "unsupported element type"); + return nullptr; + } +} + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXSinhOp +//===----------------------------------------------------------------------===// +template <> +Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, + ArrayRef operands, + ConversionPatternRewriter &rewriter) { + // ONNXSinhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))), + // ConstantOp 2) + auto loc = op->getLoc(); + Value operand = operands[0]; + auto elementType = result_types[0]; + + auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); + auto two = rewriter.create(loc, FloatAttr::get(elementType, 2)); + auto neg = rewriter.create(loc, zero, operand); + auto exp = rewriter.create(loc, operand); + auto negExp = rewriter.create(loc, neg); + auto result = rewriter.create( + loc, rewriter.create(loc, exp, negExp), two); + + return result; +} + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXCoshOp +//===----------------------------------------------------------------------===// +template <> +Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, + ArrayRef operands, + ConversionPatternRewriter &rewriter) { + // ONNXCoshOp(%X) = DivFOp(AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))), + // ConstantOp 2) + auto loc = op->getLoc(); + Value operand = operands[0]; + auto elementType = result_types[0]; + + auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); + auto two = rewriter.create(loc, FloatAttr::get(elementType, 2)); + auto neg = rewriter.create(loc, zero, operand); + auto exp = rewriter.create(loc, operand); + auto negExp = rewriter.create(loc, neg); + auto result = rewriter.create( + loc, rewriter.create(loc, exp, negExp), two); + + return result; +} + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXSigmoidOp +//===----------------------------------------------------------------------===// +template <> +Value mapToLowerScalarOp(Operation *op, + ArrayRef result_types, + ArrayRef operands, + 
ConversionPatternRewriter &rewriter) { + // ONNXSigmoidOp(%X) = DivFOp(ConstantOp 1, + // AddFOp(ConstantOp 1, ExpOp(NegFOp(%X)))) + auto loc = op->getLoc(); + Value operand = operands[0]; + auto elementType = result_types[0]; + + auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); + auto one = rewriter.create(loc, FloatAttr::get(elementType, 1)); + auto neg = rewriter.create(loc, zero, operand); + auto negExp = rewriter.create(loc, neg); + auto result = rewriter.create( + loc, one, rewriter.create(loc, one, negExp)); + + return result; +} + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXHardSigmoidOp +//===----------------------------------------------------------------------===// +template <> +Value mapToLowerScalarOp( + Operation *op, ArrayRef result_types, ArrayRef operands, + ConversionPatternRewriter &rewriter) { + // %Y = AddFOp(MulFOp(alpha, %X), beta) + // %Z = SelectOp(CmpFOp(OGT, %Y, Constant 0), + // %Y, + // Constant 0) + // ONNXHardSigmoidOp(%X) = SelectOp(CmpFOp(OLT, %Z, Constant 1), + // %Z, + // Constant 1) + auto loc = op->getLoc(); + Value operand = operands[0]; + auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(), + llvm::dyn_cast(op).alpha().convertToFloat()); + auto betaAttribute = FloatAttr::get(rewriter.getF32Type(), + llvm::dyn_cast(op).beta().convertToFloat()); + auto elementType = result_types[0]; + + auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); + auto one = rewriter.create(loc, FloatAttr::get(elementType, 1)); + auto alpha = rewriter.create(loc, alphaAttribute); + auto beta = rewriter.create(loc, betaAttribute); + + auto add = rewriter.create( + loc, rewriter.create(loc, alpha, operand), beta); + auto maxPredicate = + rewriter.create(loc, CmpFPredicate::OGT, add, zero); + auto max = rewriter.create(loc, maxPredicate, add, zero); + auto minPredicate = + rewriter.create(loc, CmpFPredicate::OLT, max, one); + auto result = rewriter.create(loc, minPredicate, max, one); + + return result; +} + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXEluOp +//===----------------------------------------------------------------------===// +template <> +Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, + ArrayRef operands, + ConversionPatternRewriter &rewriter) { + // ONNXEluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0), + // MulFOp(alpha, SubFOp(ExpOp(%X), 1)), + // %X) + auto loc = op->getLoc(); + Value operand = operands[0]; + auto elementType = result_types[0]; + + auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(), + llvm::dyn_cast(op).alpha().convertToFloat()); + auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); + auto one = rewriter.create(loc, FloatAttr::get(elementType, 1)); + auto alpha = rewriter.create(loc, alphaAttribute); + auto exp = rewriter.create(loc, operand); + auto lessThanZero = + rewriter.create(loc, CmpFPredicate::OLT, operand, zero); + auto result = rewriter.create( + loc, lessThanZero, + rewriter.create(loc, alpha, + rewriter.create(loc, exp, one)), + operand); + + return result; +} + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXReluOp +//===----------------------------------------------------------------------===// +template <> +Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, + ArrayRef operands, + ConversionPatternRewriter &rewriter) { + // 
ONNXReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0), + // ConstantOp 0, + // %X) + auto loc = op->getLoc(); + Value operand = operands[0]; + auto elementType = result_types[0]; + + auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); + auto lessThanZero = + rewriter.create(loc, CmpFPredicate::OLT, operand, zero); + auto result = rewriter.create(loc, lessThanZero, zero, operand); + + return result; +} + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXLeakyReluOp +//===----------------------------------------------------------------------===// +template <> +Value mapToLowerScalarOp(Operation *op, + ArrayRef result_types, + ArrayRef operands, + ConversionPatternRewriter &rewriter) { + // ONNXLeakyReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0), + // MulFOp(alpha, %X), + // %X) + auto loc = op->getLoc(); + Value operand = operands[0]; + auto elementType = result_types[0]; + + auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(), + llvm::dyn_cast(op).alpha().convertToFloat()); + auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); + auto alpha = rewriter.create(loc, alphaAttribute); + auto lessThanZero = + rewriter.create(loc, CmpFPredicate::OLT, operand, zero); + auto result = rewriter.create( + loc, lessThanZero, rewriter.create(loc, alpha, operand), operand); + + return result; +} + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXSeluOp +//===----------------------------------------------------------------------===// +template <> +Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, + ArrayRef operands, + ConversionPatternRewriter &rewriter) { + // ONNXSeluOp(%X) = SelectOp(CmpFOp(OGT, %X, ConstantOp 0), + // MulFOp(gamma, %X), + // MulFOp(gamma, + // SubFOp(MulFOp(alpha, ExpOp(%X)), + // alpha))) + auto loc = op->getLoc(); + Value operand = operands[0]; + auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(), + llvm::dyn_cast(op).alpha().convertToFloat()); + auto gammaAttribute = FloatAttr::get(rewriter.getF32Type(), + llvm::dyn_cast(op).gamma().convertToFloat()); + auto elementType = result_types[0]; + + auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); + auto alpha = rewriter.create(loc, alphaAttribute); + auto gamma = rewriter.create(loc, gammaAttribute); + auto exp = rewriter.create(loc, operand); + auto greaterThanZero = + rewriter.create(loc, CmpFPredicate::OGT, operand, zero); + auto select = rewriter.create( + loc, greaterThanZero, operand, + rewriter.create(loc, rewriter.create(loc, alpha, exp), + alpha)); + auto result = rewriter.create(loc, gamma, select); + + return result; +} + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXReciprocalOp +//===----------------------------------------------------------------------===// +template <> +Value mapToLowerScalarOp( + Operation *op, ArrayRef result_types, ArrayRef operands, + ConversionPatternRewriter &rewriter) { + // ONNXReciprocalOp(%X) = DivFOp(ConstantOp 1, %X) + auto loc = op->getLoc(); + Value operand = operands[0]; + auto elementType = result_types[0]; + + auto one = rewriter.create(loc, FloatAttr::get(elementType, 1)); + auto result = rewriter.create(loc, one, operand); + + return result; +} + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXSoftplusOp 
+//===----------------------------------------------------------------------===//
+template <>
+Value mapToLowerScalarOp<ONNXSoftplusOp>(
+    Operation *op, ArrayRef<Type> result_types, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) {
+  // ONNXSoftplusOp(%X) = LogOp(AddFOp(ExpOp(%X), ConstantOp 1))
+  auto loc = op->getLoc();
+  Value operand = operands[0];
+  auto elementType = result_types[0];
+
+  auto exp = rewriter.create<ExpOp>(loc, operand);
+  auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
+  auto add = rewriter.create<AddFOp>(loc, exp, one);
+  auto result = rewriter.create<LogOp>(loc, add);
+
+  return result;
+}
+
+//===----------------------------------------------------------------------===//
+// Scalar unary ops for lowering ONNXSoftsignOp
+//===----------------------------------------------------------------------===//
+template <>
+Value mapToLowerScalarOp<ONNXSoftsignOp>(
+    Operation *op, ArrayRef<Type> result_types, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) {
+  // ONNXSoftsignOp(%X) = DivFOp(%X, AddFOp(AbsFOp(%X), ConstantOp 1))
+  auto loc = op->getLoc();
+  Value operand = operands[0];
+  auto elementType = result_types[0];
+
+  auto abs = rewriter.create<AbsFOp>(loc, operand);
+  auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
+  auto add = rewriter.create<AddFOp>(loc, abs, one);
+  auto result = rewriter.create<DivFOp>(loc, operand, add);
+
+  return result;
+}
+
+//===----------------------------------------------------------------------===//
+// Scalar unary ops for lowering ONNXSignOp
+//===----------------------------------------------------------------------===//
+template <>
+Value mapToLowerScalarOp<ONNXSignOp>(Operation *op, ArrayRef<Type> result_types,
+                                     ArrayRef<Value> operands,
+                                     ConversionPatternRewriter &rewriter) {
+
+  auto loc = op->getLoc();
+  Value operand = operands[0];
+  Type element_type = operands.front().getType();
+  // TODO: unsigned int should be supported separately?
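// Added illustration: for integers the nest below computes
//   sign(x) = (x == 0) ? 0 : ((x > 0) ? 1 : -1)
// via two CmpIOp/SelectOp pairs; the float branch mirrors it with
// CmpFOp(OGT/OEQ). Unsigned integers would need an unsigned predicate,
// hence the TODO above.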
+ if (element_type.isa()) { + // %Y = SelectOP(CmpIOp(GT, %X, ConstantOp 0), + // ConstantOp 1, + // COnstantOp -1) + // ONNXSignOp(%X) = SelectOP(CmpIOp(EQ, %X, ConstantOp 0), + // ConstantOp 0, + // %Y) + auto zero = rewriter.create(loc, rewriter.getI32IntegerAttr(0)); + auto one = rewriter.create(loc, rewriter.getI32IntegerAttr(1)); + auto minusOne = + rewriter.create(loc, rewriter.getI32IntegerAttr(-1)); + auto plusPredicate = + rewriter.create(loc, CmpIPredicate::sgt, operand, zero); + auto plusSelect = + rewriter.create(loc, plusPredicate, one, minusOne); + auto zeroPredicate = + rewriter.create(loc, CmpIPredicate::eq, operand, zero); + auto result = + rewriter.create(loc, zeroPredicate, zero, plusSelect); + return result; + } else if (element_type.isa()) { + // %Y = SelectOP(CmpFOp(OGT, %X, ConstantOp 0), + // ConstantOp 1, + // ConstantOp -1) + // ONNXSignOp(%X) = SelectOP(CmpFOp(OEQ, %X, ConstantOp 0), + // ConstantOp 0, + // %Y) + auto zero = + rewriter.create(loc, rewriter.getF32FloatAttr(0.0f)); + auto one = rewriter.create(loc, rewriter.getF32FloatAttr(1.0f)); + auto minusOne = + rewriter.create(loc, rewriter.getF32FloatAttr(-1.0f)); + auto plusPredicate = + rewriter.create(loc, CmpFPredicate::OGT, operand, zero); + auto plusSelect = + rewriter.create(loc, plusPredicate, one, minusOne); + auto zeroPredicate = + rewriter.create(loc, CmpFPredicate::OEQ, operand, zero); + auto result = + rewriter.create(loc, zeroPredicate, zero, plusSelect); + return result; + } else { + emitError(loc, "unsupported element type"); + } +} + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXMaxOp +//===----------------------------------------------------------------------===// +template <> +Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, + ArrayRef operands, + ConversionPatternRewriter &rewriter) { + // ONNXMaxOp(%X, %Y) = SelectOp(CmpFOp(OGT, %X, %Y), + // %X, + // %Y) + auto loc = op->getLoc(); + Value lhs = operands[0]; + Value rhs = operands[1]; + auto max = rewriter.create(loc, CmpFPredicate::OGT, lhs, rhs); + auto result = rewriter.create(loc, max, lhs, rhs); + return result; +} + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXMinOp +//===----------------------------------------------------------------------===// +template <> +Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, + ArrayRef operands, + ConversionPatternRewriter &rewriter) { + // ONNXMinOp(%X, %Y) = SelectOp(CmpFOp(OLT, %X, %Y), + // %X, + // %Y) + auto loc = op->getLoc(); + Value lhs = operands[0]; + Value rhs = operands[1]; + auto min = rewriter.create(loc, CmpFPredicate::OLT, lhs, rhs); + auto result = rewriter.create(loc, min, lhs, rhs); + return result; +} + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXReduceMaxOp +//===----------------------------------------------------------------------===// +template <> +Value mapToLowerScalarOp(Operation *op, + ArrayRef result_types, + ArrayRef operands, + ConversionPatternRewriter &rewriter) { + auto loc = op->getLoc(); + Value lhs = operands[0]; + Value rhs = operands[1]; + Type element_type = lhs.getType(); + if (element_type.isa()) { + auto max = rewriter.create(loc, CmpIPredicate::sgt, lhs, rhs); + auto result = rewriter.create(loc, max, lhs, rhs); + return result; + } else if (element_type.isa()) { + auto max = rewriter.create(loc, 
CmpFPredicate::OGT, lhs, rhs); + auto result = rewriter.create(loc, max, lhs, rhs); + return result; + } else { + emitError(loc, "unsupported element type"); + return nullptr; + } +} + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXReduceMinOp +//===----------------------------------------------------------------------===// +template <> +Value mapToLowerScalarOp(Operation *op, + ArrayRef result_types, + ArrayRef operands, + ConversionPatternRewriter &rewriter) { + auto loc = op->getLoc(); + Value lhs = operands[0]; + Value rhs = operands[1]; + Type element_type = lhs.getType(); + if (element_type.isa()) { + auto min = rewriter.create(loc, CmpIPredicate::slt, lhs, rhs); + auto result = rewriter.create(loc, min, lhs, rhs); + return result; + } else if (element_type.isa()) { + auto min = rewriter.create(loc, CmpFPredicate::OLT, lhs, rhs); + auto result = rewriter.create(loc, min, lhs, rhs); + return result; + } else { + emitError(loc, "unsupported element type"); + return nullptr; + } +} + +// Element-wise unary ops lowering to Krnl dialect. +//===----------------------------------------------------------------------===// +template +struct ONNXElementwiseUnaryOpLowering : public ConversionPattern { + ONNXElementwiseUnaryOpLowering(MLIRContext *ctx) + : ConversionPattern(ElementwiseUnaryOp::getOperationName(), 1, ctx) {} + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + // TODO: Check that the types are valid. + // An element-wise unary operation must have all operands and the result of + // the same type. This should have been verified by the verifier. + + auto loc = op->getLoc(); + + // Insert an allocation and deallocation for the result of this operation. + auto memRefType = convertToMemRefType(*op->result_type_begin()); + + // If the output has a dynamic dimension, pass the operands required for + // each dynamic dimension to the AllocOp. The first operand of the + // operation is used. The operands of the op need to match in terms of + // dimensions with the result at this pre-optimization phase. + // TODO: verify that dimensions match. + // TODO: can the dimension of the result differ after optimizations? + Value alloc; + bool insertDealloc = checkInsertDealloc(op); + + if (hasAllConstantDimensions(memRefType)) + alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); + else + alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc, + {operands[0]}); + + std::vector originalLoops; + KrnlOptimizeLoopsOp optimizedLoopsOp; + KrnlIterateOp iterateOp; + emitKrnlLoopsAndIterationForOperand( + rewriter, loc, operands[0], originalLoops, + optimizedLoopsOp, iterateOp); + Block &optimizationBlock = optimizedLoopsOp.region().front(); + Block &iterationBlock = iterateOp.bodyRegion().front(); + + // 1. Insert any optimizations in the KrnlOptimizeLoopsOp body. + rewriter.setInsertionPointToEnd(&optimizationBlock); + // Return from KrnlOptimizeLoopsOp body. + // When no optimizations are present we just return the loops + // unchaged. + rewriter.create(loc, originalLoops); + + // 2. Insert instructions inside the KernelIterateOp body. 
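One detail worth flagging in the ONNXSignOp handler above: its unsupported-element-type branch calls emitError but never produces a value, so control falls off the end of a function that returns Value. The ONNXReduceMaxOp and ONNXReduceMinOp handlers cover the same case by returning nullptr, and the Sign handler presumably wants the same ending. A fragment sketching the assumed fix, not a full handler:

```cpp
  } else {
    emitError(loc, "unsupported element type");
    // Return a null Value so the function does not fall off the end,
    // matching the ReduceMax/ReduceMin handlers.
    return nullptr;
  }
}
```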
+ rewriter.setInsertionPointToStart(&iterationBlock); + + // Handle the operation: + SmallVector loopIVs; + for (auto arg : iterationBlock.getArguments()) + loopIVs.push_back(arg); + + auto loadedVal = rewriter.create(loc, operands[0], loopIVs); + auto loweredOpResult = mapToLowerScalarOp( + op, memRefType.getElementType(), {loadedVal}, rewriter); + // Store result in the resulting array. + rewriter.create(loc, loweredOpResult, alloc, loopIVs); + + rewriter.replaceOp(op, alloc); + + return matchSuccess(); + } +}; + +// Element-wise variadic ops lowering to Krnl dialect. +//===----------------------------------------------------------------------===// +template +struct ONNXElementwiseVariadicOpLowering : public ConversionPattern { + ONNXElementwiseVariadicOpLowering(MLIRContext *ctx) + : ConversionPattern(ElementwiseVariadicOp::getOperationName(), 1, ctx) {} + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + // TODO: Check that the types are valid. + // An element-wise variadic operation must have all operands and the result + // of the same type. This should have been verified by the verifier. + auto loc = op->getLoc(); + auto numArgs = op->getNumOperands(); + + // Insert an allocation and deallocation for the result of this operation. + auto memRefType = convertToMemRefType(*op->result_type_begin()); + + Value alloc; + bool insertDealloc = checkInsertDealloc(op); + // If the output has a dynamic dimension, we compute its dimension at + // runtime by using dimensions from the operands. + // In particular, we need to know from which operand a result dimension + // comes from. + // TODO: can the dimension of the result differ after optimizations? + if (hasAllConstantDimensions(memRefType)) + alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); + else + alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc, + operands); + + // Get run-time dimension information for unknown dimensions used for + // broadcasting. + std::map> broadcastedDimInfo = + getBroadcastedDimInfo(loc, rewriter, memRefType, operands); + + std::vector originalLoops; + KrnlOptimizeLoopsOp optimizedLoopsOp; + KrnlIterateOp iterateOp; + emitKrnlLoopsAndIterationForOperand( + rewriter, loc, alloc, originalLoops, + optimizedLoopsOp, iterateOp); + Block &optimizationBlock = optimizedLoopsOp.region().front(); + Block &iterationBlock = iterateOp.bodyRegion().front(); + + // 1. Insert any optimizations in the KrnlOptimizeLoopsOp body. + rewriter.setInsertionPointToEnd(&optimizationBlock); + // Return from KrnlOptimizeLoopsOp body. + // When no optimizations are present we just return the loops unchaged. + rewriter.create(loc, originalLoops); + + // 2. Insert instructions inside the KernelIterateOp body. 
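The per-element body of ONNXElementwiseUnaryOpLowering above (load, scalar mapping, store) is where the ElementwiseUnaryOp template parameter is consumed: each loaded scalar is handed to the matching mapToLowerScalarOp specialization defined earlier in the file, and the result is written back at the same loop indices. With the angle-bracketed arguments that the rendering dropped restored, those lines presumably read as follows (a sketch, not a quote of the original):

```cpp
// Presumed form of the per-element body of ONNXElementwiseUnaryOpLowering.
auto loadedVal = rewriter.create<LoadOp>(loc, operands[0], loopIVs);
auto loweredOpResult = mapToLowerScalarOp<ElementwiseUnaryOp>(
    op, memRefType.getElementType(), {loadedVal}, rewriter);
// Store the scalar result back into the allocated output buffer.
rewriter.create<StoreOp>(loc, loweredOpResult, alloc, loopIVs);
```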
+ rewriter.setInsertionPointToStart(&iterationBlock); + + // Handle the operation: + SmallVector loopIVs; + for (auto arg : iterationBlock.getArguments()) + loopIVs.push_back(arg); + + // Fold over operands for each of their scalar values + Value accumulated, next; + auto accumulatedLoopIVs = getLoopIVsForBroadcasting( + loc, rewriter, loopIVs, operands[0], broadcastedDimInfo[0]); + accumulated = rewriter.create(loc, operands[0], accumulatedLoopIVs); + for (unsigned i = 1; i < numArgs; i++) { + auto nextLoopIVs = getLoopIVsForBroadcasting( + loc, rewriter, loopIVs, operands[i], broadcastedDimInfo[i]); + next = rewriter.create(loc, operands[i], nextLoopIVs); + accumulated = mapToLowerScalarOp( + op, memRefType.getElementType(), {accumulated, next}, rewriter); + } + // Store result in the resulting array. + rewriter.create(loc, accumulated, alloc, loopIVs); + + rewriter.replaceOp(op, alloc); + + return matchSuccess(); + } +}; + +struct ONNXSoftmaxOpLowering : public ConversionPattern { + ONNXSoftmaxOpLowering(MLIRContext *ctx) + : ConversionPattern(mlir::ONNXSoftmaxOp::getOperationName(), 1, ctx) {} + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + // softmax(x) = let max_x = max(x) in + // let exp_x = exp(x - max_x) in + // let sum = sum(exp_x) in + // exp_x / sum + auto memRefType = convertToMemRefType(*op->result_type_begin()); + int64_t rank = memRefType.getRank(); + int64_t axis = llvm::dyn_cast(op).axis().getSExtValue(); + axis = axis >= 0 ? axis : rank + axis; + assert(axis >= -rank && axis <= rank - 1); + + auto loc = op->getLoc(); + + // Insert an allocation and deallocation for the result of this operation. + auto elementType = memRefType.getElementType(); + + Value alloc; + bool insertDealloc = checkInsertDealloc(op); + if (hasAllConstantDimensions(memRefType)) + alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); + else + alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc, + operands[0]); + + // Shape of the result + auto memRefShape = memRefType.getShape(); + + // Insert allocations and deallocations for sum and max. + MemRefType scalarMemRefType = MemRefType::get({}, elementType, {}, 0); + Value sumOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true); + Value maxOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true); + Value zero = + rewriter.create(loc, FloatAttr::get(elementType, 0)); + Value negInfinity = rewriter.create( + loc, + FloatAttr::get(elementType, -std::numeric_limits::infinity())); + + // Define loops. + std::vector originalLoops; + std::vector optimizedLoops; + Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops, + optimizedLoops, rank); + + // Coerce the input into a 2-D tensor. `axis` will be the coercing point. + // This coercing follows the softmax definition in ONNX: + // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Softmax + // Here, we create an outer loop and inner loop for handling the two + // dimensions. The outer loop is only created once `axis` is not zero. + + // Define an outer loop with respect to axis. 
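The let-bound formulation in the softmax comment above is the numerically stable form of the operator: subtracting the running maximum before exponentiating leaves the mathematical result unchanged while keeping every exponent non-positive, which is why the lowering carries both a max accumulator and a sum accumulator and stores the already-exponentiated values into the output buffer for reuse in the final normalization loop. In equation form:

\[
\mathrm{softmax}(x)_i
  = \frac{e^{x_i}}{\sum_j e^{x_j}}
  = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}},
  \qquad m = \max_k x_k .
\]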
+ std::vector outerLoops, optimizedOuterLoops; + outerLoops.reserve(axis); + optimizedOuterLoops.reserve(axis); + for (int i = 0; i < axis; ++i) { + outerLoops.push_back(originalLoops[i]); + optimizedOuterLoops.push_back(optimizedLoops[i]); + } + KrnlIterateOperandPack outerPack(rewriter, outerLoops, optimizedOuterLoops); + for (int i = 0; i < axis; ++i) + addDimensionToPack(rewriter, loc, outerPack, operands[0], i); + + // Define an inner loop with respect to axis. + std::vector innerLoops, optimizedInnerLoops; + innerLoops.reserve(rank - axis); + optimizedInnerLoops.reserve(rank - axis); + for (int i = axis; i < rank; ++i) { + innerLoops.push_back(originalLoops[i]); + optimizedInnerLoops.push_back(optimizedLoops[i]); + } + KrnlIterateOperandPack innerPack(rewriter, innerLoops, optimizedInnerLoops); + for (int i = axis; i < rank; ++i) + addDimensionToPack(rewriter, loc, innerPack, operands[0], i); + + KrnlIterateOp outerIterateOp, maxIterateOp, sumIterateOp, softmaxIterateOp; + SmallVector outerLoopIVs; + if (axis != 0) { + outerIterateOp = rewriter.create(loc, outerPack); + + // No optimization + rewriter.setInsertionPointToEnd(optimizationBlock); + rewriter.create(loc, originalLoops); + + // Insert instructions inside the outer loop. + Block &outerIterationBlock = outerIterateOp.bodyRegion().front(); + rewriter.setInsertionPointToStart(&outerIterationBlock); + for (auto arg : outerIterationBlock.getArguments()) + outerLoopIVs.push_back(arg); + + // Reset accumulators. + rewriter.create(loc, zero, sumOp); + rewriter.create(loc, negInfinity, maxOp); + + // Create an inner loop to compute max. + maxIterateOp = rewriter.create(loc, innerPack); + // Create an inner loop to compute sum. + sumIterateOp = rewriter.create(loc, innerPack); + // Create an inner loop to compute softmax. + softmaxIterateOp = rewriter.create(loc, innerPack); + } else { + // Reset accumulators. + rewriter.create(loc, zero, sumOp); + rewriter.create(loc, negInfinity, maxOp); + + // Create an inner loop to compute max. + maxIterateOp = rewriter.create(loc, innerPack); + // Create an inner loop to compute sum. + sumIterateOp = rewriter.create(loc, innerPack); + // Create an inner loop to compute softmax. + softmaxIterateOp = rewriter.create(loc, innerPack); + + // No optimization + rewriter.setInsertionPointToEnd(optimizationBlock); + rewriter.create(loc, originalLoops); + } + + // Insert instructions inside the max loop. + Block &maxIterationBlock = maxIterateOp.bodyRegion().front(); + rewriter.setInsertionPointToStart(&maxIterationBlock); + + // Get induction variables. + SmallVector maxLoopIVs; + for (auto arg : outerLoopIVs) + maxLoopIVs.push_back(arg); + for (auto arg : maxIterationBlock.getArguments()) + maxLoopIVs.push_back(arg); + + // Compute the max value. + Value max = rewriter.create(loc, maxOp); + Value nextMax = rewriter.create(loc, operands[0], maxLoopIVs); + auto maxCond = + rewriter.create(loc, CmpFPredicate::OGT, max, nextMax); + max = rewriter.create(loc, maxCond, max, nextMax); + rewriter.create(loc, max, maxOp); + + // Get the max. + rewriter.setInsertionPoint(sumIterateOp); + max = rewriter.create(loc, maxOp); + + // Insert instructions inside the sum loop. + Block &sumIterationBlock = sumIterateOp.bodyRegion().front(); + rewriter.setInsertionPointToStart(&sumIterationBlock); + + // Get induction variables. + SmallVector sumLoopIVs; + for (auto arg : outerLoopIVs) + sumLoopIVs.push_back(arg); + for (auto arg : sumIterationBlock.getArguments()) + sumLoopIVs.push_back(arg); + + // Sum up values. 
+ Value sum = rewriter.create(loc, sumOp); + Value next = rewriter.create(loc, operands[0], sumLoopIVs); + Value sub = rewriter.create(loc, next, max); + Value exp = rewriter.create(loc, sub); + sum = rewriter.create(loc, sum, exp); + rewriter.create(loc, sum, sumOp); + // Store intermediate values in the result to avoid recomputation. + rewriter.create(loc, exp, alloc, sumLoopIVs); + + // Get the sum. + rewriter.setInsertionPoint(softmaxIterateOp); + sum = rewriter.create(loc, sumOp); + + // Insert instructions inside the softmax loop. + Block &softmaxIterationBlock = softmaxIterateOp.bodyRegion().front(); + rewriter.setInsertionPointToStart(&softmaxIterationBlock); + + // Get induction variables. + SmallVector softmaxLoopIVs; + for (auto arg : outerLoopIVs) + softmaxLoopIVs.push_back(arg); + for (auto arg : softmaxIterationBlock.getArguments()) + softmaxLoopIVs.push_back(arg); + + // Compute softmax. + Value expLoadedVal = rewriter.create(loc, alloc, softmaxLoopIVs); + Value result = rewriter.create(loc, expLoadedVal, sum); + rewriter.create(loc, result, alloc, softmaxLoopIVs); + + rewriter.replaceOp(op, alloc); + + return matchSuccess(); + } +}; + +struct ONNXReshapeOpLowering : public ConversionPattern { + ONNXReshapeOpLowering(MLIRContext *ctx) + : ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + auto loc = op->getLoc(); + + auto memRefType = convertToMemRefType(*op->result_type_begin()); + auto memRefShape = memRefType.getShape(); + auto inputShape = operands[0].getType().cast().getShape(); + + // Insert an allocation and deallocation for the result of this operation. + Value alloc; + + // Compute size in bytes using the input tensor. + Value tensorSize = rewriter.create( + loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), + getMemRefEltSizeInBytes(memRefType))); + for (int i = 0; i < inputShape.size(); ++i) { + Value dimVal; + if (inputShape[i] < 0) { + Value dim = rewriter.create(loc, operands[0], i); + dimVal = + rewriter.create(loc, dim, rewriter.getIntegerType(64)); + } else { + dimVal = rewriter.create( + loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), + inputShape[i])); + } + tensorSize = rewriter.create(loc, tensorSize, dimVal); + } + + bool insertDealloc = checkInsertDealloc(op); + if (hasAllConstantDimensions(memRefType)) { + alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); + } else { + // If a dimension is zero, the actual dimension value is taken from the + // input tensor. + // + // If the shape array has a negative dimension (-1), we compute its actual + // dimension value from the other dimensions. But we don't have enough + // information about the other dimensions at this point. So, we need to + // scan the shape first to calculate reduction of all of the dimensions. + // If the reduction is negative, then the shape array contains a negative + // dimension. Otherwise, the reduction is the same as the one computed + // from the input tensor. + Value tensorSizeFromShape = rewriter.create( + loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), + getMemRefEltSizeInBytes(memRefType))); + SmallVector DimInfo; + for (int i = 0; i < memRefShape.size(); ++i) { + Value index = rewriter.create( + loc, rewriter.getIntegerAttr(rewriter.getIndexType(), i)); + // Load index from array of indices. 
+ Value loadedVal = rewriter.create(loc, operands[1], index); + // If a dimension is zero, the actual dimension value is taken from the + // input tensor. + // + // If a dimension is negative, it is computed from the other dimensions. + // But we don't have enough information about the other dimensions at + // this point. So, we let it as it is (-1), and compute it later. + if (i < inputShape.size()) { + Value dimVal; + auto loadedValType = loadedVal.getType().cast(); + if (inputShape[i] < 0) { + Value dim = rewriter.create(loc, operands[0], i); + dimVal = rewriter.create(loc, dim, loadedValType); + } else { + dimVal = rewriter.create( + loc, rewriter.getIntegerAttr(loadedValType, inputShape[i])); + } + auto zero = rewriter.create( + loc, rewriter.getIntegerAttr(loadedValType, 0)); + auto isZero = + rewriter.create(loc, CmpIPredicate::eq, loadedVal, zero); + loadedVal = rewriter.create(loc, isZero, dimVal, loadedVal); + } + // Check if the loaded index is already the correct width of 64 bits. + // Convert the value to a 64 bit integer if needed. + Value int64LoadedVal = loadedVal; + if (loadedVal.getType().cast().getWidth() < 64) + int64LoadedVal = rewriter.create( + loc, loadedVal, rewriter.getIntegerType(64)); + tensorSizeFromShape = + rewriter.create(loc, tensorSizeFromShape, int64LoadedVal); + // Store intermediate results to use later. + DimInfo.emplace_back(int64LoadedVal); + } + // Reverse tensorSizeFromShape since it is negative if the shape array has + // a negative dimension. This is safe since we only use it to compute the + // actual value for the negative dimension. + auto zero = rewriter.create( + loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + tensorSizeFromShape = + rewriter.create(loc, zero, tensorSizeFromShape); + + // Obtain operands for AllocOp. + SmallVector allocOperands; + auto negOne = rewriter.create( + loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), -1)); + + for (int i = 0; i < memRefShape.size(); ++i) { + auto dimVal = DimInfo[i]; + auto isNegOne = + rewriter.create(loc, CmpIPredicate::eq, dimVal, negOne); + // If dimension is negative, compute its value from the other + // dimensions. + auto actualDimVal = + rewriter.create(loc, tensorSize, tensorSizeFromShape); + auto loadedVal = + rewriter.create(loc, isNegOne, actualDimVal, dimVal); + allocOperands.push_back(rewriter.create( + loc, loadedVal, rewriter.getIndexType())); + } + AllocOp allocateMemref = + rewriter.create(loc, memRefType, allocOperands); + + // Make sure to allocate at the beginning of the block if + // all dimensions are known. 
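As a concrete check of the byte-size bookkeeping above, consider a hypothetical 2x3x4 f32 input reshaped with a shape operand of (4, 0, -1): the zero entry picks up the corresponding input dimension (3), and the product over the shape entries goes negative exactly because one entry is -1, so negating it and dividing the input byte size by it recovers the missing dimension:

\[
\begin{aligned}
\text{tensorSize} &= 4 \cdot (2 \cdot 3 \cdot 4) = 96 \text{ bytes},\\
\text{tensorSizeFromShape} &= 4 \cdot 4 \cdot 3 \cdot (-1) = -48
  \;\Rightarrow\; 48 \text{ after negation},\\
\text{inferred dimension} &= 96 / 48 = 2,
  \qquad \text{output shape } 4 \times 3 \times 2 .
\end{aligned}
\]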
+ auto *parentBlock = allocateMemref.getOperation()->getBlock(); + if (insertDealloc) { + auto dealloc = rewriter.create(loc, allocateMemref); + dealloc.getOperation()->moveBefore(&parentBlock->back()); + } + + alloc = allocateMemref; + } + + rewriter.create(loc, alloc, operands[0], tensorSize); + rewriter.replaceOp(op, alloc); + + return matchSuccess(); + } +}; + +struct ONNXGemmOpLowering : public ConversionPattern { + ONNXGemmOpLowering(MLIRContext *ctx) + : ConversionPattern(mlir::ONNXGemmOp::getOperationName(), 1, ctx) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + auto loc = op->getLoc(); + + Value A, B, C; + A = operands[0]; + B = operands[1]; + C = operands[2]; + + auto memRefType = convertToMemRefType(*op->result_type_begin()); + + auto alphaAttr = FloatAttr::get(memRefType.getElementType(), + llvm::dyn_cast(op).alpha().convertToFloat()); + auto betaAttr = FloatAttr::get(memRefType.getElementType(), + llvm::dyn_cast(op).beta().convertToFloat()); + auto alpha = rewriter.create(loc, alphaAttr); + auto beta = rewriter.create(loc, betaAttr); + + bool isTransA = (llvm::dyn_cast(op).transA() != 0); + bool isTransB = (llvm::dyn_cast(op).transB() != 0); + + // Insert an allocation and deallocation for the result of this operation. + Value alloc; + bool insertDealloc = checkInsertDealloc(op); + if (hasAllConstantDimensions(memRefType)) + alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); + else { + auto memRefShape = memRefType.getShape(); + SmallVector allocOperands; + if (memRefShape[0] < 0) { + auto dim = rewriter.create(loc, A, (isTransA) ? 1 : 0); + allocOperands.emplace_back(dim); + } + if (memRefShape[1] < 0) { + auto dim = rewriter.create(loc, B, (isTransB) ? 0 : 1); + allocOperands.emplace_back(dim); + } + alloc = rewriter.create(loc, memRefType, allocOperands); + if (insertDealloc) { + auto *parentBlock = alloc.getDefiningOp()->getBlock(); + auto dealloc = rewriter.create(loc, alloc); + dealloc.getOperation()->moveBefore(&parentBlock->back()); + } + } + + // Number of loops + auto memRefShape = memRefType.getShape(); + int64_t numLoops = 3; + + // Define loops. + std::vector originalLoops; + std::vector optimizedLoops; + Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops, + optimizedLoops, numLoops); + + // We have two Krnl loops: + // - Outer loop iterates over the output matrix dimensions, and + // - Reduction loop iterates over the reduction dimension. + + // Outer loop + std::vector outerLoops, optimizedOuterLoops; + outerLoops.reserve(2); + optimizedOuterLoops.reserve(2); + for (int i = 0; i < 2; ++i) { + outerLoops.push_back(originalLoops[i]); + optimizedOuterLoops.push_back(optimizedLoops[i]); + } + KrnlIterateOperandPack outerPack(rewriter, outerLoops, + optimizedOuterLoops); + // Induction variables for the outer loops + for (int i = 0; i < 2; ++i) + addDimensionToPack(rewriter, loc, outerPack, alloc, i); + + // Reduction loop + std::vector reductionLoops, optimizedReductionLoops; + reductionLoops.reserve(1); + optimizedReductionLoops.reserve(1); + reductionLoops.push_back(originalLoops[2]); + optimizedReductionLoops.push_back(optimizedLoops[2]); + KrnlIterateOperandPack reductionPack(rewriter, reductionLoops, + optimizedReductionLoops); + // Induction variable for the reduction dimension + // Try to find and use a static value from A or B first. + // If it failed then use a dynamic value. 
+ auto ATy = A.getType().cast(); + auto BTy = B.getType().cast(); + int64_t K_A_Idx = (isTransA) ? 0 : 1; + int64_t K_B_Idx = (isTransB) ? 1 : 0; + reductionPack.pushConstantBound(0); + if (ATy.getShape()[K_A_Idx] != -1) + reductionPack.pushConstantBound(ATy.getShape()[K_A_Idx]); + else + if (BTy.getShape()[K_B_Idx] != -1) + reductionPack.pushConstantBound(BTy.getShape()[K_B_Idx]); + else + reductionPack.pushOperandBound( + rewriter.create(loc, B, K_B_Idx).getResult()); + + // Get run-time dimension information for unknown dimensions used for + // broadcasting. + // GemmOp supports unidirectional broadcasting from C to A*B. + // Hence, it must be enough to get broadcasting information for C only. + std::map broadcastedDimInfo; + auto shape = C.getType().cast().getShape(); + for (int i = 0; i < shape.size(); ++i) { + if (shape[i] < 0) { + auto dim = rewriter.create(loc, C, i).getResult(); + auto one = rewriter.create(loc, 1); + auto isBroadcasted = + rewriter.create(loc, CmpIPredicate::eq, dim, one); + broadcastedDimInfo.insert(std::make_pair(i, isBroadcasted)); + } + } + + auto outerIterateOp = rewriter.create(loc, outerPack); + + // Now perform the insertions into the body of the + // just generated instructions: + + // No optimization + rewriter.setInsertionPointToEnd(optimizationBlock); + rewriter.create(loc, originalLoops); + + // Insert instructions inside the outer loop. + Block &outerIterationBlock = outerIterateOp.bodyRegion().front(); + rewriter.setInsertionPointToStart(&outerIterationBlock); + + // Induction variables + SmallVector loopMNIVs; + for (auto arg : outerIterationBlock.getArguments()) { + loopMNIVs.emplace_back(arg); + } + + // Initialize the output of A*B + auto zero = rewriter.create( + loc, FloatAttr::get(memRefType.getElementType(), 0)); + rewriter.create(loc, zero, alloc, loopMNIVs); + + // Compute A*B + auto matmulIterateOp = rewriter.create(loc, reductionPack); + + // Compute beta*C, and add up to alpha*A*B (unidirectional broadcasting) + auto loopCIVs = getLoopIVsForBroadcasting( + loc, rewriter, loopMNIVs, C, broadcastedDimInfo); + auto loadedC = rewriter.create(loc, C, loopCIVs); + auto loadedAB = rewriter.create(loc, alloc, loopMNIVs); + auto alphaAB = rewriter.create(loc, alpha, loadedAB); + auto betaC = rewriter.create(loc, beta, loadedC); + auto Y = rewriter.create(loc, alphaAB, betaC); + rewriter.create(loc, Y, alloc, loopMNIVs); + + // Insert instructions to do matrix multiplication: A*B + Block &matmulIterationBlock = matmulIterateOp.bodyRegion().front(); + rewriter.setInsertionPointToStart(&matmulIterationBlock); + + // Induction variables + SmallVector loopKIVs, loopAIVs, loopBIVs; + for (auto arg : matmulIterationBlock.getArguments()) + loopKIVs.emplace_back(arg); + if (isTransA) { + loopAIVs.emplace_back(loopKIVs[0]); + loopAIVs.emplace_back(loopMNIVs[0]); + } else { + loopAIVs.emplace_back(loopMNIVs[0]); + loopAIVs.emplace_back(loopKIVs[0]); + } + if (isTransB) { + loopBIVs.emplace_back(loopMNIVs[1]); + loopBIVs.emplace_back(loopKIVs[0]); + } else { + loopBIVs.emplace_back(loopKIVs[0]); + loopBIVs.emplace_back(loopMNIVs[1]); + } + + // Matmul computation + auto loadedA = rewriter.create(loc, A, loopAIVs); + auto loadedB = rewriter.create(loc, B, loopBIVs); + auto loadedY = rewriter.create(loc, alloc, loopMNIVs); + auto AB = rewriter.create(loc, loadedA, loadedB); + auto accumulated = rewriter.create(loc, loadedY, AB); + rewriter.create(loc, accumulated, alloc, loopMNIVs); + + rewriter.replaceOp(op, alloc); + + return matchSuccess(); + } +}; + 
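For reference, the computation the two loop nests implement is the standard ONNX Gemm contraction, with the transA/transB flags deciding which side of each input the reduction index runs over, and with C broadcast unidirectionally onto the result (the broadcastedDimInfo map above records, per dimension, whether a run-time extent of C equals 1):

\[
Y_{mn} = \alpha \sum_{k} A'_{mk}\, B'_{kn} + \beta\, C_{mn},
\qquad
A' = \begin{cases} A^{\mathsf T} & \text{if transA} \neq 0,\\ A & \text{otherwise,}\end{cases}
\qquad
B' = \begin{cases} B^{\mathsf T} & \text{if transB} \neq 0,\\ B & \text{otherwise.}\end{cases}
\]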
+struct ONNXUnsqueezeOpLowering : public ConversionPattern { + ONNXUnsqueezeOpLowering(MLIRContext *ctx) + : ConversionPattern(mlir::ONNXUnsqueezeOp::getOperationName(), 1, ctx) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + auto loc = op->getLoc(); + auto memRefType = convertToMemRefType(*op->result_type_begin()); + int outRank = memRefType.getRank(); + + // Assume that `axes` has been validated by shape inference. + // So, here we just get it. + ArrayAttr axisAttrs = llvm::dyn_cast(op).axesAttr(); + SmallVector axes; + for (auto axisAttr : axisAttrs.getValue()) { + int axis = axisAttr.cast().getInt(); + axis = axis >= 0 ? axis : (outRank + axis); + axes.emplace_back(axis); + } + + // Insert an allocation and deallocation for the result of this operation. + Value alloc; + + // Compute size in bytes. + Value tensorSize = rewriter.create( + loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), + getMemRefEltSizeInBytes(memRefType))); + + bool insertDealloc = checkInsertDealloc(op); + auto memRefShape = memRefType.getShape(); + if (hasAllConstantDimensions(memRefType)) { + alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); + for (int i = 0; i < memRefShape.size(); ++i) { + Value dimVal = rewriter.create( + loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), + memRefShape[i])); + tensorSize = rewriter.create(loc, tensorSize, dimVal); + } + } else { + // Unknown dimensions are always the operand's dimensions. + SmallVector allocOperands; + for (int outIdx = 0, inIdx = 0; outIdx < memRefShape.size(); ++outIdx) { + Value dimVal = nullptr; + if (memRefShape[outIdx] < 0) { + Value index = rewriter.create(loc, operands[0], inIdx); + dimVal = rewriter.create( + loc, index, rewriter.getIntegerType(64)); + allocOperands.emplace_back(index); + } else { + dimVal = rewriter.create( + loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), + memRefShape[outIdx])); + } + tensorSize = rewriter.create(loc, tensorSize, dimVal); + if (std::find(axes.begin(), axes.end(), outIdx) == axes.end()) + inIdx++; + } + alloc = rewriter.create(loc, memRefType, allocOperands); + auto *parentBlock = alloc.getDefiningOp()->getBlock(); + if (insertDealloc) { + auto dealloc = rewriter.create(loc, alloc); + dealloc.getOperation()->moveBefore(&parentBlock->back()); + } + } + rewriter.create(loc, alloc, operands[0], tensorSize); + rewriter.replaceOp(op, alloc); + return matchSuccess(); + } +}; + +struct ONNXTransposeOpLowering : public ConversionPattern { + ONNXTransposeOpLowering(MLIRContext *ctx) + : ConversionPattern(mlir::ONNXTransposeOp::getOperationName(), 1, ctx) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + auto loc = op->getLoc(); + // Insert an allocation and deallocation for the result of this operation. + auto memRefType = convertToMemRefType(*op->result_type_begin()); + Value alloc; + bool insertDealloc = checkInsertDealloc(op); + + if (hasAllConstantDimensions(memRefType)) + alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); + else + alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc, + {operands[0]}); + + // Number of loops + auto memRefShape = memRefType.getShape(); + int64_t rank = memRefShape.size(); + + // Define loops. 
+ std::vector originalLoops; + std::vector optimizedLoops; + Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops, + optimizedLoops, rank); + + KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops); + // Iterate over the loop nest using the input shape. + for (int i = 0; i < rank; ++i) + addDimensionToPack(rewriter, loc, pack, operands[0], i); + + auto iterateOp = rewriter.create(loc, pack); + Block &iterationBlock = iterateOp.bodyRegion().front(); + + // Now perform the insertions into the body of the + // just generated instructions: + + // 1. Insert any optimizations in the KrnlOptimizeLoopsOp body. + rewriter.setInsertionPointToEnd(optimizationBlock); + // Return from KrnlOptimizeLoopsOp body. + // When no optimizations are present we just return the loops + // unchaged. + rewriter.create(loc, originalLoops); + + // 2. Insert instructions inside the KernelIterateOp body. + rewriter.setInsertionPointToStart(&iterationBlock); + + // Handle the operation. + + // Read perm attribute. + SmallVector perm; + auto permAttribute = llvm::dyn_cast(op).permAttr(); + if (permAttribute) { + for (auto permVal : permAttribute.getValue()) + perm.emplace_back(permVal.cast().getInt()); + } else { + // TODO: Remove when perm is guaranteed to be present (even for + // the default case). This means that perm was added by shape + // inference or another pass to contain the values corresponding + // to the default behavior of Transpose. + for (int i = iterationBlock.getArguments().size()-1; i >= 0; i--) + perm.emplace_back(i); + } + + SmallVector inLoopIVs; + for (auto arg : iterationBlock.getArguments()) + inLoopIVs.emplace_back(arg); + + SmallVector outLoopIVs; + for (int i=0; i(loc, operands[0], inLoopIVs); + rewriter.create(loc, inVal, alloc, outLoopIVs); + + rewriter.replaceOp(op, alloc); + + return matchSuccess(); + } +}; + +struct ONNXIdentityOpLowering : public ConversionPattern { + ONNXIdentityOpLowering(MLIRContext *ctx) + : ConversionPattern(mlir::ONNXIdentityOp::getOperationName(), 1, ctx) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + rewriter.replaceOp(op, operands[0]); + return matchSuccess(); + } +}; + +struct ONNXConvNoBiasOpLowering : public ConversionPattern { + ONNXConvNoBiasOpLowering(MLIRContext *ctx) + : ConversionPattern(mlir::ONNXConvNoBiasOp::getOperationName(), 1, ctx) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + auto loc = op->getLoc(); + // Insert an allocation and deallocation for the result of this operation. 
+ auto memRefType = convertToMemRefType(*op->result_type_begin()); + Value alloc; + bool insertDealloc = checkInsertDealloc(op); + ONNXConvNoBiasOp convOp = llvm::dyn_cast(op); + + if (hasAllConstantDimensions(memRefType)) + alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); + else + alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc, + {operands[0]}); + + auto resultShape = memRefType.getShape(); + auto inputShape = operands[0].getType().cast().getShape(); + auto kernelShape = operands[1].getType().cast().getShape(); + + // R = ConvNoBias(D, K) + // + // The input/output shapes will look like this: + // + // D (NxCxHxW) x K (MxC/groupxKHxKW) -> R (NxMxRHxRW) + // + // M is a multiple of the number of groups: + // M = group * kernelsPerGroup + // + // The loop nest will look as follows: + // + // strides = [s1, s2] + // + // kernelsPerGroup = M / group; + // for n = 0 .. N: + // for g = 0 .. group: + // for m = 0 .. kernelsPerGroup: + // kernel = g * kernelsPerGroup + m; + // for r1 = 0 .. RH: + // for r2 = 0 .. RW: + // R[n][kernel][r1][r2] = 0; + // for c = 0 .. C/group: + // for k1 = 0 .. KH: + // for k2 = 0 .. KW: + // R[n][kernel][r1][r2] = + // D[n][g * (C / group) + c][s1 * r1 + k1][s2 * r2 + k2] * + // K[kernel][c][k1][k2]; + // + // Naming: + // n, g, m: outer loop nest indices + // r1, r2: spatial loop nest indices + // c, k1, k2: inner loop nest indices + // + // TODO: handle padding. + // + // In the general case: + // + // D (NxCxD1xD2x...xDdim) x K (MxC/groupxK1xK2x...xKdim) + // -> R (NxMxR1xR2x...xRdim) + // + // The above loop nest can be adapted by increasing the number + // of r- and k-index loop i.e. r1 r2 and k1 k2 loops. + + // Set up outermost loops: n g m r1 r2 ... rdim + // Skip g if group is 1. + + // Before we start the iteration we need to compute the number of + // unsplit kernels and fetch the number of groups from the attribute + // list. Group is always a compilation constant. + int64_t group = convOp.group().getSExtValue(); + // Compute the number of unsplit kernels. The number of kernels + // must be a multiple of the number of groups. + int64_t kernelsPerGroup = floor(kernelShape[0] / group); + auto kernelsPerGroupValue = + rewriter.create(loc, kernelsPerGroup); + auto zero = rewriter.create( + loc, FloatAttr::get(memRefType.getElementType(), 0)); + Value subchannels; + if (kernelShape[1] < 0) { + subchannels = + rewriter.create(loc, operands[1], 1).getResult(); + } else { + subchannels = rewriter.create( + loc, kernelShape[1]); + } + + // 1. Define outer loops and emit empty optimization block: + int64_t nOuterLoops = (group > 1) ? 3 : 2; + std::vector outerLoops; + std::vector optimizedOuterLoops; + Block *optimizationBlock = defineLoops(rewriter, loc, outerLoops, + optimizedOuterLoops, nOuterLoops); + + // Prepare iteration arguments over outer loop nest. + KrnlIterateOperandPack pack( + rewriter, outerLoops, optimizedOuterLoops); + // for n = 0 .. N: + pack.pushConstantBound(0); + if (inputShape[0] < 0) + pack.pushOperandBound( + rewriter.create(loc, operands[0], 0).getResult()); + else + pack.pushConstantBound(inputShape[0]); + // for g = 0 .. N: + if (group > 1) { + pack.pushConstantBound(0); + pack.pushConstantBound(group); + } + // for m = 0 .. kernelsPerGroup: + pack.pushConstantBound(0); + pack.pushConstantBound(kernelsPerGroup); + // Outer loop iteration. 
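Written as a single accumulation, the loop nest sketched in the comment above computes, for each batch index n, group g, and in-group kernel m (with kernel = g * kernelsPerGroup + m):

\[
R[n][\mathit{kernel}][r_1][r_2] \;=\;
\sum_{c=0}^{C/\mathit{group}-1}\;\sum_{k_1=0}^{K_H-1}\;\sum_{k_2=0}^{K_W-1}
D[n]\!\left[g\,\tfrac{C}{\mathit{group}} + c\right]\![s_1 r_1 + k_1][s_2 r_2 + k_2]\;
K[\mathit{kernel}][c][k_1][k_2].
\]

Since padding is an explicit TODO here, the spatial extents implied by this indexing are presumably the usual no-padding, unit-dilation ones produced by shape inference, \(R_i = \lfloor (D_i - K_i)/s_i \rfloor + 1\).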
+ auto iterateOp = rewriter.create(loc, pack); + Block &outerIterationBlock = iterateOp.bodyRegion().front(); + // Emit optimizations for outer loops: + rewriter.setInsertionPointToEnd(optimizationBlock); + rewriter.create(loc, outerLoops); + rewriter.setInsertionPointToStart(&outerIterationBlock); + { + // 2. Emit the body of the outer loop nest. + + // 2.1 Compute kernel order number: kernel = g * kernelsPerGroup + m; + // If group is not set then the value of the kernel ID is + // identical to that of the loop over kernels. + Value kernel = outerIterationBlock.getArguments()[1]; + if (group > 1) { + // Middle loop is over groups and third loop is over the + // kernel identifiers in the current group. + auto kernelsOffset = rewriter.create(loc, + outerIterationBlock.getArguments()[1], + kernelsPerGroupValue); + kernel = rewriter.create(loc, kernelsOffset, + outerIterationBlock.getArguments()[2]); + } + + // 2.2 Define spatial loops + int64_t nSpatialLoops = resultShape.size() - 2; + std::vector spatialLoops; + std::vector optimizedSpatialLoops; + Block *optSpatialLoopBlock = defineLoops(rewriter, loc, spatialLoops, + optimizedSpatialLoops, nSpatialLoops); + + // 2.3 Prepare iteration arguments for spatial loop nest. + KrnlIterateOperandPack spatialPack( + rewriter, spatialLoops, optimizedSpatialLoops); + for (int i = 2; i < resultShape.size(); ++i) + addDimensionToPack(rewriter, loc, spatialPack, alloc, i); + + // 2.4 Emit loop nest over output spatial dimensions. + // for rX = 0 .. RX + auto spatialIterateOp = + rewriter.create(loc, spatialPack); + Block &spatialIterationBlock = spatialIterateOp.bodyRegion().front(); + // 2.5 Emit optimizations for outer loops: + rewriter.setInsertionPointToEnd(optSpatialLoopBlock); + rewriter.create(loc, spatialLoops); + rewriter.setInsertionPointToStart(&spatialIterationBlock); + { + // 3. Emit the body of the spatial loop nest. + // 3.1 Emit: R[n][kernel][r1][r2] = 0; + SmallVector resultIndices; + // n + resultIndices.emplace_back(outerIterationBlock.getArguments()[0]); + // kernel + resultIndices.emplace_back(kernel); + // rX + for (auto arg : spatialIterationBlock.getArguments()) + resultIndices.emplace_back(arg); + // Store initializer value into output location. + rewriter.create(loc, zero, alloc, resultIndices); + + // 3.2 Define inner loops. + int64_t nInnerLoops = 1 + (kernelShape.size() - 2); + std::vector innerLoops; + std::vector optimizedInnerLoops; + Block *optInnerLoopBlock = defineLoops(rewriter, loc, innerLoops, + optimizedInnerLoops, nInnerLoops); + + // 3.3 Prepare iteration arguments for inner loop nest. + KrnlIterateOperandPack innerPack( + rewriter, innerLoops, optimizedInnerLoops); + // for c = 0 .. C/group + innerPack.pushConstantBound(0); + innerPack.pushConstantBound(kernelShape[1]); + // for Kx = 0 .. KX + for (int i = 2; i < kernelShape.size(); ++i) + addDimensionToPack(rewriter, loc, innerPack, operands[1], i); + + // 3.4 Emit inner loop nest. + auto innerIterateOp = + rewriter.create(loc, innerPack); + Block &innerIterationBlock = innerIterateOp.bodyRegion().front(); + // 3.5 Emit optimizations for outer loops: + rewriter.setInsertionPointToEnd(optInnerLoopBlock); + rewriter.create(loc, innerLoops); + rewriter.setInsertionPointToStart(&innerIterationBlock); + { + // 4. Emit inner loop body + // R[n][kernel][r1][r2] = + // D[n][g * (C / group) + c][s1 * r1 + k1][s2 * r2 + k2] * + // K[kernel][c][k1][k2]; + + // 4.1 Prepare indices for accesing the data tensor. 
+ SmallVector dataIndices; + // n + dataIndices.emplace_back(outerIterationBlock.getArguments()[0]); + // g * (C / group) + c + Value channelDepth = innerIterationBlock.getArguments()[0]; + if (group > 1) + channelDepth = rewriter.create(loc, channelDepth, + rewriter.create(loc, subchannels, + outerIterationBlock.getArguments()[1])); + dataIndices.emplace_back(channelDepth); + // sX * rX + kX + auto stridesAttribute = convOp.stridesAttr(); + // Read strides attribute + SmallVector strides; + if (stridesAttribute) + for (auto stride : stridesAttribute.getValue()) + strides.emplace_back(stride.cast().getInt()); + for (int i = 0; i < kernelShape.size() - 2; ++i) { + Value spatialIndex = spatialIterationBlock.getArguments()[i]; + // If strides are present then emit the correct access index. + if (stridesAttribute && strides[i] > 1) + spatialIndex = rewriter.create(loc, + rewriter.create(loc, strides[i]), + spatialIterationBlock.getArguments()[i]); + dataIndices.emplace_back( + rewriter.create(loc, spatialIndex, + innerIterationBlock.getArguments()[i+1])); + } + + // 4.2 Prepare indices for accessing the kernel tensor. + SmallVector kernelIndices; + // kernel + kernelIndices.emplace_back(kernel); + // c + kernelIndices.emplace_back(innerIterationBlock.getArguments()[0]); + // kX + for (int i = 0; i < kernelShape.size() - 2; ++i) + kernelIndices.emplace_back( + innerIterationBlock.getArguments()[i+1]); + + // 4.3 Compute convolution. + auto loadData = + rewriter.create(loc, operands[0], dataIndices); + auto loadKernel = + rewriter.create(loc, operands[1], kernelIndices); + auto loadPartialSum = + rewriter.create(loc, alloc, resultIndices); + Value result = rewriter.create(loc, loadPartialSum, + rewriter.create(loc, loadData, loadKernel)); + // 4.4 Store computed value into output location. + rewriter.create(loc, result, alloc, resultIndices); + } + } + } + rewriter.replaceOp(op, alloc); + + return matchSuccess(); + } +}; + +//===----------------------------------------------------------------------===// +// Reduction ops lowering to Krnl dialect. +//===----------------------------------------------------------------------===// +template +struct ONNXReductionOpLowering : public ConversionPattern { + ONNXReductionOpLowering(MLIRContext *ctx) + : ConversionPattern(ONNXReductionOp::getOperationName(), 1, ctx) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + /* + * Condition: reduction function must be associative and commutative. + * + * Example 1 (here, reduction function is `+`): + * Induction variables: (i0, i1, i2) + * axes = [0, 2] + * keepdims = true + * krnl.iterate() with (i0, i1, i2) { + * Y(0, i1, 0) += X(i0, i1, i2) + * } + * + * Example 2 (here, reduction function is `+`): + * Induction variables: (i0, i1, i2) + * axes = [0, 2] + * keepdims = false + * krnl.iterate() with (i0, i1, i2) { + * Y(i1) += X(i0, i1, i2) + * } + * + */ + auto loc = op->getLoc(); + auto memRefInType = operands[0].getType().cast(); + auto memRefInShape = memRefInType.getShape(); + auto memRefOutType = convertToMemRefType(*op->result_type_begin()); + int64_t inRank = memRefInType.getRank(); + int64_t outRank = memRefOutType.getRank(); + + // Get attributes + ArrayAttr axisAttrs = llvm::dyn_cast(op).axesAttr(); + std::vector axes; + if (axisAttrs) { + for (auto axisAttr : axisAttrs.getValue()) { + int64_t axis = axisAttr.cast().getInt(); + axis = axis >= 0 ? 
axis : (inRank + axis); + assert(axis >= -inRank && axis <= inRank - 1); + if (std::find(axes.begin(), axes.end(), axis) == axes.end()) + axes.push_back(axis); + } + } else { + for (decltype(inRank) i = 0; i < inRank; ++i) { + axes.push_back(i); + } + } + // KeepDims + auto keepdims = + llvm::dyn_cast(op).keepdims(); + bool isKeepdims = (keepdims == 1) ? true : false; + + // Get type information + auto memRefOutShape = memRefOutType.getShape(); + auto elementOutType = memRefOutType.getElementType(); + std::map outInDimMap = + getReductionMapping(memRefInType, axes, isKeepdims); + + // Insert an allocation and deallocation for the result of this operation. + Value alloc; + bool insertDealloc = checkInsertDealloc(op); + if (hasAllConstantDimensions(memRefOutType)) { + alloc = insertAllocAndDealloc(memRefOutType, loc, rewriter, insertDealloc); + } else { + SmallVector allocOperands; + for (decltype(outRank) i = 0; i < outRank; ++i) { + if (memRefOutShape[i] < 0) { + auto dim = rewriter.create(loc, operands[0], outInDimMap[i]); + allocOperands.push_back(dim); + } + } + alloc = rewriter.create(loc, memRefOutType, allocOperands); + if (insertDealloc) { + auto *parentBlock = alloc.getDefiningOp()->getBlock(); + auto dealloc = rewriter.create(loc, alloc); + dealloc.getOperation()->moveBefore(&parentBlock->back()); + } + } + + // There are two Krnl loops: + // - One to initialize the result memref, and + // - One to do reduction + + // Define loops to initialize the result. + std::vector originalLoopsInit; + std::vector optimizedLoopsInit; + Block *optimizationBlockInit = defineLoops(rewriter, loc, originalLoopsInit, + optimizedLoopsInit, outRank); + + // Iteration information + KrnlIterateOperandPack packInit(rewriter, originalLoopsInit, + optimizedLoopsInit); + for (decltype(outRank) i = 0; i < outRank; ++i) { + addDimensionToPack(rewriter, loc, packInit, alloc, i); + } + auto iterateOpInit = rewriter.create(loc, packInit); + Block &iterationBlockInit = iterateOpInit.bodyRegion().front(); + + // Perform the insertions into the body of the initialization loop. + // No optimization + rewriter.setInsertionPointToEnd(optimizationBlockInit); + rewriter.create(loc, originalLoopsInit); + + // Insert instructions inside the KernelIterateOp body. + rewriter.setInsertionPointToStart(&iterationBlockInit); + + // Handle the operation: + SmallVector loopIVs; + for (auto arg : iterationBlockInit.getArguments()) { + loopIVs.push_back(arg); + } + + Value identity; + if (elementOutType.isa()) { + identity = rewriter.create( + loc, FloatAttr::get(elementOutType, + getIdentityValue())); + } else if (elementOutType.isa()) { + identity = rewriter.create( + loc, IntegerAttr::get(elementOutType, + getIdentityValue())); + } else { + emitError(loc, "unsupported element type"); + } + rewriter.create(loc, identity, alloc, loopIVs); + + // Define an Krnl loop to do reduction. + rewriter.setInsertionPointAfter(iterateOpInit); + std::vector originalLoops, optimizedLoops; + Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops, + optimizedLoops, inRank); + // Iteration information + KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops); + for (decltype(inRank) i = 0; i < inRank; ++i) { + addDimensionToPack(rewriter, loc, pack, operands[0], i); + } + auto iterateOp = rewriter.create(loc, pack); + Block &iterationBlock = iterateOp.bodyRegion().front(); + + // Perform the insertions into the body of the reduction loop. 
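The initialization loop above seeds every output element with the identity of the reduction (via getIdentityValue) so that the second loop can fold input elements in without special-casing the first visit. For the reductions this file provides scalar handlers for, and for the other reduction patterns registered at the end of the pass (their template arguments are likewise stripped in this rendering), the identities are presumably the usual ones:

\[
\mathrm{id}(\max) = -\infty, \qquad
\mathrm{id}(\min) = +\infty, \qquad
\mathrm{id}(+) = 0, \qquad
\mathrm{id}(\times) = 1 .
\]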
+ // No optimization + rewriter.setInsertionPointToEnd(optimizationBlock); + rewriter.create(loc, originalLoops); + + // Insert instructions inside the KernelIterateOp body. + rewriter.setInsertionPointToStart(&iterationBlock); + + // Handle the operation: + SmallVector inLoopIVs, outLoopIVs; + auto args = iterationBlock.getArguments(); + for (int i = 0; i < args.size(); ++i) { + inLoopIVs.push_back(args[i]); + } + Value zeroIndex = nullptr; + for (decltype(inRank) i = 0; i < outRank; ++i) { + if (outInDimMap.find(i) != outInDimMap.end()) { + outLoopIVs.push_back(inLoopIVs[outInDimMap[i]]); + } else { + if (zeroIndex) { + outLoopIVs.push_back(zeroIndex); + } else { + zeroIndex = rewriter.create(loc, 0); + outLoopIVs.push_back(zeroIndex); + } + } + } + + Value next, accumulated; + next = rewriter.create(loc, operands[0], inLoopIVs); + accumulated = rewriter.create(loc, alloc, outLoopIVs); + accumulated = mapToLowerScalarOp( + op, memRefOutType.getElementType(), {accumulated, next}, rewriter); + rewriter.create(loc, accumulated, alloc, outLoopIVs); + + rewriter.replaceOp(op, alloc); + return matchSuccess(); + } +}; + +//===----------------------------------------------------------------------===// +// EntryPoint Op lowering to Krnl Entry Point. +//===----------------------------------------------------------------------===// + +class ONNXEntryPointLowering : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(ONNXEntryPointOp op, + PatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp( + op, + op.getAttrOfType( + ONNXEntryPointOp::getEntryPointFuncAttrName()), + op.getAttrOfType(ONNXEntryPointOp::getNumInputsAttrName()), + op.getAttrOfType( + ONNXEntryPointOp::getNumOutputsAttrName())); + return matchSuccess(); + } +}; + +//===----------------------------------------------------------------------===// +// Conversion from Tensor type to the Standard dialect MemRef type. +//===----------------------------------------------------------------------===// + +struct TensorTypeConverter : public TypeConverter { + using TypeConverter::TypeConverter; + + LogicalResult convertType(Type t, SmallVectorImpl &results) override { + if (auto type = convertToMemRefType(t)) { + results.push_back(type); + return success(); + } + + results.push_back(t); + return success(); + } + + /// Return true if the inputs and outputs of the given function type are + /// legal. [Taken from MLIR and adapted to only check the legality of the + /// inputs. Once unranked results can be handled gracefully this + /// override needs to be removed in favour of the original MLIR one.] + bool isSignatureLegal(FunctionType funcType) { + return llvm::all_of(funcType.getInputs(), + [this](Type type) { return isLegal(type); }); + } +}; + +} // end anonymous namespace. + +//===----------------------------------------------------------------------===// +// Frontend to Krnl Dialect lowering pass +//===----------------------------------------------------------------------===// + +/// This is a partial lowering to Krnl loops of the ONNX operations. +namespace { +struct FrontendToKrnlLoweringPass + : public ModulePass { + void runOnModule() final; +}; +} // end anonymous namespace. + +void FrontendToKrnlLoweringPass::runOnModule() { + auto module = getModule(); + + // The first thing to define is the conversion target. This will define the + // final target for this lowering. 
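The ONNXEntryPointLowering class above is another place where the angle-bracketed text was lost: it presumably derives from OpRewritePattern<ONNXEntryPointOp>, rewrites the op into a KrnlEntryPointOp, and reads the three attributes back with typed getAttrOfType calls. A sketch of the presumed original follows; the SymbolRefAttr and IntegerAttr types are assumptions based on how the attributes are used, not a quote of the original file. (The pass declaration nearby likewise presumably reads ModulePass<FrontendToKrnlLoweringPass>.)

```cpp
// Sketch: presumed form of the entry-point lowering with template
// arguments restored.
class ONNXEntryPointLowering : public OpRewritePattern<ONNXEntryPointOp> {
public:
  using OpRewritePattern<ONNXEntryPointOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(ONNXEntryPointOp op,
                                     PatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<KrnlEntryPointOp>(
        op,
        op.getAttrOfType<SymbolRefAttr>(
            ONNXEntryPointOp::getEntryPointFuncAttrName()),
        op.getAttrOfType<IntegerAttr>(
            ONNXEntryPointOp::getNumInputsAttrName()),
        op.getAttrOfType<IntegerAttr>(
            ONNXEntryPointOp::getNumOutputsAttrName()));
    return matchSuccess();
  }
};
```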
+ ConversionTarget target(getContext()); + + // We define the specific operations, or dialects, that are legal targets for + // this lowering. + target + .addLegalDialect(); + + // TODO: enable this once more ops are supported. + // We also define the ONNX dialect as Illegal so that the conversion will fail + // if any of these operations are *not* converted. + // target.addIllegalDialect(); + + // TODO: add any other ops which are considered legal. + // Some operations can be marked as being still legal. + // Example: target.addLegalOp(); + + // Now that the conversion target has been defined, we just need to provide + // the set of patterns that will lower the frontend operations. + OwningRewritePatternList patterns; + + // Convert TensorType to MemRef + TensorTypeConverter tensor_to_memref_converter; + target.addDynamicallyLegalOp([&](FuncOp op) { + // FuncOp is legal only if types have been converted to Std types. + return tensor_to_memref_converter.isSignatureLegal(op.getType()); + }); + + // Type conversion for function signatures. + // Call MLIR FuncOp signature conversion when result type is + // a ranked tensor. + populateFuncOpTypeConversionPattern(patterns, &getContext(), + tensor_to_memref_converter); + + // Frontent operation lowering. + patterns.insert, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseVariadicOpLowering, + ONNXElementwiseVariadicOpLowering, + ONNXElementwiseVariadicOpLowering, + ONNXElementwiseVariadicOpLowering, + ONNXElementwiseVariadicOpLowering, + ONNXElementwiseVariadicOpLowering, + ONNXElementwiseVariadicOpLowering, + ONNXElementwiseVariadicOpLowering, + ONNXElementwiseVariadicOpLowering, + ONNXElementwiseVariadicOpLowering, + ONNXReshapeOpLowering, ONNXEntryPointLowering, + ONNXReductionOpLowering, + ONNXReductionOpLowering, + ONNXReductionOpLowering, + ONNXReductionOpLowering, + ONNXSoftmaxOpLowering, ONNXGemmOpLowering, + ONNXUnsqueezeOpLowering, ONNXTransposeOpLowering, + ONNXIdentityOpLowering, ONNXConvNoBiasOpLowering + >(&getContext()); + + // With the target and rewrite patterns defined, we can now attempt the + // conversion. The conversion will signal failure if any of our `illegal` + // operations were not converted successfully. 
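Both the addLegalDialect call and the patterns.insert list above lost their template argument lists in this rendering, which is what makes them look empty. Based on the patterns defined earlier in this file, they presumably look roughly like the sketch below; only representative elementwise entries are shown, and the exact dialect and op lists are assumptions.

```cpp
// Sketch of the presumed target/pattern registration; abbreviated.
target.addLegalDialect<KrnlOpsDialect, AffineOpsDialect, StandardOpsDialect>();

patterns.insert<ONNXElementwiseUnaryOpLowering<mlir::ONNXReluOp>,
                ONNXElementwiseUnaryOpLowering<mlir::ONNXTanhOp>,
                /* ...remaining unary elementwise ops... */
                ONNXElementwiseVariadicOpLowering<mlir::ONNXAddOp>,
                ONNXElementwiseVariadicOpLowering<mlir::ONNXMaxOp>,
                /* ...remaining variadic elementwise ops... */
                ONNXReshapeOpLowering, ONNXEntryPointLowering,
                ONNXReductionOpLowering<mlir::ONNXReduceMaxOp>,
                ONNXReductionOpLowering<mlir::ONNXReduceMinOp>,
                ONNXReductionOpLowering<mlir::ONNXReduceProdOp>,
                ONNXReductionOpLowering<mlir::ONNXReduceSumOp>,
                ONNXSoftmaxOpLowering, ONNXGemmOpLowering,
                ONNXUnsqueezeOpLowering, ONNXTransposeOpLowering,
                ONNXIdentityOpLowering, ONNXConvNoBiasOpLowering>(
    &getContext());
```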
+ if (failed(applyPartialConversion(module, target, patterns))) + signalPassFailure(); +} + +std::unique_ptr mlir::createLowerToKrnlPass() { + return std::make_unique(); +} + +static PassRegistration + pass("lower-frontend", "Lower frontend ops to Krnl dialect."); diff --git a/src/pass/shape_inference_pass.cpp b/src/pass/shape_inference_pass.cpp index 0d4ae18..2f80ea7 100644 --- a/src/pass/shape_inference_pass.cpp +++ b/src/pass/shape_inference_pass.cpp @@ -130,6 +130,7 @@ public: op->getName().getStringRef() != "onnx.ConvNoBias" && op->getName().getStringRef() != "onnx.PadConstantPad" && op->getName().getStringRef() != "onnx.PadConstantValuePad" && + op->getName().getStringRef() != "onnx.BatchNormalizationTestMode" && op->getName().getStringRef() != "onnx.Unsqueeze") return false; return llvm::any_of(op->getResultTypes(), [](Type result_type) { diff --git a/src/runtime/dyn_memref.cpp b/src/runtime/dyn_memref.cpp index a20001c..454947b 100644 --- a/src/runtime/dyn_memref.cpp +++ b/src/runtime/dyn_memref.cpp @@ -35,7 +35,7 @@ DynMemRef *getDynMemRef(OrderedDynMemRefDict *tensorDict, int idx) { void setDynMemRef(OrderedDynMemRefDict *tensorDict, int idx, DynMemRef *tensor) { - if (tensorDict->orderedNames.capacity() <= idx) + if (tensorDict->orderedNames.size() <= idx) tensorDict->orderedNames.resize(idx + 1); // The dynamic memref is essentially anonymous, since we are storing it by diff --git a/src/tool/onnf_opt/onnf_opt.cpp b/src/tool/onnf_opt/onnf_opt.cpp index 2311b66..597bfd4 100644 --- a/src/tool/onnf_opt/onnf_opt.cpp +++ b/src/tool/onnf_opt/onnf_opt.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -46,6 +47,10 @@ static llvm::cl::opt verify_passes( llvm::cl::init(true)); int main(int argc, char **argv) { + mlir::registerDialect(); + mlir::registerDialect(); + mlir::registerDialect(); + mlir::registerDialect(); llvm::InitLLVM y(argc, argv); mlir::registerDialect(); diff --git a/src/transform/lower_to_llvm.cpp b/src/transform/lower_to_llvm.cpp index 7d01207..b4b46f4 100644 --- a/src/transform/lower_to_llvm.cpp +++ b/src/transform/lower_to_llvm.cpp @@ -224,7 +224,10 @@ public: // Based on the static entry point type signature, unpack dynamic memory // refs to corresponding static memory refs. - auto *staticEntryPointFunc = module.lookupSymbol(staticEntryPointFuncName); + auto wrappedStaticEntryPointFuncName = + "_mlir_ciface_" + staticEntryPointFuncName.lower(); + auto *staticEntryPointFunc = + module.lookupSymbol(wrappedStaticEntryPointFuncName); assert(staticEntryPointFunc && isa(staticEntryPointFunc) && "entry point func must exist and be an llvm func op"); @@ -268,7 +271,8 @@ public: // Call static entry point with the memref ptrs created, and get output. auto outputMemRefs = rewriter.create( loc, staticEntryPointTy.getFunctionResultType(), - rewriter.getSymbolRefAttr(staticEntryPointFuncName), staticInputs); + rewriter.getSymbolRefAttr(wrappedStaticEntryPointFuncName), + staticInputs); // Create wrapped output. auto wrappedOutput = callApi(rewriter, loc, apiRegistry, @@ -563,7 +567,9 @@ void KrnlToLLVMLoweringPass::runOnModule() { OwningRewritePatternList patterns; populateAffineToStdConversionPatterns(patterns, &getContext()); populateLoopToStdConversionPatterns(patterns, &getContext()); - populateStdToLLVMConversionPatterns(typeConverter, patterns); + populateStdToLLVMConversionPatterns(typeConverter, patterns, + /*useAlloca=*/false, + /*emitCWrapper=*/true); // Lower from the `krnl` dialect i.e. the Reshape operation. 
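The dyn_memref.cpp change above (capacity() to size()) fixes a real indexing hazard: reserve or earlier growth can leave capacity() well above the logical size, in which case the old guard skipped the resize and the subsequent element assignment wrote past the end of the vector. A small standalone illustration of the difference, using hypothetical values rather than project code:

```cpp
#include <cassert>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> orderedNames;
  orderedNames.reserve(8); // capacity() >= 8, but size() == 0
  int idx = 3;

  // Old guard: "capacity() <= idx" is false here, so no resize would happen
  // and writing orderedNames[idx] below would be out of bounds.
  assert(!(orderedNames.capacity() <= static_cast<size_t>(idx)));

  // New guard: grow based on the logical size instead.
  if (orderedNames.size() <= static_cast<size_t>(idx))
    orderedNames.resize(idx + 1);

  orderedNames[idx] = "some_tensor_name"; // now safe
  assert(orderedNames.size() == 4);
  return 0;
}
```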
  patterns.insert, %arg1 : tensor<4xi32>) -> tensor<*x
  "std.return"(%0) : (tensor<*xf32>) -> ()

   // CHECK: llvm.func @llvm.memcpy.p0i8.p0i8.i64(!llvm<"i8*">, !llvm<"i8*">, !llvm.i64, !llvm.i1)
+  // CHECK: [[TMP:%.+]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: llvm.insertvalue %arg0, %0[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: llvm.insertvalue %arg1, %1[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: llvm.insertvalue %arg2, %2[2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: llvm.insertvalue %arg3, %3[3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: llvm.insertvalue %arg5, %4[4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: llvm.insertvalue %arg4, %5[3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: [[TMP1:%.+]] = llvm.insertvalue %arg6, %6[4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
   // CHECK: [[RES:%.+]] = llvm.insertvalue {{.*}}[4, 3] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
   // CHECK: [[EXT_VAL_0:%.+]] = llvm.extractvalue [[RES]][1] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
   // CHECK: [[DST:%.+]] = llvm.bitcast [[EXT_VAL_0]] : !llvm<"float*"> to !llvm<"i8*">
-  // CHECK: [[EXT_VAL_1:%.+]] = llvm.extractvalue %0[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: [[EXT_VAL_1:%.+]] = llvm.extractvalue [[TMP1]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
   // CHECK: [[SRC:%.+]] = llvm.bitcast [[EXT_VAL_1]] : !llvm<"float*"> to !llvm<"i8*">
   // CHECK: [[SIZE:%.+]] = llvm.sext %{{.*}} : !llvm.i64 to !llvm.i64
   // CHECK: [[VOLATILE:%.+]] = llvm.mlir.constant(0 : i1) : !llvm.i1
diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir
index 8f4843a..9da12ac 100644
--- a/test/mlir/onnx/onnx_lowering.mlir
+++ b/test/mlir/onnx/onnx_lowering.mlir
@@ -795,12 +795,42 @@ func @test_gemm(%arg0 : tensor<5x10xf32>, %arg1 : tensor<5x10xf32>, %arg2: tenso
   // CHECK: [[SUM:%.+]] = addf [[Y]], [[AB]] : f32
   // CHECK: store [[SUM]], [[RES]][%arg3, %arg4] : memref<10x10xf32>
   // CHECK: }
-  // CHECK: [[C:%.+]] = load %arg2[%arg4] : memref<10xf32>
   // CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg3, %arg4] : memref<10x10xf32>
   // CHECK: [[ALPHA_AB:%.+]] = mulf [[ALPHA]], [[LOAD_Y]] : f32
+  // CHECK: [[C:%.+]] = load %arg2[%arg4] : memref<10xf32>
   // CHECK: [[BETA_C:%.+]] = mulf [[BETA]], [[C]] : f32
   // CHECK: [[Y_RES:%.+]] = addf [[ALPHA_AB]], [[BETA_C]] : f32
   // CHECK: store [[Y_RES]], [[RES]][%arg3, %arg4] : memref<10x10xf32>
+  // CHECK: }
+  // CHECK: return [[RES]] : memref<10x10xf32>
+  // CHECK: }
+}
+
+func @test_gemm_no_bias(%arg0 : tensor<5x10xf32>, %arg1 : tensor<5x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.GemmNoBias"(%arg0, %arg1) {alpha = 1.0 : f32, beta = 5.0 : f32, transA = 1, transB = 0} : (tensor<5x10xf32>, tensor<5x10xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_gemm_no_bias
+  // CHECK: [[RES:%.+]] = alloc() : memref<10x10xf32>
+  // CHECK: [[ALPHA:%.+]] = constant 1.000000e+00 : f32
+  // CHECK: [[BETA:%.+]] = constant 5.000000e+00 : f32
+  // CHECK: [[DEF_LOOPS:%.+]]:3 = krnl.define_loops 3
+  // CHECK: [[OPT_LOOPS:%.+]]:3 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1, [[DEF_LOOPS]]#2
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#2) with ([[DEF_LOOPS]]#2 -> %arg4 = 0 to 5) {
+  // CHECK: [[A:%.+]] = load %arg0[%arg4, %arg2] : memref<5x10xf32>
+  // CHECK: [[B:%.+]] = load %arg1[%arg4, %arg3] : memref<5x10xf32>
+  // CHECK: [[Y:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xf32>
+  // CHECK: [[AB:%.+]] = mulf [[A]], [[B]] : f32
+  // CHECK: [[SUM:%.+]] = addf [[Y]], [[AB]] : f32
+  // CHECK: store [[SUM]], [[RES]][%arg2, %arg3] : memref<10x10xf32>
+  // CHECK: }
+  // CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xf32>
+  // CHECK: [[ALPHA_AB:%.+]] = mulf [[ALPHA]], [[LOAD_Y]] : f32
+  // CHECK: store [[ALPHA_AB]], [[RES]][%arg2, %arg3] : memref<10x10xf32>
+  // CHECK: }
   // CHECK: return [[RES]] : memref<10x10xf32>
   // CHECK: }
 }
@@ -1283,3 +1313,63 @@ func @test_conv_no_bias_no_pad_w_strides(%arg0 : tensor<1x9x32x64xf32>, %arg1 :
   // CHECK: return [[RES]] : memref<1x5x14x29xf32>
 }
+
+func @test_batchnorm_testmode_Nd(%arg0: tensor<1x2x1x3xf32>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>, %arg3: tensor<2xf32>, %arg4: tensor<2xf32>) -> tensor<1x2x1x3xf32> {
+  %0 = "onnx.BatchNormalizationTestMode"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<1x2x1x3xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) -> tensor<1x2x1x3xf32>
+  return %0 : tensor<1x2x1x3xf32>
+
+  // CHECK-LABEL: test_batchnorm_testmode_Nd
+  // CHECK: [[RES:%.+]] = alloc() : memref<1x2x1x3xf32>
+  // CHECK: [[EPSILON:%.+]] = constant 9.99999974E-6 : f32
+  // CHECK: [[DEF_LOOPS:%.+]]:4 = krnl.define_loops 4
+  // CHECK: [[OPT_LOOPS:%.+]]:4 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1, [[DEF_LOOPS]]#2, [[DEF_LOOPS]]#3
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop)
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#1 -> %arg5 = 0 to 2) {
+  // CHECK: [[SCALE:%.+]] = load %arg1[%arg5] : memref<2xf32>
+  // CHECK: [[BIAS:%.+]] = load %arg2[%arg5] : memref<2xf32>
+  // CHECK: [[MEAN:%.+]] = load %arg3[%arg5] : memref<2xf32>
+  // CHECK: [[VARIANCE:%.+]] = load %arg4[%arg5] : memref<2xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#2, [[OPT_LOOPS]]#3) with ([[DEF_LOOPS]]#0 -> %arg6 = 0 to 1, [[DEF_LOOPS]]#2 -> %arg7 = 0 to 1, [[DEF_LOOPS]]#3 -> %arg8 = 0 to 3) {
+  // CHECK: [[LOADED_VAL:%.+]] = load %arg0[%arg6, %arg5, %arg7, %arg8] : memref<1x2x1x3xf32>
+  // CHECK: [[DIVIDEND:%.+]] = subf [[LOADED_VAL]], [[MEAN]] : f32
+  // CHECK: [[ADJUSTED_VARIANCE:%.+]] = addf [[VARIANCE]], [[EPSILON]] : f32
+  // CHECK: [[DIVISOR:%.+]] = "krnl.sqrt"([[ADJUSTED_VARIANCE]]) : (f32) -> f32
+  // CHECK: [[NORM:%.+]] = divf [[DIVIDEND]], [[DIVISOR]] : f32
+  // CHECK: [[SCALE_NORM:%.+]] = mulf [[SCALE]], [[NORM]] : f32
+  // CHECK: [[SHIFT_SCALE_NORM:%.+]] = addf [[SCALE_NORM]], [[BIAS]] : f32
+  // CHECK: store [[SHIFT_SCALE_NORM]], [[RES]][%arg6, %arg5, %arg7, %arg8] : memref<1x2x1x3xf32>
+  // CHECK: }
+  // CHECK: }
+  // CHECK: return [[RES]] : memref<1x2x1x3xf32>
+}
+
+func @test_batchnorm_testmode_1d(%arg0: tensor<10xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<10xf32> {
+  %0 = "onnx.BatchNormalizationTestMode"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<10xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<10xf32>
+  return %0 : tensor<10xf32>
+
+  // CHECK-LABEL: test_batchnorm_testmode_1d
+  // CHECK: [[RES:%.+]] = alloc() : memref<10xf32>
+  // CHECK: [[EPSILON:%.+]] = constant 9.99999974E-6 : f32
+  // CHECK: [[DEF_LOOPS:%.+]] = krnl.define_loops 1
+  // CHECK: [[OPT_LOOPS:%.+]] = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]
+  // CHECK: } : () -> !krnl.loop
+  // CHECK: %[[ZERO_INDEX:.+]] = constant 0 : index
+  // CHECK: [[SCALE:%.+]] = load %arg1[%[[ZERO_INDEX]]] : memref<1xf32>
+  // CHECK: [[BIAS:%.+]] = load %arg2[%[[ZERO_INDEX]]] : memref<1xf32>
+  // CHECK: [[MEAN:%.+]] = load %arg3[%[[ZERO_INDEX]]] : memref<1xf32>
+  // CHECK: [[VARIANCE:%.+]] = load %arg4[%[[ZERO_INDEX]]] : memref<1xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]) with ([[DEF_LOOPS]] -> %arg5 = 0 to 10) {
+  // CHECK: [[LOADED_VAL:%.+]] = load %arg0[%arg5] : memref<10xf32>
+  // CHECK: [[DIVIDEND:%.+]] = subf [[LOADED_VAL]], [[MEAN]] : f32
+  // CHECK: [[ADJUSTED_VARIANCE:%.+]] = addf [[VARIANCE]], [[EPSILON]] : f32
+  // CHECK: [[DIVISOR:%.+]] = "krnl.sqrt"([[ADJUSTED_VARIANCE]]) : (f32) -> f32
+  // CHECK: [[NORM:%.+]] = divf [[DIVIDEND]], [[DIVISOR]] : f32
+  // CHECK: [[SCALE_NORM:%.+]] = mulf [[SCALE]], [[NORM]] : f32
+  // CHECK: [[SHIFT_SCALE_NORM:%.+]] = addf [[SCALE_NORM]], [[BIAS]] : f32
+  // CHECK: store [[SHIFT_SCALE_NORM]], [[RES]][%arg5] : memref<10xf32>
+  // CHECK: }
+  // CHECK: return [[RES]] : memref<10xf32>
+}
+
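Note: the test_batchnorm_testmode_* checks above encode the per-channel test-mode formula y = scale * (x - mean) / sqrt(var + epsilon) + B, where scale, B, mean, and var are indexed by the channel (second) dimension and loaded once per channel outside the spatial loops. The following standalone C++ sketch of that arithmetic is for reference only; it is not part of the patch, and all names (batchNormTestMode, spatialSize, etc.) are illustrative.

// Reference sketch of ONNX BatchNormalization in test (inference) mode.
// Assumes an N x C x spatial layout flattened into `data`; names are
// illustrative and not taken from the patch.
#include <cmath>
#include <cstddef>
#include <vector>

void batchNormTestMode(std::vector<float> &data, size_t N, size_t C,
                       size_t spatialSize, const std::vector<float> &scale,
                       const std::vector<float> &bias,
                       const std::vector<float> &mean,
                       const std::vector<float> &var,
                       float epsilon = 1e-5f) {
  for (size_t n = 0; n < N; ++n)
    for (size_t c = 0; c < C; ++c) {
      // Per-channel divisor, mirroring the krnl.sqrt(var + epsilon) checks.
      float divisor = std::sqrt(var[c] + epsilon);
      for (size_t s = 0; s < spatialSize; ++s) {
        size_t idx = (n * C + c) * spatialSize + s;
        data[idx] = scale[c] * ((data[idx] - mean[c]) / divisor) + bias[c];
      }
    }
}

For the 1-D test case above, C is effectively 1, so the per-channel parameters collapse to a single element; this corresponds to the lowering loading scale, B, mean, and var once at the constant index 0 before the single data loop.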