From b9f2f25b563e3fb6f37b6322aa1cfd7e5205fef9 Mon Sep 17 00:00:00 2001 From: "Tung D. Le" Date: Wed, 19 Feb 2020 16:17:48 +0900 Subject: [PATCH] [NFC] Categorize ONNX ops lowering (#80) * Create two categories: elementwise and tensor * typos * Create directories for categories * Edit comments * Extract a function that creates a KrnlIterateOp * Add comments * Extract some common parts * Revise softmax * Add reduction.inc * Move lower-frontend to lib/conversion * Move directory to directory * Change file/directory names * Comment format * Add matmul.inc --- src/CMakeLists.txt | 2 +- .../onnx_to_krnl/convert_onnx_to_krnl.cpp | 529 ++++ .../rewrite_patterns/math/elementwise.inc | 646 ++++ .../rewrite_patterns/math/gemm.inc | 209 ++ .../rewrite_patterns/math/matmul.inc | 345 +++ .../rewrite_patterns/math/reduction.inc | 307 ++ .../rewrite_patterns/math/softmax.inc | 205 ++ .../onnx_to_krnl/rewrite_patterns/nn/conv.inc | 282 ++ .../rewrite_patterns/tensor/identity.inc | 26 + .../rewrite_patterns/tensor/reshape.inc | 151 + .../rewrite_patterns/tensor/transpose.inc | 99 + .../rewrite_patterns/tensor/unsqueeze.inc | 86 + src/pass/lower_frontend_to_krnl.cpp | 2719 ----------------- 13 files changed, 2886 insertions(+), 2720 deletions(-) create mode 100644 src/conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp create mode 100644 src/conversion/onnx_to_krnl/rewrite_patterns/math/elementwise.inc create mode 100644 src/conversion/onnx_to_krnl/rewrite_patterns/math/gemm.inc create mode 100644 src/conversion/onnx_to_krnl/rewrite_patterns/math/matmul.inc create mode 100644 src/conversion/onnx_to_krnl/rewrite_patterns/math/reduction.inc create mode 100644 src/conversion/onnx_to_krnl/rewrite_patterns/math/softmax.inc create mode 100644 src/conversion/onnx_to_krnl/rewrite_patterns/nn/conv.inc create mode 100644 src/conversion/onnx_to_krnl/rewrite_patterns/tensor/identity.inc create mode 100644 src/conversion/onnx_to_krnl/rewrite_patterns/tensor/reshape.inc create mode 100644 src/conversion/onnx_to_krnl/rewrite_patterns/tensor/transpose.inc create mode 100644 src/conversion/onnx_to_krnl/rewrite_patterns/tensor/unsqueeze.inc delete mode 100644 src/pass/lower_frontend_to_krnl.cpp diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index bf44984..4a03cbd 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -57,7 +57,7 @@ target_include_directories(onnf_shape_inference target_link_libraries(onnf_shape_inference ${MLIRLibs}) add_dependencies(onnf_shape_inference gen_krnl_ops) -add_library(onnf_lower_frontend pass/lower_frontend_to_krnl.cpp) +add_library(onnf_lower_frontend conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp) target_include_directories(onnf_lower_frontend PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT} ${ONNF_SRC_ROOT}) diff --git a/src/conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp b/src/conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp new file mode 100644 index 0000000..9c9b826 --- /dev/null +++ b/src/conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp @@ -0,0 +1,529 @@ +//====- convert_onnx_to_krnl.cpp - ONNX dialects to Krnl lowering ---------===// +// +// Copyright 2019 The IBM Research Authors. +// +// ============================================================================= +// +// This file implements the lowering of frontend operations to a combination of +// Krnl IR and standard operations. 
+// +//===----------------------------------------------------------------------===// +#include + +#include "mlir/Dialect/AffineOps/AffineOps.h" +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/Sequence.h" + +#include "src/dialect/krnl/krnl_helper.hpp" +#include "src/dialect/krnl/krnl_ops.hpp" +#include "src/dialect/onnx/onnx_ops.hpp" +#include "src/pass/passes.hpp" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// FrontendToAffine RewritePatterns +//===----------------------------------------------------------------------===// + +/// Check is all dimensions are known at compile time. +static bool hasAllConstantDimensions(MemRefType type) { + auto memRefShape = type.getShape(); + for (int i = 0; i < memRefShape.size(); ++i) + if (memRefShape[i] < 0) + return false; + return true; +} + +/// Convert the given TensorType into the corresponding MemRefType. +static MemRefType convertTensorToMemRef(TensorType type) { + assert(type.hasRank() && "expected only ranked shapes"); + return MemRefType::get(type.getShape(), type.getElementType()); +} + +/// Insert an allocation and deallocation for the given MemRefType. +static Value insertAllocAndDealloc(MemRefType type, Location loc, + PatternRewriter &rewriter, + bool insertDealloc, + ArrayRef operands = {}) { + // Put together alloc operands for any dynamic dimensions of the memref. + AllocOp alloc; + if (!operands.empty()) { + auto memRefShape = type.getShape(); + auto rank = memRefShape.size(); + + std::map fromOperands; + for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) { + int memRefDimIdx = rank - 1 - reversedIdx; + if (memRefShape[memRefDimIdx] < 0) { // unknown dimension + Value maxDim = nullptr; + for (int i = 0; i < operands.size(); i++) { + auto operandShape = + operands[i].getType().cast().getShape(); + int operandDimIdx = operandShape.size() - 1 - reversedIdx; + + if (operandDimIdx < 0) + continue; + + // In case of operations with broadcasting, the dimension of the + // alloc result is the maximum size along each dimension of the + // operands. + auto operandDim = + rewriter.create(loc, operands[i], operandDimIdx); + if (maxDim) { + auto maxCondition = rewriter.create(loc, CmpIPredicate::sgt, + operandDim, maxDim); + maxDim = rewriter.create(loc, maxCondition, operandDim, + maxDim); + } else { + maxDim = operandDim; + } + } + fromOperands.insert(std::make_pair(memRefDimIdx, maxDim)); + } + } + + SmallVector allocOperands; + for (int i = 0; i < rank; ++i) + if (memRefShape[i] < 0) + allocOperands.push_back(fromOperands[i]); + alloc = rewriter.create(loc, type, allocOperands); + } else { + alloc = rewriter.create(loc, type); + } + + // Make sure to allocate at the beginning of the block if + // all dimensions are known. + auto *parentBlock = alloc.getOperation()->getBlock(); + if (hasAllConstantDimensions(type)) + alloc.getOperation()->moveBefore(&parentBlock->front()); + + if (insertDealloc) { + auto dealloc = rewriter.create(loc, alloc); + dealloc.getOperation()->moveBefore(&parentBlock->back()); + } + + return alloc; +} + +// Determine if current function returns the result value of the +// current op being lowered. If it does then dealloc should not be +// inserted. 
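+// For example, if this op's single result is also an operand of the
+// enclosing function's ReturnOp, the allocated buffer must outlive the op,
+// so no DeallocOp is emitted for it.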
+static bool checkInsertDealloc(Operation *currentOp) { + auto parentBlock = currentOp->getBlock(); + + bool insertDealloc = true; + parentBlock->walk([&insertDealloc, currentOp](ReturnOp op) { + assert(currentOp->getNumResults() < 2 && + "No more than one result supported (for now)."); + // If there is at least one result to investigate. + if (currentOp->getNumResults() > 0) { + auto result = currentOp->getResult(0); + for (const auto &operand : op.getOperands()) + if (operand == result) + insertDealloc = false; + } + }); + + return insertDealloc; +} + +// Create a mapping from result type's dimensions to input type's dimensions, +// given that the result type is the result of a reduction op over the input +// type. +std::map +getReductionMapping(MemRefType inputTy, ArrayRef axes, bool keepdims) { + std::map OutInDimMap; + int64_t rank = inputTy.getRank(); + + // Mark reduction axes. + std::vector isReductionAxis; + for (decltype(rank) i = 0; i < rank; ++i) { + if (std::find(axes.begin(), axes.end(), i) != axes.end()) + isReductionAxis.push_back(true); + else + isReductionAxis.push_back(false); + } + + for (decltype(rank) inIndex = 0, outIndex = 0; inIndex < rank; ++inIndex) { + // If it is a reduction axis, there is no relationship among dimensions. + if (isReductionAxis[inIndex]) { + if (keepdims) + outIndex++; + } else { + OutInDimMap.insert(std::make_pair(outIndex, inIndex)); + outIndex++; + } + } + + return OutInDimMap; +} + +// Add bounds associated with the op operand to the KRNL iteration pack. +// Dynamic dimenions are supported. +static void addDimensionToPack(ConversionPatternRewriter &rewriter, + Location loc, KrnlIterateOperandPack &pack, + Value operand, int index) { + auto shape = operand.getType().cast().getShape(); + if (shape[index] < 0) { + pack.pushConstantBound(0); + pack.pushOperandBound( + rewriter.create(loc, operand, index).getResult()); + } else { + pack.pushConstantBound(0); + pack.pushConstantBound(shape[index]); + } +} + +// Function that defines the KRNL dialect loops and their respective +// optimized version. +static KrnlOptimizeLoopsOp +emitOptimizedLoops(ConversionPatternRewriter &rewriter, Location loc, + std::vector &loops, + std::vector &optimizedLoops, int64_t numLoops) { + // Define loops. + auto loopsOp = rewriter.create(loc, numLoops); + loops.reserve(numLoops); + for (auto result : loopsOp.getResults()) + loops.push_back(result); + + // Define optimized version of the loops. + auto optimizedLoopsOp = rewriter.create(loc, numLoops); + optimizedLoops.reserve(numLoops); + for (auto result : optimizedLoopsOp.getResults()) + optimizedLoops.push_back(result); + + return optimizedLoopsOp; +} + +// Function that emits the loops and their optimized version. +// The function returns a reference to the inner optimization block. +static Block *defineLoops(ConversionPatternRewriter &rewriter, Location loc, + std::vector &loops, + std::vector &optimizedLoops, + int64_t numLoops) { + KrnlOptimizeLoopsOp optimizedLoopsOp = + emitOptimizedLoops(rewriter, loc, loops, optimizedLoops, numLoops); + return &optimizedLoopsOp.region().front(); +} + +// Function which emits a basic set of loops and optimized loops +// for a given operation argument. A reference to the loop optimization +// block is returned in the last argument of the function. 
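+// For a ranked operand of rank R, this emits R krnl loops whose bounds come
+// from the operand's shape (a DimOp is used for dynamic dimensions) and a
+// KrnlIterateOp over those loops.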
+static void emitKrnlLoopsAndIterationForOperand( + ConversionPatternRewriter &rewriter, Location loc, Value operand, + std::vector &originalLoops, KrnlOptimizeLoopsOp &optimizedLoopsOp, + KrnlIterateOp &iterateOp) { + // Operand shape. + auto shape = operand.getType().cast().getShape(); + + // Number of loops. + int64_t rank = shape.size(); + + // Define loops and optimized loops. + std::vector optimizedLoops; + optimizedLoopsOp = + emitOptimizedLoops(rewriter, loc, originalLoops, optimizedLoops, rank); + + KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops); + // Iterate over the loop nest. + for (int i = 0; i < rank; ++i) + addDimensionToPack(rewriter, loc, pack, operand, i); + + iterateOp = rewriter.create(loc, pack); +} + +unsigned getMemRefEltSizeInBytes(MemRefType memRefType) { + auto elementType = memRefType.getElementType(); + + unsigned sizeInBits; + if (elementType.isIntOrFloat()) { + sizeInBits = elementType.getIntOrFloatBitWidth(); + } else { + auto vectorType = elementType.cast(); + sizeInBits = + vectorType.getElementTypeBitWidth() * vectorType.getNumElements(); + } + return llvm::divideCeil(sizeInBits, 8); +} + +// Get run-time dimension information for unknown dimensions used for +// broadcasting. +std::map> +getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter, + MemRefType memRefType, ArrayRef operands) { + auto memRefShape = memRefType.getShape(); + int64_t rank = memRefShape.size(); + // For unknown dimensions, we need to get dimension values at runtime in + // order to do broadcasting. + std::map> DimInfo; + // For each result dimension, compute the number of sharing operands. + // Sharing operands are operands sharing the same index (counting from the + // rightmost to the leftmost) for a given dimension. + std::map sharedDimCount; + for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) { + int dimIdx = rank - 1 - reversedIdx; + sharedDimCount[dimIdx] = 0; + for (int i = 0; i < operands.size(); ++i) { + auto shape = operands[i].getType().cast().getShape(); + if (reversedIdx <= shape.size() - 1) + sharedDimCount[dimIdx]++; + } + } + // An unknown dimension can have a value of 1 or N (N > 1). + // If its value is 1, it is broadcasted dimension. + // Otherwise, non-broadcasted dimension. + // We only care about unknown dimensions whose number of sharing operands is + // more than one, since they are potentially broadcasted dimensions. + for (int i = 0; i < operands.size(); ++i) { + std::map broadcastedDims; + auto shape = operands[i].getType().cast().getShape(); + int size = shape.size(); + for (int j = 0; j < shape.size(); ++j) { + if (shape[j] < 0 and sharedDimCount[rank - size + j] > 1) { + auto dim = rewriter.create(loc, operands[i], j).getResult(); + auto one = rewriter.create(loc, 1); + auto isBroadcasted = + rewriter.create(loc, CmpIPredicate::eq, dim, one); + broadcastedDims.insert(std::make_pair(j, isBroadcasted)); + } + } + DimInfo.insert(std::make_pair(i, broadcastedDims)); + } + return DimInfo; +} + +// Extract induction variables that are used for broadcasting values of a +// given operand. +std::vector +getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter, + ArrayRef loopIVs, Value operand, + std::map broadcastedDims) { + // `operand` must has a ranked type. This should have been checked by the + // shape inference pass. 
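+ // Example: for an operand of shape <1 x N> accessed inside a 2-D loop nest
+ // with induction variables (i, j), dimension 0 is broadcasted, so a constant
+ // zero index is used and the returned IVs are (0, j).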
+ auto operandShape = operand.getType().cast().getShape(); + auto rank = operandShape.size(); + auto loopCount = loopIVs.size(); + + std::vector newLoopIVs; + for (unsigned reversedIdx = 0; reversedIdx < rank; ++reversedIdx) { + auto dimIdx = rank - 1 - reversedIdx; + auto loopIdx = loopCount - 1 - reversedIdx; + if (operandShape[dimIdx] == 1) { + // Broadcasted dimension + auto zero = rewriter.create(loc, 0); + newLoopIVs.insert(newLoopIVs.begin(), zero); + } else if ((operandShape[dimIdx] == -1) && + (broadcastedDims.find(dimIdx) != broadcastedDims.end())) { + // Unknown dimension, it can have a value of 1 or N (N > 1). + // If its value is 1, it is broadcasted dimension. + // Otherwise, non-broadcasted dimension. + auto zero = rewriter.create(loc, 0); + auto idx = rewriter.create(loc, broadcastedDims[dimIdx], zero, + loopIVs[loopIdx]); + newLoopIVs.insert(newLoopIVs.begin(), idx); + } else { + // Non-broadcasted dimension + newLoopIVs.insert(newLoopIVs.begin(), loopIVs[loopIdx]); + } + } + return newLoopIVs; +} + +namespace { + +// This is to get a scalar operation of a given type for a specific operation. +template +struct ScalarOp { + using FOp = void; + using IOp = void; +}; + +template +using ScalarFOp = typename ScalarOp::FOp; +template +using ScalarIOp = typename ScalarOp::IOp; + +// Get the identity element of a operation. +// Return NULL if the function does not have identity. +template +DataType getIdentityValue() { + return NULL; +} + +//===----------------------------------------------------------------------===// +// This is used in the innermost loop of a KrnlIterateOp to insert computation +// composed of one or many scalar ops. +// Use template specialization for each of different ONNX operations. +//===----------------------------------------------------------------------===// +template +Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, + ArrayRef operands, + ConversionPatternRewriter &rewriter) { + auto loc = op->getLoc(); + Type element_type = operands.front().getType(); + if (element_type.isa()) { + return rewriter.create>(loc, result_types, operands, + mlir::None); + } else if (element_type.isa()) { + return rewriter.create>(loc, result_types, operands, + mlir::None); + } else { + emitError(loc, "unsupported element type"); + return nullptr; + } +} + +// We divide the operator lowering into different categories. +// These categories are mostly similar to the operator categories in ONNX: +// https://github.com/onnx/onnx/tree/master/onnx/defs. +// Besides, it is better to put operators with the same computation pattern into +// the same category, e.g. element-wise operators will belong to the elementwise +// category. 
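+// Each included .inc file defines the rewrite patterns for one category and
+// exposes a populateLoweringONNX*OpPattern entry point that is registered in
+// runOnModule below.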
+ +// Math +#include "src/conversion/onnx_to_krnl/rewrite_patterns/math/elementwise.inc" +#include "src/conversion/onnx_to_krnl/rewrite_patterns/math/gemm.inc" +#include "src/conversion/onnx_to_krnl/rewrite_patterns/math/reduction.inc" +#include "src/conversion/onnx_to_krnl/rewrite_patterns/math/softmax.inc" +#include "src/conversion/onnx_to_krnl/rewrite_patterns/math/matmul.inc" +// Tensor +#include "src/conversion/onnx_to_krnl/rewrite_patterns/tensor/identity.inc" +#include "src/conversion/onnx_to_krnl/rewrite_patterns/tensor/reshape.inc" +#include "src/conversion/onnx_to_krnl/rewrite_patterns/tensor/transpose.inc" +#include "src/conversion/onnx_to_krnl/rewrite_patterns/tensor/unsqueeze.inc" +// Neural network +#include "src/conversion/onnx_to_krnl/rewrite_patterns/nn/conv.inc" + +//===----------------------------------------------------------------------===// +// EntryPoint Op lowering to Krnl Entry Point. +//===----------------------------------------------------------------------===// + +class ONNXEntryPointLowering : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(ONNXEntryPointOp op, + PatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp( + op, + op.getAttrOfType( + ONNXEntryPointOp::getEntryPointFuncAttrName()), + op.getAttrOfType(ONNXEntryPointOp::getNumInputsAttrName()), + op.getAttrOfType( + ONNXEntryPointOp::getNumOutputsAttrName())); + return matchSuccess(); + } +}; + +//===----------------------------------------------------------------------===// +// Conversion from Tensor type to the Standard dialect MemRef type. +//===----------------------------------------------------------------------===// + +struct TensorTypeConverter : public TypeConverter { + using TypeConverter::TypeConverter; + + LogicalResult convertType(Type t, SmallVectorImpl &results) override { + if (auto tensor_type = t.dyn_cast()) { + results.push_back(convertTensorToMemRef(tensor_type)); + return success(); + } + + results.push_back(t); + return success(); + } + + /// Return true if the inputs and outputs of the given function type are + /// legal. [Taken from MLIR and adapted to only check the legality of the + /// inputs. Once unranked results can be handled gracefully this + /// override needs to be removed in favour of the original MLIR one.] + bool isSignatureLegal(FunctionType funcType) { + return llvm::all_of(funcType.getInputs(), + [this](Type type) { return isLegal(type); }); + } +}; + +} // end anonymous namespace. + +//===----------------------------------------------------------------------===// +// Frontend to Krnl Dialect lowering pass +//===----------------------------------------------------------------------===// + +/// This is a partial lowering to Krnl loops of the ONNX operations. +namespace { +struct FrontendToKrnlLoweringPass + : public ModulePass { + void runOnModule() final; +}; +} // end anonymous namespace. + +void FrontendToKrnlLoweringPass::runOnModule() { + auto module = getModule(); + + // The first thing to define is the conversion target. This will define the + // final target for this lowering. + ConversionTarget target(getContext()); + + // We define the specific operations, or dialects, that are legal targets for + // this lowering. + target + .addLegalDialect(); + + // TODO: enable this once more ops are supported. + // We also define the ONNX dialect as Illegal so that the conversion will fail + // if any of these operations are *not* converted. 
+ // target.addIllegalDialect(); + + // TODO: add any other ops which are considered legal. + // Some operations can be marked as being still legal. + // Example: target.addLegalOp(); + + // Now that the conversion target has been defined, we just need to provide + // the set of patterns that will lower the frontend operations. + OwningRewritePatternList patterns; + + // Convert TensorType to MemRef + TensorTypeConverter tensor_to_memref_converter; + target.addDynamicallyLegalOp([&](FuncOp op) { + // FuncOp is legal only if types have been converted to Std types. + return tensor_to_memref_converter.isSignatureLegal(op.getType()); + }); + + // Type conversion for function signatures. + // Call MLIR FuncOp signature conversion when result type is + // a ranked tensor. + populateFuncOpTypeConversionPattern(patterns, &getContext(), + tensor_to_memref_converter); + + // Frontend operation lowering. + // Math + populateLoweringONNXElementwiseOpPattern(patterns, &getContext()); + populateLoweringONNXGemmOpPattern(patterns, &getContext()); + populateLoweringONNXReductionOpPattern(patterns, &getContext()); + populateLoweringONNXSoftmaxOpPattern(patterns, &getContext()); + populateLoweringONNXMatMulOpPattern(patterns, &getContext()); + // Tensor + populateLoweringONNXReshapeOpPattern(patterns, &getContext()); + populateLoweringONNXUnsqueezeOpPattern(patterns, &getContext()); + populateLoweringONNXTransposeOpPattern(patterns, &getContext()); + populateLoweringONNXIdentityOpPattern(patterns, &getContext()); + // Neural network + populateLoweringONNXConvOpPattern(patterns, &getContext()); + // Entry point + patterns.insert(&getContext()); + + // With the target and rewrite patterns defined, we can now attempt the + // conversion. The conversion will signal failure if any of our `illegal` + // operations were not converted successfully. + if (failed(applyPartialConversion(module, target, patterns))) + signalPassFailure(); +} + +std::unique_ptr mlir::createLowerToKrnlPass() { + return std::make_unique(); +} + +static PassRegistration + pass("lower-frontend", "Lower frontend ops to Krnl dialect."); diff --git a/src/conversion/onnx_to_krnl/rewrite_patterns/math/elementwise.inc b/src/conversion/onnx_to_krnl/rewrite_patterns/math/elementwise.inc new file mode 100644 index 0000000..b48e23a --- /dev/null +++ b/src/conversion/onnx_to_krnl/rewrite_patterns/math/elementwise.inc @@ -0,0 +1,646 @@ +//===----- elementwise.inc - Elementwise Ops ------------------------------===// +// +// Copyright 2019 The IBM Research Authors. +// +// ============================================================================= +// +// This file lowers ONNX element-wise operators to Krnl dialect. 
+// +//===----------------------------------------------------------------------===// + +template <> +struct ScalarOp { + using FOp = AddFOp; + using IOp = AddIOp; +}; + +template <> +struct ScalarOp { + using FOp = MulFOp; + using IOp = MulIOp; +}; + +template <> +struct ScalarOp { + using FOp = DivFOp; + using IOp = SignedDivIOp; +}; + +template <> +struct ScalarOp { + using FOp = SubFOp; + using IOp = SubIOp; +}; + +template <> +struct ScalarOp { + using FOp = AndOp; // not use + using IOp = AndOp; +}; + +template <> +struct ScalarOp { + using FOp = OrOp; // not use + using IOp = OrOp; +}; + +template <> +struct ScalarOp { + using FOp = XOrOp; // not use + using IOp = XOrOp; +}; + +template <> +struct ScalarOp { + using FOp = ExpOp; + using IOp = ExpOp; // not use +}; + +template <> +struct ScalarOp { + using FOp = AddFOp; + using IOp = AddIOp; +}; + +template <> +struct ScalarOp { + using FOp = TanhOp; + using IOp = TanhOp; // not use +}; + +template <> +struct ScalarOp { + using FOp = CosOp; + using IOp = CosOp; // not use +}; + +template <> +struct ScalarOp { + using FOp = LogOp; + using IOp = LogOp; // not use +}; + +template <> +struct ScalarOp { + using FOp = KrnlSqrtOp; + using IOp = KrnlSqrtOp; // not use +}; + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXSinhOp +//===----------------------------------------------------------------------===// +template <> +Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, + ArrayRef operands, + ConversionPatternRewriter &rewriter) { + // ONNXSinhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))), + // ConstantOp 2) + auto loc = op->getLoc(); + Value operand = operands[0]; + auto elementType = result_types[0]; + + auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); + auto two = rewriter.create(loc, FloatAttr::get(elementType, 2)); + auto neg = rewriter.create(loc, zero, operand); + auto exp = rewriter.create(loc, operand); + auto negExp = rewriter.create(loc, neg); + auto result = rewriter.create( + loc, rewriter.create(loc, exp, negExp), two); + + return result; +} + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXCoshOp +//===----------------------------------------------------------------------===// +template <> +Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, + ArrayRef operands, + ConversionPatternRewriter &rewriter) { + // ONNXCoshOp(%X) = DivFOp(AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))), + // ConstantOp 2) + auto loc = op->getLoc(); + Value operand = operands[0]; + auto elementType = result_types[0]; + + auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); + auto two = rewriter.create(loc, FloatAttr::get(elementType, 2)); + auto neg = rewriter.create(loc, zero, operand); + auto exp = rewriter.create(loc, operand); + auto negExp = rewriter.create(loc, neg); + auto result = rewriter.create( + loc, rewriter.create(loc, exp, negExp), two); + + return result; +} + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXSigmoidOp +//===----------------------------------------------------------------------===// +template <> +Value mapToLowerScalarOp(Operation *op, + ArrayRef result_types, + ArrayRef operands, + ConversionPatternRewriter &rewriter) { + // ONNXSigmoidOp(%X) = DivFOp(ConstantOp 1, + // AddFOp(ConstantOp 1, ExpOp(NegFOp(%X)))) + auto loc = op->getLoc(); + Value operand = 
operands[0]; + auto elementType = result_types[0]; + + auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); + auto one = rewriter.create(loc, FloatAttr::get(elementType, 1)); + auto neg = rewriter.create(loc, zero, operand); + auto negExp = rewriter.create(loc, neg); + auto result = rewriter.create( + loc, one, rewriter.create(loc, one, negExp)); + + return result; +} + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXHardSigmoidOp +//===----------------------------------------------------------------------===// +template <> +Value mapToLowerScalarOp( + Operation *op, ArrayRef result_types, ArrayRef operands, + ConversionPatternRewriter &rewriter) { + // %Y = AddFOp(MulFOp(alpha, %X), beta) + // %Z = SelectOp(CmpFOp(OGT, %Y, Constant 0), + // %Y, + // Constant 0) + // ONNXHardSigmoidOp(%X) = SelectOp(CmpFOp(OLT, %Z, Constant 1), + // %Z, + // Constant 1) + auto loc = op->getLoc(); + Value operand = operands[0]; + auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(), + llvm::dyn_cast(op).alpha().convertToFloat()); + auto betaAttribute = FloatAttr::get(rewriter.getF32Type(), + llvm::dyn_cast(op).beta().convertToFloat()); + auto elementType = result_types[0]; + + auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); + auto one = rewriter.create(loc, FloatAttr::get(elementType, 1)); + auto alpha = rewriter.create(loc, alphaAttribute); + auto beta = rewriter.create(loc, betaAttribute); + + auto add = rewriter.create( + loc, rewriter.create(loc, alpha, operand), beta); + auto maxPredicate = + rewriter.create(loc, CmpFPredicate::OGT, add, zero); + auto max = rewriter.create(loc, maxPredicate, add, zero); + auto minPredicate = + rewriter.create(loc, CmpFPredicate::OLT, max, one); + auto result = rewriter.create(loc, minPredicate, max, one); + + return result; +} + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXEluOp +//===----------------------------------------------------------------------===// +template <> +Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, + ArrayRef operands, + ConversionPatternRewriter &rewriter) { + // ONNXEluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0), + // MulFOp(alpha, SubFOp(ExpOp(%X), 1)), + // %X) + auto loc = op->getLoc(); + Value operand = operands[0]; + auto elementType = result_types[0]; + + auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(), + llvm::dyn_cast(op).alpha().convertToFloat()); + auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); + auto one = rewriter.create(loc, FloatAttr::get(elementType, 1)); + auto alpha = rewriter.create(loc, alphaAttribute); + auto exp = rewriter.create(loc, operand); + auto lessThanZero = + rewriter.create(loc, CmpFPredicate::OLT, operand, zero); + auto result = rewriter.create( + loc, lessThanZero, + rewriter.create(loc, alpha, + rewriter.create(loc, exp, one)), + operand); + + return result; +} + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXReluOp +//===----------------------------------------------------------------------===// +template <> +Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, + ArrayRef operands, + ConversionPatternRewriter &rewriter) { + // ONNXReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0), + // ConstantOp 0, + // %X) + auto loc = op->getLoc(); + Value operand = operands[0]; + auto elementType = 
result_types[0]; + + auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); + auto lessThanZero = + rewriter.create(loc, CmpFPredicate::OLT, operand, zero); + auto result = rewriter.create(loc, lessThanZero, zero, operand); + + return result; +} + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXLeakyReluOp +//===----------------------------------------------------------------------===// +template <> +Value mapToLowerScalarOp(Operation *op, + ArrayRef result_types, + ArrayRef operands, + ConversionPatternRewriter &rewriter) { + // ONNXLeakyReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0), + // MulFOp(alpha, %X), + // %X) + auto loc = op->getLoc(); + Value operand = operands[0]; + auto elementType = result_types[0]; + + auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(), + llvm::dyn_cast(op).alpha().convertToFloat()); + auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); + auto alpha = rewriter.create(loc, alphaAttribute); + auto lessThanZero = + rewriter.create(loc, CmpFPredicate::OLT, operand, zero); + auto result = rewriter.create( + loc, lessThanZero, rewriter.create(loc, alpha, operand), operand); + + return result; +} + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXSeluOp +//===----------------------------------------------------------------------===// +template <> +Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, + ArrayRef operands, + ConversionPatternRewriter &rewriter) { + // ONNXSeluOp(%X) = SelectOp(CmpFOp(OGT, %X, ConstantOp 0), + // MulFOp(gamma, %X), + // MulFOp(gamma, + // SubFOp(MulFOp(alpha, ExpOp(%X)), + // alpha))) + auto loc = op->getLoc(); + Value operand = operands[0]; + auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(), + llvm::dyn_cast(op).alpha().convertToFloat()); + auto gammaAttribute = FloatAttr::get(rewriter.getF32Type(), + llvm::dyn_cast(op).gamma().convertToFloat()); + auto elementType = result_types[0]; + + auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); + auto alpha = rewriter.create(loc, alphaAttribute); + auto gamma = rewriter.create(loc, gammaAttribute); + auto exp = rewriter.create(loc, operand); + auto greaterThanZero = + rewriter.create(loc, CmpFPredicate::OGT, operand, zero); + auto select = rewriter.create( + loc, greaterThanZero, operand, + rewriter.create(loc, rewriter.create(loc, alpha, exp), + alpha)); + auto result = rewriter.create(loc, gamma, select); + + return result; +} + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXReciprocalOp +//===----------------------------------------------------------------------===// +template <> +Value mapToLowerScalarOp( + Operation *op, ArrayRef result_types, ArrayRef operands, + ConversionPatternRewriter &rewriter) { + // ONNXReciprocalOp(%X) = DivFOp(ConstantOp 1, %X) + auto loc = op->getLoc(); + Value operand = operands[0]; + auto elementType = result_types[0]; + + auto one = rewriter.create(loc, FloatAttr::get(elementType, 1)); + auto result = rewriter.create(loc, one, operand); + + return result; +} + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXSoftplusOp +//===----------------------------------------------------------------------===// +template <> +Value mapToLowerScalarOp( + Operation *op, ArrayRef result_types, ArrayRef operands, + 
ConversionPatternRewriter &rewriter) { + // ONNXSoftplusOp(%X) = LogOp(AddFOp(ExpOp(%X), ConstantOp 1)) + auto loc = op->getLoc(); + Value operand = operands[0]; + auto elementType = result_types[0]; + + auto exp = rewriter.create(loc, operand); + auto one = rewriter.create(loc, FloatAttr::get(elementType, 1)); + auto add = rewriter.create(loc, exp, one); + auto result = rewriter.create(loc, add); + + return result; +} + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXSoftsignOp +//===----------------------------------------------------------------------===// +template <> +Value mapToLowerScalarOp( + Operation *op, ArrayRef result_types, ArrayRef operands, + ConversionPatternRewriter &rewriter) { + // ONNXSoftsignOp(%X) = DivFOp(ConstantOp 1, %X) + auto loc = op->getLoc(); + Value operand = operands[0]; + auto elementType = result_types[0]; + + auto abs = rewriter.create(loc, operand); + auto one = rewriter.create(loc, FloatAttr::get(elementType, 1)); + auto add = rewriter.create(loc, abs, one); + auto result = rewriter.create(loc, operand, add); + + return result; +} + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXSignOp +//===----------------------------------------------------------------------===// +template <> +Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, + ArrayRef operands, + ConversionPatternRewriter &rewriter) { + + auto loc = op->getLoc(); + Value operand = operands[0]; + Type element_type = operands.front().getType(); + // TODO: unsigned int should be supported separately? + if (element_type.isa()) { + // %Y = SelectOP(CmpIOp(GT, %X, ConstantOp 0), + // ConstantOp 1, + // COnstantOp -1) + // ONNXSignOp(%X) = SelectOP(CmpIOp(EQ, %X, ConstantOp 0), + // ConstantOp 0, + // %Y) + auto zero = rewriter.create(loc, rewriter.getI32IntegerAttr(0)); + auto one = rewriter.create(loc, rewriter.getI32IntegerAttr(1)); + auto minusOne = + rewriter.create(loc, rewriter.getI32IntegerAttr(-1)); + auto plusPredicate = + rewriter.create(loc, CmpIPredicate::sgt, operand, zero); + auto plusSelect = + rewriter.create(loc, plusPredicate, one, minusOne); + auto zeroPredicate = + rewriter.create(loc, CmpIPredicate::eq, operand, zero); + auto result = + rewriter.create(loc, zeroPredicate, zero, plusSelect); + return result; + } else if (element_type.isa()) { + // %Y = SelectOP(CmpFOp(OGT, %X, ConstantOp 0), + // ConstantOp 1, + // ConstantOp -1) + // ONNXSignOp(%X) = SelectOP(CmpFOp(OEQ, %X, ConstantOp 0), + // ConstantOp 0, + // %Y) + auto zero = + rewriter.create(loc, rewriter.getF32FloatAttr(0.0f)); + auto one = rewriter.create(loc, rewriter.getF32FloatAttr(1.0f)); + auto minusOne = + rewriter.create(loc, rewriter.getF32FloatAttr(-1.0f)); + auto plusPredicate = + rewriter.create(loc, CmpFPredicate::OGT, operand, zero); + auto plusSelect = + rewriter.create(loc, plusPredicate, one, minusOne); + auto zeroPredicate = + rewriter.create(loc, CmpFPredicate::OEQ, operand, zero); + auto result = + rewriter.create(loc, zeroPredicate, zero, plusSelect); + return result; + } else { + emitError(loc, "unsupported element type"); + } +} + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXMaxOp +//===----------------------------------------------------------------------===// +template <> +Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, + ArrayRef operands, + 
ConversionPatternRewriter &rewriter) { + // ONNXMaxOp(%X, %Y) = SelectOp(CmpFOp(OGT, %X, %Y), + // %X, + // %Y) + auto loc = op->getLoc(); + Value lhs = operands[0]; + Value rhs = operands[1]; + auto max = rewriter.create(loc, CmpFPredicate::OGT, lhs, rhs); + auto result = rewriter.create(loc, max, lhs, rhs); + return result; +} + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXMinOp +//===----------------------------------------------------------------------===// +template <> +Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, + ArrayRef operands, + ConversionPatternRewriter &rewriter) { + // ONNXMinOp(%X, %Y) = SelectOp(CmpFOp(OLT, %X, %Y), + // %X, + // %Y) + auto loc = op->getLoc(); + Value lhs = operands[0]; + Value rhs = operands[1]; + auto min = rewriter.create(loc, CmpFPredicate::OLT, lhs, rhs); + auto result = rewriter.create(loc, min, lhs, rhs); + return result; +} + +// Element-wise unary ops lowering to Krnl dialect. +//===----------------------------------------------------------------------===// +template +struct ONNXElementwiseUnaryOpLowering : public ConversionPattern { + ONNXElementwiseUnaryOpLowering(MLIRContext *ctx) + : ConversionPattern(ElementwiseUnaryOp::getOperationName(), 1, ctx) {} + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + // TODO: Check that the types are valid. + // An element-wise unary operation must have all operands and the result of + // the same type. This should have been verified by the verifier. + auto tensorType = (*op->result_type_begin()).cast(); + auto loc = op->getLoc(); + + // Insert an allocation and deallocation for the result of this operation. + auto memRefType = convertTensorToMemRef(tensorType); + + // If the output has a dynamic dimension, pass the operands required for + // each dynamic dimension to the AllocOp. The first operand of the + // operation is used. The operands of the op need to match in terms of + // dimensions with the result at this pre-optimization phase. + // TODO: verify that dimensions match. + // TODO: can the dimension of the result differ after optimizations? + Value alloc; + bool insertDealloc = checkInsertDealloc(op); + + if (hasAllConstantDimensions(memRefType)) + alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); + else + alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc, + {operands[0]}); + + std::vector originalLoops; + KrnlOptimizeLoopsOp optimizedLoopsOp; + KrnlIterateOp iterateOp; + emitKrnlLoopsAndIterationForOperand( + rewriter, loc, operands[0], originalLoops, + optimizedLoopsOp, iterateOp); + Block &optimizationBlock = optimizedLoopsOp.region().front(); + Block &iterationBlock = iterateOp.bodyRegion().front(); + + // 1. Insert any optimizations in the KrnlOptimizeLoopsOp body. + rewriter.setInsertionPointToEnd(&optimizationBlock); + // Return from KrnlOptimizeLoopsOp body. + // When no optimizations are present we just return the loops + // unchaged. + rewriter.create(loc, originalLoops); + + // 2. Insert instructions inside the KernelIterateOp body. 
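+ // The loop body loads one element of the operand at the current induction
+ // variables, applies the scalar computation via mapToLowerScalarOp, and
+ // stores the result into the allocated buffer.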
+ rewriter.setInsertionPointToStart(&iterationBlock); + + // Handle the operation: + SmallVector loopIVs; + for (auto arg : iterationBlock.getArguments()) + loopIVs.push_back(arg); + + auto loadedVal = rewriter.create(loc, operands[0], loopIVs); + auto loweredOpResult = mapToLowerScalarOp( + op, memRefType.getElementType(), {loadedVal}, rewriter); + // Store result in the resulting array. + rewriter.create(loc, loweredOpResult, alloc, loopIVs); + + rewriter.replaceOp(op, alloc); + + return matchSuccess(); + } +}; + +// Element-wise variadic ops lowering to Krnl dialect. +//===----------------------------------------------------------------------===// +template +struct ONNXElementwiseVariadicOpLowering : public ConversionPattern { + ONNXElementwiseVariadicOpLowering(MLIRContext *ctx) + : ConversionPattern(ElementwiseVariadicOp::getOperationName(), 1, ctx) {} + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + // TODO: Check that the types are valid. + // An element-wise variadic operation must have all operands and the result + // of the same type. This should have been verified by the verifier. + auto tensorType = (*op->result_type_begin()).cast(); + auto loc = op->getLoc(); + auto numArgs = op->getNumOperands(); + + // Insert an allocation and deallocation for the result of this operation. + auto memRefType = convertTensorToMemRef(tensorType); + + Value alloc; + bool insertDealloc = checkInsertDealloc(op); + // If the output has a dynamic dimension, we compute its dimension at + // runtime by using dimensions from the operands. + // In particular, we need to know from which operand a result dimension + // comes from. + // TODO: can the dimension of the result differ after optimizations? + if (hasAllConstantDimensions(memRefType)) + alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); + else + alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc, + operands); + + // Get run-time dimension information for unknown dimensions used for + // broadcasting. + std::map> broadcastedDimInfo = + getBroadcastedDimInfo(loc, rewriter, memRefType, operands); + + std::vector originalLoops; + KrnlOptimizeLoopsOp optimizedLoopsOp; + KrnlIterateOp iterateOp; + emitKrnlLoopsAndIterationForOperand( + rewriter, loc, alloc, originalLoops, + optimizedLoopsOp, iterateOp); + Block &optimizationBlock = optimizedLoopsOp.region().front(); + Block &iterationBlock = iterateOp.bodyRegion().front(); + + // 1. Insert any optimizations in the KrnlOptimizeLoopsOp body. + rewriter.setInsertionPointToEnd(&optimizationBlock); + // Return from KrnlOptimizeLoopsOp body. + // When no optimizations are present we just return the loops unchaged. + rewriter.create(loc, originalLoops); + + // 2. Insert instructions inside the KernelIterateOp body. 
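+ // The loop body folds the scalar computation over all operands: each
+ // operand element is loaded using broadcasting-adjusted induction variables,
+ // accumulated via mapToLowerScalarOp, and the final value is stored.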
+ rewriter.setInsertionPointToStart(&iterationBlock); + + // Handle the operation: + SmallVector loopIVs; + for (auto arg : iterationBlock.getArguments()) + loopIVs.push_back(arg); + + // Fold over operands for each of their scalar values + Value accumulated, next; + auto accumulatedLoopIVs = getLoopIVsForBroadcasting( + loc, rewriter, loopIVs, operands[0], broadcastedDimInfo[0]); + accumulated = rewriter.create(loc, operands[0], accumulatedLoopIVs); + for (unsigned i = 1; i < numArgs; i++) { + auto nextLoopIVs = getLoopIVsForBroadcasting( + loc, rewriter, loopIVs, operands[i], broadcastedDimInfo[i]); + next = rewriter.create(loc, operands[i], nextLoopIVs); + accumulated = mapToLowerScalarOp( + op, memRefType.getElementType(), {accumulated, next}, rewriter); + } + // Store result in the resulting array. + rewriter.create(loc, accumulated, alloc, loopIVs); + + rewriter.replaceOp(op, alloc); + + return matchSuccess(); + } +}; + +void populateLoweringONNXElementwiseOpPattern( + OwningRewritePatternList &patterns, MLIRContext *ctx) { + patterns.insert, + ONNXElementwiseVariadicOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseVariadicOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseVariadicOpLowering, + ONNXElementwiseVariadicOpLowering, + ONNXElementwiseVariadicOpLowering, + ONNXElementwiseVariadicOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseVariadicOpLowering, + ONNXElementwiseVariadicOpLowering, + ONNXElementwiseUnaryOpLowering, + ONNXElementwiseVariadicOpLowering>(ctx); +} diff --git a/src/conversion/onnx_to_krnl/rewrite_patterns/math/gemm.inc b/src/conversion/onnx_to_krnl/rewrite_patterns/math/gemm.inc new file mode 100644 index 0000000..f25dc44 --- /dev/null +++ b/src/conversion/onnx_to_krnl/rewrite_patterns/math/gemm.inc @@ -0,0 +1,209 @@ +//===----- gemm.inc - Lowering Gemm Op ------------------------------------===// +// +// Copyright 2019 The IBM Research Authors. +// +// ============================================================================= +// +// This file lowers the ONNX Gemm Operator to Krnl dialect. 
+// +//===----------------------------------------------------------------------===// + +struct ONNXGemmOpLowering : public ConversionPattern { + ONNXGemmOpLowering(MLIRContext *ctx) + : ConversionPattern(mlir::ONNXGemmOp::getOperationName(), 1, ctx) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + auto tensorType = (*op->result_type_begin()).cast(); + auto loc = op->getLoc(); + + Value A, B, C; + A = operands[0]; + B = operands[1]; + C = operands[2]; + + auto alphaAttr = FloatAttr::get(tensorType.getElementType(), + llvm::dyn_cast(op).alpha().convertToFloat()); + auto betaAttr = FloatAttr::get(tensorType.getElementType(), + llvm::dyn_cast(op).beta().convertToFloat()); + auto alpha = rewriter.create(loc, alphaAttr); + auto beta = rewriter.create(loc, betaAttr); + + bool isTransA = (llvm::dyn_cast(op).transA() != 0); + bool isTransB = (llvm::dyn_cast(op).transB() != 0); + + // Result type + auto memRefType = convertTensorToMemRef(tensorType); + + // Insert an allocation and deallocation for the result of this operation. + Value alloc; + bool insertDealloc = checkInsertDealloc(op); + if (hasAllConstantDimensions(memRefType)) + alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); + else { + auto memRefShape = memRefType.getShape(); + SmallVector allocOperands; + if (memRefShape[0] < 0) { + auto dim = rewriter.create(loc, A, (isTransA) ? 1 : 0); + allocOperands.emplace_back(dim); + } + if (memRefShape[1] < 0) { + auto dim = rewriter.create(loc, B, (isTransB) ? 0 : 1); + allocOperands.emplace_back(dim); + } + alloc = rewriter.create(loc, memRefType, allocOperands); + if (insertDealloc) { + auto *parentBlock = alloc.getDefiningOp()->getBlock(); + auto dealloc = rewriter.create(loc, alloc); + dealloc.getOperation()->moveBefore(&parentBlock->back()); + } + } + + // Number of loops + auto memRefShape = memRefType.getShape(); + int64_t numLoops = 3; + + // Define loops. + std::vector originalLoops; + std::vector optimizedLoops; + Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops, + optimizedLoops, numLoops); + + // We have two Krnl loops: + // - Outer loop iterates over the output matrix dimensions, and + // - Reduction loop iterates over the reduction dimension. + + // Outer loop + std::vector outerLoops, optimizedOuterLoops; + outerLoops.reserve(2); + optimizedOuterLoops.reserve(2); + for (int i = 0; i < 2; ++i) { + outerLoops.push_back(originalLoops[i]); + optimizedOuterLoops.push_back(optimizedLoops[i]); + } + KrnlIterateOperandPack outerPack(rewriter, outerLoops, + optimizedOuterLoops); + // Induction variables for the outer loops + for (int i = 0; i < 2; ++i) + addDimensionToPack(rewriter, loc, outerPack, alloc, i); + + // Reduction loop + std::vector reductionLoops, optimizedReductionLoops; + reductionLoops.reserve(1); + optimizedReductionLoops.reserve(1); + reductionLoops.push_back(originalLoops[2]); + optimizedReductionLoops.push_back(optimizedLoops[2]); + KrnlIterateOperandPack reductionPack(rewriter, reductionLoops, + optimizedReductionLoops); + // Induction variable for the reduction dimension + // Try to find and use a static value from A or B first. + // If it failed then use a dynamic value. + auto ATy = A.getType().cast(); + auto BTy = B.getType().cast(); + int64_t K_A_Idx = (isTransA) ? 0 : 1; + int64_t K_B_Idx = (isTransB) ? 
1 : 0; + reductionPack.pushConstantBound(0); + if (ATy.getShape()[K_A_Idx] != -1) + reductionPack.pushConstantBound(ATy.getShape()[K_A_Idx]); + else + if (BTy.getShape()[K_B_Idx] != -1) + reductionPack.pushConstantBound(BTy.getShape()[K_B_Idx]); + else + reductionPack.pushOperandBound( + rewriter.create(loc, B, K_B_Idx).getResult()); + + // Get run-time dimension information for unknown dimensions used for + // broadcasting. + // GemmOp supports unidirectional broadcasting from C to A*B. + // Hence, it must be enough to get broadcasting information for C only. + std::map broadcastedDimInfo; + auto shape = C.getType().cast().getShape(); + for (int i = 0; i < shape.size(); ++i) { + if (shape[i] < 0) { + auto dim = rewriter.create(loc, C, i).getResult(); + auto one = rewriter.create(loc, 1); + auto isBroadcasted = + rewriter.create(loc, CmpIPredicate::eq, dim, one); + broadcastedDimInfo.insert(std::make_pair(i, isBroadcasted)); + } + } + + auto outerIterateOp = rewriter.create(loc, outerPack); + + // Now perform the insertions into the body of the + // just generated instructions: + + // No optimization + rewriter.setInsertionPointToEnd(optimizationBlock); + rewriter.create(loc, originalLoops); + + // Insert instructions inside the outer loop. + Block &outerIterationBlock = outerIterateOp.bodyRegion().front(); + rewriter.setInsertionPointToStart(&outerIterationBlock); + + // Induction variables + SmallVector loopMNIVs; + for (auto arg : outerIterationBlock.getArguments()) { + loopMNIVs.emplace_back(arg); + } + + // Initialize the output of A*B + auto zero = rewriter.create( + loc, FloatAttr::get(memRefType.getElementType(), 0)); + rewriter.create(loc, zero, alloc, loopMNIVs); + + // Compute A*B + auto matmulIterateOp = rewriter.create(loc, reductionPack); + + // Compute beta*C, and add up to alpha*A*B (unidirectional broadcasting) + auto loopCIVs = getLoopIVsForBroadcasting( + loc, rewriter, loopMNIVs, C, broadcastedDimInfo); + auto loadedC = rewriter.create(loc, C, loopCIVs); + auto loadedAB = rewriter.create(loc, alloc, loopMNIVs); + auto alphaAB = rewriter.create(loc, alpha, loadedAB); + auto betaC = rewriter.create(loc, beta, loadedC); + auto Y = rewriter.create(loc, alphaAB, betaC); + rewriter.create(loc, Y, alloc, loopMNIVs); + + // Insert instructions to do matrix multiplication: A*B + Block &matmulIterationBlock = matmulIterateOp.bodyRegion().front(); + rewriter.setInsertionPointToStart(&matmulIterationBlock); + + // Induction variables + SmallVector loopKIVs, loopAIVs, loopBIVs; + for (auto arg : matmulIterationBlock.getArguments()) + loopKIVs.emplace_back(arg); + if (isTransA) { + loopAIVs.emplace_back(loopKIVs[0]); + loopAIVs.emplace_back(loopMNIVs[0]); + } else { + loopAIVs.emplace_back(loopMNIVs[0]); + loopAIVs.emplace_back(loopKIVs[0]); + } + if (isTransB) { + loopBIVs.emplace_back(loopMNIVs[1]); + loopBIVs.emplace_back(loopKIVs[0]); + } else { + loopBIVs.emplace_back(loopKIVs[0]); + loopBIVs.emplace_back(loopMNIVs[1]); + } + + // Matmul computation + auto loadedA = rewriter.create(loc, A, loopAIVs); + auto loadedB = rewriter.create(loc, B, loopBIVs); + auto loadedY = rewriter.create(loc, alloc, loopMNIVs); + auto AB = rewriter.create(loc, loadedA, loadedB); + auto accumulated = rewriter.create(loc, loadedY, AB); + rewriter.create(loc, accumulated, alloc, loopMNIVs); + + rewriter.replaceOp(op, alloc); + + return matchSuccess(); + } +}; + +void populateLoweringONNXGemmOpPattern( + OwningRewritePatternList &patterns, MLIRContext *ctx) { + patterns.insert(ctx); +} diff --git 
a/src/conversion/onnx_to_krnl/rewrite_patterns/math/matmul.inc b/src/conversion/onnx_to_krnl/rewrite_patterns/math/matmul.inc new file mode 100644 index 0000000..5c6ebd7 --- /dev/null +++ b/src/conversion/onnx_to_krnl/rewrite_patterns/math/matmul.inc @@ -0,0 +1,345 @@ +//===----- matmul.inc - Lowering Matmul Op --------------------------------===// +// +// Copyright 2019 The IBM Research Authors. +// +// ============================================================================= +// +// This file lowers the ONNX Matmul Operator to Krnl dialect. +// +//===----------------------------------------------------------------------===// + +struct ONNXMatMulOpLowering : public ConversionPattern { + ONNXMatMulOpLowering(MLIRContext *ctx) + : ConversionPattern(mlir::ONNXMatMulOp::getOperationName(), 1, ctx) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + auto tensorType = (*op->result_type_begin()).cast(); + auto loc = op->getLoc(); + + Value A = operands[0]; + Value B = operands[1]; + auto AShape = A.getType().cast().getShape(); + auto BShape = B.getType().cast().getShape(); + + // There are three cases related to the shapes of the two arguments: + // - Both arguments are N-D, N >= 2 + // - Either argument is 1-D, the other is N-D, N >= 2 + // - Both arguments are 1-D + + // Result type + auto memRefType = convertTensorToMemRef(tensorType); + auto elementType = memRefType.getElementType(); + auto memRefShape = memRefType.getShape(); + + // A value zero + Value zero; + if (elementType.isa()) { + zero = rewriter.create( + loc, IntegerAttr::get(memRefType.getElementType(), 0)); + } else if (elementType.isa()) { + zero = rewriter.create( + loc, FloatAttr::get(memRefType.getElementType(), 0)); + } else { + emitError(loc, "unsupported element type"); + } + + // Insert an allocation and deallocation for the result of this operation. + Value alloc; + bool insertDealloc = checkInsertDealloc(op); + if (hasAllConstantDimensions(memRefType)) + alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); + else { + SmallVector allocOperands; + if (AShape.size() >= 2 && BShape.size() >= 2) { + // Both arguments are N-D, N >= 2 + // (s1 x s2 x... x sK x M x K) MATMUL (K x N) + // => + // (s1 x s2 x... x sK x M x N) + for (int i = 0; i < memRefShape.size() - 2; ++i) { + if (memRefShape[i] < 0) { + if ((AShape.size() == 2) && (BShape.size() > 2)) + allocOperands.emplace_back(rewriter.create(loc, B, i)); + else if ((AShape.size() > 2) && (BShape.size() == 2)) + allocOperands.emplace_back(rewriter.create(loc, A, i)); + } + } + if (memRefShape[memRefShape.size() - 2] < 0) { + auto dim = rewriter.create(loc, A, memRefShape.size() - 2); + allocOperands.emplace_back(dim); + } + if (memRefShape[memRefShape.size() - 1] < 0) { + auto dim = rewriter.create(loc, B, memRefShape.size() - 1); + allocOperands.emplace_back(dim); + } + } else if (AShape.size() == 1 && BShape.size() >= 2) { + // Either argument is 1-D + // K MATMUL (s1 x s2 x... x sK x K x N) + // => + // (s1 x s2 x... x sK x N) + for (int i = 0; i < memRefShape.size() - 1; ++i) { + if (memRefShape[i] < 0) { + auto dim = rewriter.create(loc, B, i); + allocOperands.emplace_back(dim); + } + } + if (memRefShape[memRefShape.size() - 1] < 0) { + auto dim = rewriter.create(loc, B, BShape.size() - 1); + allocOperands.emplace_back(dim); + } + } else if (AShape.size() >= 2 && BShape.size() == 1) { + // Either argument is 1-D + // (s1 x s2 x... 
x sK x M x K) MATMUL K + // => + // (s1 x s2 x... x sK x M) + for (int i = 0; i < memRefShape.size() - 1; ++i) { + if (memRefShape[i] < 0) { + auto dim = rewriter.create(loc, A, i); + allocOperands.emplace_back(dim); + } + } + if (memRefShape[memRefShape.size() - 1] < 0) { + auto dim = rewriter.create(loc, A, AShape.size() - 2); + allocOperands.emplace_back(dim); + } + } else if (AShape.size() == 1 && BShape.size() == 1) { + // Both arguments are 1-D + if (memRefShape[0] < 0) { + auto dim = rewriter.create(loc, A, 0); + allocOperands.emplace_back(dim); + } + } else { + emitError(loc, "Invalid shapes"); + } + + alloc = rewriter.create(loc, memRefType, allocOperands); + } + + if (AShape.size() >= 2 || BShape.size() >= 2) { + // Cases 1 and 2: + // - Both arguments are N-D, N >= 2 + // - Either argument is 1-D, the other is N-D, N >= 2 + + // Define loops for batch dimensions. + std::vector originalLoops; + std::vector optimizedLoops; + Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops, + optimizedLoops, memRefShape.size()); + + // Outer KrnlIterateOp + SmallVector loopBatchIVs; + bool hasBatchLoop = false; + if (AShape.size() > 2 || BShape.size() > 2) { + SmallVector batchAxes; + int matmulResultDims = + ((AShape.size() == 1 || BShape.size() == 1)) ? 1 : 2; + for (int i = 0; i < memRefShape.size() - matmulResultDims; ++i) + batchAxes.emplace_back(i); + + std::vector outerLoops, optimizedOuterLoops; + outerLoops.reserve(batchAxes.size()); + optimizedOuterLoops.reserve(batchAxes.size()); + for (int i = 0; i < batchAxes.size(); ++i) { + outerLoops.push_back(originalLoops[i]); + optimizedOuterLoops.push_back(optimizedLoops[i]); + } + KrnlIterateOperandPack outerPack(rewriter, outerLoops, + optimizedOuterLoops); + for (int i = 0; i < batchAxes.size(); ++i) { + addDimensionToPack(rewriter, loc, outerPack, alloc, i); + } + auto outerIterateOp = rewriter.create(loc, outerPack); + + // No optimization + rewriter.setInsertionPointToEnd(optimizationBlock); + rewriter.create(loc, originalLoops); + + // Insert instructions into the outer KrnlIterateOp. + Block &outerIterationBlock = outerIterateOp.bodyRegion().front(); + rewriter.setInsertionPointToStart(&outerIterationBlock); + + // Induction variables: non-matrix-multiplication variables. + for (auto arg : outerIterationBlock.getArguments()) { + loopBatchIVs.emplace_back(arg); + } + + hasBatchLoop = true; + } + + // Now, we define loops for matrix multiplication. + + // Create a KrnlIterateOp for matrix multiplication. + KrnlIterateOp matmulIterateOp; + std::vector matmulLoops, optimizedMatmulLoops; + if (AShape.size() >= 2 && BShape.size() >= 2) { + // 2-D x 2-D. Result has two dimensions. + matmulLoops.reserve(2); + optimizedMatmulLoops.reserve(2); + for (int i = 2; i > 0; --i) { + matmulLoops.emplace_back(originalLoops[memRefShape.size() - i]); + optimizedMatmulLoops.emplace_back( + optimizedLoops[memRefShape.size() - i]); + } + KrnlIterateOperandPack matmulPack(rewriter, matmulLoops, + optimizedMatmulLoops); + for (int i = 2; i > 0; --i) { + addDimensionToPack(rewriter, loc, matmulPack, alloc, + memRefShape.size() - i); + } + matmulIterateOp = rewriter.create(loc, matmulPack); + } else { + // 1-D x 2-D, and vice versa. Result has one dimension. 
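+ // Only the last dimension of the result is iterated here; batch
+ // dimensions, if any, are handled by the outer KrnlIterateOp above.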
+ matmulLoops.reserve(1); + optimizedMatmulLoops.reserve(1); + matmulLoops.emplace_back(originalLoops[memRefShape.size() - 1]); + optimizedMatmulLoops.emplace_back( + optimizedLoops[memRefShape.size() - 1]); + KrnlIterateOperandPack matmulPack(rewriter, matmulLoops, + optimizedMatmulLoops); + addDimensionToPack(rewriter, loc, matmulPack, alloc, + memRefShape.size() - 1); + matmulIterateOp = rewriter.create(loc, matmulPack); + } + + if (!hasBatchLoop) { + // No optimization + rewriter.setInsertionPointToEnd(optimizationBlock); + rewriter.create(loc, originalLoops); + } + + // Insert instructions into the matmul KrnlIterateOp. + Block &matmulIterationBlock = matmulIterateOp.bodyRegion().front(); + rewriter.setInsertionPointToStart(&matmulIterationBlock); + + // Induction variables: M, N + SmallVector loopMNIVs; + for (auto arg : matmulIterationBlock.getArguments()) { + loopMNIVs.emplace_back(arg); + } + // Induction variables for the final result. + SmallVector loopBatchMNIVs; + for (auto arg : loopBatchIVs) { + loopBatchMNIVs.emplace_back(arg); + } + for (auto arg : loopMNIVs) { + loopBatchMNIVs.emplace_back(arg); + } + + // Fill the output with value 0. + rewriter.create(loc, zero, alloc, loopBatchMNIVs); + + // Iterate along the reduction dimension. + // Use a value from A. + std::vector reduceLoops; + std::vector optimizedReduceLoops; + Block *optimizationReduceBlock = + defineLoops(rewriter, loc, reduceLoops, optimizedReduceLoops, 1); + KrnlIterateOperandPack reducePack(rewriter, reduceLoops, + optimizedReduceLoops); + addDimensionToPack(rewriter, loc, reducePack, A, AShape.size() - 1); + auto reduceIterateOp = rewriter.create(loc, reducePack); + + // No optimization + rewriter.setInsertionPointToEnd(optimizationReduceBlock); + rewriter.create(loc, reduceLoops); + + // Insert instructions into the reduction KrnlIterateOp. + Block &reduceIterationBlock = reduceIterateOp.bodyRegion().front(); + rewriter.setInsertionPointToStart(&reduceIterationBlock); + + // Induction variables + SmallVector loopKIVs, loopBatchMKIVs, loopBatchKNIVs; + // K + loopKIVs.emplace_back(reduceIterationBlock.getArguments()[0]); + // MK + if (AShape.size() > 2) + for (auto arg : loopBatchIVs) + loopBatchMKIVs.emplace_back(arg); + if (AShape.size() >= 2) + loopBatchMKIVs.emplace_back(loopMNIVs[0]); + loopBatchMKIVs.emplace_back(loopKIVs[0]); + // KN + if (BShape.size() > 2) + for (auto arg : loopBatchIVs) + loopBatchKNIVs.emplace_back(arg); + loopBatchKNIVs.emplace_back(loopKIVs[0]); + if (BShape.size() >= 2) + if (AShape.size() >= 2) + loopBatchKNIVs.emplace_back(loopMNIVs[1]); + else + loopBatchKNIVs.emplace_back(loopMNIVs[0]); + + // Matmul computation + auto loadedA = rewriter.create(loc, A, loopBatchMKIVs); + auto loadedB = rewriter.create(loc, B, loopBatchKNIVs); + auto loadedY = rewriter.create(loc, alloc, loopBatchMNIVs); + if (elementType.isa()) { + auto AB = rewriter.create(loc, loadedA, loadedB); + auto accumulated = rewriter.create(loc, loadedY, AB); + rewriter.create(loc, accumulated, alloc, loopBatchMNIVs); + } else if (elementType.isa()) { + auto AB = rewriter.create(loc, loadedA, loadedB); + auto accumulated = rewriter.create(loc, loadedY, AB); + rewriter.create(loc, accumulated, alloc, loopBatchMNIVs); + } + } else if ((AShape.size() == 1) && (BShape.size() == 1)) { + // Case 3: + // - Both arguments are 1-D + + // Fill the output with value 0. + Value zeroIndex = rewriter.create(loc, 0); + rewriter.create(loc, zero, alloc, zeroIndex); + + // Iterate along the reduction dimension. 
+ // Use a value from A. + std::vector reduceLoops; + std::vector optimizedReduceLoops; + Block *optimizationReduceBlock = + defineLoops(rewriter, loc, reduceLoops, optimizedReduceLoops, 1); + KrnlIterateOperandPack reducePack(rewriter, reduceLoops, + optimizedReduceLoops); + addDimensionToPack(rewriter, loc, reducePack, A, 0); + auto reduceIterateOp = rewriter.create(loc, reducePack); + + // No optimization + rewriter.setInsertionPointToEnd(optimizationReduceBlock); + rewriter.create(loc, reduceLoops); + + // Insert instructions into the reduction KrnlIterateOp. + Block &reduceIterationBlock = reduceIterateOp.bodyRegion().front(); + rewriter.setInsertionPointToStart(&reduceIterationBlock); + + // Induction variables + SmallVector loopKIVs; + // K + loopKIVs.emplace_back(reduceIterationBlock.getArgument(0)); + + // Matmul computation + auto loadedA = rewriter.create(loc, A, loopKIVs); + auto loadedB = rewriter.create(loc, B, loopKIVs); + auto loadedY = rewriter.create(loc, alloc, zeroIndex); + if (elementType.isa()) { + auto AB = rewriter.create(loc, loadedA, loadedB); + auto accumulated = rewriter.create(loc, loadedY, AB); + rewriter.create(loc, accumulated, alloc, zeroIndex); + } else if (elementType.isa()) { + auto AB = rewriter.create(loc, loadedA, loadedB); + auto accumulated = rewriter.create(loc, loadedY, AB); + rewriter.create(loc, accumulated, alloc, zeroIndex); + } + } else { + // No scalar matrix multiplication. + llvm_unreachable("Unsupported scalar matrix multiplication."); + } + + rewriter.replaceOp(op, alloc); + + return matchSuccess(); + } +}; + +void populateLoweringONNXMatMulOpPattern( + OwningRewritePatternList &patterns, MLIRContext *ctx) { + patterns.insert(ctx); +} diff --git a/src/conversion/onnx_to_krnl/rewrite_patterns/math/reduction.inc b/src/conversion/onnx_to_krnl/rewrite_patterns/math/reduction.inc new file mode 100644 index 0000000..27f594e --- /dev/null +++ b/src/conversion/onnx_to_krnl/rewrite_patterns/math/reduction.inc @@ -0,0 +1,307 @@ +//===----- reduction.inc - Lowering Reduction Ops -------------------------===// +// +// Copyright 2019 The IBM Research Authors. +// +// ============================================================================= +// +// This file lowers the ONNX Reduction Operators to Krnl dialect. 
+// +//===----------------------------------------------------------------------===// + +// Identity values +template <> +float getIdentityValue(){ + return (float)-std::numeric_limits::infinity(); +} + +template <> +int getIdentityValue(){ + return std::numeric_limits::min(); +} + +template <> +float getIdentityValue(){ + return (float)std::numeric_limits::infinity(); +} + +template <> +int getIdentityValue(){ + return std::numeric_limits::max(); +} + +template <> +float getIdentityValue(){ + return (float)1.0; +} + +template <> +int getIdentityValue(){ + return 1; +} + +template <> +float getIdentityValue(){ + return (float)0; +} + +template <> +int getIdentityValue(){ + return 0; +} + +// Scalar ops +template <> +struct ScalarOp { + using FOp = MulFOp; + using IOp = MulIOp; +}; + +template <> +struct ScalarOp { + using FOp = AddFOp; + using IOp = AddIOp; +}; + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXReduceMaxOp +//===----------------------------------------------------------------------===// +template <> +Value mapToLowerScalarOp(Operation *op, + ArrayRef result_types, + ArrayRef operands, + ConversionPatternRewriter &rewriter) { + auto loc = op->getLoc(); + Value lhs = operands[0]; + Value rhs = operands[1]; + Type element_type = lhs.getType(); + if (element_type.isa()) { + auto max = rewriter.create(loc, CmpIPredicate::sgt, lhs, rhs); + auto result = rewriter.create(loc, max, lhs, rhs); + return result; + } else if (element_type.isa()) { + auto max = rewriter.create(loc, CmpFPredicate::OGT, lhs, rhs); + auto result = rewriter.create(loc, max, lhs, rhs); + return result; + } else { + emitError(loc, "unsupported element type"); + return nullptr; + } +} + +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXReduceMinOp +//===----------------------------------------------------------------------===// +template <> +Value mapToLowerScalarOp(Operation *op, + ArrayRef result_types, + ArrayRef operands, + ConversionPatternRewriter &rewriter) { + auto loc = op->getLoc(); + Value lhs = operands[0]; + Value rhs = operands[1]; + Type element_type = lhs.getType(); + if (element_type.isa()) { + auto min = rewriter.create(loc, CmpIPredicate::slt, lhs, rhs); + auto result = rewriter.create(loc, min, lhs, rhs); + return result; + } else if (element_type.isa()) { + auto min = rewriter.create(loc, CmpFPredicate::OLT, lhs, rhs); + auto result = rewriter.create(loc, min, lhs, rhs); + return result; + } else { + emitError(loc, "unsupported element type"); + return nullptr; + } +} + +template +struct ONNXReductionOpLowering : public ConversionPattern { + ONNXReductionOpLowering(MLIRContext *ctx) + : ConversionPattern(ONNXReductionOp::getOperationName(), 1, ctx) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + /* + * Condition: reduction function must be associative and commutative. 
+ * + * Example 1 (here, reduction function is `+`): + * Induction variables: (i0, i1, i2) + * axes = [0, 2] + * keepdims = true + * krnl.iterate() with (i0, i1, i2) { + * Y(0, i1, 0) += X(i0, i1, i2) + * } + * + * Example 2 (here, reduction function is `+`): + * Induction variables: (i0, i1, i2) + * axes = [0, 2] + * keepdims = false + * krnl.iterate() with (i0, i1, i2) { + * Y(i1) += X(i0, i1, i2) + * } + * + */ + auto loc = op->getLoc(); + auto memRefInType = operands[0].getType().cast(); + auto memRefInShape = memRefInType.getShape(); + auto tensorOutType = (*op->result_type_begin()).cast(); + int64_t inRank = memRefInType.getRank(); + int64_t outRank = tensorOutType.getRank(); + + // Get attributes + ArrayAttr axisAttrs = llvm::dyn_cast(op).axesAttr(); + std::vector axes; + if (axisAttrs) { + for (auto axisAttr : axisAttrs.getValue()) { + int64_t axis = axisAttr.cast().getInt(); + axis = axis >= 0 ? axis : (inRank + axis); + assert(axis >= -inRank && axis <= inRank - 1); + if (std::find(axes.begin(), axes.end(), axis) == axes.end()) + axes.push_back(axis); + } + } else { + for (decltype(inRank) i = 0; i < inRank; ++i) { + axes.push_back(i); + } + } + // KeepDims + auto keepdims = + llvm::dyn_cast(op).keepdims(); + bool isKeepdims = (keepdims == 1) ? true : false; + + // Get type information + auto memRefOutType = convertTensorToMemRef(tensorOutType); + auto memRefOutShape = memRefOutType.getShape(); + auto elementOutType = memRefOutType.getElementType(); + std::map outInDimMap = + getReductionMapping(memRefInType, axes, isKeepdims); + + // Insert an allocation and deallocation for the result of this operation. + Value alloc; + bool insertDealloc = checkInsertDealloc(op); + if (hasAllConstantDimensions(memRefOutType)) { + alloc = insertAllocAndDealloc(memRefOutType, loc, rewriter, insertDealloc); + } else { + SmallVector allocOperands; + for (decltype(outRank) i = 0; i < outRank; ++i) { + if (memRefOutShape[i] < 0) { + auto dim = rewriter.create(loc, operands[0], outInDimMap[i]); + allocOperands.push_back(dim); + } + } + alloc = rewriter.create(loc, memRefOutType, allocOperands); + if (insertDealloc) { + auto *parentBlock = alloc.getDefiningOp()->getBlock(); + auto dealloc = rewriter.create(loc, alloc); + dealloc.getOperation()->moveBefore(&parentBlock->back()); + } + } + + // There are two Krnl loops: + // - One to initialize the result memref, and + // - One to do reduction + + // Define loops to initialize the result. + std::vector originalLoopsInit; + std::vector optimizedLoopsInit; + Block *optimizationBlockInit = defineLoops(rewriter, loc, originalLoopsInit, + optimizedLoopsInit, outRank); + + // Iteration information + KrnlIterateOperandPack packInit(rewriter, originalLoopsInit, + optimizedLoopsInit); + for (decltype(outRank) i = 0; i < outRank; ++i) { + addDimensionToPack(rewriter, loc, packInit, alloc, i); + } + auto iterateOpInit = rewriter.create(loc, packInit); + Block &iterationBlockInit = iterateOpInit.bodyRegion().front(); + + // Perform the insertions into the body of the initialization loop. + // No optimization + rewriter.setInsertionPointToEnd(optimizationBlockInit); + rewriter.create(loc, originalLoopsInit); + + // Insert instructions inside the KernelIterateOp body. 
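// Illustrative sketch (standalone C++, not part of the pattern): Example 2
// above, i.e. ReduceSum over a D0 x D1 x D2 input with axes = {0, 2} and
// keepdims = false. The two phases mirror the two Krnl loops: fill the
// output with the identity of the reduction, then accumulate every input
// element into the output position obtained by dropping the reduced axes.
// All names are hypothetical.
#include <cstddef>
#include <vector>

static std::vector<float> reduceSumAxes02(const std::vector<float> &X,
                                          size_t D0, size_t D1, size_t D2) {
  std::vector<float> Y(D1);
  // Init loop: seed with the identity value (0 for sum).
  for (size_t i1 = 0; i1 < D1; ++i1)
    Y[i1] = 0.0f;
  // Reduction loop: Y(i1) += X(i0, i1, i2).
  for (size_t i0 = 0; i0 < D0; ++i0)
    for (size_t i1 = 0; i1 < D1; ++i1)
      for (size_t i2 = 0; i2 < D2; ++i2)
        Y[i1] += X[(i0 * D1 + i1) * D2 + i2];
  return Y;
}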
+ rewriter.setInsertionPointToStart(&iterationBlockInit); + + // Handle the operation: + SmallVector loopIVs; + for (auto arg : iterationBlockInit.getArguments()) { + loopIVs.push_back(arg); + } + + Value identity; + if (elementOutType.isa()) { + identity = rewriter.create( + loc, FloatAttr::get(elementOutType, + getIdentityValue())); + } else if (elementOutType.isa()) { + identity = rewriter.create( + loc, IntegerAttr::get(elementOutType, + getIdentityValue())); + } else { + emitError(loc, "unsupported element type"); + } + rewriter.create(loc, identity, alloc, loopIVs); + + // Define an Krnl loop to do reduction. + rewriter.setInsertionPointAfter(iterateOpInit); + std::vector originalLoops, optimizedLoops; + Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops, + optimizedLoops, inRank); + // Iteration information + KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops); + for (decltype(inRank) i = 0; i < inRank; ++i) { + addDimensionToPack(rewriter, loc, pack, operands[0], i); + } + auto iterateOp = rewriter.create(loc, pack); + Block &iterationBlock = iterateOp.bodyRegion().front(); + + // Perform the insertions into the body of the reduction loop. + // No optimization + rewriter.setInsertionPointToEnd(optimizationBlock); + rewriter.create(loc, originalLoops); + + // Insert instructions inside the KernelIterateOp body. + rewriter.setInsertionPointToStart(&iterationBlock); + + // Handle the operation: + SmallVector inLoopIVs, outLoopIVs; + auto args = iterationBlock.getArguments(); + for (int i = 0; i < args.size(); ++i) { + inLoopIVs.push_back(args[i]); + } + Value zeroIndex = nullptr; + for (decltype(inRank) i = 0; i < outRank; ++i) { + if (outInDimMap.find(i) != outInDimMap.end()) { + outLoopIVs.push_back(inLoopIVs[outInDimMap[i]]); + } else { + if (zeroIndex) { + outLoopIVs.push_back(zeroIndex); + } else { + zeroIndex = rewriter.create(loc, 0); + outLoopIVs.push_back(zeroIndex); + } + } + } + + Value next, accumulated; + next = rewriter.create(loc, operands[0], inLoopIVs); + accumulated = rewriter.create(loc, alloc, outLoopIVs); + accumulated = mapToLowerScalarOp( + op, memRefOutType.getElementType(), {accumulated, next}, rewriter); + rewriter.create(loc, accumulated, alloc, outLoopIVs); + + rewriter.replaceOp(op, alloc); + return matchSuccess(); + } +}; + +void populateLoweringONNXReductionOpPattern( + OwningRewritePatternList &patterns, MLIRContext *ctx) { + patterns.insert, + ONNXReductionOpLowering, + ONNXReductionOpLowering, + ONNXReductionOpLowering>(ctx); +} diff --git a/src/conversion/onnx_to_krnl/rewrite_patterns/math/softmax.inc b/src/conversion/onnx_to_krnl/rewrite_patterns/math/softmax.inc new file mode 100644 index 0000000..eb126c0 --- /dev/null +++ b/src/conversion/onnx_to_krnl/rewrite_patterns/math/softmax.inc @@ -0,0 +1,205 @@ +//===----- softmax.inc - Softmax Op ---------------------------------------===// +// +// Copyright 2019 The IBM Research Authors. +// +// ============================================================================= +// +// This file lowers ONNX softmax operator to Krnl dialect. 
+// +//===----------------------------------------------------------------------===// + +struct ONNXSoftmaxOpLowering : public ConversionPattern { + ONNXSoftmaxOpLowering(MLIRContext *ctx) + : ConversionPattern(mlir::ONNXSoftmaxOp::getOperationName(), 1, ctx) {} + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + // softmax(x) = let max_x = max(x) in + // let exp_x = exp(x - max_x) in + // let sum = sum(exp_x) in + // exp_x / sum + auto tensorType = (*op->result_type_begin()).cast(); + int64_t rank = tensorType.getRank(); + int64_t axis = llvm::dyn_cast(op).axis().getSExtValue(); + axis = axis >= 0 ? axis : rank + axis; + assert(axis >= -rank && axis <= rank - 1); + + auto loc = op->getLoc(); + + // Insert an allocation and deallocation for the result of this operation. + auto memRefType = convertTensorToMemRef(tensorType); + auto elementType = memRefType.getElementType(); + + Value alloc; + bool insertDealloc = checkInsertDealloc(op); + if (hasAllConstantDimensions(memRefType)) + alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); + else + alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc, + operands[0]); + + // Shape of the result + auto memRefShape = memRefType.getShape(); + + // Insert allocations and deallocations for sum and max. + MemRefType scalarMemRefType = MemRefType::get({}, elementType, {}, 0); + Value sumOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true); + Value maxOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true); + Value zero = + rewriter.create(loc, FloatAttr::get(elementType, 0)); + Value negInfinity = rewriter.create( + loc, + FloatAttr::get(elementType, -std::numeric_limits::infinity())); + + // Define loops. + std::vector originalLoops; + std::vector optimizedLoops; + Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops, + optimizedLoops, rank); + + // Coerce the input into a 2-D tensor. `axis` will be the coercing point. + // This coercing follows the softmax definition in ONNX: + // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Softmax + // Here, we create an outer loop and inner loop for handling the two + // dimensions. The outer loop is only created once `axis` is not zero. + + // Define an outer loop with respect to axis. + std::vector outerLoops, optimizedOuterLoops; + outerLoops.reserve(axis); + optimizedOuterLoops.reserve(axis); + for (int i = 0; i < axis; ++i) { + outerLoops.push_back(originalLoops[i]); + optimizedOuterLoops.push_back(optimizedLoops[i]); + } + KrnlIterateOperandPack outerPack(rewriter, outerLoops, optimizedOuterLoops); + for (int i = 0; i < axis; ++i) + addDimensionToPack(rewriter, loc, outerPack, operands[0], i); + + // Define an inner loop with respect to axis. 
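// Illustrative sketch (standalone C++, hypothetical helper): the coercion
// described above as plain arithmetic. For a shape (d0, ..., d{r-1}) and a
// given axis, the op behaves like softmax over a 2-D view of extent
// (d0*...*d{axis-1}) x (d{axis}*...*d{r-1}); the outer loops index the first
// factor and the inner loops the second.
#include <cstddef>
#include <utility>
#include <vector>

static std::pair<size_t, size_t> coerceTo2D(const std::vector<size_t> &shape,
                                            size_t axis) {
  size_t outer = 1, inner = 1;
  for (size_t i = 0; i < shape.size(); ++i)
    (i < axis ? outer : inner) *= shape[i];
  return {outer, inner};
}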
+ std::vector innerLoops, optimizedInnerLoops; + innerLoops.reserve(rank - axis); + optimizedInnerLoops.reserve(rank - axis); + for (int i = axis; i < rank; ++i) { + innerLoops.push_back(originalLoops[i]); + optimizedInnerLoops.push_back(optimizedLoops[i]); + } + KrnlIterateOperandPack innerPack(rewriter, innerLoops, optimizedInnerLoops); + for (int i = axis; i < rank; ++i) + addDimensionToPack(rewriter, loc, innerPack, operands[0], i); + + KrnlIterateOp outerIterateOp, maxIterateOp, sumIterateOp, softmaxIterateOp; + SmallVector outerLoopIVs; + if (axis != 0) { + outerIterateOp = rewriter.create(loc, outerPack); + + // No optimization + rewriter.setInsertionPointToEnd(optimizationBlock); + rewriter.create(loc, originalLoops); + + // Insert instructions inside the outer loop. + Block &outerIterationBlock = outerIterateOp.bodyRegion().front(); + rewriter.setInsertionPointToStart(&outerIterationBlock); + for (auto arg : outerIterationBlock.getArguments()) + outerLoopIVs.push_back(arg); + + // Reset accumulators. + rewriter.create(loc, zero, sumOp); + rewriter.create(loc, negInfinity, maxOp); + + // Create an inner loop to compute max. + maxIterateOp = rewriter.create(loc, innerPack); + // Create an inner loop to compute sum. + sumIterateOp = rewriter.create(loc, innerPack); + // Create an inner loop to compute softmax. + softmaxIterateOp = rewriter.create(loc, innerPack); + } else { + // Reset accumulators. + rewriter.create(loc, zero, sumOp); + rewriter.create(loc, negInfinity, maxOp); + + // Create an inner loop to compute max. + maxIterateOp = rewriter.create(loc, innerPack); + // Create an inner loop to compute sum. + sumIterateOp = rewriter.create(loc, innerPack); + // Create an inner loop to compute softmax. + softmaxIterateOp = rewriter.create(loc, innerPack); + + // No optimization + rewriter.setInsertionPointToEnd(optimizationBlock); + rewriter.create(loc, originalLoops); + } + + // Insert instructions inside the max loop. + Block &maxIterationBlock = maxIterateOp.bodyRegion().front(); + rewriter.setInsertionPointToStart(&maxIterationBlock); + + // Get induction variables. + SmallVector maxLoopIVs; + for (auto arg : outerLoopIVs) + maxLoopIVs.push_back(arg); + for (auto arg : maxIterationBlock.getArguments()) + maxLoopIVs.push_back(arg); + + // Compute the max value. + Value max = rewriter.create(loc, maxOp); + Value nextMax = rewriter.create(loc, operands[0], maxLoopIVs); + auto maxCond = + rewriter.create(loc, CmpFPredicate::OGT, max, nextMax); + max = rewriter.create(loc, maxCond, max, nextMax); + rewriter.create(loc, max, maxOp); + + // Get the max. + rewriter.setInsertionPoint(sumIterateOp); + max = rewriter.create(loc, maxOp); + + // Insert instructions inside the sum loop. + Block &sumIterationBlock = sumIterateOp.bodyRegion().front(); + rewriter.setInsertionPointToStart(&sumIterationBlock); + + // Get induction variables. + SmallVector sumLoopIVs; + for (auto arg : outerLoopIVs) + sumLoopIVs.push_back(arg); + for (auto arg : sumIterationBlock.getArguments()) + sumLoopIVs.push_back(arg); + + // Sum up values. + Value sum = rewriter.create(loc, sumOp); + Value next = rewriter.create(loc, operands[0], sumLoopIVs); + Value sub = rewriter.create(loc, next, max); + Value exp = rewriter.create(loc, sub); + sum = rewriter.create(loc, sum, exp); + rewriter.create(loc, sum, sumOp); + // Store intermediate values in the result to avoid recomputation. + rewriter.create(loc, exp, alloc, sumLoopIVs); + + // Get the sum. 
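// Illustrative sketch (standalone C++, not part of the pattern): what the
// three inner loops created above compute for each outer index — a max pass,
// a sum pass that caches exp(x - max) in the result buffer, and a final pass
// that divides by the sum. Names are hypothetical.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>

static void softmaxSpan(const float *x, float *y, size_t n) {
  float maxVal = -std::numeric_limits<float>::infinity();
  for (size_t i = 0; i < n; ++i)      // max loop
    maxVal = std::max(maxVal, x[i]);
  float sum = 0.0f;
  for (size_t i = 0; i < n; ++i) {    // sum loop, storing exp(x - max)
    y[i] = std::exp(x[i] - maxVal);
    sum += y[i];
  }
  for (size_t i = 0; i < n; ++i)      // softmax loop
    y[i] /= sum;
}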
+ rewriter.setInsertionPoint(softmaxIterateOp); + sum = rewriter.create(loc, sumOp); + + // Insert instructions inside the softmax loop. + Block &softmaxIterationBlock = softmaxIterateOp.bodyRegion().front(); + rewriter.setInsertionPointToStart(&softmaxIterationBlock); + + // Get induction variables. + SmallVector softmaxLoopIVs; + for (auto arg : outerLoopIVs) + softmaxLoopIVs.push_back(arg); + for (auto arg : softmaxIterationBlock.getArguments()) + softmaxLoopIVs.push_back(arg); + + // Compute softmax. + Value expLoadedVal = rewriter.create(loc, alloc, softmaxLoopIVs); + Value result = rewriter.create(loc, expLoadedVal, sum); + rewriter.create(loc, result, alloc, softmaxLoopIVs); + + rewriter.replaceOp(op, alloc); + + return matchSuccess(); + } +}; + +void populateLoweringONNXSoftmaxOpPattern( + OwningRewritePatternList &patterns, MLIRContext *ctx) { + patterns.insert(ctx); +} diff --git a/src/conversion/onnx_to_krnl/rewrite_patterns/nn/conv.inc b/src/conversion/onnx_to_krnl/rewrite_patterns/nn/conv.inc new file mode 100644 index 0000000..3ecfa3e --- /dev/null +++ b/src/conversion/onnx_to_krnl/rewrite_patterns/nn/conv.inc @@ -0,0 +1,282 @@ +//===----- conv.inc - Lowering Convolution Op -----------------------------===// +// +// Copyright 2019 The IBM Research Authors. +// +// ============================================================================= +// +// This file lowers the ONNX Convolution Operators to Krnl dialect. +// +//===----------------------------------------------------------------------===// + +struct ONNXConvNoBiasOpLowering : public ConversionPattern { + ONNXConvNoBiasOpLowering(MLIRContext *ctx) + : ConversionPattern(mlir::ONNXConvNoBiasOp::getOperationName(), 1, ctx) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + auto tensorType = (*op->result_type_begin()).cast(); + auto loc = op->getLoc(); + // Insert an allocation and deallocation for the result of this operation. + auto memRefType = convertTensorToMemRef(tensorType); + Value alloc; + bool insertDealloc = checkInsertDealloc(op); + ONNXConvNoBiasOp convOp = llvm::dyn_cast(op); + + if (hasAllConstantDimensions(memRefType)) + alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); + else + alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc, + {operands[0]}); + + auto resultShape = memRefType.getShape(); + auto inputShape = operands[0].getType().cast().getShape(); + auto kernelShape = operands[1].getType().cast().getShape(); + + // R = ConvNoBias(D, K) + // + // The input/output shapes will look like this: + // + // D (NxCxHxW) x K (MxC/groupxKHxKW) -> R (NxMxRHxRW) + // + // M is a multiple of the number of groups: + // M = group * kernelsPerGroup + // + // The loop nest will look as follows: + // + // strides = [s1, s2] + // + // kernelsPerGroup = M / group; + // for n = 0 .. N: + // for g = 0 .. group: + // for m = 0 .. kernelsPerGroup: + // kernel = g * kernelsPerGroup + m; + // for r1 = 0 .. RH: + // for r2 = 0 .. RW: + // R[n][kernel][r1][r2] = 0; + // for c = 0 .. C/group: + // for k1 = 0 .. KH: + // for k2 = 0 .. KW: + // R[n][kernel][r1][r2] = + // D[n][g * (C / group) + c][s1 * r1 + k1][s2 * r2 + k2] * + // K[kernel][c][k1][k2]; + // + // Naming: + // n, g, m: outer loop nest indices + // r1, r2: spatial loop nest indices + // c, k1, k2: inner loop nest indices + // + // TODO: handle padding. 
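// Illustrative sketch (standalone C++, not part of the pattern): the loop
// nest from the comment above, for NCHW float buffers with no padding and
// strides (s1, s2). All dimension names (N, C, H, W, M, KH, KW, group,
// RH, RW) are assumptions for the example; output spatial sizes are taken
// as already known from shape inference.
#include <cstddef>
#include <vector>

static void convNoBiasNoPad(const std::vector<float> &D, // N x C x H x W
                            const std::vector<float> &K, // M x C/group x KH x KW
                            std::vector<float> &R,       // N x M x RH x RW
                            size_t N, size_t C, size_t H, size_t W, size_t M,
                            size_t KH, size_t KW, size_t group, size_t s1,
                            size_t s2, size_t RH, size_t RW) {
  size_t kernelsPerGroup = M / group, subC = C / group;
  for (size_t n = 0; n < N; ++n)
    for (size_t g = 0; g < group; ++g)
      for (size_t m = 0; m < kernelsPerGroup; ++m) {
        size_t kernel = g * kernelsPerGroup + m;
        for (size_t r1 = 0; r1 < RH; ++r1)
          for (size_t r2 = 0; r2 < RW; ++r2) {
            float acc = 0.0f; // R[n][kernel][r1][r2] = 0;
            for (size_t c = 0; c < subC; ++c)
              for (size_t k1 = 0; k1 < KH; ++k1)
                for (size_t k2 = 0; k2 < KW; ++k2)
                  acc += D[((n * C + g * subC + c) * H + s1 * r1 + k1) * W +
                           s2 * r2 + k2] *
                         K[((kernel * subC + c) * KH + k1) * KW + k2];
            R[((n * M + kernel) * RH + r1) * RW + r2] = acc;
          }
      }
}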
+ // + // In the general case: + // + // D (NxCxD1xD2x...xDdim) x K (MxC/groupxK1xK2x...xKdim) + // -> R (NxMxR1xR2x...xRdim) + // + // The above loop nest can be adapted by increasing the number + // of r- and k-index loop i.e. r1 r2 and k1 k2 loops. + + // Set up outermost loops: n g m r1 r2 ... rdim + // Skip g if group is 1. + + // Before we start the iteration we need to compute the number of + // unsplit kernels and fetch the number of groups from the attribute + // list. Group is always a compilation constant. + int64_t group = convOp.group().getSExtValue(); + // Compute the number of unsplit kernels. The number of kernels + // must be a multiple of the number of groups. + int64_t kernelsPerGroup = floor(kernelShape[0] / group); + auto kernelsPerGroupValue = + rewriter.create(loc, kernelsPerGroup); + auto zero = rewriter.create( + loc, FloatAttr::get(memRefType.getElementType(), 0)); + Value subchannels; + if (kernelShape[1] < 0) { + subchannels = + rewriter.create(loc, operands[1], 1).getResult(); + } else { + subchannels = rewriter.create( + loc, kernelShape[1]); + } + + // 1. Define outer loops and emit empty optimization block: + int64_t nOuterLoops = (group > 1) ? 3 : 2; + std::vector outerLoops; + std::vector optimizedOuterLoops; + Block *optimizationBlock = defineLoops(rewriter, loc, outerLoops, + optimizedOuterLoops, nOuterLoops); + + // Prepare iteration arguments over outer loop nest. + KrnlIterateOperandPack pack( + rewriter, outerLoops, optimizedOuterLoops); + // for n = 0 .. N: + pack.pushConstantBound(0); + if (inputShape[0] < 0) + pack.pushOperandBound( + rewriter.create(loc, operands[0], 0).getResult()); + else + pack.pushConstantBound(inputShape[0]); + // for g = 0 .. N: + if (group > 1) { + pack.pushConstantBound(0); + pack.pushConstantBound(group); + } + // for m = 0 .. kernelsPerGroup: + pack.pushConstantBound(0); + pack.pushConstantBound(kernelsPerGroup); + // Outer loop iteration. + auto iterateOp = rewriter.create(loc, pack); + Block &outerIterationBlock = iterateOp.bodyRegion().front(); + // Emit optimizations for outer loops: + rewriter.setInsertionPointToEnd(optimizationBlock); + rewriter.create(loc, outerLoops); + rewriter.setInsertionPointToStart(&outerIterationBlock); + { + // 2. Emit the body of the outer loop nest. + + // 2.1 Compute kernel order number: kernel = g * kernelsPerGroup + m; + // If group is not set then the value of the kernel ID is + // identical to that of the loop over kernels. + Value kernel = outerIterationBlock.getArguments()[1]; + if (group > 1) { + // Middle loop is over groups and third loop is over the + // kernel identifiers in the current group. + auto kernelsOffset = rewriter.create(loc, + outerIterationBlock.getArguments()[1], + kernelsPerGroupValue); + kernel = rewriter.create(loc, kernelsOffset, + outerIterationBlock.getArguments()[2]); + } + + // 2.2 Define spatial loops + int64_t nSpatialLoops = resultShape.size() - 2; + std::vector spatialLoops; + std::vector optimizedSpatialLoops; + Block *optSpatialLoopBlock = defineLoops(rewriter, loc, spatialLoops, + optimizedSpatialLoops, nSpatialLoops); + + // 2.3 Prepare iteration arguments for spatial loop nest. + KrnlIterateOperandPack spatialPack( + rewriter, spatialLoops, optimizedSpatialLoops); + for (int i = 2; i < resultShape.size(); ++i) + addDimensionToPack(rewriter, loc, spatialPack, alloc, i); + + // 2.4 Emit loop nest over output spatial dimensions. + // for rX = 0 .. 
RX + auto spatialIterateOp = + rewriter.create(loc, spatialPack); + Block &spatialIterationBlock = spatialIterateOp.bodyRegion().front(); + // 2.5 Emit optimizations for outer loops: + rewriter.setInsertionPointToEnd(optSpatialLoopBlock); + rewriter.create(loc, spatialLoops); + rewriter.setInsertionPointToStart(&spatialIterationBlock); + { + // 3. Emit the body of the spatial loop nest. + // 3.1 Emit: R[n][kernel][r1][r2] = 0; + SmallVector resultIndices; + // n + resultIndices.emplace_back(outerIterationBlock.getArguments()[0]); + // kernel + resultIndices.emplace_back(kernel); + // rX + for (auto arg : spatialIterationBlock.getArguments()) + resultIndices.emplace_back(arg); + // Store initializer value into output location. + rewriter.create(loc, zero, alloc, resultIndices); + + // 3.2 Define inner loops. + int64_t nInnerLoops = 1 + (kernelShape.size() - 2); + std::vector innerLoops; + std::vector optimizedInnerLoops; + Block *optInnerLoopBlock = defineLoops(rewriter, loc, innerLoops, + optimizedInnerLoops, nInnerLoops); + + // 3.3 Prepare iteration arguments for inner loop nest. + KrnlIterateOperandPack innerPack( + rewriter, innerLoops, optimizedInnerLoops); + // for c = 0 .. C/group + innerPack.pushConstantBound(0); + innerPack.pushConstantBound(kernelShape[1]); + // for Kx = 0 .. KX + for (int i = 2; i < kernelShape.size(); ++i) + addDimensionToPack(rewriter, loc, innerPack, operands[1], i); + + // 3.4 Emit inner loop nest. + auto innerIterateOp = + rewriter.create(loc, innerPack); + Block &innerIterationBlock = innerIterateOp.bodyRegion().front(); + // 3.5 Emit optimizations for outer loops: + rewriter.setInsertionPointToEnd(optInnerLoopBlock); + rewriter.create(loc, innerLoops); + rewriter.setInsertionPointToStart(&innerIterationBlock); + { + // 4. Emit inner loop body + // R[n][kernel][r1][r2] = + // D[n][g * (C / group) + c][s1 * r1 + k1][s2 * r2 + k2] * + // K[kernel][c][k1][k2]; + + // 4.1 Prepare indices for accesing the data tensor. + SmallVector dataIndices; + // n + dataIndices.emplace_back(outerIterationBlock.getArguments()[0]); + // g * (C / group) + c + Value channelDepth = innerIterationBlock.getArguments()[0]; + if (group > 1) + channelDepth = rewriter.create(loc, channelDepth, + rewriter.create(loc, subchannels, + outerIterationBlock.getArguments()[1])); + dataIndices.emplace_back(channelDepth); + // sX * rX + kX + auto stridesAttribute = convOp.stridesAttr(); + // Read strides attribute + SmallVector strides; + if (stridesAttribute) + for (auto stride : stridesAttribute.getValue()) + strides.emplace_back(stride.cast().getInt()); + for (int i = 0; i < kernelShape.size() - 2; ++i) { + Value spatialIndex = spatialIterationBlock.getArguments()[i]; + // If strides are present then emit the correct access index. + if (stridesAttribute && strides[i] > 1) + spatialIndex = rewriter.create(loc, + rewriter.create(loc, strides[i]), + spatialIterationBlock.getArguments()[i]); + dataIndices.emplace_back( + rewriter.create(loc, spatialIndex, + innerIterationBlock.getArguments()[i+1])); + } + + // 4.2 Prepare indices for accessing the kernel tensor. + SmallVector kernelIndices; + // kernel + kernelIndices.emplace_back(kernel); + // c + kernelIndices.emplace_back(innerIterationBlock.getArguments()[0]); + // kX + for (int i = 0; i < kernelShape.size() - 2; ++i) + kernelIndices.emplace_back( + innerIterationBlock.getArguments()[i+1]); + + // 4.3 Compute convolution. 
+ auto loadData = + rewriter.create(loc, operands[0], dataIndices); + auto loadKernel = + rewriter.create(loc, operands[1], kernelIndices); + auto loadPartialSum = + rewriter.create(loc, alloc, resultIndices); + Value result = rewriter.create(loc, loadPartialSum, + rewriter.create(loc, loadData, loadKernel)); + // 4.4 Store computed value into output location. + rewriter.create(loc, result, alloc, resultIndices); + } + } + } + rewriter.replaceOp(op, alloc); + + return matchSuccess(); + } +}; + +void populateLoweringONNXConvOpPattern( + OwningRewritePatternList &patterns, MLIRContext *ctx) { + patterns.insert(ctx); +} diff --git a/src/conversion/onnx_to_krnl/rewrite_patterns/tensor/identity.inc b/src/conversion/onnx_to_krnl/rewrite_patterns/tensor/identity.inc new file mode 100644 index 0000000..2ff1633 --- /dev/null +++ b/src/conversion/onnx_to_krnl/rewrite_patterns/tensor/identity.inc @@ -0,0 +1,26 @@ +//===----- identity.inc - Lowering Identity Op ----------------------------===// +// +// Copyright 2019 The IBM Research Authors. +// +// ============================================================================= +// +// This file lowers the ONNX Identity Operator to Krnl dialect. +// +//===----------------------------------------------------------------------===// + +struct ONNXIdentityOpLowering : public ConversionPattern { + ONNXIdentityOpLowering(MLIRContext *ctx) + : ConversionPattern(mlir::ONNXIdentityOp::getOperationName(), 1, ctx) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + rewriter.replaceOp(op, operands[0]); + return matchSuccess(); + } +}; + +void populateLoweringONNXIdentityOpPattern( + OwningRewritePatternList &patterns, MLIRContext *ctx) { + patterns.insert(ctx); +} diff --git a/src/conversion/onnx_to_krnl/rewrite_patterns/tensor/reshape.inc b/src/conversion/onnx_to_krnl/rewrite_patterns/tensor/reshape.inc new file mode 100644 index 0000000..ed2b185 --- /dev/null +++ b/src/conversion/onnx_to_krnl/rewrite_patterns/tensor/reshape.inc @@ -0,0 +1,151 @@ +//===----- reshape.inc - Lowering Reshape Op ------------------------------===// +// +// Copyright 2019 The IBM Research Authors. +// +// ============================================================================= +// +// This file lowers the ONNX Reshape Operator to Krnl dialect. +// +//===----------------------------------------------------------------------===// + +struct ONNXReshapeOpLowering : public ConversionPattern { + ONNXReshapeOpLowering(MLIRContext *ctx) + : ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + auto tensorType = (*op->result_type_begin()).cast(); + auto inputShape = operands[0].getType().cast().getShape(); + auto loc = op->getLoc(); + + // Insert an allocation and deallocation for the result of this operation. + auto memRefType = convertTensorToMemRef(tensorType); + auto memRefShape = memRefType.getShape(); + Value alloc; + + // Compute size in bytes using the input tensor. 
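// Illustrative sketch (standalone C++, hypothetical helper): what the size
// computation that follows boils down to — the byte count copied by the
// final memcpy is the element size times the product of the input
// dimensions, with dynamic dimensions read at run time via DimOp.
#include <cstdint>
#include <vector>

static int64_t tensorSizeInBytes(const std::vector<int64_t> &shape,
                                 int64_t eltSizeInBytes) {
  int64_t size = eltSizeInBytes;
  for (int64_t d : shape) // one factor per (constant or runtime) dimension
    size *= d;
  return size;
}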
+ Value tensorSize = rewriter.create( + loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), + getMemRefEltSizeInBytes(memRefType))); + for (int i = 0; i < inputShape.size(); ++i) { + Value dimVal; + if (inputShape[i] < 0) { + Value dim = rewriter.create(loc, operands[0], i); + dimVal = + rewriter.create(loc, dim, rewriter.getIntegerType(64)); + } else { + dimVal = rewriter.create( + loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), + inputShape[i])); + } + tensorSize = rewriter.create(loc, tensorSize, dimVal); + } + + bool insertDealloc = checkInsertDealloc(op); + if (hasAllConstantDimensions(memRefType)) { + alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); + } else { + // If a dimension is zero, the actual dimension value is taken from the + // input tensor. + // + // If the shape array has a negative dimension (-1), we compute its actual + // dimension value from the other dimensions. But we don't have enough + // information about the other dimensions at this point. So, we need to + // scan the shape first to calculate reduction of all of the dimensions. + // If the reduction is negative, then the shape array contains a negative + // dimension. Otherwise, the reduction is the same as the one computed + // from the input tensor. + Value tensorSizeFromShape = rewriter.create( + loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), + getMemRefEltSizeInBytes(memRefType))); + SmallVector DimInfo; + for (int i = 0; i < memRefShape.size(); ++i) { + Value index = rewriter.create( + loc, rewriter.getIntegerAttr(rewriter.getIndexType(), i)); + // Load index from array of indices. + Value loadedVal = rewriter.create(loc, operands[1], index); + // If a dimension is zero, the actual dimension value is taken from the + // input tensor. + // + // If a dimension is negative, it is computed from the other dimensions. + // But we don't have enough information about the other dimensions at + // this point. So, we let it as it is (-1), and compute it later. + if (i < inputShape.size()) { + Value dimVal; + auto loadedValType = loadedVal.getType().cast(); + if (inputShape[i] < 0) { + Value dim = rewriter.create(loc, operands[0], i); + dimVal = rewriter.create(loc, dim, loadedValType); + } else { + dimVal = rewriter.create( + loc, rewriter.getIntegerAttr(loadedValType, inputShape[i])); + } + auto zero = rewriter.create( + loc, rewriter.getIntegerAttr(loadedValType, 0)); + auto isZero = + rewriter.create(loc, CmpIPredicate::eq, loadedVal, zero); + loadedVal = rewriter.create(loc, isZero, dimVal, loadedVal); + } + // Check if the loaded index is already the correct width of 64 bits. + // Convert the value to a 64 bit integer if needed. + Value int64LoadedVal = loadedVal; + if (loadedVal.getType().cast().getWidth() < 64) + int64LoadedVal = rewriter.create( + loc, loadedVal, rewriter.getIntegerType(64)); + tensorSizeFromShape = + rewriter.create(loc, tensorSizeFromShape, int64LoadedVal); + // Store intermediate results to use later. + DimInfo.emplace_back(int64LoadedVal); + } + // Reverse tensorSizeFromShape since it is negative if the shape array has + // a negative dimension. This is safe since we only use it to compute the + // actual value for the negative dimension. + auto zero = rewriter.create( + loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + tensorSizeFromShape = + rewriter.create(loc, zero, tensorSizeFromShape); + + // Obtain operands for AllocOp. 
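// Illustrative sketch (standalone C++, not part of the pattern): the
// dimension resolution described above with shapes assumed static purely for
// illustration. A 0 entry in the shape operand keeps the corresponding input
// dimension, and a single -1 entry is inferred so that element counts match;
// the pattern obtains the same value by dividing the two byte sizes.
#include <cstdint>
#include <vector>

static std::vector<int64_t>
resolveReshapeDims(const std::vector<int64_t> &inputShape,
                   const std::vector<int64_t> &shapeOperand) {
  int64_t inputSize = 1;
  for (int64_t d : inputShape)
    inputSize *= d;
  std::vector<int64_t> out(shapeOperand);
  int64_t knownProduct = 1;
  int64_t negOneIdx = -1;
  for (size_t i = 0; i < out.size(); ++i) {
    if (out[i] == 0 && i < inputShape.size())
      out[i] = inputShape[i];          // 0 means "copy the input dimension"
    if (out[i] == -1)
      negOneIdx = static_cast<int64_t>(i);
    else
      knownProduct *= out[i];
  }
  if (negOneIdx >= 0)
    out[negOneIdx] = inputSize / knownProduct; // infer the -1 dimension
  return out;
}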
+ SmallVector allocOperands; + auto negOne = rewriter.create( + loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), -1)); + + for (int i = 0; i < memRefShape.size(); ++i) { + auto dimVal = DimInfo[i]; + auto isNegOne = + rewriter.create(loc, CmpIPredicate::eq, dimVal, negOne); + // If dimension is negative, compute its value from the other + // dimensions. + auto actualDimVal = + rewriter.create(loc, tensorSize, tensorSizeFromShape); + auto loadedVal = + rewriter.create(loc, isNegOne, actualDimVal, dimVal); + allocOperands.push_back(rewriter.create( + loc, loadedVal, rewriter.getIndexType())); + } + AllocOp allocateMemref = + rewriter.create(loc, memRefType, allocOperands); + + // Make sure to allocate at the beginning of the block if + // all dimensions are known. + auto *parentBlock = allocateMemref.getOperation()->getBlock(); + if (insertDealloc) { + auto dealloc = rewriter.create(loc, allocateMemref); + dealloc.getOperation()->moveBefore(&parentBlock->back()); + } + + alloc = allocateMemref; + } + + rewriter.create(loc, alloc, operands[0], tensorSize); + rewriter.replaceOp(op, alloc); + + return matchSuccess(); + } +}; + +void populateLoweringONNXReshapeOpPattern( + OwningRewritePatternList &patterns, MLIRContext *ctx) { + patterns.insert(ctx); +} diff --git a/src/conversion/onnx_to_krnl/rewrite_patterns/tensor/transpose.inc b/src/conversion/onnx_to_krnl/rewrite_patterns/tensor/transpose.inc new file mode 100644 index 0000000..39cfa8c --- /dev/null +++ b/src/conversion/onnx_to_krnl/rewrite_patterns/tensor/transpose.inc @@ -0,0 +1,99 @@ +//===----- transpose.inc - Lowering Transpose Op --------------------------===// +// +// Copyright 2019 The IBM Research Authors. +// +// ============================================================================= +// +// This file lowers the ONNX Transpose Operator to Krnl dialect. +// +//===----------------------------------------------------------------------===// + +struct ONNXTransposeOpLowering : public ConversionPattern { + ONNXTransposeOpLowering(MLIRContext *ctx) + : ConversionPattern(mlir::ONNXTransposeOp::getOperationName(), 1, ctx) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + auto tensorType = (*op->result_type_begin()).cast(); + auto loc = op->getLoc(); + // Insert an allocation and deallocation for the result of this operation. + auto memRefType = convertTensorToMemRef(tensorType); + Value alloc; + bool insertDealloc = checkInsertDealloc(op); + + if (hasAllConstantDimensions(memRefType)) + alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); + else + alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc, + {operands[0]}); + + // Number of loops + auto memRefShape = memRefType.getShape(); + int64_t rank = memRefShape.size(); + + // Define loops. + std::vector originalLoops; + std::vector optimizedLoops; + Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops, + optimizedLoops, rank); + + KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops); + // Iterate over the loop nest using the input shape. + for (int i = 0; i < rank; ++i) + addDimensionToPack(rewriter, loc, pack, operands[0], i); + + auto iterateOp = rewriter.create(loc, pack); + Block &iterationBlock = iterateOp.bodyRegion().front(); + + // Now perform the insertions into the body of the + // just generated instructions: + + // 1. Insert any optimizations in the KrnlOptimizeLoopsOp body. 
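// Illustrative sketch (standalone C++, not part of the pattern): the index
// remapping this lowering performs. The loop nest runs over the input index
// space and the store location takes input induction variables according to
// `perm` (outIndex[j] = inIndex[perm[j]]); when `perm` is absent the default
// is the reversed order. A 3-D example with hypothetical names:
#include <array>
#include <cstddef>
#include <vector>

static void transpose3D(const std::vector<float> &X, std::vector<float> &Y,
                        std::array<size_t, 3> inShape,
                        std::array<size_t, 3> perm) {
  std::array<size_t, 3> outShape = {inShape[perm[0]], inShape[perm[1]],
                                    inShape[perm[2]]};
  for (size_t i0 = 0; i0 < inShape[0]; ++i0)
    for (size_t i1 = 0; i1 < inShape[1]; ++i1)
      for (size_t i2 = 0; i2 < inShape[2]; ++i2) {
        std::array<size_t, 3> in = {i0, i1, i2};
        std::array<size_t, 3> out = {in[perm[0]], in[perm[1]], in[perm[2]]};
        Y[(out[0] * outShape[1] + out[1]) * outShape[2] + out[2]] =
            X[(i0 * inShape[1] + i1) * inShape[2] + i2];
      }
}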
+ rewriter.setInsertionPointToEnd(optimizationBlock); + // Return from KrnlOptimizeLoopsOp body. + // When no optimizations are present we just return the loops + // unchaged. + rewriter.create(loc, originalLoops); + + // 2. Insert instructions inside the KernelIterateOp body. + rewriter.setInsertionPointToStart(&iterationBlock); + + // Handle the operation. + + // Read perm attribute. + SmallVector perm; + auto permAttribute = llvm::dyn_cast(op).permAttr(); + if (permAttribute) { + for (auto permVal : permAttribute.getValue()) + perm.emplace_back(permVal.cast().getInt()); + } else { + // TODO: Remove when perm is guaranteed to be present (even for + // the default case). This means that perm was added by shape + // inference or another pass to contain the values corresponding + // to the default behavior of Transpose. + for (int i = iterationBlock.getArguments().size()-1; i >= 0; i--) + perm.emplace_back(i); + } + + SmallVector inLoopIVs; + for (auto arg : iterationBlock.getArguments()) + inLoopIVs.emplace_back(arg); + + SmallVector outLoopIVs; + for (int i=0; i(loc, operands[0], inLoopIVs); + rewriter.create(loc, inVal, alloc, outLoopIVs); + + rewriter.replaceOp(op, alloc); + + return matchSuccess(); + } +}; + +void populateLoweringONNXTransposeOpPattern( + OwningRewritePatternList &patterns, MLIRContext *ctx) { + patterns.insert(ctx); +} diff --git a/src/conversion/onnx_to_krnl/rewrite_patterns/tensor/unsqueeze.inc b/src/conversion/onnx_to_krnl/rewrite_patterns/tensor/unsqueeze.inc new file mode 100644 index 0000000..18b9f8b --- /dev/null +++ b/src/conversion/onnx_to_krnl/rewrite_patterns/tensor/unsqueeze.inc @@ -0,0 +1,86 @@ +//===----- unsqueeze.inc - Lowering Unsqueeze Op --------------------------===// +// +// Copyright 2019 The IBM Research Authors. +// +// ============================================================================= +// +// This file lowers the ONNX Unsqueeze Operator to Krnl dialect. +// +//===----------------------------------------------------------------------===// + +struct ONNXUnsqueezeOpLowering : public ConversionPattern { + ONNXUnsqueezeOpLowering(MLIRContext *ctx) + : ConversionPattern(mlir::ONNXUnsqueezeOp::getOperationName(), 1, ctx) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + auto loc = op->getLoc(); + auto tensorType = (*op->result_type_begin()).cast(); + int outRank = tensorType.getRank(); + + // Assume that `axes` has been validated by shape inference. + // So, here we just get it. + ArrayAttr axisAttrs = llvm::dyn_cast(op).axesAttr(); + SmallVector axes; + for (auto axisAttr : axisAttrs.getValue()) { + int axis = axisAttr.cast().getInt(); + axis = axis >= 0 ? axis : (outRank + axis); + axes.emplace_back(axis); + } + + // Insert an allocation and deallocation for the result of this operation. + auto memRefType = convertTensorToMemRef(tensorType); + Value alloc; + + // Compute size in bytes. 
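// Illustrative sketch (standalone C++, not part of the pattern): how the
// Unsqueeze output shape relates to the input shape. Every axis listed in
// `axes` (normalized against the output rank) becomes a size-1 dimension,
// the remaining output dimensions consume the input dimensions in order, and
// the data itself is then copied unchanged. Names are hypothetical.
#include <algorithm>
#include <cstdint>
#include <vector>

static std::vector<int64_t> unsqueezeShape(const std::vector<int64_t> &inShape,
                                           std::vector<int64_t> axes) {
  int64_t outRank = static_cast<int64_t>(inShape.size() + axes.size());
  for (int64_t &a : axes)
    if (a < 0)
      a += outRank; // normalize negative axes against the output rank
  std::vector<int64_t> outShape;
  size_t inIdx = 0;
  for (int64_t outIdx = 0; outIdx < outRank; ++outIdx) {
    if (std::find(axes.begin(), axes.end(), outIdx) != axes.end())
      outShape.push_back(1);                // inserted size-1 dimension
    else
      outShape.push_back(inShape[inIdx++]); // taken from the input, in order
  }
  return outShape;
}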
+ Value tensorSize = rewriter.create( + loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), + getMemRefEltSizeInBytes(memRefType))); + + bool insertDealloc = checkInsertDealloc(op); + auto memRefShape = memRefType.getShape(); + if (hasAllConstantDimensions(memRefType)) { + alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); + for (int i = 0; i < memRefShape.size(); ++i) { + Value dimVal = rewriter.create( + loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), + memRefShape[i])); + tensorSize = rewriter.create(loc, tensorSize, dimVal); + } + } else { + // Unknown dimensions are always the operand's dimensions. + SmallVector allocOperands; + for (int outIdx = 0, inIdx = 0; outIdx < memRefShape.size(); ++outIdx) { + Value dimVal = nullptr; + if (memRefShape[outIdx] < 0) { + Value index = rewriter.create(loc, operands[0], inIdx); + dimVal = rewriter.create( + loc, index, rewriter.getIntegerType(64)); + allocOperands.emplace_back(index); + } else { + dimVal = rewriter.create( + loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), + memRefShape[outIdx])); + } + tensorSize = rewriter.create(loc, tensorSize, dimVal); + if (std::find(axes.begin(), axes.end(), outIdx) == axes.end()) + inIdx++; + } + alloc = rewriter.create(loc, memRefType, allocOperands); + auto *parentBlock = alloc.getDefiningOp()->getBlock(); + if (insertDealloc) { + auto dealloc = rewriter.create(loc, alloc); + dealloc.getOperation()->moveBefore(&parentBlock->back()); + } + } + rewriter.create(loc, alloc, operands[0], tensorSize); + rewriter.replaceOp(op, alloc); + return matchSuccess(); + } +}; + +void populateLoweringONNXUnsqueezeOpPattern( + OwningRewritePatternList &patterns, MLIRContext *ctx) { + patterns.insert(ctx); +} diff --git a/src/pass/lower_frontend_to_krnl.cpp b/src/pass/lower_frontend_to_krnl.cpp deleted file mode 100644 index 08230cb..0000000 --- a/src/pass/lower_frontend_to_krnl.cpp +++ /dev/null @@ -1,2719 +0,0 @@ -//====- lower_frontend_to_krnl.cpp - Frontend dialects to Krnl lowering ---===// -// -// Copyright 2019 The IBM Research Authors. -// -// ============================================================================= -// -// This file implements the lowering of frontend operations to a combination of -// Krnl IR and standard operations. -// -//===----------------------------------------------------------------------===// -#include - -#include "mlir/Dialect/AffineOps/AffineOps.h" -#include "mlir/Dialect/StandardOps/Ops.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/DialectConversion.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/Sequence.h" - -#include "src/dialect/krnl/krnl_helper.hpp" -#include "src/dialect/krnl/krnl_ops.hpp" -#include "src/dialect/onnx/onnx_ops.hpp" -#include "src/pass/passes.hpp" - -using namespace mlir; - -//===----------------------------------------------------------------------===// -// FrontendToAffine RewritePatterns -//===----------------------------------------------------------------------===// - -/// Check is all dimensions are known at compile time. -static bool hasAllConstantDimensions(MemRefType type) { - auto memRefShape = type.getShape(); - for (int i = 0; i < memRefShape.size(); ++i) - if (memRefShape[i] < 0) - return false; - return true; -} - -/// Convert the given TensorType into the corresponding MemRefType. 
-static MemRefType convertTensorToMemRef(TensorType type) { - assert(type.hasRank() && "expected only ranked shapes"); - return MemRefType::get(type.getShape(), type.getElementType()); -} - -/// Insert an allocation and deallocation for the given MemRefType. -static Value insertAllocAndDealloc(MemRefType type, Location loc, - PatternRewriter &rewriter, - bool insertDealloc, - ArrayRef operands = {}) { - // Put together alloc operands for any dynamic dimensions of the memref. - AllocOp alloc; - if (!operands.empty()) { - auto memRefShape = type.getShape(); - auto rank = memRefShape.size(); - - std::map fromOperands; - for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) { - int memRefDimIdx = rank - 1 - reversedIdx; - if (memRefShape[memRefDimIdx] < 0) { // unknown dimension - Value maxDim = nullptr; - for (int i = 0; i < operands.size(); i++) { - auto operandShape = - operands[i].getType().cast().getShape(); - int operandDimIdx = operandShape.size() - 1 - reversedIdx; - - if (operandDimIdx < 0) - continue; - - // In case of operations with broadcasting, the dimension of the - // alloc result is the maximum size along each dimension of the - // operands. - auto operandDim = - rewriter.create(loc, operands[i], operandDimIdx); - if (maxDim) { - auto maxCondition = rewriter.create(loc, CmpIPredicate::sgt, - operandDim, maxDim); - maxDim = rewriter.create(loc, maxCondition, operandDim, - maxDim); - } else { - maxDim = operandDim; - } - } - fromOperands.insert(std::make_pair(memRefDimIdx, maxDim)); - } - } - - SmallVector allocOperands; - for (int i = 0; i < rank; ++i) - if (memRefShape[i] < 0) - allocOperands.push_back(fromOperands[i]); - alloc = rewriter.create(loc, type, allocOperands); - } else { - alloc = rewriter.create(loc, type); - } - - // Make sure to allocate at the beginning of the block if - // all dimensions are known. - auto *parentBlock = alloc.getOperation()->getBlock(); - if (hasAllConstantDimensions(type)) - alloc.getOperation()->moveBefore(&parentBlock->front()); - - if (insertDealloc) { - auto dealloc = rewriter.create(loc, alloc); - dealloc.getOperation()->moveBefore(&parentBlock->back()); - } - - return alloc; -} - -// Determine if current function returns the result value of the -// current op being lowered. If it does then dealloc should not be -// inserted. -static bool checkInsertDealloc(Operation *currentOp) { - auto parentBlock = currentOp->getBlock(); - - bool insertDealloc = true; - parentBlock->walk([&insertDealloc, currentOp](ReturnOp op) { - assert(currentOp->getNumResults() < 2 && - "No more than one result supported (for now)."); - // If there is at least one result to investigate. - if (currentOp->getNumResults() > 0) { - auto result = currentOp->getResult(0); - for (const auto &operand : op.getOperands()) - if (operand == result) - insertDealloc = false; - } - }); - - return insertDealloc; -} - -// Create a mapping from result type's dimensions to input type's dimensions, -// given that the result type is the result of a reduction op over the input -// type. -std::map -getReductionMapping(MemRefType inputTy, ArrayRef axes, bool keepdims) { - std::map OutInDimMap; - int64_t rank = inputTy.getRank(); - - // Mark reduction axes. 
- std::vector isReductionAxis; - for (decltype(rank) i = 0; i < rank; ++i) { - if (std::find(axes.begin(), axes.end(), i) != axes.end()) - isReductionAxis.push_back(true); - else - isReductionAxis.push_back(false); - } - - for (decltype(rank) inIndex = 0, outIndex = 0; inIndex < rank; ++inIndex) { - // If it is a reduction axis, there is no relationship among dimensions. - if (isReductionAxis[inIndex]) { - if (keepdims) - outIndex++; - } else { - OutInDimMap.insert(std::make_pair(outIndex, inIndex)); - outIndex++; - } - } - - return OutInDimMap; -} - -// Add bounds associated with the op operand to the KRNL iteration pack. -// Dynamic dimenions are supported. -static void addDimensionToPack(ConversionPatternRewriter &rewriter, - Location loc, KrnlIterateOperandPack &pack, Value operand, int index) { - auto shape = operand.getType().cast().getShape(); - if (shape[index] < 0) { - pack.pushConstantBound(0); - pack.pushOperandBound( - rewriter.create(loc, operand, index).getResult()); - } else { - pack.pushConstantBound(0); - pack.pushConstantBound(shape[index]); - } -} - -// Function that defines the KRNL dialect loops and their respective -// optimized version. -static KrnlOptimizeLoopsOp emitOptimizedLoops( - ConversionPatternRewriter &rewriter, Location loc, - std::vector &loops, std::vector &optimizedLoops, - int64_t numLoops) { - // Define loops. - auto loopsOp = rewriter.create(loc, numLoops); - loops.reserve(numLoops); - for (auto result : loopsOp.getResults()) - loops.push_back(result); - - // Define optimized version of the loops. - auto optimizedLoopsOp = rewriter.create(loc, numLoops); - optimizedLoops.reserve(numLoops); - for (auto result : optimizedLoopsOp.getResults()) - optimizedLoops.push_back(result); - - return optimizedLoopsOp; -} - -// Function that emits the loops and their optimized version. -// The function returns a reference to the inner optimization block. -static Block* defineLoops(ConversionPatternRewriter &rewriter, - Location loc, std::vector &loops, - std::vector &optimizedLoops, int64_t numLoops) { - KrnlOptimizeLoopsOp optimizedLoopsOp = emitOptimizedLoops( - rewriter, loc, loops, optimizedLoops, numLoops); - return &optimizedLoopsOp.region().front(); -} - -// Function which emits a basic set of loops and optimized loops -// for a given operation argument. A reference to the loop optimization -// block is returned in the last argument of the function. -static void emitKrnlLoopsAndIterationForOperand( - ConversionPatternRewriter &rewriter, Location loc, - Value operand, std::vector &originalLoops, - KrnlOptimizeLoopsOp &optimizedLoopsOp, KrnlIterateOp &iterateOp) { - // Operand shape. - auto shape = operand.getType().cast().getShape(); - - // Number of loops. - int64_t rank = shape.size(); - - // Define loops and optimized loops. - std::vector optimizedLoops; - optimizedLoopsOp = emitOptimizedLoops(rewriter, loc, originalLoops, - optimizedLoops, rank); - - KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops); - // Iterate over the loop nest. 
- for (int i = 0; i < rank; ++i) - addDimensionToPack(rewriter, loc, pack, operand, i); - - iterateOp = rewriter.create(loc, pack); -} - -unsigned getMemRefEltSizeInBytes(MemRefType memRefType) { - auto elementType = memRefType.getElementType(); - - unsigned sizeInBits; - if (elementType.isIntOrFloat()) { - sizeInBits = elementType.getIntOrFloatBitWidth(); - } else { - auto vectorType = elementType.cast(); - sizeInBits = - vectorType.getElementTypeBitWidth() * vectorType.getNumElements(); - } - return llvm::divideCeil(sizeInBits, 8); -} - -// Get run-time dimension information for unknown dimensions used for -// broadcasting. -std::map> -getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter, - MemRefType memRefType, ArrayRef operands) { - auto memRefShape = memRefType.getShape(); - int64_t rank = memRefShape.size(); - // For unknown dimensions, we need to get dimension values at runtime in - // order to do broadcasting. - std::map> DimInfo; - // For each result dimension, compute the number of sharing operands. - // Sharing operands are operands sharing the same index (counting from the - // rightmost to the leftmost) for a given dimension. - std::map sharedDimCount; - for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) { - int dimIdx = rank - 1 - reversedIdx; - sharedDimCount[dimIdx] = 0; - for (int i = 0; i < operands.size(); ++i) { - auto shape = operands[i].getType().cast().getShape(); - if (reversedIdx <= shape.size() - 1) - sharedDimCount[dimIdx]++; - } - } - // An unknown dimension can have a value of 1 or N (N > 1). - // If its value is 1, it is broadcasted dimension. - // Otherwise, non-broadcasted dimension. - // We only care about unknown dimensions whose number of sharing operands is - // more than one, since they are potentially broadcasted dimensions. - for (int i = 0; i < operands.size(); ++i) { - std::map broadcastedDims; - auto shape = operands[i].getType().cast().getShape(); - int size = shape.size(); - for (int j = 0; j < shape.size(); ++j) { - if (shape[j] < 0 and sharedDimCount[rank - size + j] > 1) { - auto dim = rewriter.create(loc, operands[i], j).getResult(); - auto one = rewriter.create(loc, 1); - auto isBroadcasted = - rewriter.create(loc, CmpIPredicate::eq, dim, one); - broadcastedDims.insert(std::make_pair(j, isBroadcasted)); - } - } - DimInfo.insert(std::make_pair(i, broadcastedDims)); - } - return DimInfo; -} - -// Extract induction variables that are used for broadcasting values of a -// given operand. -std::vector -getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter, - ArrayRef loopIVs, Value operand, - std::map broadcastedDims) { - // `operand` must has a ranked type. This should have been checked by the - // shape inference pass. - auto operandShape = operand.getType().cast().getShape(); - auto rank = operandShape.size(); - auto loopCount = loopIVs.size(); - - std::vector newLoopIVs; - for (unsigned reversedIdx = 0; reversedIdx < rank; ++reversedIdx) { - auto dimIdx = rank - 1 - reversedIdx; - auto loopIdx = loopCount - 1 - reversedIdx; - if (operandShape[dimIdx] == 1) { - // Broadcasted dimension - auto zero = rewriter.create(loc, 0); - newLoopIVs.insert(newLoopIVs.begin(), zero); - } else if ((operandShape[dimIdx] == -1) && - (broadcastedDims.find(dimIdx) != broadcastedDims.end())) { - // Unknown dimension, it can have a value of 1 or N (N > 1). - // If its value is 1, it is broadcasted dimension. - // Otherwise, non-broadcasted dimension. 
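// Illustrative sketch (standalone C++, not part of this helper): the
// access-index rule it implements. Indices are aligned from the rightmost
// dimension; an operand dimension of size 1 is read at index 0 (broadcast),
// otherwise the matching loop induction variable is used. For dimensions
// only known at run time the helper emits a select on "size == 1" instead;
// here the shape is assumed known. Names are hypothetical.
#include <cstddef>
#include <vector>

static std::vector<size_t>
broadcastIndices(const std::vector<size_t> &loopIVs,      // result-space index
                 const std::vector<size_t> &operandShape) // rank <= #loopIVs
{
  std::vector<size_t> idx(operandShape.size());
  for (size_t r = 0; r < operandShape.size(); ++r) {
    size_t dim = operandShape.size() - 1 - r; // operand dim, right-aligned
    size_t loop = loopIVs.size() - 1 - r;     // matching result-space dim
    idx[dim] = (operandShape[dim] == 1) ? 0 : loopIVs[loop];
  }
  return idx;
}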
- auto zero = rewriter.create(loc, 0); - auto idx = rewriter.create(loc, broadcastedDims[dimIdx], zero, - loopIVs[loopIdx]); - newLoopIVs.insert(newLoopIVs.begin(), idx); - } else { - // Non-broadcasted dimension - newLoopIVs.insert(newLoopIVs.begin(), loopIVs[loopIdx]); - } - } - return newLoopIVs; -} - -namespace { - -template -struct ScalarOp; - -template <> -struct ScalarOp { - using FOp = AddFOp; - using IOp = AddIOp; -}; - -template <> -struct ScalarOp { - using FOp = MulFOp; - using IOp = MulIOp; -}; - -template <> -struct ScalarOp { - using FOp = DivFOp; - using IOp = SignedDivIOp; -}; - -template <> -struct ScalarOp { - using FOp = SubFOp; - using IOp = SubIOp; -}; - -template <> -struct ScalarOp { - using FOp = AndOp; // not use - using IOp = AndOp; -}; - -template <> -struct ScalarOp { - using FOp = OrOp; // not use - using IOp = OrOp; -}; - -template <> -struct ScalarOp { - using FOp = XOrOp; // not use - using IOp = XOrOp; -}; - -template <> -struct ScalarOp { - using FOp = ExpOp; - using IOp = ExpOp; // not use -}; - -template <> -struct ScalarOp { - using FOp = AddFOp; - using IOp = AddIOp; -}; - -template <> -struct ScalarOp { - using FOp = TanhOp; - using IOp = TanhOp; // not use -}; - -template <> -struct ScalarOp { - using FOp = CosOp; - using IOp = CosOp; // not use -}; - -template <> -struct ScalarOp { - using FOp = LogOp; - using IOp = LogOp; // not use -}; - -template <> -struct ScalarOp { - using FOp = MulFOp; - using IOp = MulIOp; -}; - -template <> -struct ScalarOp { - using FOp = AddFOp; - using IOp = AddIOp; -}; - -template <> -struct ScalarOp { - using FOp = KrnlSqrtOp; - using IOp = KrnlSqrtOp; // not use -}; - -template -using ScalarFOp = typename ScalarOp::FOp; -template -using ScalarIOp = typename ScalarOp::IOp; - -// Get the identity element of a operation. -// Return NULL if the function does not have identity. -template -DataType getIdentityValue() { - return NULL; -} - -template <> -float getIdentityValue(){ - return (float)-std::numeric_limits::infinity(); -} - -template <> -int getIdentityValue(){ - return std::numeric_limits::min(); -} - -template <> -float getIdentityValue(){ - return (float)std::numeric_limits::infinity(); -} - -template <> -int getIdentityValue(){ - return std::numeric_limits::max(); -} - -template <> -float getIdentityValue(){ - return (float)1.0; -} - -template <> -int getIdentityValue(){ - return 1; -} - -template <> -float getIdentityValue(){ - return (float)0; -} - -template <> -int getIdentityValue(){ - return 0; -} - -//===----------------------------------------------------------------------===// -// Scalar unary ops for lowering to Krnl dialect. -//===----------------------------------------------------------------------===// -template -Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, - ArrayRef operands, - ConversionPatternRewriter &rewriter) { - /* Lower UnaryOp to Ops in the Standard dialect. 
- */ - auto loc = op->getLoc(); - Type element_type = operands.front().getType(); - if (element_type.isa()) { - return rewriter.create>(loc, result_types, operands, - mlir::None); - } else if (element_type.isa()) { - return rewriter.create>(loc, result_types, operands, - mlir::None); - } else { - emitError(loc, "unsupported element type"); - return nullptr; - } -} - -//===----------------------------------------------------------------------===// -// Scalar unary ops for lowering ONNXSinhOp -//===----------------------------------------------------------------------===// -template <> -Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, - ArrayRef operands, - ConversionPatternRewriter &rewriter) { - // ONNXSinhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))), - // ConstantOp 2) - auto loc = op->getLoc(); - Value operand = operands[0]; - auto elementType = result_types[0]; - - auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); - auto two = rewriter.create(loc, FloatAttr::get(elementType, 2)); - auto neg = rewriter.create(loc, zero, operand); - auto exp = rewriter.create(loc, operand); - auto negExp = rewriter.create(loc, neg); - auto result = rewriter.create( - loc, rewriter.create(loc, exp, negExp), two); - - return result; -} - -//===----------------------------------------------------------------------===// -// Scalar unary ops for lowering ONNXCoshOp -//===----------------------------------------------------------------------===// -template <> -Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, - ArrayRef operands, - ConversionPatternRewriter &rewriter) { - // ONNXCoshOp(%X) = DivFOp(AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))), - // ConstantOp 2) - auto loc = op->getLoc(); - Value operand = operands[0]; - auto elementType = result_types[0]; - - auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); - auto two = rewriter.create(loc, FloatAttr::get(elementType, 2)); - auto neg = rewriter.create(loc, zero, operand); - auto exp = rewriter.create(loc, operand); - auto negExp = rewriter.create(loc, neg); - auto result = rewriter.create( - loc, rewriter.create(loc, exp, negExp), two); - - return result; -} - -//===----------------------------------------------------------------------===// -// Scalar unary ops for lowering ONNXSigmoidOp -//===----------------------------------------------------------------------===// -template <> -Value mapToLowerScalarOp(Operation *op, - ArrayRef result_types, - ArrayRef operands, - ConversionPatternRewriter &rewriter) { - // ONNXSigmoidOp(%X) = DivFOp(ConstantOp 1, - // AddFOp(ConstantOp 1, ExpOp(NegFOp(%X)))) - auto loc = op->getLoc(); - Value operand = operands[0]; - auto elementType = result_types[0]; - - auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); - auto one = rewriter.create(loc, FloatAttr::get(elementType, 1)); - auto neg = rewriter.create(loc, zero, operand); - auto negExp = rewriter.create(loc, neg); - auto result = rewriter.create( - loc, one, rewriter.create(loc, one, negExp)); - - return result; -} - -//===----------------------------------------------------------------------===// -// Scalar unary ops for lowering ONNXHardSigmoidOp -//===----------------------------------------------------------------------===// -template <> -Value mapToLowerScalarOp( - Operation *op, ArrayRef result_types, ArrayRef operands, - ConversionPatternRewriter &rewriter) { - // %Y = AddFOp(MulFOp(alpha, %X), beta) - // %Z = SelectOp(CmpFOp(OGT, %Y, Constant 0), - // %Y, - // Constant 0) - // 
ONNXHardSigmoidOp(%X) = SelectOp(CmpFOp(OLT, %Z, Constant 1), - // %Z, - // Constant 1) - auto loc = op->getLoc(); - Value operand = operands[0]; - auto alphaAttribute = FloatAttr::get( - rewriter.getF32Type(), - llvm::dyn_cast(op).alpha().convertToFloat()); - auto betaAttribute = FloatAttr::get( - rewriter.getF32Type(), - llvm::dyn_cast(op).beta().convertToFloat()); - auto elementType = result_types[0]; - - auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); - auto one = rewriter.create(loc, FloatAttr::get(elementType, 1)); - auto alpha = rewriter.create(loc, alphaAttribute); - auto beta = rewriter.create(loc, betaAttribute); - - auto add = rewriter.create( - loc, rewriter.create(loc, alpha, operand), beta); - auto maxPredicate = - rewriter.create(loc, CmpFPredicate::OGT, add, zero); - auto max = rewriter.create(loc, maxPredicate, add, zero); - auto minPredicate = - rewriter.create(loc, CmpFPredicate::OLT, max, one); - auto result = rewriter.create(loc, minPredicate, max, one); - - return result; -} - -//===----------------------------------------------------------------------===// -// Scalar unary ops for lowering ONNXEluOp -//===----------------------------------------------------------------------===// -template <> -Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, - ArrayRef operands, - ConversionPatternRewriter &rewriter) { - // ONNXEluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0), - // MulFOp(alpha, SubFOp(ExpOp(%X), 1)), - // %X) - auto loc = op->getLoc(); - Value operand = operands[0]; - auto elementType = result_types[0]; - - auto alphaAttribute = - FloatAttr::get(rewriter.getF32Type(), - llvm::dyn_cast(op).alpha().convertToFloat()); - auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); - auto one = rewriter.create(loc, FloatAttr::get(elementType, 1)); - auto alpha = rewriter.create(loc, alphaAttribute); - auto exp = rewriter.create(loc, operand); - auto lessThanZero = - rewriter.create(loc, CmpFPredicate::OLT, operand, zero); - auto result = rewriter.create( - loc, lessThanZero, - rewriter.create(loc, alpha, - rewriter.create(loc, exp, one)), - operand); - - return result; -} - -//===----------------------------------------------------------------------===// -// Scalar unary ops for lowering ONNXReluOp -//===----------------------------------------------------------------------===// -template <> -Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, - ArrayRef operands, - ConversionPatternRewriter &rewriter) { - // ONNXReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0), - // ConstantOp 0, - // %X) - auto loc = op->getLoc(); - Value operand = operands[0]; - auto elementType = result_types[0]; - - auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); - auto lessThanZero = - rewriter.create(loc, CmpFPredicate::OLT, operand, zero); - auto result = rewriter.create(loc, lessThanZero, zero, operand); - - return result; -} - -//===----------------------------------------------------------------------===// -// Scalar unary ops for lowering ONNXLeakyReluOp -//===----------------------------------------------------------------------===// -template <> -Value mapToLowerScalarOp(Operation *op, - ArrayRef result_types, - ArrayRef operands, - ConversionPatternRewriter &rewriter) { - // ONNXLeakyReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0), - // MulFOp(alpha, %X), - // %X) - auto loc = op->getLoc(); - Value operand = operands[0]; - auto elementType = result_types[0]; - - auto alphaAttribute = FloatAttr::get( - 
rewriter.getF32Type(), - llvm::dyn_cast(op).alpha().convertToFloat()); - auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); - auto alpha = rewriter.create(loc, alphaAttribute); - auto lessThanZero = - rewriter.create(loc, CmpFPredicate::OLT, operand, zero); - auto result = rewriter.create( - loc, lessThanZero, rewriter.create(loc, alpha, operand), operand); - - return result; -} - -//===----------------------------------------------------------------------===// -// Scalar unary ops for lowering ONNXSeluOp -//===----------------------------------------------------------------------===// -template <> -Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, - ArrayRef operands, - ConversionPatternRewriter &rewriter) { - // ONNXSeluOp(%X) = SelectOp(CmpFOp(OGT, %X, ConstantOp 0), - // MulFOp(gamma, %X), - // MulFOp(gamma, - // SubFOp(MulFOp(alpha, ExpOp(%X)), - // alpha))) - auto loc = op->getLoc(); - Value operand = operands[0]; - auto alphaAttribute = - FloatAttr::get(rewriter.getF32Type(), - llvm::dyn_cast(op).alpha().convertToFloat()); - auto gammaAttribute = - FloatAttr::get(rewriter.getF32Type(), - llvm::dyn_cast(op).gamma().convertToFloat()); - auto elementType = result_types[0]; - - auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); - auto alpha = rewriter.create(loc, alphaAttribute); - auto gamma = rewriter.create(loc, gammaAttribute); - auto exp = rewriter.create(loc, operand); - auto greaterThanZero = - rewriter.create(loc, CmpFPredicate::OGT, operand, zero); - auto select = rewriter.create( - loc, greaterThanZero, operand, - rewriter.create(loc, rewriter.create(loc, alpha, exp), - alpha)); - auto result = rewriter.create(loc, gamma, select); - - return result; -} - -//===----------------------------------------------------------------------===// -// Scalar unary ops for lowering ONNXReciprocalOp -//===----------------------------------------------------------------------===// -template <> -Value mapToLowerScalarOp( - Operation *op, ArrayRef result_types, ArrayRef operands, - ConversionPatternRewriter &rewriter) { - // ONNXReciprocalOp(%X) = DivFOp(ConstantOp 1, %X) - auto loc = op->getLoc(); - Value operand = operands[0]; - auto elementType = result_types[0]; - - auto one = rewriter.create(loc, FloatAttr::get(elementType, 1)); - auto result = rewriter.create(loc, one, operand); - - return result; -} - -//===----------------------------------------------------------------------===// -// Scalar unary ops for lowering ONNXSoftplusOp -//===----------------------------------------------------------------------===// -template <> -Value mapToLowerScalarOp(Operation *op, - ArrayRef result_types, - ArrayRef operands, - ConversionPatternRewriter &rewriter) { - // ONNXSoftplusOp(%X) = LogOp(AddFOp(ExpOp(%X), ConstantOp 1)) - auto loc = op->getLoc(); - Value operand = operands[0]; - auto elementType = result_types[0]; - - auto exp = rewriter.create(loc, operand); - auto one = rewriter.create(loc, FloatAttr::get(elementType, 1)); - auto add = rewriter.create(loc, exp, one); - auto result = rewriter.create(loc, add); - - return result; -} - -//===----------------------------------------------------------------------===// -// Scalar unary ops for lowering ONNXSoftsignOp -//===----------------------------------------------------------------------===// -template <> -Value mapToLowerScalarOp(Operation *op, - ArrayRef result_types, - ArrayRef operands, - ConversionPatternRewriter &rewriter) { - // ONNXSoftsignOp(%X) = DivFOp(ConstantOp 1, %X) - auto loc = 
op->getLoc(); - Value operand = operands[0]; - auto elementType = result_types[0]; - - auto abs = rewriter.create(loc, operand); - auto one = rewriter.create(loc, FloatAttr::get(elementType, 1)); - auto add = rewriter.create(loc, abs, one); - auto result = rewriter.create(loc, operand, add); - - return result; -} - -//===----------------------------------------------------------------------===// -// Scalar unary ops for lowering ONNXSignOp -//===----------------------------------------------------------------------===// -template <> -Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, - ArrayRef operands, - ConversionPatternRewriter &rewriter) { - - auto loc = op->getLoc(); - Value operand = operands[0]; - Type element_type = operands.front().getType(); - // TODO: unsigned int should be supported separately? - if (element_type.isa()) { - // %Y = SelectOP(CmpIOp(GT, %X, ConstantOp 0), - // ConstantOp 1, - // COnstantOp -1) - // ONNXSignOp(%X) = SelectOP(CmpIOp(EQ, %X, ConstantOp 0), - // ConstantOp 0, - // %Y) - auto zero = rewriter.create(loc, rewriter.getI32IntegerAttr(0)); - auto one = rewriter.create(loc, rewriter.getI32IntegerAttr(1)); - auto minusOne = - rewriter.create(loc, rewriter.getI32IntegerAttr(-1)); - auto plusPredicate = - rewriter.create(loc, CmpIPredicate::sgt, operand, zero); - auto plusSelect = - rewriter.create(loc, plusPredicate, one, minusOne); - auto zeroPredicate = - rewriter.create(loc, CmpIPredicate::eq, operand, zero); - auto result = - rewriter.create(loc, zeroPredicate, zero, plusSelect); - return result; - } else if (element_type.isa()) { - // %Y = SelectOP(CmpFOp(OGT, %X, ConstantOp 0), - // ConstantOp 1, - // ConstantOp -1) - // ONNXSignOp(%X) = SelectOP(CmpFOp(OEQ, %X, ConstantOp 0), - // ConstantOp 0, - // %Y) - auto zero = - rewriter.create(loc, rewriter.getF32FloatAttr(0.0f)); - auto one = rewriter.create(loc, rewriter.getF32FloatAttr(1.0f)); - auto minusOne = - rewriter.create(loc, rewriter.getF32FloatAttr(-1.0f)); - auto plusPredicate = - rewriter.create(loc, CmpFPredicate::OGT, operand, zero); - auto plusSelect = - rewriter.create(loc, plusPredicate, one, minusOne); - auto zeroPredicate = - rewriter.create(loc, CmpFPredicate::OEQ, operand, zero); - auto result = - rewriter.create(loc, zeroPredicate, zero, plusSelect); - return result; - } else { - emitError(loc, "unsupported element type"); - } -} - -//===----------------------------------------------------------------------===// -// Scalar unary ops for lowering ONNXMaxOp -//===----------------------------------------------------------------------===// -template <> -Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, - ArrayRef operands, - ConversionPatternRewriter &rewriter) { - // ONNXMaxOp(%X, %Y) = SelectOp(CmpFOp(OGT, %X, %Y), - // %X, - // %Y) - auto loc = op->getLoc(); - Value lhs = operands[0]; - Value rhs = operands[1]; - auto max = rewriter.create(loc, CmpFPredicate::OGT, lhs, rhs); - auto result = rewriter.create(loc, max, lhs, rhs); - return result; -} - -//===----------------------------------------------------------------------===// -// Scalar unary ops for lowering ONNXMinOp -//===----------------------------------------------------------------------===// -template <> -Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, - ArrayRef operands, - ConversionPatternRewriter &rewriter) { - // ONNXMinOp(%X, %Y) = SelectOp(CmpFOp(OLT, %X, %Y), - // %X, - // %Y) - auto loc = op->getLoc(); - Value lhs = operands[0]; - Value rhs = operands[1]; - auto min = 
rewriter.create(loc, CmpFPredicate::OLT, lhs, rhs); - auto result = rewriter.create(loc, min, lhs, rhs); - return result; -} - -//===----------------------------------------------------------------------===// -// Scalar unary ops for lowering ONNXReduceMaxOp -//===----------------------------------------------------------------------===// -template <> -Value mapToLowerScalarOp(Operation *op, - ArrayRef result_types, - ArrayRef operands, - ConversionPatternRewriter &rewriter) { - auto loc = op->getLoc(); - Value lhs = operands[0]; - Value rhs = operands[1]; - Type element_type = lhs.getType(); - if (element_type.isa()) { - auto max = rewriter.create(loc, CmpIPredicate::sgt, lhs, rhs); - auto result = rewriter.create(loc, max, lhs, rhs); - return result; - } else if (element_type.isa()) { - auto max = rewriter.create(loc, CmpFPredicate::OGT, lhs, rhs); - auto result = rewriter.create(loc, max, lhs, rhs); - return result; - } else { - emitError(loc, "unsupported element type"); - return nullptr; - } -} - -//===----------------------------------------------------------------------===// -// Scalar unary ops for lowering ONNXReduceMinOp -//===----------------------------------------------------------------------===// -template <> -Value mapToLowerScalarOp(Operation *op, - ArrayRef result_types, - ArrayRef operands, - ConversionPatternRewriter &rewriter) { - auto loc = op->getLoc(); - Value lhs = operands[0]; - Value rhs = operands[1]; - Type element_type = lhs.getType(); - if (element_type.isa()) { - auto min = rewriter.create(loc, CmpIPredicate::slt, lhs, rhs); - auto result = rewriter.create(loc, min, lhs, rhs); - return result; - } else if (element_type.isa()) { - auto min = rewriter.create(loc, CmpFPredicate::OLT, lhs, rhs); - auto result = rewriter.create(loc, min, lhs, rhs); - return result; - } else { - emitError(loc, "unsupported element type"); - return nullptr; - } -} - -// Element-wise unary ops lowering to Krnl dialect. -//===----------------------------------------------------------------------===// -template -struct ONNXElementwiseUnaryOpLowering : public ConversionPattern { - ONNXElementwiseUnaryOpLowering(MLIRContext *ctx) - : ConversionPattern(ElementwiseUnaryOp::getOperationName(), 1, ctx) {} - PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final { - // TODO: Check that the types are valid. - // An element-wise unary operation must have all operands and the result of - // the same type. This should have been verified by the verifier. - auto tensorType = (*op->result_type_begin()).cast(); - auto loc = op->getLoc(); - - // Insert an allocation and deallocation for the result of this operation. - auto memRefType = convertTensorToMemRef(tensorType); - - // If the output has a dynamic dimension, pass the operands required for - // each dynamic dimension to the AllocOp. The first operand of the - // operation is used. The operands of the op need to match in terms of - // dimensions with the result at this pre-optimization phase. - // TODO: verify that dimensions match. - // TODO: can the dimension of the result differ after optimizations? 
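// Reference scalar semantics (plain C++) for several of the mapToLowerScalarOp
// specializations above. These mirror the formulas in the comments; they are an
// illustration only, not the MLIR builder code itself.
#include <algorithm>
#include <cmath>

float sigmoidRef(float x) { return 1.0f / (1.0f + std::exp(-x)); }
float hardSigmoidRef(float x, float alpha, float beta) {
  // clamp(alpha * x + beta) to [0, 1]
  return std::min(std::max(alpha * x + beta, 0.0f), 1.0f);
}
float eluRef(float x, float alpha) {
  return x < 0.0f ? alpha * (std::exp(x) - 1.0f) : x;
}
float leakyReluRef(float x, float alpha) { return x < 0.0f ? alpha * x : x; }
float seluRef(float x, float alpha, float gamma) {
  return gamma * (x > 0.0f ? x : alpha * std::exp(x) - alpha);
}
float softplusRef(float x) { return std::log(std::exp(x) + 1.0f); }
float softsignRef(float x) { return x / (std::fabs(x) + 1.0f); }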
- Value alloc; - bool insertDealloc = checkInsertDealloc(op); - - if (hasAllConstantDimensions(memRefType)) - alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); - else - alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc, - {operands[0]}); - - std::vector originalLoops; - KrnlOptimizeLoopsOp optimizedLoopsOp; - KrnlIterateOp iterateOp; - emitKrnlLoopsAndIterationForOperand( - rewriter, loc, operands[0], originalLoops, - optimizedLoopsOp, iterateOp); - Block &optimizationBlock = optimizedLoopsOp.region().front(); - Block &iterationBlock = iterateOp.bodyRegion().front(); - - // 1. Insert any optimizations in the KrnlOptimizeLoopsOp body. - rewriter.setInsertionPointToEnd(&optimizationBlock); - // Return from KrnlOptimizeLoopsOp body. - // When no optimizations are present we just return the loops - // unchaged. - rewriter.create(loc, originalLoops); - - // 2. Insert instructions inside the KernelIterateOp body. - rewriter.setInsertionPointToStart(&iterationBlock); - - // Handle the operation: - SmallVector loopIVs; - for (auto arg : iterationBlock.getArguments()) - loopIVs.push_back(arg); - - auto loadedVal = rewriter.create(loc, operands[0], loopIVs); - auto loweredOpResult = mapToLowerScalarOp( - op, memRefType.getElementType(), {loadedVal}, rewriter); - // Store result in the resulting array. - rewriter.create(loc, loweredOpResult, alloc, loopIVs); - - rewriter.replaceOp(op, alloc); - - return matchSuccess(); - } -}; - -// Element-wise variadic ops lowering to Krnl dialect. -//===----------------------------------------------------------------------===// -template -struct ONNXElementwiseVariadicOpLowering : public ConversionPattern { - ONNXElementwiseVariadicOpLowering(MLIRContext *ctx) - : ConversionPattern(ElementwiseVariadicOp::getOperationName(), 1, ctx) {} - PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final { - // TODO: Check that the types are valid. - // An element-wise variadic operation must have all operands and the result - // of the same type. This should have been verified by the verifier. - auto tensorType = (*op->result_type_begin()).cast(); - auto loc = op->getLoc(); - auto numArgs = op->getNumOperands(); - - // Insert an allocation and deallocation for the result of this operation. - auto memRefType = convertTensorToMemRef(tensorType); - - Value alloc; - bool insertDealloc = checkInsertDealloc(op); - // If the output has a dynamic dimension, we compute its dimension at - // runtime by using dimensions from the operands. - // In particular, we need to know from which operand a result dimension - // comes from. - // TODO: can the dimension of the result differ after optimizations? - if (hasAllConstantDimensions(memRefType)) - alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); - else - alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc, - operands); - - // Get run-time dimension information for unknown dimensions used for - // broadcasting. - std::map> broadcastedDimInfo = - getBroadcastedDimInfo(loc, rewriter, memRefType, operands); - - std::vector originalLoops; - KrnlOptimizeLoopsOp optimizedLoopsOp; - KrnlIterateOp iterateOp; - emitKrnlLoopsAndIterationForOperand( - rewriter, loc, alloc, originalLoops, - optimizedLoopsOp, iterateOp); - Block &optimizationBlock = optimizedLoopsOp.region().front(); - Block &iterationBlock = iterateOp.bodyRegion().front(); - - // 1. 
Insert any optimizations in the KrnlOptimizeLoopsOp body. - rewriter.setInsertionPointToEnd(&optimizationBlock); - // Return from KrnlOptimizeLoopsOp body. - // When no optimizations are present we just return the loops unchaged. - rewriter.create(loc, originalLoops); - - // 2. Insert instructions inside the KernelIterateOp body. - rewriter.setInsertionPointToStart(&iterationBlock); - - // Handle the operation: - SmallVector loopIVs; - for (auto arg : iterationBlock.getArguments()) - loopIVs.push_back(arg); - - // Fold over operands for each of their scalar values - Value accumulated, next; - auto accumulatedLoopIVs = getLoopIVsForBroadcasting( - loc, rewriter, loopIVs, operands[0], broadcastedDimInfo[0]); - accumulated = rewriter.create(loc, operands[0], accumulatedLoopIVs); - for (unsigned i = 1; i < numArgs; i++) { - auto nextLoopIVs = getLoopIVsForBroadcasting( - loc, rewriter, loopIVs, operands[i], broadcastedDimInfo[i]); - next = rewriter.create(loc, operands[i], nextLoopIVs); - accumulated = mapToLowerScalarOp( - op, memRefType.getElementType(), {accumulated, next}, rewriter); - } - // Store result in the resulting array. - rewriter.create(loc, accumulated, alloc, loopIVs); - - rewriter.replaceOp(op, alloc); - - return matchSuccess(); - } -}; - -struct ONNXSoftmaxOpLowering : public ConversionPattern { - ONNXSoftmaxOpLowering(MLIRContext *ctx) - : ConversionPattern(mlir::ONNXSoftmaxOp::getOperationName(), 1, ctx) {} - PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final { - // softmax(x) = let max_x = max(x) in - // let exp_x = exp(x - max_x) in - // let sum = sum(exp_x) in - // exp_x / sum - auto tensorType = (*op->result_type_begin()).cast(); - int64_t rank = tensorType.getRank(); - int64_t axis = llvm::dyn_cast(op).axis().getSExtValue(); - axis = axis >= 0 ? axis : rank + axis; - assert(axis >= -rank && axis <= rank - 1); - - auto loc = op->getLoc(); - - // Insert an allocation and deallocation for the result of this operation. - auto memRefType = convertTensorToMemRef(tensorType); - auto elementType = memRefType.getElementType(); - - Value alloc; - bool insertDealloc = checkInsertDealloc(op); - if (hasAllConstantDimensions(memRefType)) - alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); - else - alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc, - operands[0]); - - // Shape of the result - auto memRefShape = memRefType.getShape(); - - // Insert allocations and deallocations for sum and max. - MemRefType scalarMemRefType = MemRefType::get({}, elementType, {}, 0); - Value sumOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true); - Value maxOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true); - Value zero = - rewriter.create(loc, FloatAttr::get(elementType, 0)); - Value negInfinity = rewriter.create( - loc, - FloatAttr::get(elementType, -std::numeric_limits::infinity())); - - // Define loops. - std::vector originalLoops; - std::vector optimizedLoops; - Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops, - optimizedLoops, rank); - - // Coerce the input into a 2-D tensor. `axis` will be the coercing point. - // This coercing follows the softmax definition in ONNX: - // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Softmax - // Here, we create an outer loop and inner loop for handling the two - // dimensions. The outer loop is only created once `axis` is not zero. - - // Define an outer loop with respect to axis. 
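// Plain C++ sketch of the numerically stable softmax that the loop nest above
// computes. The tensor is viewed as (outer x inner), where `inner` is the product
// of the dimensions from `axis` onward; each inner slice gets a max pass, an
// exp/sum pass, and a normalization pass. Buffer layout here is an assumption.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>
#include <vector>

void softmaxRef(std::vector<float> &x, int64_t outer, int64_t inner) {
  for (int64_t o = 0; o < outer; ++o) {
    float *slice = x.data() + o * inner;
    float maxVal = -std::numeric_limits<float>::infinity();
    for (int64_t i = 0; i < inner; ++i)   // max pass
      maxVal = std::max(maxVal, slice[i]);
    float sum = 0.0f;
    for (int64_t i = 0; i < inner; ++i) { // exp/sum pass; exp stored in place
      slice[i] = std::exp(slice[i] - maxVal);
      sum += slice[i];
    }
    for (int64_t i = 0; i < inner; ++i)   // normalization pass
      slice[i] /= sum;
  }
}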
- std::vector outerLoops, optimizedOuterLoops; - outerLoops.reserve(axis); - optimizedOuterLoops.reserve(axis); - for (int i = 0; i < axis; ++i) { - outerLoops.push_back(originalLoops[i]); - optimizedOuterLoops.push_back(optimizedLoops[i]); - } - KrnlIterateOperandPack outerPack(rewriter, outerLoops, optimizedOuterLoops); - for (int i = 0; i < axis; ++i) - addDimensionToPack(rewriter, loc, outerPack, operands[0], i); - - // Define an inner loop with respect to axis. - std::vector innerLoops, optimizedInnerLoops; - innerLoops.reserve(rank - axis); - optimizedInnerLoops.reserve(rank - axis); - for (int i = axis; i < rank; ++i) { - innerLoops.push_back(originalLoops[i]); - optimizedInnerLoops.push_back(optimizedLoops[i]); - } - KrnlIterateOperandPack innerPack(rewriter, innerLoops, optimizedInnerLoops); - for (int i = axis; i < rank; ++i) - addDimensionToPack(rewriter, loc, innerPack, operands[0], i); - - KrnlIterateOp outerIterateOp, maxIterateOp, sumIterateOp, softmaxIterateOp; - SmallVector outerLoopIVs; - if (axis != 0) { - outerIterateOp = rewriter.create(loc, outerPack); - - // No optimization - rewriter.setInsertionPointToEnd(optimizationBlock); - rewriter.create(loc, originalLoops); - - // Insert instructions inside the outer loop. - Block &outerIterationBlock = outerIterateOp.bodyRegion().front(); - rewriter.setInsertionPointToStart(&outerIterationBlock); - for (auto arg : outerIterationBlock.getArguments()) - outerLoopIVs.push_back(arg); - - // Reset accumulators. - rewriter.create(loc, zero, sumOp); - rewriter.create(loc, negInfinity, maxOp); - - // Create an inner loop to compute max. - maxIterateOp = rewriter.create(loc, innerPack); - // Create an inner loop to compute sum. - sumIterateOp = rewriter.create(loc, innerPack); - // Create an inner loop to compute softmax. - softmaxIterateOp = rewriter.create(loc, innerPack); - } else { - // Reset accumulators. - rewriter.create(loc, zero, sumOp); - rewriter.create(loc, negInfinity, maxOp); - - // Create an inner loop to compute max. - maxIterateOp = rewriter.create(loc, innerPack); - // Create an inner loop to compute sum. - sumIterateOp = rewriter.create(loc, innerPack); - // Create an inner loop to compute softmax. - softmaxIterateOp = rewriter.create(loc, innerPack); - - // No optimization - rewriter.setInsertionPointToEnd(optimizationBlock); - rewriter.create(loc, originalLoops); - } - - // Insert instructions inside the max loop. - Block &maxIterationBlock = maxIterateOp.bodyRegion().front(); - rewriter.setInsertionPointToStart(&maxIterationBlock); - - // Get induction variables. - SmallVector maxLoopIVs; - for (auto arg : outerLoopIVs) - maxLoopIVs.push_back(arg); - for (auto arg : maxIterationBlock.getArguments()) - maxLoopIVs.push_back(arg); - - // Compute the max value. - Value max = rewriter.create(loc, maxOp); - Value nextMax = rewriter.create(loc, operands[0], maxLoopIVs); - auto maxCond = - rewriter.create(loc, CmpFPredicate::OGT, max, nextMax); - max = rewriter.create(loc, maxCond, max, nextMax); - rewriter.create(loc, max, maxOp); - - // Get the max. - rewriter.setInsertionPoint(sumIterateOp); - max = rewriter.create(loc, maxOp); - - // Insert instructions inside the sum loop. - Block &sumIterationBlock = sumIterateOp.bodyRegion().front(); - rewriter.setInsertionPointToStart(&sumIterationBlock); - - // Get induction variables. - SmallVector sumLoopIVs; - for (auto arg : outerLoopIVs) - sumLoopIVs.push_back(arg); - for (auto arg : sumIterationBlock.getArguments()) - sumLoopIVs.push_back(arg); - - // Sum up values. 
- Value sum = rewriter.create(loc, sumOp); - Value next = rewriter.create(loc, operands[0], sumLoopIVs); - Value sub = rewriter.create(loc, next, max); - Value exp = rewriter.create(loc, sub); - sum = rewriter.create(loc, sum, exp); - rewriter.create(loc, sum, sumOp); - // Store intermediate values in the result to avoid recomputation. - rewriter.create(loc, exp, alloc, sumLoopIVs); - - // Get the sum. - rewriter.setInsertionPoint(softmaxIterateOp); - sum = rewriter.create(loc, sumOp); - - // Insert instructions inside the softmax loop. - Block &softmaxIterationBlock = softmaxIterateOp.bodyRegion().front(); - rewriter.setInsertionPointToStart(&softmaxIterationBlock); - - // Get induction variables. - SmallVector softmaxLoopIVs; - for (auto arg : outerLoopIVs) - softmaxLoopIVs.push_back(arg); - for (auto arg : softmaxIterationBlock.getArguments()) - softmaxLoopIVs.push_back(arg); - - // Compute softmax. - Value expLoadedVal = rewriter.create(loc, alloc, softmaxLoopIVs); - Value result = rewriter.create(loc, expLoadedVal, sum); - rewriter.create(loc, result, alloc, softmaxLoopIVs); - - rewriter.replaceOp(op, alloc); - - return matchSuccess(); - } -}; - -struct ONNXReshapeOpLowering : public ConversionPattern { - ONNXReshapeOpLowering(MLIRContext *ctx) - : ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {} - - PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final { - auto tensorType = (*op->result_type_begin()).cast(); - auto inputShape = operands[0].getType().cast().getShape(); - auto loc = op->getLoc(); - - // Insert an allocation and deallocation for the result of this operation. - auto memRefType = convertTensorToMemRef(tensorType); - auto memRefShape = memRefType.getShape(); - Value alloc; - - // Compute size in bytes using the input tensor. - Value tensorSize = rewriter.create( - loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), - getMemRefEltSizeInBytes(memRefType))); - for (int i = 0; i < inputShape.size(); ++i) { - Value dimVal; - if (inputShape[i] < 0) { - Value dim = rewriter.create(loc, operands[0], i); - dimVal = - rewriter.create(loc, dim, rewriter.getIntegerType(64)); - } else { - dimVal = rewriter.create( - loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), - inputShape[i])); - } - tensorSize = rewriter.create(loc, tensorSize, dimVal); - } - - bool insertDealloc = checkInsertDealloc(op); - if (hasAllConstantDimensions(memRefType)) { - alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); - } else { - // If a dimension is zero, the actual dimension value is taken from the - // input tensor. - // - // If the shape array has a negative dimension (-1), we compute its actual - // dimension value from the other dimensions. But we don't have enough - // information about the other dimensions at this point. So, we need to - // scan the shape first to calculate reduction of all of the dimensions. - // If the reduction is negative, then the shape array contains a negative - // dimension. Otherwise, the reduction is the same as the one computed - // from the input tensor. - Value tensorSizeFromShape = rewriter.create( - loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), - getMemRefEltSizeInBytes(memRefType))); - SmallVector DimInfo; - for (int i = 0; i < memRefShape.size(); ++i) { - Value index = rewriter.create( - loc, rewriter.getIntegerAttr(rewriter.getIndexType(), i)); - // Load index from array of indices. 
- Value loadedVal = rewriter.create(loc, operands[1], index); - // If a dimension is zero, the actual dimension value is taken from the - // input tensor. - // - // If a dimension is negative, it is computed from the other dimensions. - // But we don't have enough information about the other dimensions at - // this point. So, we let it as it is (-1), and compute it later. - if (i < inputShape.size()) { - Value dimVal; - auto loadedValType = loadedVal.getType().cast(); - if (inputShape[i] < 0) { - Value dim = rewriter.create(loc, operands[0], i); - dimVal = rewriter.create(loc, dim, loadedValType); - } else { - dimVal = rewriter.create( - loc, rewriter.getIntegerAttr(loadedValType, inputShape[i])); - } - auto zero = rewriter.create( - loc, rewriter.getIntegerAttr(loadedValType, 0)); - auto isZero = - rewriter.create(loc, CmpIPredicate::eq, loadedVal, zero); - loadedVal = rewriter.create(loc, isZero, dimVal, loadedVal); - } - // Check if the loaded index is already the correct width of 64 bits. - // Convert the value to a 64 bit integer if needed. - Value int64LoadedVal = loadedVal; - if (loadedVal.getType().cast().getWidth() < 64) - int64LoadedVal = rewriter.create( - loc, loadedVal, rewriter.getIntegerType(64)); - tensorSizeFromShape = - rewriter.create(loc, tensorSizeFromShape, int64LoadedVal); - // Store intermediate results to use later. - DimInfo.emplace_back(int64LoadedVal); - } - // Reverse tensorSizeFromShape since it is negative if the shape array has - // a negative dimension. This is safe since we only use it to compute the - // actual value for the negative dimension. - auto zero = rewriter.create( - loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); - tensorSizeFromShape = - rewriter.create(loc, zero, tensorSizeFromShape); - - // Obtain operands for AllocOp. - SmallVector allocOperands; - auto negOne = rewriter.create( - loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), -1)); - - for (int i = 0; i < memRefShape.size(); ++i) { - auto dimVal = DimInfo[i]; - auto isNegOne = - rewriter.create(loc, CmpIPredicate::eq, dimVal, negOne); - // If dimension is negative, compute its value from the other - // dimensions. - auto actualDimVal = - rewriter.create(loc, tensorSize, tensorSizeFromShape); - auto loadedVal = - rewriter.create(loc, isNegOne, actualDimVal, dimVal); - allocOperands.push_back(rewriter.create( - loc, loadedVal, rewriter.getIndexType())); - } - AllocOp allocateMemref = - rewriter.create(loc, memRefType, allocOperands); - - // Make sure to allocate at the beginning of the block if - // all dimensions are known. 
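// Standalone sketch (plain C++, hypothetical helper) of the shape resolution the
// Reshape lowering performs above: a 0 entry copies the corresponding input
// dimension, and a single -1 entry is inferred so that the element counts match.
// The real lowering does the equivalent arithmetic on byte sizes at runtime.
#include <cstdint>
#include <vector>

std::vector<int64_t> resolveReshapeShape(
    const std::vector<int64_t> &inputShape, std::vector<int64_t> shape) {
  int64_t inputSize = 1;
  for (int64_t d : inputShape)
    inputSize *= d;

  int64_t knownSize = 1;
  int64_t inferredIdx = -1;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (shape[i] == 0 && i < inputShape.size())
      shape[i] = inputShape[i];                // 0 means "same as input"
    if (shape[i] == -1)
      inferredIdx = static_cast<int64_t>(i);   // dimension to infer later
    else
      knownSize *= shape[i];
  }
  if (inferredIdx >= 0)
    shape[inferredIdx] = inputSize / knownSize;
  return shape;
}
// Example: input shape (2, 3, 4) with shape array (0, -1) resolves to (2, 12).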
- auto *parentBlock = allocateMemref.getOperation()->getBlock(); - if (insertDealloc) { - auto dealloc = rewriter.create(loc, allocateMemref); - dealloc.getOperation()->moveBefore(&parentBlock->back()); - } - - alloc = allocateMemref; - } - - rewriter.create(loc, alloc, operands[0], tensorSize); - rewriter.replaceOp(op, alloc); - - return matchSuccess(); - } -}; - -struct ONNXMatMulOpLowering : public ConversionPattern { - ONNXMatMulOpLowering(MLIRContext *ctx) - : ConversionPattern(mlir::ONNXMatMulOp::getOperationName(), 1, ctx) {} - - PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final { - auto tensorType = (*op->result_type_begin()).cast(); - auto loc = op->getLoc(); - - Value A = operands[0]; - Value B = operands[1]; - auto AShape = A.getType().cast().getShape(); - auto BShape = B.getType().cast().getShape(); - - // There are three cases related to the shapes of the two arguments: - // - Both arguments are N-D, N >= 2 - // - Either argument is 1-D, the other is N-D, N >= 2 - // - Both arguments are 1-D - - // Result type - auto memRefType = convertTensorToMemRef(tensorType); - auto elementType = memRefType.getElementType(); - auto memRefShape = memRefType.getShape(); - - // A value zero - Value zero; - if (elementType.isa()) { - zero = rewriter.create( - loc, IntegerAttr::get(memRefType.getElementType(), 0)); - } else if (elementType.isa()) { - zero = rewriter.create( - loc, FloatAttr::get(memRefType.getElementType(), 0)); - } else { - emitError(loc, "unsupported element type"); - } - - // Insert an allocation and deallocation for the result of this operation. - Value alloc; - bool insertDealloc = checkInsertDealloc(op); - if (hasAllConstantDimensions(memRefType)) - alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); - else { - SmallVector allocOperands; - if (AShape.size() >= 2 && BShape.size() >= 2) { - // Both arguments are N-D, N >= 2 - // (s1 x s2 x... x sK x M x K) MATMUL (K x N) - // => - // (s1 x s2 x... x sK x M x N) - for (int i = 0; i < memRefShape.size() - 2; ++i) { - if (memRefShape[i] < 0) { - if ((AShape.size() == 2) && (BShape.size() > 2)) - allocOperands.emplace_back(rewriter.create(loc, B, i)); - else if ((AShape.size() > 2) && (BShape.size() == 2)) - allocOperands.emplace_back(rewriter.create(loc, A, i)); - } - } - if (memRefShape[memRefShape.size() - 2] < 0) { - auto dim = rewriter.create(loc, A, memRefShape.size() - 2); - allocOperands.emplace_back(dim); - } - if (memRefShape[memRefShape.size() - 1] < 0) { - auto dim = rewriter.create(loc, B, memRefShape.size() - 1); - allocOperands.emplace_back(dim); - } - } else if (AShape.size() == 1 && BShape.size() >= 2) { - // Either argument is 1-D - // K MATMUL (s1 x s2 x... x sK x K x N) - // => - // (s1 x s2 x... x sK x N) - for (int i = 0; i < memRefShape.size() - 1; ++i) { - if (memRefShape[i] < 0) { - auto dim = rewriter.create(loc, B, i); - allocOperands.emplace_back(dim); - } - } - if (memRefShape[memRefShape.size() - 1] < 0) { - auto dim = rewriter.create(loc, B, BShape.size() - 1); - allocOperands.emplace_back(dim); - } - } else if (AShape.size() >= 2 && BShape.size() == 1) { - // Either argument is 1-D - // (s1 x s2 x... x sK x M x K) MATMUL K - // => - // (s1 x s2 x... 
x sK x M) - for (int i = 0; i < memRefShape.size() - 1; ++i) { - if (memRefShape[i] < 0) { - auto dim = rewriter.create(loc, A, i); - allocOperands.emplace_back(dim); - } - } - if (memRefShape[memRefShape.size() - 1] < 0) { - auto dim = rewriter.create(loc, A, AShape.size() - 2); - allocOperands.emplace_back(dim); - } - } else if (AShape.size() == 1 && BShape.size() == 1) { - // Both arguments are 1-D - if (memRefShape[0] < 0) { - auto dim = rewriter.create(loc, A, 0); - allocOperands.emplace_back(dim); - } - } else { - emitError(loc, "Invalid shapes"); - } - - alloc = rewriter.create(loc, memRefType, allocOperands); - } - - if (AShape.size() >= 2 || BShape.size() >= 2) { - // Cases 1 and 2: - // - Both arguments are N-D, N >= 2 - // - Either argument is 1-D, the other is N-D, N >= 2 - - // Define loops for batch dimensions. - std::vector originalLoops; - std::vector optimizedLoops; - Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops, - optimizedLoops, memRefShape.size()); - - // Outer KrnlIterateOp - SmallVector loopBatchIVs; - bool hasBatchLoop = false; - if (AShape.size() > 2 || BShape.size() > 2) { - SmallVector batchAxes; - int matmulResultDims = - ((AShape.size() == 1 || BShape.size() == 1)) ? 1 : 2; - for (int i = 0; i < memRefShape.size() - matmulResultDims; ++i) - batchAxes.emplace_back(i); - - std::vector outerLoops, optimizedOuterLoops; - outerLoops.reserve(batchAxes.size()); - optimizedOuterLoops.reserve(batchAxes.size()); - for (int i = 0; i < batchAxes.size(); ++i) { - outerLoops.push_back(originalLoops[i]); - optimizedOuterLoops.push_back(optimizedLoops[i]); - } - KrnlIterateOperandPack outerPack(rewriter, outerLoops, - optimizedOuterLoops); - for (int i = 0; i < batchAxes.size(); ++i) { - addDimensionToPack(rewriter, loc, outerPack, alloc, i); - } - auto outerIterateOp = rewriter.create(loc, outerPack); - - // No optimization - rewriter.setInsertionPointToEnd(optimizationBlock); - rewriter.create(loc, originalLoops); - - // Insert instructions into the outer KrnlIterateOp. - Block &outerIterationBlock = outerIterateOp.bodyRegion().front(); - rewriter.setInsertionPointToStart(&outerIterationBlock); - - // Induction variables: non-matrix-multiplication variables. - for (auto arg : outerIterationBlock.getArguments()) { - loopBatchIVs.emplace_back(arg); - } - - hasBatchLoop = true; - } - - // Now, we define loops for matrix multiplication. - - // Create a KrnlIterateOp for matrix multiplication. - KrnlIterateOp matmulIterateOp; - std::vector matmulLoops, optimizedMatmulLoops; - if (AShape.size() >= 2 && BShape.size() >= 2) { - // 2-D x 2-D. Result has two dimensions. - matmulLoops.reserve(2); - optimizedMatmulLoops.reserve(2); - for (int i = 2; i > 0; --i) { - matmulLoops.emplace_back(originalLoops[memRefShape.size() - i]); - optimizedMatmulLoops.emplace_back( - optimizedLoops[memRefShape.size() - i]); - } - KrnlIterateOperandPack matmulPack(rewriter, matmulLoops, - optimizedMatmulLoops); - for (int i = 2; i > 0; --i) { - addDimensionToPack(rewriter, loc, matmulPack, alloc, - memRefShape.size() - i); - } - matmulIterateOp = rewriter.create(loc, matmulPack); - } else { - // 1-D x 2-D, and vice versa. Result has one dimension. 
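// Plain C++ reference for the core accumulation pattern emitted above for the
// 2-D x 2-D case: the output is first filled with 0, then accumulated along the
// reduction dimension K. Batch loops simply wrap this kernel. Row-major buffers
// are an assumption of this sketch.
#include <cstdint>
#include <vector>

void matmulRef(const std::vector<float> &A, const std::vector<float> &B,
               std::vector<float> &Y, int64_t M, int64_t K, int64_t N) {
  for (int64_t m = 0; m < M; ++m)
    for (int64_t n = 0; n < N; ++n) {
      Y[m * N + n] = 0.0f;                      // fill the output with 0
      for (int64_t k = 0; k < K; ++k)           // reduce along K
        Y[m * N + n] += A[m * K + k] * B[k * N + n];
    }
}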
- matmulLoops.reserve(1); - optimizedMatmulLoops.reserve(1); - matmulLoops.emplace_back(originalLoops[memRefShape.size() - 1]); - optimizedMatmulLoops.emplace_back( - optimizedLoops[memRefShape.size() - 1]); - KrnlIterateOperandPack matmulPack(rewriter, matmulLoops, - optimizedMatmulLoops); - addDimensionToPack(rewriter, loc, matmulPack, alloc, - memRefShape.size() - 1); - matmulIterateOp = rewriter.create(loc, matmulPack); - } - - if (!hasBatchLoop) { - // No optimization - rewriter.setInsertionPointToEnd(optimizationBlock); - rewriter.create(loc, originalLoops); - } - - // Insert instructions into the matmul KrnlIterateOp. - Block &matmulIterationBlock = matmulIterateOp.bodyRegion().front(); - rewriter.setInsertionPointToStart(&matmulIterationBlock); - - // Induction variables: M, N - SmallVector loopMNIVs; - for (auto arg : matmulIterationBlock.getArguments()) { - loopMNIVs.emplace_back(arg); - } - // Induction variables for the final result. - SmallVector loopBatchMNIVs; - for (auto arg : loopBatchIVs) { - loopBatchMNIVs.emplace_back(arg); - } - for (auto arg : loopMNIVs) { - loopBatchMNIVs.emplace_back(arg); - } - - // Fill the output with value 0. - rewriter.create(loc, zero, alloc, loopBatchMNIVs); - - // Iterate along the reduction dimension. - // Use a value from A. - std::vector reduceLoops; - std::vector optimizedReduceLoops; - Block *optimizationReduceBlock = - defineLoops(rewriter, loc, reduceLoops, optimizedReduceLoops, 1); - KrnlIterateOperandPack reducePack(rewriter, reduceLoops, - optimizedReduceLoops); - addDimensionToPack(rewriter, loc, reducePack, A, AShape.size() - 1); - auto reduceIterateOp = rewriter.create(loc, reducePack); - - // No optimization - rewriter.setInsertionPointToEnd(optimizationReduceBlock); - rewriter.create(loc, reduceLoops); - - // Insert instructions into the reduction KrnlIterateOp. - Block &reduceIterationBlock = reduceIterateOp.bodyRegion().front(); - rewriter.setInsertionPointToStart(&reduceIterationBlock); - - // Induction variables - SmallVector loopKIVs, loopBatchMKIVs, loopBatchKNIVs; - // K - loopKIVs.emplace_back(reduceIterationBlock.getArguments()[0]); - // MK - if (AShape.size() > 2) - for (auto arg : loopBatchIVs) - loopBatchMKIVs.emplace_back(arg); - if (AShape.size() >= 2) - loopBatchMKIVs.emplace_back(loopMNIVs[0]); - loopBatchMKIVs.emplace_back(loopKIVs[0]); - // KN - if (BShape.size() > 2) - for (auto arg : loopBatchIVs) - loopBatchKNIVs.emplace_back(arg); - loopBatchKNIVs.emplace_back(loopKIVs[0]); - if (BShape.size() >= 2) - if (AShape.size() >= 2) - loopBatchKNIVs.emplace_back(loopMNIVs[1]); - else - loopBatchKNIVs.emplace_back(loopMNIVs[0]); - - // Matmul computation - auto loadedA = rewriter.create(loc, A, loopBatchMKIVs); - auto loadedB = rewriter.create(loc, B, loopBatchKNIVs); - auto loadedY = rewriter.create(loc, alloc, loopBatchMNIVs); - if (elementType.isa()) { - auto AB = rewriter.create(loc, loadedA, loadedB); - auto accumulated = rewriter.create(loc, loadedY, AB); - rewriter.create(loc, accumulated, alloc, loopBatchMNIVs); - } else if (elementType.isa()) { - auto AB = rewriter.create(loc, loadedA, loadedB); - auto accumulated = rewriter.create(loc, loadedY, AB); - rewriter.create(loc, accumulated, alloc, loopBatchMNIVs); - } - } else if ((AShape.size() == 1) && (BShape.size() == 1)) { - // Case 3: - // - Both arguments are 1-D - - // Fill the output with value 0. - Value zeroIndex = rewriter.create(loc, 0); - rewriter.create(loc, zero, alloc, zeroIndex); - - // Iterate along the reduction dimension. 
- // Use a value from A. - std::vector reduceLoops; - std::vector optimizedReduceLoops; - Block *optimizationReduceBlock = - defineLoops(rewriter, loc, reduceLoops, optimizedReduceLoops, 1); - KrnlIterateOperandPack reducePack(rewriter, reduceLoops, - optimizedReduceLoops); - addDimensionToPack(rewriter, loc, reducePack, A, 0); - auto reduceIterateOp = rewriter.create(loc, reducePack); - - // No optimization - rewriter.setInsertionPointToEnd(optimizationReduceBlock); - rewriter.create(loc, reduceLoops); - - // Insert instructions into the reduction KrnlIterateOp. - Block &reduceIterationBlock = reduceIterateOp.bodyRegion().front(); - rewriter.setInsertionPointToStart(&reduceIterationBlock); - - // Induction variables - SmallVector loopKIVs; - // K - loopKIVs.emplace_back(reduceIterationBlock.getArgument(0)); - - // Matmul computation - auto loadedA = rewriter.create(loc, A, loopKIVs); - auto loadedB = rewriter.create(loc, B, loopKIVs); - auto loadedY = rewriter.create(loc, alloc, zeroIndex); - if (elementType.isa()) { - auto AB = rewriter.create(loc, loadedA, loadedB); - auto accumulated = rewriter.create(loc, loadedY, AB); - rewriter.create(loc, accumulated, alloc, zeroIndex); - } else if (elementType.isa()) { - auto AB = rewriter.create(loc, loadedA, loadedB); - auto accumulated = rewriter.create(loc, loadedY, AB); - rewriter.create(loc, accumulated, alloc, zeroIndex); - } - } else { - // No scalar matrix multiplication. - llvm_unreachable("Unsupported scalar matrix multiplication."); - } - - rewriter.replaceOp(op, alloc); - - return matchSuccess(); - } -}; - -struct ONNXGemmOpLowering : public ConversionPattern { - ONNXGemmOpLowering(MLIRContext *ctx) - : ConversionPattern(mlir::ONNXGemmOp::getOperationName(), 1, ctx) {} - - PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final { - auto tensorType = (*op->result_type_begin()).cast(); - auto loc = op->getLoc(); - - Value A, B, C; - A = operands[0]; - B = operands[1]; - C = operands[2]; - - auto alphaAttr = - FloatAttr::get(tensorType.getElementType(), - llvm::dyn_cast(op).alpha().convertToFloat()); - auto betaAttr = - FloatAttr::get(tensorType.getElementType(), - llvm::dyn_cast(op).beta().convertToFloat()); - auto alpha = rewriter.create(loc, alphaAttr); - auto beta = rewriter.create(loc, betaAttr); - - bool isTransA = (llvm::dyn_cast(op).transA() != 0); - bool isTransB = (llvm::dyn_cast(op).transB() != 0); - - // Result type - auto memRefType = convertTensorToMemRef(tensorType); - - // Insert an allocation and deallocation for the result of this operation. - Value alloc; - bool insertDealloc = checkInsertDealloc(op); - if (hasAllConstantDimensions(memRefType)) - alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); - else { - auto memRefShape = memRefType.getShape(); - SmallVector allocOperands; - if (memRefShape[0] < 0) { - auto dim = rewriter.create(loc, A, (isTransA) ? 1 : 0); - allocOperands.emplace_back(dim); - } - if (memRefShape[1] < 0) { - auto dim = rewriter.create(loc, B, (isTransB) ? 0 : 1); - allocOperands.emplace_back(dim); - } - alloc = rewriter.create(loc, memRefType, allocOperands); - if (insertDealloc) { - auto *parentBlock = alloc.getDefiningOp()->getBlock(); - auto dealloc = rewriter.create(loc, alloc); - dealloc.getOperation()->moveBefore(&parentBlock->back()); - } - } - - // Number of loops - auto memRefShape = memRefType.getShape(); - int64_t numLoops = 3; - - // Define loops. 
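// Reference semantics (plain C++) of the Gemm lowering set up above:
// Y = alpha * op(A) * op(B) + beta * C, where op() is an optional transpose and
// C is broadcast unidirectionally to the (M x N) result. In this sketch C is
// assumed to already match (M x N) and buffers are row-major; the real lowering
// selects index 0 for any broadcast dimension of C at runtime.
#include <cstdint>
#include <vector>

void gemmRef(const std::vector<float> &A, const std::vector<float> &B,
             const std::vector<float> &C, std::vector<float> &Y,
             int64_t M, int64_t N, int64_t K, float alpha, float beta,
             bool transA, bool transB) {
  for (int64_t m = 0; m < M; ++m)
    for (int64_t n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int64_t k = 0; k < K; ++k) {
        float a = transA ? A[k * M + m] : A[m * K + k]; // A is (K x M) if transA
        float b = transB ? B[n * K + k] : B[k * N + n]; // B is (N x K) if transB
        acc += a * b;
      }
      Y[m * N + n] = alpha * acc + beta * C[m * N + n];
    }
}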
- std::vector originalLoops; - std::vector optimizedLoops; - Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops, - optimizedLoops, numLoops); - - // We have two Krnl loops: - // - Outer loop iterates over the output matrix dimensions, and - // - Reduction loop iterates over the reduction dimension. - - // Outer loop - std::vector outerLoops, optimizedOuterLoops; - outerLoops.reserve(2); - optimizedOuterLoops.reserve(2); - for (int i = 0; i < 2; ++i) { - outerLoops.push_back(originalLoops[i]); - optimizedOuterLoops.push_back(optimizedLoops[i]); - } - KrnlIterateOperandPack outerPack(rewriter, outerLoops, optimizedOuterLoops); - // Induction variables for the outer loops - for (int i = 0; i < 2; ++i) - addDimensionToPack(rewriter, loc, outerPack, alloc, i); - - // Reduction loop - std::vector reductionLoops, optimizedReductionLoops; - reductionLoops.reserve(1); - optimizedReductionLoops.reserve(1); - reductionLoops.push_back(originalLoops[2]); - optimizedReductionLoops.push_back(optimizedLoops[2]); - KrnlIterateOperandPack reductionPack(rewriter, reductionLoops, - optimizedReductionLoops); - // Induction variable for the reduction dimension - // Try to find and use a static value from A or B first. - // If it failed then use a dynamic value. - auto ATy = A.getType().cast(); - auto BTy = B.getType().cast(); - int K_A_Idx = (isTransA) ? 0 : 1; - int K_B_Idx = (isTransB) ? 1 : 0; - reductionPack.pushConstantBound(0); - if (ATy.getShape()[K_A_Idx] != -1) - reductionPack.pushConstantBound(ATy.getShape()[K_A_Idx]); - else if (BTy.getShape()[K_B_Idx] != -1) - reductionPack.pushConstantBound(BTy.getShape()[K_B_Idx]); - else - reductionPack.pushOperandBound( - rewriter.create(loc, B, K_B_Idx).getResult()); - - // Get run-time dimension information for unknown dimensions used for - // broadcasting. - // GemmOp supports unidirectional broadcasting from C to A*B. - // Hence, it must be enough to get broadcasting information for C only. - std::map broadcastedDimInfo; - auto shape = C.getType().cast().getShape(); - for (int i = 0; i < shape.size(); ++i) { - if (shape[i] < 0) { - auto dim = rewriter.create(loc, C, i).getResult(); - auto one = rewriter.create(loc, 1); - auto isBroadcasted = - rewriter.create(loc, CmpIPredicate::eq, dim, one); - broadcastedDimInfo.insert(std::make_pair(i, isBroadcasted)); - } - } - - auto outerIterateOp = rewriter.create(loc, outerPack); - - // Now perform the insertions into the body of the - // just generated instructions: - - // No optimization - rewriter.setInsertionPointToEnd(optimizationBlock); - rewriter.create(loc, originalLoops); - - // Insert instructions inside the outer loop. 
- Block &outerIterationBlock = outerIterateOp.bodyRegion().front(); - rewriter.setInsertionPointToStart(&outerIterationBlock); - - // Induction variables - SmallVector loopMNIVs; - for (auto arg : outerIterationBlock.getArguments()) { - loopMNIVs.emplace_back(arg); - } - - // Initialize the output of A*B - auto zero = rewriter.create( - loc, FloatAttr::get(memRefType.getElementType(), 0)); - rewriter.create(loc, zero, alloc, loopMNIVs); - - // Compute A*B - auto matmulIterateOp = rewriter.create(loc, reductionPack); - - // Compute beta*C, and add up to alpha*A*B (unidirectional broadcasting) - auto loopCIVs = getLoopIVsForBroadcasting(loc, rewriter, loopMNIVs, C, - broadcastedDimInfo); - auto loadedC = rewriter.create(loc, C, loopCIVs); - auto loadedAB = rewriter.create(loc, alloc, loopMNIVs); - auto alphaAB = rewriter.create(loc, alpha, loadedAB); - auto betaC = rewriter.create(loc, beta, loadedC); - auto Y = rewriter.create(loc, alphaAB, betaC); - rewriter.create(loc, Y, alloc, loopMNIVs); - - // Insert instructions to do matrix multiplication: A*B - Block &matmulIterationBlock = matmulIterateOp.bodyRegion().front(); - rewriter.setInsertionPointToStart(&matmulIterationBlock); - - // Induction variables - SmallVector loopKIVs, loopAIVs, loopBIVs; - for (auto arg : matmulIterationBlock.getArguments()) - loopKIVs.emplace_back(arg); - if (isTransA) { - loopAIVs.emplace_back(loopKIVs[0]); - loopAIVs.emplace_back(loopMNIVs[0]); - } else { - loopAIVs.emplace_back(loopMNIVs[0]); - loopAIVs.emplace_back(loopKIVs[0]); - } - if (isTransB) { - loopBIVs.emplace_back(loopMNIVs[1]); - loopBIVs.emplace_back(loopKIVs[0]); - } else { - loopBIVs.emplace_back(loopKIVs[0]); - loopBIVs.emplace_back(loopMNIVs[1]); - } - - // Matmul computation - auto loadedA = rewriter.create(loc, A, loopAIVs); - auto loadedB = rewriter.create(loc, B, loopBIVs); - auto loadedY = rewriter.create(loc, alloc, loopMNIVs); - auto AB = rewriter.create(loc, loadedA, loadedB); - auto accumulated = rewriter.create(loc, loadedY, AB); - rewriter.create(loc, accumulated, alloc, loopMNIVs); - - rewriter.replaceOp(op, alloc); - - return matchSuccess(); - } -}; - -struct ONNXUnsqueezeOpLowering : public ConversionPattern { - ONNXUnsqueezeOpLowering(MLIRContext *ctx) - : ConversionPattern(mlir::ONNXUnsqueezeOp::getOperationName(), 1, ctx) {} - - PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final { - auto loc = op->getLoc(); - auto tensorType = (*op->result_type_begin()).cast(); - int outRank = tensorType.getRank(); - - // Assume that `axes` has been validated by shape inference. - // So, here we just get it. - ArrayAttr axisAttrs = llvm::dyn_cast(op).axesAttr(); - SmallVector axes; - for (auto axisAttr : axisAttrs.getValue()) { - int axis = axisAttr.cast().getInt(); - axis = axis >= 0 ? axis : (outRank + axis); - axes.emplace_back(axis); - } - - // Insert an allocation and deallocation for the result of this operation. - auto memRefType = convertTensorToMemRef(tensorType); - Value alloc; - - // Compute size in bytes. 
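// Standalone sketch (plain C++) of the dimension bookkeeping used by the
// Unsqueeze lowering below: output dimensions listed in `axes` get size 1, and
// every other output dimension is taken from the next input dimension. The
// helper name is illustrative only.
#include <algorithm>
#include <cstdint>
#include <vector>

std::vector<int64_t> unsqueezeShape(const std::vector<int64_t> &inShape,
                                    const std::vector<int64_t> &axes) {
  int64_t outRank = static_cast<int64_t>(inShape.size() + axes.size());
  std::vector<int64_t> outShape;
  for (int64_t outIdx = 0, inIdx = 0; outIdx < outRank; ++outIdx) {
    if (std::find(axes.begin(), axes.end(), outIdx) != axes.end())
      outShape.push_back(1);                  // inserted axis of size 1
    else
      outShape.push_back(inShape[inIdx++]);   // copied from the input
  }
  return outShape;
}
// Example: input shape (3, 4) with axes {0, 3} yields (1, 3, 4, 1).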
- Value tensorSize = rewriter.create( - loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), - getMemRefEltSizeInBytes(memRefType))); - - bool insertDealloc = checkInsertDealloc(op); - auto memRefShape = memRefType.getShape(); - if (hasAllConstantDimensions(memRefType)) { - alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); - for (int i = 0; i < memRefShape.size(); ++i) { - Value dimVal = rewriter.create( - loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), - memRefShape[i])); - tensorSize = rewriter.create(loc, tensorSize, dimVal); - } - } else { - // Unknown dimensions are always the operand's dimensions. - SmallVector allocOperands; - for (int outIdx = 0, inIdx = 0; outIdx < memRefShape.size(); ++outIdx) { - Value dimVal = nullptr; - if (memRefShape[outIdx] < 0) { - Value index = rewriter.create(loc, operands[0], inIdx); - dimVal = rewriter.create(loc, index, - rewriter.getIntegerType(64)); - allocOperands.emplace_back(index); - } else { - dimVal = rewriter.create( - loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), - memRefShape[outIdx])); - } - tensorSize = rewriter.create(loc, tensorSize, dimVal); - if (std::find(axes.begin(), axes.end(), outIdx) == axes.end()) - inIdx++; - } - alloc = rewriter.create(loc, memRefType, allocOperands); - auto *parentBlock = alloc.getDefiningOp()->getBlock(); - if (insertDealloc) { - auto dealloc = rewriter.create(loc, alloc); - dealloc.getOperation()->moveBefore(&parentBlock->back()); - } - } - rewriter.create(loc, alloc, operands[0], tensorSize); - rewriter.replaceOp(op, alloc); - return matchSuccess(); - } -}; - -struct ONNXTransposeOpLowering : public ConversionPattern { - ONNXTransposeOpLowering(MLIRContext *ctx) - : ConversionPattern(mlir::ONNXTransposeOp::getOperationName(), 1, ctx) {} - - PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final { - auto tensorType = (*op->result_type_begin()).cast(); - auto loc = op->getLoc(); - // Insert an allocation and deallocation for the result of this operation. - auto memRefType = convertTensorToMemRef(tensorType); - Value alloc; - bool insertDealloc = checkInsertDealloc(op); - - if (hasAllConstantDimensions(memRefType)) - alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); - else - alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc, - {operands[0]}); - - // Number of loops - auto memRefShape = memRefType.getShape(); - int64_t rank = memRefShape.size(); - - // Define loops. - std::vector originalLoops; - std::vector optimizedLoops; - Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops, - optimizedLoops, rank); - - KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops); - // Iterate over the loop nest using the input shape. - for (int i = 0; i < rank; ++i) - addDimensionToPack(rewriter, loc, pack, operands[0], i); - - auto iterateOp = rewriter.create(loc, pack); - Block &iterationBlock = iterateOp.bodyRegion().front(); - - // Now perform the insertions into the body of the - // just generated instructions: - - // 1. Insert any optimizations in the KrnlOptimizeLoopsOp body. - rewriter.setInsertionPointToEnd(optimizationBlock); - // Return from KrnlOptimizeLoopsOp body. - // When no optimizations are present we just return the loops - // unchaged. - rewriter.create(loc, originalLoops); - - // 2. Insert instructions inside the KernelIterateOp body. - rewriter.setInsertionPointToStart(&iterationBlock); - - // Handle the operation. 
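// Small sketch (plain C++) of the index permutation used by the Transpose store
// below: the value read at input index `inIVs` is written to the output index
// whose i-th coordinate is inIVs[perm[i]]. Illustration only.
#include <cstdint>
#include <vector>

std::vector<int64_t> permuteIndex(const std::vector<int64_t> &inIVs,
                                  const std::vector<int64_t> &perm) {
  std::vector<int64_t> outIVs(inIVs.size());
  for (size_t i = 0; i < inIVs.size(); ++i)
    outIVs[i] = inIVs[perm[i]];
  return outIVs;
}
// Example: inIVs = (i, j, k) with perm = (2, 0, 1) stores at (k, i, j).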
-
-    // Read perm attribute.
-    SmallVector<int, 4> perm;
-    auto permAttribute = llvm::dyn_cast<ONNXTransposeOp>(op).permAttr();
-    if (permAttribute) {
-      for (auto permVal : permAttribute.getValue())
-        perm.emplace_back(permVal.cast<IntegerAttr>().getInt());
-    } else {
-      // TODO: Remove when perm is guaranteed to be present (even for
-      // the default case). This means that perm was added by shape
-      // inference or another pass to contain the values corresponding
-      // to the default behavior of Transpose.
-      for (int i = iterationBlock.getArguments().size() - 1; i >= 0; i--)
-        perm.emplace_back(i);
-    }
-
-    SmallVector<Value, 4> inLoopIVs;
-    for (auto arg : iterationBlock.getArguments())
-      inLoopIVs.emplace_back(arg);
-
-    SmallVector<Value, 4> outLoopIVs;
-    for (int i = 0; i < iterationBlock.getArguments().size(); ++i)
-      outLoopIVs.emplace_back(iterationBlock.getArguments()[perm[i]]);
-
-    auto inVal = rewriter.create<LoadOp>(loc, operands[0], inLoopIVs);
-    rewriter.create<StoreOp>(loc, inVal, alloc, outLoopIVs);
-
-    rewriter.replaceOp(op, alloc);
-
-    return matchSuccess();
-  }
-};
-
-struct ONNXIdentityOpLowering : public ConversionPattern {
-  ONNXIdentityOpLowering(MLIRContext *ctx)
-      : ConversionPattern(mlir::ONNXIdentityOp::getOperationName(), 1, ctx) {}
-
-  PatternMatchResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const final {
-    rewriter.replaceOp(op, operands[0]);
-    return matchSuccess();
-  }
-};
-
-struct ONNXConvNoBiasOpLowering : public ConversionPattern {
-  ONNXConvNoBiasOpLowering(MLIRContext *ctx)
-      : ConversionPattern(mlir::ONNXConvNoBiasOp::getOperationName(), 1, ctx) {}
-
-  PatternMatchResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const final {
-    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
-    auto loc = op->getLoc();
-    // Insert an allocation and deallocation for the result of this operation.
-    auto memRefType = convertTensorToMemRef(tensorType);
-    Value alloc;
-    bool insertDealloc = checkInsertDealloc(op);
-    ONNXConvNoBiasOp convOp = llvm::dyn_cast<ONNXConvNoBiasOp>(op);
-
-    if (hasAllConstantDimensions(memRefType))
-      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
-    else
-      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
-                                    {operands[0]});
-
-    auto resultShape = memRefType.getShape();
-    auto inputShape = operands[0].getType().cast<MemRefType>().getShape();
-    auto kernelShape = operands[1].getType().cast<MemRefType>().getShape();
-
-    // R = ConvNoBias(D, K)
-    //
-    // The input/output shapes will look like this:
-    //
-    // D (NxCxHxW) x K (MxC/groupxKHxKW) -> R (NxMxRHxRW)
-    //
-    // M is a multiple of the number of groups:
-    //   M = group * kernelsPerGroup
-    //
-    // The loop nest will look as follows:
-    //
-    // strides = [s1, s2]
-    //
-    // kernelsPerGroup = M / group;
-    // for n = 0 .. N:
-    //   for g = 0 .. group:
-    //     for m = 0 .. kernelsPerGroup:
-    //       kernel = g * kernelsPerGroup + m;
-    //       for r1 = 0 .. RH:
-    //         for r2 = 0 .. RW:
-    //           R[n][kernel][r1][r2] = 0;
-    //           for c = 0 .. C/group:
-    //             for k1 = 0 .. KH:
-    //               for k2 = 0 .. KW:
-    //                 R[n][kernel][r1][r2] +=
-    //                     D[n][g * (C / group) + c][s1 * r1 + k1][s2 * r2 + k2] *
-    //                     K[kernel][c][k1][k2];
-    //
-    // Naming:
-    //   n, g, m: outer loop nest indices
-    //   r1, r2: spatial loop nest indices
-    //   c, k1, k2: inner loop nest indices
-    //
-    // TODO: handle padding.
-    //
-    // In the general case:
-    //
-    // D (NxCxD1xD2x...xDdim) x K (MxC/groupxK1xK2x...xKdim)
-    //     -> R (NxMxR1xR2x...xRdim)
-    //
-    // The above loop nest can be adapted by increasing the number
-    // of r- and k-index loops, i.e. r1 r2 and k1 k2 loops.
-
-    // Set up outermost loops: n g m r1 r2 ... rdim
-    // Skip g if group is 1.
-
-    // Before we start the iteration we need to compute the number of
-    // unsplit kernels and fetch the number of groups from the attribute
-    // list. Group is always a compilation constant.
-    int64_t group = convOp.group().getSExtValue();
-    // Compute the number of unsplit kernels. The number of kernels
-    // must be a multiple of the number of groups.
-    int64_t kernelsPerGroup = floor(kernelShape[0] / group);
-    auto kernelsPerGroupValue =
-        rewriter.create<ConstantIndexOp>(loc, kernelsPerGroup);
-    auto zero = rewriter.create<ConstantOp>(
-        loc, FloatAttr::get(memRefType.getElementType(), 0));
-    Value subchannels;
-    if (kernelShape[1] < 0) {
-      subchannels =
-          rewriter.create<DimOp>(loc, operands[1], 1).getResult();
-    } else {
-      subchannels = rewriter.create<ConstantIndexOp>(
-          loc, kernelShape[1]);
-    }
-
-    // 1. Define outer loops and emit empty optimization block:
-    int64_t nOuterLoops = (group > 1) ? 3 : 2;
-    std::vector<Value> outerLoops;
-    std::vector<Value> optimizedOuterLoops;
-    Block *optimizationBlock = defineLoops(rewriter, loc, outerLoops,
-                                           optimizedOuterLoops, nOuterLoops);
-
-    // Prepare iteration arguments over outer loop nest.
-    KrnlIterateOperandPack pack(
-        rewriter, outerLoops, optimizedOuterLoops);
-    // for n = 0 .. N:
-    pack.pushConstantBound(0);
-    if (inputShape[0] < 0)
-      pack.pushOperandBound(
-          rewriter.create<DimOp>(loc, operands[0], 0).getResult());
-    else
-      pack.pushConstantBound(inputShape[0]);
-    // for g = 0 .. group:
-    if (group > 1) {
-      pack.pushConstantBound(0);
-      pack.pushConstantBound(group);
-    }
-    // for m = 0 .. kernelsPerGroup:
-    pack.pushConstantBound(0);
-    pack.pushConstantBound(kernelsPerGroup);
-    // Outer loop iteration.
-    auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
-    Block &outerIterationBlock = iterateOp.bodyRegion().front();
-    // Emit optimizations for outer loops:
-    rewriter.setInsertionPointToEnd(optimizationBlock);
-    rewriter.create<KrnlReturnLoopsOp>(loc, outerLoops);
-    rewriter.setInsertionPointToStart(&outerIterationBlock);
-    {
-      // 2. Emit the body of the outer loop nest.
-
-      // 2.1 Compute kernel order number: kernel = g * kernelsPerGroup + m;
-      // If group is not set then the value of the kernel ID is
-      // identical to that of the loop over kernels.
-      Value kernel = outerIterationBlock.getArguments()[1];
-      if (group > 1) {
-        // Middle loop is over groups and third loop is over the
-        // kernel identifiers in the current group.
-        auto kernelsOffset = rewriter.create<MulIOp>(loc,
-            outerIterationBlock.getArguments()[1],
-            kernelsPerGroupValue);
-        kernel = rewriter.create<AddIOp>(loc, kernelsOffset,
-            outerIterationBlock.getArguments()[2]);
-      }
-
-      // 2.2 Define spatial loops
-      int64_t nSpatialLoops = resultShape.size() - 2;
-      std::vector<Value> spatialLoops;
-      std::vector<Value> optimizedSpatialLoops;
-      Block *optSpatialLoopBlock = defineLoops(rewriter, loc, spatialLoops,
-          optimizedSpatialLoops, nSpatialLoops);
-
-      // 2.3 Prepare iteration arguments for spatial loop nest.
-      KrnlIterateOperandPack spatialPack(
-          rewriter, spatialLoops, optimizedSpatialLoops);
-      for (int i = 2; i < resultShape.size(); ++i)
-        addDimensionToPack(rewriter, loc, spatialPack, alloc, i);
-
-      // 2.4 Emit loop nest over output spatial dimensions.
-      // for rX = 0 .. RX
-      auto spatialIterateOp =
-          rewriter.create<KrnlIterateOp>(loc, spatialPack);
-      Block &spatialIterationBlock = spatialIterateOp.bodyRegion().front();
-      // 2.5 Emit optimizations for spatial loops:
-      rewriter.setInsertionPointToEnd(optSpatialLoopBlock);
-      rewriter.create<KrnlReturnLoopsOp>(loc, spatialLoops);
-      rewriter.setInsertionPointToStart(&spatialIterationBlock);
-      {
-        // 3. Emit the body of the spatial loop nest.
-        // 3.1 Emit: R[n][kernel][r1][r2] = 0;
-        SmallVector<Value, 4> resultIndices;
-        // n
-        resultIndices.emplace_back(outerIterationBlock.getArguments()[0]);
-        // kernel
-        resultIndices.emplace_back(kernel);
-        // rX
-        for (auto arg : spatialIterationBlock.getArguments())
-          resultIndices.emplace_back(arg);
-        // Store initializer value into output location.
-        rewriter.create<StoreOp>(loc, zero, alloc, resultIndices);
-
-        // 3.2 Define inner loops.
-        int64_t nInnerLoops = 1 + (kernelShape.size() - 2);
-        std::vector<Value> innerLoops;
-        std::vector<Value> optimizedInnerLoops;
-        Block *optInnerLoopBlock = defineLoops(rewriter, loc, innerLoops,
-            optimizedInnerLoops, nInnerLoops);
-
-        // 3.3 Prepare iteration arguments for inner loop nest.
-        KrnlIterateOperandPack innerPack(
-            rewriter, innerLoops, optimizedInnerLoops);
-        // for c = 0 .. C/group
-        innerPack.pushConstantBound(0);
-        innerPack.pushConstantBound(kernelShape[1]);
-        // for Kx = 0 .. KX
-        for (int i = 2; i < kernelShape.size(); ++i)
-          addDimensionToPack(rewriter, loc, innerPack, operands[1], i);
-
-        // 3.4 Emit inner loop nest.
-        auto innerIterateOp =
-            rewriter.create<KrnlIterateOp>(loc, innerPack);
-        Block &innerIterationBlock = innerIterateOp.bodyRegion().front();
-        // 3.5 Emit optimizations for inner loops:
-        rewriter.setInsertionPointToEnd(optInnerLoopBlock);
-        rewriter.create<KrnlReturnLoopsOp>(loc, innerLoops);
-        rewriter.setInsertionPointToStart(&innerIterationBlock);
-        {
-          // 4. Emit inner loop body
-          //    R[n][kernel][r1][r2] +=
-          //        D[n][g * (C / group) + c][s1 * r1 + k1][s2 * r2 + k2] *
-          //        K[kernel][c][k1][k2];
-
-          // 4.1 Prepare indices for accessing the data tensor.
-          SmallVector<Value, 4> dataIndices;
-          // n
-          dataIndices.emplace_back(outerIterationBlock.getArguments()[0]);
-          // g * (C / group) + c
-          Value channelDepth = innerIterationBlock.getArguments()[0];
-          if (group > 1)
-            channelDepth = rewriter.create<AddIOp>(loc, channelDepth,
-                rewriter.create<MulIOp>(loc, subchannels,
-                    outerIterationBlock.getArguments()[1]));
-          dataIndices.emplace_back(channelDepth);
-          // sX * rX + kX
-          auto stridesAttribute = convOp.stridesAttr();
-          // Read strides attribute
-          SmallVector<int, 4> strides;
-          if (stridesAttribute)
-            for (auto stride : stridesAttribute.getValue())
-              strides.emplace_back(stride.cast<IntegerAttr>().getInt());
-          for (int i = 0; i < kernelShape.size() - 2; ++i) {
-            Value spatialIndex = spatialIterationBlock.getArguments()[i];
-            // If strides are present then emit the correct access index.
-            if (stridesAttribute && strides[i] > 1)
-              spatialIndex = rewriter.create<MulIOp>(loc,
-                  rewriter.create<ConstantIndexOp>(loc, strides[i]),
-                  spatialIterationBlock.getArguments()[i]);
-            dataIndices.emplace_back(
-                rewriter.create<AddIOp>(loc, spatialIndex,
-                    innerIterationBlock.getArguments()[i + 1]));
-          }
-
-          // 4.2 Prepare indices for accessing the kernel tensor.
-          SmallVector<Value, 4> kernelIndices;
-          // kernel
-          kernelIndices.emplace_back(kernel);
-          // c
-          kernelIndices.emplace_back(innerIterationBlock.getArguments()[0]);
-          // kX
-          for (int i = 0; i < kernelShape.size() - 2; ++i)
-            kernelIndices.emplace_back(
-                innerIterationBlock.getArguments()[i + 1]);
-
-          // 4.3 Compute convolution.
-          auto loadData =
-              rewriter.create<LoadOp>(loc, operands[0], dataIndices);
-          auto loadKernel =
-              rewriter.create<LoadOp>(loc, operands[1], kernelIndices);
-          auto loadPartialSum =
-              rewriter.create<LoadOp>(loc, alloc, resultIndices);
-          Value result = rewriter.create<AddFOp>(loc, loadPartialSum,
-              rewriter.create<MulFOp>(loc, loadData, loadKernel));
-          // 4.4 Store computed value into output location.
-          rewriter.create<StoreOp>(loc, result, alloc, resultIndices);
-        }
-      }
-    }
-    rewriter.replaceOp(op, alloc);
-
-    return matchSuccess();
-  }
-};
-
-//===----------------------------------------------------------------------===//
-// Reduction ops lowering to Krnl dialect.
-//===----------------------------------------------------------------------===//
-template <typename ONNXReductionOp>
-struct ONNXReductionOpLowering : public ConversionPattern {
-  ONNXReductionOpLowering(MLIRContext *ctx)
-      : ConversionPattern(ONNXReductionOp::getOperationName(), 1, ctx) {}
-
-  PatternMatchResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const final {
-    /*
-     * Condition: reduction function must be associative and commutative.
-     *
-     * Example 1 (here, reduction function is `+`):
-     * Induction variables: (i0, i1, i2)
-     * axes = [0, 2]
-     * keepdims = true
-     * krnl.iterate() with (i0, i1, i2) {
-     *   Y(0, i1, 0) += X(i0, i1, i2)
-     * }
-     *
-     * Example 2 (here, reduction function is `+`):
-     * Induction variables: (i0, i1, i2)
-     * axes = [0, 2]
-     * keepdims = false
-     * krnl.iterate() with (i0, i1, i2) {
-     *   Y(i1) += X(i0, i1, i2)
-     * }
-     *
-     */
-    auto loc = op->getLoc();
-    auto memRefInType = operands[0].getType().cast<MemRefType>();
-    auto memRefInShape = memRefInType.getShape();
-    auto tensorOutType = (*op->result_type_begin()).cast<TensorType>();
-    int64_t inRank = memRefInType.getRank();
-    int64_t outRank = tensorOutType.getRank();
-
-    // Get attributes
-    ArrayAttr axisAttrs = llvm::dyn_cast<ONNXReductionOp>(op).axesAttr();
-    std::vector<int64_t> axes;
-    if (axisAttrs) {
-      for (auto axisAttr : axisAttrs.getValue()) {
-        int64_t axis = axisAttr.cast<IntegerAttr>().getInt();
-        axis = axis >= 0 ? axis : (inRank + axis);
-        assert(axis >= -inRank && axis <= inRank - 1);
-        if (std::find(axes.begin(), axes.end(), axis) == axes.end())
-          axes.push_back(axis);
-      }
-    } else {
-      for (decltype(inRank) i = 0; i < inRank; ++i) {
-        axes.push_back(i);
-      }
-    }
-    // KeepDims
-    auto keepdims =
-        llvm::dyn_cast<ONNXReductionOp>(op).keepdims();
-    bool isKeepdims = (keepdims == 1) ? true : false;
-
-    // Get type information
-    auto memRefOutType = convertTensorToMemRef(tensorOutType);
-    auto memRefOutShape = memRefOutType.getShape();
-    auto elementOutType = memRefOutType.getElementType();
-    std::map<int64_t, int64_t> outInDimMap =
-        getReductionMapping(memRefInType, axes, isKeepdims);
-
-    // Insert an allocation and deallocation for the result of this operation.
-    Value alloc;
-    bool insertDealloc = checkInsertDealloc(op);
-    if (hasAllConstantDimensions(memRefOutType)) {
-      alloc = insertAllocAndDealloc(memRefOutType, loc, rewriter, insertDealloc);
-    } else {
-      SmallVector<Value, 4> allocOperands;
-      for (decltype(outRank) i = 0; i < outRank; ++i) {
-        if (memRefOutShape[i] < 0) {
-          auto dim = rewriter.create<DimOp>(loc, operands[0], outInDimMap[i]);
-          allocOperands.push_back(dim);
-        }
-      }
-      alloc = rewriter.create<AllocOp>(loc, memRefOutType, allocOperands);
-      if (insertDealloc) {
-        auto *parentBlock = alloc.getDefiningOp()->getBlock();
-        auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
-        dealloc.getOperation()->moveBefore(&parentBlock->back());
-      }
-    }
-
-    // There are two Krnl loops:
-    // - One to initialize the result memref, and
-    // - One to do reduction
-
-    // Define loops to initialize the result.
-    std::vector<Value> originalLoopsInit;
-    std::vector<Value> optimizedLoopsInit;
-    Block *optimizationBlockInit = defineLoops(rewriter, loc, originalLoopsInit,
-        optimizedLoopsInit, outRank);
-
-    // Iteration information
-    KrnlIterateOperandPack packInit(rewriter, originalLoopsInit,
-        optimizedLoopsInit);
-    for (decltype(outRank) i = 0; i < outRank; ++i) {
-      addDimensionToPack(rewriter, loc, packInit, alloc, i);
-    }
-    auto iterateOpInit = rewriter.create<KrnlIterateOp>(loc, packInit);
-    Block &iterationBlockInit = iterateOpInit.bodyRegion().front();
-
-    // Perform the insertions into the body of the initialization loop.
-    // No optimization
-    rewriter.setInsertionPointToEnd(optimizationBlockInit);
-    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoopsInit);
-
-    // Insert instructions inside the KernelIterateOp body.
-    rewriter.setInsertionPointToStart(&iterationBlockInit);
-
-    // Handle the operation:
-    SmallVector<Value, 4> loopIVs;
-    for (auto arg : iterationBlockInit.getArguments()) {
-      loopIVs.push_back(arg);
-    }
-
-    Value identity;
-    if (elementOutType.isa<FloatType>()) {
-      identity = rewriter.create<ConstantOp>(
-          loc, FloatAttr::get(elementOutType,
-                              getIdentityValue()));
-    } else if (elementOutType.isa<IntegerType>()) {
-      identity = rewriter.create<ConstantOp>(
-          loc, IntegerAttr::get(elementOutType,
-                                getIdentityValue()));
-    } else {
-      emitError(loc, "unsupported element type");
-    }
-    rewriter.create<StoreOp>(loc, identity, alloc, loopIVs);
-
-    // Define a Krnl loop to do reduction.
-    rewriter.setInsertionPointAfter(iterateOpInit);
-    std::vector<Value> originalLoops, optimizedLoops;
-    Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops,
-        optimizedLoops, inRank);
-    // Iteration information
-    KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
-    for (decltype(inRank) i = 0; i < inRank; ++i) {
-      addDimensionToPack(rewriter, loc, pack, operands[0], i);
-    }
-    auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
-    Block &iterationBlock = iterateOp.bodyRegion().front();
-
-    // Perform the insertions into the body of the reduction loop.
-    // No optimization
-    rewriter.setInsertionPointToEnd(optimizationBlock);
-    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
-
-    // Insert instructions inside the KernelIterateOp body.
-    rewriter.setInsertionPointToStart(&iterationBlock);
-
-    // Handle the operation:
-    SmallVector<Value, 4> inLoopIVs, outLoopIVs;
-    auto args = iterationBlock.getArguments();
-    for (int i = 0; i < args.size(); ++i) {
-      inLoopIVs.push_back(args[i]);
-    }
-    Value zeroIndex = nullptr;
-    for (decltype(inRank) i = 0; i < outRank; ++i) {
-      if (outInDimMap.find(i) != outInDimMap.end()) {
-        outLoopIVs.push_back(inLoopIVs[outInDimMap[i]]);
-      } else {
-        if (zeroIndex) {
-          outLoopIVs.push_back(zeroIndex);
-        } else {
-          zeroIndex = rewriter.create<ConstantIndexOp>(loc, 0);
-          outLoopIVs.push_back(zeroIndex);
-        }
-      }
-    }
-
-    Value next, accumulated;
-    next = rewriter.create<LoadOp>(loc, operands[0], inLoopIVs);
-    accumulated = rewriter.create<LoadOp>(loc, alloc, outLoopIVs);
-    accumulated = mapToLowerScalarOp<ONNXReductionOp>(
-        op, memRefOutType.getElementType(), {accumulated, next}, rewriter);
-    rewriter.create<StoreOp>(loc, accumulated, alloc, outLoopIVs);
-
-    rewriter.replaceOp(op, alloc);
-    return matchSuccess();
-  }
-};
-
-//===----------------------------------------------------------------------===//
-// EntryPoint Op lowering to Krnl Entry Point.
-//===----------------------------------------------------------------------===//
-
-class ONNXEntryPointLowering : public OpRewritePattern<ONNXEntryPointOp> {
-public:
-  using OpRewritePattern<ONNXEntryPointOp>::OpRewritePattern;
-
-  PatternMatchResult matchAndRewrite(ONNXEntryPointOp op,
-                                     PatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<KrnlEntryPointOp>(
-        op,
-        op.getAttrOfType<SymbolRefAttr>(
-            ONNXEntryPointOp::getEntryPointFuncAttrName()),
-        op.getAttrOfType<IntegerAttr>(ONNXEntryPointOp::getNumInputsAttrName()),
-        op.getAttrOfType<IntegerAttr>(
-            ONNXEntryPointOp::getNumOutputsAttrName()));
-    return matchSuccess();
-  }
-};
-
-//===----------------------------------------------------------------------===//
-// Conversion from Tensor type to the Standard dialect MemRef type.
-//===----------------------------------------------------------------------===//
-
-struct TensorTypeConverter : public TypeConverter {
-  using TypeConverter::TypeConverter;
-
-  LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) override {
-    if (auto tensor_type = t.dyn_cast<TensorType>()) {
-      results.push_back(convertTensorToMemRef(tensor_type));
-      return success();
-    }
-
-    results.push_back(t);
-    return success();
-  }
-
-  /// Return true if the inputs and outputs of the given function type are
-  /// legal. [Taken from MLIR and adapted to only check the legality of the
-  /// inputs. Once unranked results can be handled gracefully this
-  /// override needs to be removed in favour of the original MLIR one.]
-  bool isSignatureLegal(FunctionType funcType) {
-    return llvm::all_of(funcType.getInputs(),
-                        [this](Type type) { return isLegal(type); });
-  }
-};
-
-} // end anonymous namespace.
-
-//===----------------------------------------------------------------------===//
-// Frontend to Krnl Dialect lowering pass
-//===----------------------------------------------------------------------===//
-
-/// This is a partial lowering to Krnl loops of the ONNX operations.
-namespace {
-struct FrontendToKrnlLoweringPass
-    : public ModulePass<FrontendToKrnlLoweringPass> {
-  void runOnModule() final;
-};
-} // end anonymous namespace.
-
-void FrontendToKrnlLoweringPass::runOnModule() {
-  auto module = getModule();
-
-  // The first thing to define is the conversion target. This will define the
-  // final target for this lowering.
-  ConversionTarget target(getContext());
-
-  // We define the specific operations, or dialects, that are legal targets for
-  // this lowering.
-  target
-      .addLegalDialect<KrnlOpsDialect, AffineOpsDialect, StandardOpsDialect>();
-
-  // TODO: enable this once more ops are supported.
-  // We also define the ONNX dialect as Illegal so that the conversion will
-  // fail if any of these operations are *not* converted.
-  // target.addIllegalDialect();
-
-  // TODO: add any other ops which are considered legal.
-  // Some operations can be marked as being still legal.
-  // Example: target.addLegalOp();
-
-  // Now that the conversion target has been defined, we just need to provide
-  // the set of patterns that will lower the frontend operations.
-  OwningRewritePatternList patterns;
-
-  // Convert TensorType to MemRef
-  TensorTypeConverter tensor_to_memref_converter;
-  target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
-    // FuncOp is legal only if types have been converted to Std types.
-    return tensor_to_memref_converter.isSignatureLegal(op.getType());
-  });
-
-  // Type conversion for function signatures.
-  // Call MLIR FuncOp signature conversion when result type is
-  // a ranked tensor.
-  populateFuncOpTypeConversionPattern(patterns, &getContext(),
-                                      tensor_to_memref_converter);
-
-  // Frontend operation lowering.
-  patterns.insert,
-      ONNXElementwiseUnaryOpLowering,
-      ONNXElementwiseUnaryOpLowering,
-      ONNXElementwiseUnaryOpLowering,
-      ONNXElementwiseUnaryOpLowering,
-      ONNXElementwiseUnaryOpLowering,
-      ONNXElementwiseUnaryOpLowering,
-      ONNXElementwiseUnaryOpLowering,
-      ONNXElementwiseUnaryOpLowering,
-      ONNXElementwiseUnaryOpLowering,
-      ONNXElementwiseUnaryOpLowering,
-      ONNXElementwiseUnaryOpLowering,
-      ONNXElementwiseUnaryOpLowering,
-      ONNXElementwiseUnaryOpLowering,
-      ONNXElementwiseUnaryOpLowering,
-      ONNXElementwiseUnaryOpLowering,
-      ONNXElementwiseUnaryOpLowering,
-      ONNXElementwiseVariadicOpLowering,
-      ONNXElementwiseVariadicOpLowering,
-      ONNXElementwiseVariadicOpLowering,
-      ONNXElementwiseVariadicOpLowering,
-      ONNXElementwiseVariadicOpLowering,
-      ONNXElementwiseVariadicOpLowering,
-      ONNXElementwiseVariadicOpLowering,
-      ONNXElementwiseVariadicOpLowering,
-      ONNXElementwiseVariadicOpLowering,
-      ONNXElementwiseVariadicOpLowering,
-      ONNXReshapeOpLowering, ONNXEntryPointLowering,
-      ONNXReductionOpLowering,
-      ONNXReductionOpLowering,
-      ONNXReductionOpLowering,
-      ONNXReductionOpLowering,
-      ONNXSoftmaxOpLowering, ONNXGemmOpLowering,
-      ONNXUnsqueezeOpLowering, ONNXTransposeOpLowering,
-      ONNXIdentityOpLowering, ONNXConvNoBiasOpLowering,
-      ONNXMatMulOpLowering>(&getContext());
-
-  // With the target and rewrite patterns defined, we can now attempt the
-  // conversion. The conversion will signal failure if any of our `illegal`
-  // operations were not converted successfully.
-  if (failed(applyPartialConversion(module, target, patterns)))
-    signalPassFailure();
-}
-
-std::unique_ptr<Pass> mlir::createLowerToKrnlPass() {
-  return std::make_unique<FrontendToKrnlLoweringPass>();
-}
-
-static PassRegistration<FrontendToKrnlLoweringPass>
-    pass("lower-frontend", "Lower frontend ops to Krnl dialect.");
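
For readers new to this lowering scheme, the skeleton shared by the loop-based patterns above can be summarized in a short sketch. This is an illustrative reconstruction, not part of the patch: ONNXFooOp is a hypothetical placeholder op, and the helpers it calls (convertTensorToMemRef, checkInsertDealloc, hasAllConstantDimensions, insertAllocAndDealloc, defineLoops, addDimensionToPack, KrnlIterateOperandPack) are the ones defined in convert_onnx_to_krnl.cpp above.

// Condensed sketch of the common lowering skeleton (illustrative only).
// ONNXFooOp is a hypothetical stand-in for any of the ops lowered above.
struct ONNXFooOpLowering : public ConversionPattern {
  ONNXFooOpLowering(MLIRContext *ctx)
      : ConversionPattern(mlir::ONNXFooOp::getOperationName(), 1, ctx) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    auto loc = op->getLoc();

    // 1. Allocate the output buffer (and a dealloc if the result is not
    //    returned from the function).
    auto memRefType =
        convertTensorToMemRef((*op->result_type_begin()).cast<TensorType>());
    bool insertDealloc = checkInsertDealloc(op);
    Value alloc = hasAllConstantDimensions(memRefType)
                      ? insertAllocAndDealloc(memRefType, loc, rewriter,
                                              insertDealloc)
                      : insertAllocAndDealloc(memRefType, loc, rewriter,
                                              insertDealloc, {operands[0]});

    // 2. Define the loop nest together with an (empty) optimization block.
    int64_t rank = memRefType.getRank();
    std::vector<Value> originalLoops, optimizedLoops;
    Block *optimizationBlock =
        defineLoops(rewriter, loc, originalLoops, optimizedLoops, rank);

    // 3. Bound each loop by the corresponding output dimension.
    KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
    for (int i = 0; i < rank; ++i)
      addDimensionToPack(rewriter, loc, pack, alloc, i);
    auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);

    // 4. Return the loops unchanged from the optimization block (no
    //    schedule transformation is applied at this point).
    rewriter.setInsertionPointToEnd(optimizationBlock);
    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);

    // 5. Emit the scalar computation inside the iteration body, using the
    //    block arguments as induction variables.
    Block &iterationBlock = iterateOp.bodyRegion().front();
    rewriter.setInsertionPointToStart(&iterationBlock);
    SmallVector<Value, 4> loopIVs;
    for (auto arg : iterationBlock.getArguments())
      loopIVs.push_back(arg);
    Value x = rewriter.create<LoadOp>(loc, operands[0], loopIVs);
    rewriter.create<StoreOp>(loc, x, alloc, loopIVs);

    rewriter.replaceOp(op, alloc);
    return matchSuccess();
  }
};

Each concrete pattern above specializes step 5: elementwise lowerings map the loaded scalars through mapToLowerScalarOp, reductions accumulate into the output buffer after an initialization loop, and Gemm/MatMul/Conv nest additional KrnlIterateOps for their reduction dimensions.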