diff --git a/.circleci/config.yml b/.circleci/config.yml index 48fda88..3863f72 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -38,7 +38,7 @@ jobs: - run: name: Run End-To-End Tests command: | - sudo pip install -q onnx + sudo pip install -q -e ./ONNF/third_party/onnx cd ONNF/build cmake --build . --target run-onnx-backend-test - run: diff --git a/doc/Dialects/onnx.md b/doc/Dialects/onnx.md index ff0e8c5..e4ca150 100644 --- a/doc/Dialects/onnx.md +++ b/doc/Dialects/onnx.md @@ -1558,33 +1558,6 @@ ONNX Gather operation 1. `output`: memref of any type values or tensor of any type values -### onnx.GemmNoBias (ONNXGemmNoBiasOp) -ONNX general matrix multiply operation without bias. - -#### Description: - - -The "onnx.Gemm" generic matrix multiplication without bias. - - -#### Operands: - -1. `A`: memref of any type values or tensor of any type values -1. `B`: memref of any type values or tensor of any type values - -#### Attributes: - -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `alpha` | `FloatAttr` | 32-bit float attribute attribute | -| `beta` | `FloatAttr` | 32-bit float attribute attribute | -| `transA` | `IntegerAttr` | 64-bit integer attribute attribute | -| `transB` | `IntegerAttr` | 64-bit integer attribute attribute | - -#### Results: - -1. `o_Y`: memref of any type values or tensor of any type values - ### onnx.Gemm (ONNXGemmOp) ONNX Gemm operation diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index d895be5..b210275 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -62,7 +62,21 @@ target_include_directories(onnf_shape_inference target_link_libraries(onnf_shape_inference ${MLIRLibs}) add_dependencies(onnf_shape_inference gen_krnl_ops) -add_library(onnf_lower_frontend conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp) +add_library(onnf_lower_frontend + conversion/onnx_to_krnl/onnx_to_krnl_common.cpp + conversion/onnx_to_krnl/onnx_to_krnl_common.hpp + conversion/onnx_to_krnl/math/elementwise.cpp + conversion/onnx_to_krnl/math/gemm.cpp + conversion/onnx_to_krnl/math/matmul.cpp + conversion/onnx_to_krnl/math/reduction.cpp + conversion/onnx_to_krnl/math/softmax.cpp + conversion/onnx_to_krnl/nn/conv.cpp + conversion/onnx_to_krnl/nn/normalization.cpp + conversion/onnx_to_krnl/tensor/identity.cpp + conversion/onnx_to_krnl/tensor/reshape.cpp + conversion/onnx_to_krnl/tensor/transpose.cpp + conversion/onnx_to_krnl/tensor/unsqueeze.cpp + conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp) target_include_directories(onnf_lower_frontend PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT} ${ONNF_SRC_ROOT}) diff --git a/src/builder/frontend_dialect_transformer.cpp b/src/builder/frontend_dialect_transformer.cpp index cd23e8c..0efca22 100644 --- a/src/builder/frontend_dialect_transformer.cpp +++ b/src/builder/frontend_dialect_transformer.cpp @@ -189,8 +189,9 @@ private: } } - mlir::Type elementType = - convertONNXTypeToMLIRType(input.type().tensor_type().elem_type()); + auto elementOnnxType = + (onnx::TensorProto_DataType)input.type().tensor_type().elem_type(); + mlir::Type elementType = convertONNXTypeToMLIRType(elementOnnxType); llvm::ArrayRef tensor_dims(dims.data(), dims.size()); arg_types.emplace_back( mlir::RankedTensorType::get(tensor_dims, elementType)); diff --git a/src/conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp b/src/conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp index 84d4be8..ffc7219 100644 --- a/src/conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp +++ b/src/conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp @@ -8,404 +8,11 
@@ // Krnl IR and standard operations. // //===----------------------------------------------------------------------===// -#include -#include "mlir/Dialect/AffineOps/AffineOps.h" -#include "mlir/Dialect/StandardOps/Ops.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/DialectConversion.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/Sequence.h" - -#include "src/dialect/krnl/krnl_helper.hpp" -#include "src/dialect/krnl/krnl_ops.hpp" -#include "src/dialect/onnx/onnx_ops.hpp" -#include "src/pass/passes.hpp" +#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp" using namespace mlir; -//===----------------------------------------------------------------------===// -// FrontendToAffine RewritePatterns -//===----------------------------------------------------------------------===// - -/// Check is all dimensions are known at compile time. -static bool hasAllConstantDimensions(MemRefType type) { - auto memRefShape = type.getShape(); - for (int i = 0; i < memRefShape.size(); ++i) - if (memRefShape[i] < 0) - return false; - return true; -} - -/// Get the corresponding MemRefType of a given TensorType/MemRefType. -static MemRefType convertToMemRefType(Type type) { - MemRefType memRefType; - auto tensorType = type.dyn_cast(); - if (tensorType) { - assert(tensorType.hasRank() && "expected only ranked shapes"); - memRefType = - MemRefType::get(tensorType.getShape(), tensorType.getElementType()); - } else { - memRefType = type.dyn_cast(); - } - return memRefType; -} - -/// Insert an allocation and deallocation for the given MemRefType. -static Value insertAllocAndDealloc(MemRefType type, Location loc, - PatternRewriter &rewriter, - bool insertDealloc, - ArrayRef operands = {}) { - // Put together alloc operands for any dynamic dimensions of the memref. - AllocOp alloc; - if (!operands.empty()) { - auto memRefShape = type.getShape(); - auto rank = memRefShape.size(); - - std::map fromOperands; - for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) { - int memRefDimIdx = rank - 1 - reversedIdx; - if (memRefShape[memRefDimIdx] < 0) { // unknown dimension - Value maxDim = nullptr; - for (int i = 0; i < operands.size(); i++) { - auto operandShape = - operands[i].getType().cast().getShape(); - int operandDimIdx = operandShape.size() - 1 - reversedIdx; - - if (operandDimIdx < 0) - continue; - - // In case of operations with broadcasting, the dimension of the - // alloc result is the maximum size along each dimension of the - // operands. - auto operandDim = - rewriter.create(loc, operands[i], operandDimIdx); - if (maxDim) { - auto maxCondition = rewriter.create(loc, CmpIPredicate::sgt, - operandDim, maxDim); - maxDim = rewriter.create(loc, maxCondition, operandDim, - maxDim); - } else { - maxDim = operandDim; - } - } - fromOperands.insert(std::make_pair(memRefDimIdx, maxDim)); - } - } - - SmallVector allocOperands; - for (int i = 0; i < rank; ++i) - if (memRefShape[i] < 0) - allocOperands.push_back(fromOperands[i]); - alloc = rewriter.create(loc, type, allocOperands); - } else { - alloc = rewriter.create(loc, type); - } - - // Make sure to allocate at the beginning of the block if - // all dimensions are known. 
- auto *parentBlock = alloc.getOperation()->getBlock(); - if (hasAllConstantDimensions(type)) - alloc.getOperation()->moveBefore(&parentBlock->front()); - - if (insertDealloc) { - auto dealloc = rewriter.create(loc, alloc); - dealloc.getOperation()->moveBefore(&parentBlock->back()); - } - - return alloc; -} - -// Determine if current function returns the result value of the -// current op being lowered. If it does then dealloc should not be -// inserted. -static bool checkInsertDealloc(Operation *currentOp) { - auto parentBlock = currentOp->getBlock(); - - bool insertDealloc = true; - parentBlock->walk([&insertDealloc, currentOp](ReturnOp op) { - assert(currentOp->getNumResults() < 2 && - "No more than one result supported (for now)."); - // If there is at least one result to investigate. - if (currentOp->getNumResults() > 0) { - auto result = currentOp->getResult(0); - for (const auto &operand : op.getOperands()) - if (operand == result) - insertDealloc = false; - } - }); - - return insertDealloc; -} - -// Create a mapping from result type's dimensions to input type's dimensions, -// given that the result type is the result of a reduction op over the input -// type. -std::map -getReductionMapping(MemRefType inputTy, ArrayRef axes, bool keepdims) { - std::map OutInDimMap; - int64_t rank = inputTy.getRank(); - - // Mark reduction axes. - std::vector isReductionAxis; - for (decltype(rank) i = 0; i < rank; ++i) { - if (std::find(axes.begin(), axes.end(), i) != axes.end()) - isReductionAxis.push_back(true); - else - isReductionAxis.push_back(false); - } - - for (decltype(rank) inIndex = 0, outIndex = 0; inIndex < rank; ++inIndex) { - // If it is a reduction axis, there is no relationship among dimensions. - if (isReductionAxis[inIndex]) { - if (keepdims) - outIndex++; - } else { - OutInDimMap.insert(std::make_pair(outIndex, inIndex)); - outIndex++; - } - } - - return OutInDimMap; -} - -// Add bounds associated with the op operand to the KRNL iteration pack. -// Dynamic dimenions are supported. -static void addDimensionToPack(ConversionPatternRewriter &rewriter, - Location loc, KrnlIterateOperandPack &pack, - Value operand, int index) { - auto shape = operand.getType().cast().getShape(); - if (shape[index] < 0) { - pack.pushConstantBound(0); - pack.pushOperandBound( - rewriter.create(loc, operand, index).getResult()); - } else { - pack.pushConstantBound(0); - pack.pushConstantBound(shape[index]); - } -} - -// Function that defines the KRNL dialect loops and their respective -// optimized version. -static KrnlOptimizeLoopsOp -emitOptimizedLoops(ConversionPatternRewriter &rewriter, Location loc, - std::vector &loops, - std::vector &optimizedLoops, int64_t numLoops) { - // Define loops. - auto loopsOp = rewriter.create(loc, numLoops); - loops.reserve(numLoops); - for (auto result : loopsOp.getResults()) - loops.push_back(result); - - // Define optimized version of the loops. - auto optimizedLoopsOp = rewriter.create(loc, numLoops); - optimizedLoops.reserve(numLoops); - for (auto result : optimizedLoopsOp.getResults()) - optimizedLoops.push_back(result); - - return optimizedLoopsOp; -} - -// Function that emits the loops and their optimized version. -// The function returns a reference to the inner optimization block. 
-static Block *defineLoops(ConversionPatternRewriter &rewriter, Location loc, - std::vector &loops, - std::vector &optimizedLoops, - int64_t numLoops) { - KrnlOptimizeLoopsOp optimizedLoopsOp = - emitOptimizedLoops(rewriter, loc, loops, optimizedLoops, numLoops); - return &optimizedLoopsOp.region().front(); -} - -// Function which emits a basic set of loops and optimized loops -// for a given operation argument. A reference to the loop optimization -// block is returned in the last argument of the function. -static void emitKrnlLoopsAndIterationForOperand( - ConversionPatternRewriter &rewriter, Location loc, Value operand, - std::vector &originalLoops, KrnlOptimizeLoopsOp &optimizedLoopsOp, - KrnlIterateOp &iterateOp) { - // Operand shape. - auto shape = operand.getType().cast().getShape(); - - // Number of loops. - int64_t rank = shape.size(); - - // Define loops and optimized loops. - std::vector optimizedLoops; - optimizedLoopsOp = - emitOptimizedLoops(rewriter, loc, originalLoops, optimizedLoops, rank); - - KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops); - // Iterate over the loop nest. - for (int i = 0; i < rank; ++i) - addDimensionToPack(rewriter, loc, pack, operand, i); - - iterateOp = rewriter.create(loc, pack); -} - -unsigned getMemRefEltSizeInBytes(MemRefType memRefType) { - auto elementType = memRefType.getElementType(); - - unsigned sizeInBits; - if (elementType.isIntOrFloat()) { - sizeInBits = elementType.getIntOrFloatBitWidth(); - } else { - auto vectorType = elementType.cast(); - sizeInBits = - vectorType.getElementTypeBitWidth() * vectorType.getNumElements(); - } - return llvm::divideCeil(sizeInBits, 8); -} - -// Get run-time dimension information for unknown dimensions used for -// broadcasting. -std::map> -getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter, - MemRefType memRefType, ArrayRef operands) { - auto memRefShape = memRefType.getShape(); - int64_t rank = memRefShape.size(); - // For unknown dimensions, we need to get dimension values at runtime in - // order to do broadcasting. - std::map> DimInfo; - // For each result dimension, compute the number of sharing operands. - // Sharing operands are operands sharing the same index (counting from the - // rightmost to the leftmost) for a given dimension. - std::map sharedDimCount; - for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) { - int dimIdx = rank - 1 - reversedIdx; - sharedDimCount[dimIdx] = 0; - for (int i = 0; i < operands.size(); ++i) { - auto shape = operands[i].getType().cast().getShape(); - if (reversedIdx <= shape.size() - 1) - sharedDimCount[dimIdx]++; - } - } - // An unknown dimension can have a value of 1 or N (N > 1). - // If its value is 1, it is broadcasted dimension. - // Otherwise, non-broadcasted dimension. - // We only care about unknown dimensions whose number of sharing operands is - // more than one, since they are potentially broadcasted dimensions. 
- for (int i = 0; i < operands.size(); ++i) { - std::map broadcastedDims; - auto shape = operands[i].getType().cast().getShape(); - int size = shape.size(); - for (int j = 0; j < shape.size(); ++j) { - if (shape[j] < 0 and sharedDimCount[rank - size + j] > 1) { - auto dim = rewriter.create(loc, operands[i], j).getResult(); - auto one = rewriter.create(loc, 1); - auto isBroadcasted = - rewriter.create(loc, CmpIPredicate::eq, dim, one); - broadcastedDims.insert(std::make_pair(j, isBroadcasted)); - } - } - DimInfo.insert(std::make_pair(i, broadcastedDims)); - } - return DimInfo; -} - -// Extract induction variables that are used for broadcasting values of a -// given operand. -std::vector -getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter, - ArrayRef loopIVs, Value operand, - std::map broadcastedDims) { - // `operand` must has a ranked type. This should have been checked by the - // shape inference pass. - auto operandShape = operand.getType().cast().getShape(); - auto rank = operandShape.size(); - auto loopCount = loopIVs.size(); - - std::vector newLoopIVs; - for (unsigned reversedIdx = 0; reversedIdx < rank; ++reversedIdx) { - auto dimIdx = rank - 1 - reversedIdx; - auto loopIdx = loopCount - 1 - reversedIdx; - if (operandShape[dimIdx] == 1) { - // Broadcasted dimension - auto zero = rewriter.create(loc, 0); - newLoopIVs.insert(newLoopIVs.begin(), zero); - } else if ((operandShape[dimIdx] == -1) && - (broadcastedDims.find(dimIdx) != broadcastedDims.end())) { - // Unknown dimension, it can have a value of 1 or N (N > 1). - // If its value is 1, it is broadcasted dimension. - // Otherwise, non-broadcasted dimension. - auto zero = rewriter.create(loc, 0); - auto idx = rewriter.create(loc, broadcastedDims[dimIdx], zero, - loopIVs[loopIdx]); - newLoopIVs.insert(newLoopIVs.begin(), idx); - } else { - // Non-broadcasted dimension - newLoopIVs.insert(newLoopIVs.begin(), loopIVs[loopIdx]); - } - } - return newLoopIVs; -} - -namespace { - -// This is to get a scalar operation of a given type for a specific operation. -template -struct ScalarOp { - using FOp = void; - using IOp = void; -}; - -template -using ScalarFOp = typename ScalarOp::FOp; -template -using ScalarIOp = typename ScalarOp::IOp; - -// Get the identity element of a operation. -// Return NULL if the function does not have identity. -template -DataType getIdentityValue() { - return NULL; -} - -//===----------------------------------------------------------------------===// -// This is used in the innermost loop of a KrnlIterateOp to insert computation -// composed of one or many scalar ops. -// Use template specialization for each of different ONNX operations. -//===----------------------------------------------------------------------===// -template -Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, - ArrayRef operands, - ConversionPatternRewriter &rewriter) { - auto loc = op->getLoc(); - Type element_type = operands.front().getType(); - if (element_type.isa()) { - return rewriter.create>(loc, result_types, operands, - mlir::None); - } else if (element_type.isa()) { - return rewriter.create>(loc, result_types, operands, - mlir::None); - } else { - emitError(loc, "unsupported element type"); - return nullptr; - } -} - -// We divide the operator lowering into different categories. -// These categories are mostly similar to the operator categories in ONNX: -// https://github.com/onnx/onnx/tree/master/onnx/defs. 
-// Besides, it is better to put operators with the same computation pattern into -// the same category, e.g. element-wise operators will belong to the elementwise -// category. - -// Math -#include "src/conversion/onnx_to_krnl/rewrite_patterns/math/elementwise.inc" -#include "src/conversion/onnx_to_krnl/rewrite_patterns/math/gemm.inc" -#include "src/conversion/onnx_to_krnl/rewrite_patterns/math/reduction.inc" -#include "src/conversion/onnx_to_krnl/rewrite_patterns/math/softmax.inc" -#include "src/conversion/onnx_to_krnl/rewrite_patterns/math/matmul.inc" -// Tensor -#include "src/conversion/onnx_to_krnl/rewrite_patterns/tensor/identity.inc" -#include "src/conversion/onnx_to_krnl/rewrite_patterns/tensor/reshape.inc" -#include "src/conversion/onnx_to_krnl/rewrite_patterns/tensor/transpose.inc" -#include "src/conversion/onnx_to_krnl/rewrite_patterns/tensor/unsqueeze.inc" -// Neural network -#include "src/conversion/onnx_to_krnl/rewrite_patterns/nn/conv.inc" -#include "src/conversion/onnx_to_krnl/rewrite_patterns/nn/normalization.inc" - //===----------------------------------------------------------------------===// // EntryPoint Op lowering to Krnl Entry Point. //===----------------------------------------------------------------------===// @@ -427,39 +34,6 @@ public: } }; -//===----------------------------------------------------------------------===// -// Conversion from Tensor type to the Standard dialect MemRef type. -//===----------------------------------------------------------------------===// - -struct TensorTypeConverter : public TypeConverter { - using TypeConverter::TypeConverter; - - TensorTypeConverter() { - addConversion(convertType); - } - - static LogicalResult convertType(Type t, SmallVectorImpl &results) { - if (auto type = convertToMemRefType(t)) { - results.push_back(type); - return success(); - } - - results.push_back(t); - return success(); - } - - /// Return true if the inputs and outputs of the given function type are - /// legal. [Taken from MLIR and adapted to only check the legality of the - /// inputs. Once unranked results can be handled gracefully this - /// override needs to be removed in favour of the original MLIR one.] - bool isSignatureLegal(FunctionType funcType) { - return llvm::all_of(funcType.getInputs(), - [this](Type type) { return isLegal(type); }); - } -}; - -} // end anonymous namespace. - //===----------------------------------------------------------------------===// // Frontend to Krnl Dialect lowering pass //===----------------------------------------------------------------------===// diff --git a/src/conversion/onnx_to_krnl/rewrite_patterns/math/elementwise.inc b/src/conversion/onnx_to_krnl/math/elementwise.cpp similarity index 99% rename from src/conversion/onnx_to_krnl/rewrite_patterns/math/elementwise.inc rename to src/conversion/onnx_to_krnl/math/elementwise.cpp index 945d4da..b397281 100644 --- a/src/conversion/onnx_to_krnl/rewrite_patterns/math/elementwise.inc +++ b/src/conversion/onnx_to_krnl/math/elementwise.cpp @@ -1,4 +1,4 @@ -//===----- elementwise.inc - Elementwise Ops ------------------------------===// +//===----- elementwise.cpp - Elementwise Ops ------------------------------===// // // Copyright 2019 The IBM Research Authors. 
 //
@@ -8,6 +8,10 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
+
+using namespace mlir;
+
 template <>
 struct ScalarOp<ONNXAddOp> {
   using FOp = AddFOp;
diff --git a/src/conversion/onnx_to_krnl/rewrite_patterns/math/gemm.inc b/src/conversion/onnx_to_krnl/math/gemm.cpp
similarity index 96%
rename from src/conversion/onnx_to_krnl/rewrite_patterns/math/gemm.inc
rename to src/conversion/onnx_to_krnl/math/gemm.cpp
index 8a9bf8e..0eed272 100644
--- a/src/conversion/onnx_to_krnl/rewrite_patterns/math/gemm.inc
+++ b/src/conversion/onnx_to_krnl/math/gemm.cpp
@@ -1,4 +1,4 @@
-//===----- gemm.inc - Lowering Gemm Op ------------------------------------===//
+//===----- gemm.cpp - Lowering Gemm Op ------------------------------------===//
 //
 // Copyright 2019 The IBM Research Authors.
 //
@@ -8,6 +8,10 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
+
+using namespace mlir;
+
 template <typename GemmOp>
 struct ONNXGemmOpLowering : public ConversionPattern {
   ONNXGemmOpLowering(MLIRContext *ctx)
@@ -17,9 +21,7 @@ struct ONNXGemmOpLowering : public ConversionPattern {
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
     auto loc = op->getLoc();
-    // The first predicate is unnecessary when we remove ONXGemmNoBiasOp.
-    bool hasBias = (operands.size() == 3) &&
-                   (!op->getOperand(2).getType().isa<NoneType>());
+    bool hasBias = !op->getOperand(2).getType().isa<NoneType>();
 
     Value A, B, C;
     A = operands[0];
@@ -215,5 +217,4 @@ struct ONNXGemmOpLowering : public ConversionPattern {
 void populateLoweringONNXGemmOpPattern(OwningRewritePatternList &patterns,
                                        MLIRContext *ctx) {
   patterns.insert<ONNXGemmOpLowering<ONNXGemmOp>>(ctx);
-  patterns.insert<ONNXGemmOpLowering<ONNXGemmNoBiasOp>>(ctx);
 }
diff --git a/src/conversion/onnx_to_krnl/rewrite_patterns/math/matmul.inc b/src/conversion/onnx_to_krnl/math/matmul.cpp
similarity index 98%
rename from src/conversion/onnx_to_krnl/rewrite_patterns/math/matmul.inc
rename to src/conversion/onnx_to_krnl/math/matmul.cpp
index 1af1f1b..a3cb26a 100644
--- a/src/conversion/onnx_to_krnl/rewrite_patterns/math/matmul.inc
+++ b/src/conversion/onnx_to_krnl/math/matmul.cpp
@@ -1,4 +1,4 @@
-//===----- matmul.inc - Lowering Matmul Op --------------------------------===//
+//===----- matmul.cpp - Lowering Matmul Op --------------------------------===//
 //
 // Copyright 2019 The IBM Research Authors.
 //
@@ -8,6 +8,10 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
+
+using namespace mlir;
+
 struct ONNXMatMulOpLowering : public ConversionPattern {
   ONNXMatMulOpLowering(MLIRContext *ctx)
       : ConversionPattern(mlir::ONNXMatMulOp::getOperationName(), 1, ctx) {}
diff --git a/src/conversion/onnx_to_krnl/rewrite_patterns/math/reduction.inc b/src/conversion/onnx_to_krnl/math/reduction.cpp
similarity index 98%
rename from src/conversion/onnx_to_krnl/rewrite_patterns/math/reduction.inc
rename to src/conversion/onnx_to_krnl/math/reduction.cpp
index 9b94861..42b074a 100644
--- a/src/conversion/onnx_to_krnl/rewrite_patterns/math/reduction.inc
+++ b/src/conversion/onnx_to_krnl/math/reduction.cpp
@@ -1,4 +1,4 @@
-//===----- reduction.inc - Lowering Reduction Ops -------------------------===//
+//===----- reduction.cpp - Lowering Reduction Ops -------------------------===//
 //
 // Copyright 2019 The IBM Research Authors.
 //
@@ -8,6 +8,10 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
+
+using namespace mlir;
+
 // Identity values
 template <>
 float getIdentityValue<float, ONNXReduceMaxOp>(){
diff --git a/src/conversion/onnx_to_krnl/rewrite_patterns/math/softmax.inc b/src/conversion/onnx_to_krnl/math/softmax.cpp
similarity index 98%
rename from src/conversion/onnx_to_krnl/rewrite_patterns/math/softmax.inc
rename to src/conversion/onnx_to_krnl/math/softmax.cpp
index 3f24a6e..3277635 100644
--- a/src/conversion/onnx_to_krnl/rewrite_patterns/math/softmax.inc
+++ b/src/conversion/onnx_to_krnl/math/softmax.cpp
@@ -1,4 +1,4 @@
-//===----- softmax.inc - Softmax Op ---------------------------------------===//
+//===----- softmax.cpp - Softmax Op ---------------------------------------===//
 //
 // Copyright 2019 The IBM Research Authors.
 //
@@ -8,6 +8,10 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
+
+using namespace mlir;
+
 struct ONNXSoftmaxOpLowering : public ConversionPattern {
   ONNXSoftmaxOpLowering(MLIRContext *ctx)
       : ConversionPattern(mlir::ONNXSoftmaxOp::getOperationName(), 1, ctx) {}
diff --git a/src/conversion/onnx_to_krnl/rewrite_patterns/nn/conv.inc b/src/conversion/onnx_to_krnl/nn/conv.cpp
similarity index 98%
rename from src/conversion/onnx_to_krnl/rewrite_patterns/nn/conv.inc
rename to src/conversion/onnx_to_krnl/nn/conv.cpp
index 6e3afe1..851668a 100644
--- a/src/conversion/onnx_to_krnl/rewrite_patterns/nn/conv.inc
+++ b/src/conversion/onnx_to_krnl/nn/conv.cpp
@@ -1,4 +1,4 @@
-//===----- conv.inc - Lowering Convolution Op -----------------------------===//
+//===----- conv.cpp - Lowering Convolution Op -----------------------------===//
 //
 // Copyright 2019 The IBM Research Authors.
 //
@@ -8,6 +8,10 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
+
+using namespace mlir;
+
 struct ONNXConvNoBiasOpLowering : public ConversionPattern {
   ONNXConvNoBiasOpLowering(MLIRContext *ctx)
       : ConversionPattern(mlir::ONNXConvNoBiasOp::getOperationName(), 1, ctx) {}
diff --git a/src/conversion/onnx_to_krnl/rewrite_patterns/nn/normalization.inc b/src/conversion/onnx_to_krnl/nn/normalization.cpp
similarity index 97%
rename from src/conversion/onnx_to_krnl/rewrite_patterns/nn/normalization.inc
rename to src/conversion/onnx_to_krnl/nn/normalization.cpp
index cb98b13..d151f0a 100644
--- a/src/conversion/onnx_to_krnl/rewrite_patterns/nn/normalization.inc
+++ b/src/conversion/onnx_to_krnl/nn/normalization.cpp
@@ -1,4 +1,4 @@
-//===----- normalization.inc - Lowering Normalization Ops -----------------===//
+//===----- normalization.cpp - Lowering Normalization Ops -----------------===//
 //
 // Copyright 2019 The IBM Research Authors.
 //
@@ -8,6 +8,10 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
+
+using namespace mlir;
+
 struct ONNXBatchNormalizationTestModeOpLowering : public ConversionPattern {
   ONNXBatchNormalizationTestModeOpLowering(MLIRContext *ctx)
       : ConversionPattern(
diff --git a/src/conversion/onnx_to_krnl/onnx_to_krnl_common.cpp b/src/conversion/onnx_to_krnl/onnx_to_krnl_common.cpp
new file mode 100644
index 0000000..16bc499
--- /dev/null
+++ b/src/conversion/onnx_to_krnl/onnx_to_krnl_common.cpp
@@ -0,0 +1,324 @@
+//====-- onnx_to_krnl_common.cpp - ONNX dialects to Krnl lowering ---------===//
+//
+// Copyright 2019 The IBM Research Authors.
+//
+// =============================================================================
+//
+// This file contains common code shared by the functions performing the
+// lowering to the KRNL dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
+
+/// Check if all dimensions are known at compile time.
+bool hasAllConstantDimensions(MemRefType type) {
+  auto memRefShape = type.getShape();
+  for (int i = 0; i < memRefShape.size(); ++i)
+    if (memRefShape[i] < 0)
+      return false;
+  return true;
+}
+
+/// Get the corresponding MemRefType of a given TensorType/MemRefType.
+MemRefType convertToMemRefType(Type type) {
+  MemRefType memRefType;
+  auto tensorType = type.dyn_cast<TensorType>();
+  if (tensorType) {
+    assert(tensorType.hasRank() && "expected only ranked shapes");
+    memRefType =
+        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+  } else {
+    memRefType = type.dyn_cast<MemRefType>();
+  }
+  return memRefType;
+}
+
+/// Insert an allocation and deallocation for the given MemRefType.
+Value insertAllocAndDealloc(MemRefType type, Location loc,
+                            PatternRewriter &rewriter,
+                            bool insertDealloc,
+                            ArrayRef<Value> operands) {
+  // Put together alloc operands for any dynamic dimensions of the memref.
+  AllocOp alloc;
+  if (!operands.empty()) {
+    auto memRefShape = type.getShape();
+    auto rank = memRefShape.size();
+
+    std::map<int, Value> fromOperands;
+    for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
+      int memRefDimIdx = rank - 1 - reversedIdx;
+      if (memRefShape[memRefDimIdx] < 0) { // unknown dimension
+        Value maxDim = nullptr;
+        for (int i = 0; i < operands.size(); i++) {
+          auto operandShape =
+              operands[i].getType().cast<MemRefType>().getShape();
+          int operandDimIdx = operandShape.size() - 1 - reversedIdx;
+
+          if (operandDimIdx < 0)
+            continue;
+
+          // In case of operations with broadcasting, the dimension of the
+          // alloc result is the maximum size along each dimension of the
+          // operands.
+          auto operandDim =
+              rewriter.create<DimOp>(loc, operands[i], operandDimIdx);
+          if (maxDim) {
+            auto maxCondition = rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt,
+                                                        operandDim, maxDim);
+            maxDim = rewriter.create<SelectOp>(loc, maxCondition, operandDim,
+                                               maxDim);
+          } else {
+            maxDim = operandDim;
+          }
+        }
+        fromOperands.insert(std::make_pair(memRefDimIdx, maxDim));
+      }
+    }
+
+    SmallVector<Value, 4> allocOperands;
+    for (int i = 0; i < rank; ++i)
+      if (memRefShape[i] < 0)
+        allocOperands.push_back(fromOperands[i]);
+    alloc = rewriter.create<AllocOp>(loc, type, allocOperands);
+  } else {
+    alloc = rewriter.create<AllocOp>(loc, type);
+  }
+
+  // Make sure to allocate at the beginning of the block if
+  // all dimensions are known.
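+  // (A static alloc is hoisted to the front of its block so that it
+  // dominates all of its uses, and the matching dealloc, when requested,
+  // is placed just before the block terminator.)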
+  auto *parentBlock = alloc.getOperation()->getBlock();
+  if (hasAllConstantDimensions(type))
+    alloc.getOperation()->moveBefore(&parentBlock->front());
+
+  if (insertDealloc) {
+    auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
+    dealloc.getOperation()->moveBefore(&parentBlock->back());
+  }
+
+  return alloc;
+}
+
+// Determine if current function returns the result value of the
+// current op being lowered. If it does then dealloc should not be
+// inserted.
+bool checkInsertDealloc(Operation *currentOp) {
+  auto parentBlock = currentOp->getBlock();
+
+  bool insertDealloc = true;
+  parentBlock->walk([&insertDealloc, currentOp](ReturnOp op) {
+    assert(currentOp->getNumResults() < 2 &&
+           "No more than one result supported (for now).");
+    // If there is at least one result to investigate.
+    if (currentOp->getNumResults() > 0) {
+      auto result = currentOp->getResult(0);
+      for (const auto &operand : op.getOperands())
+        if (operand == result)
+          insertDealloc = false;
+    }
+  });
+
+  return insertDealloc;
+}
+
+// Create a mapping from result type's dimensions to input type's dimensions,
+// given that the result type is the result of a reduction op over the input
+// type.
+std::map<int64_t, int64_t>
+getReductionMapping(MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims) {
+  std::map<int64_t, int64_t> OutInDimMap;
+  int64_t rank = inputTy.getRank();
+
+  // Mark reduction axes.
+  std::vector<bool> isReductionAxis;
+  for (decltype(rank) i = 0; i < rank; ++i) {
+    if (std::find(axes.begin(), axes.end(), i) != axes.end())
+      isReductionAxis.push_back(true);
+    else
+      isReductionAxis.push_back(false);
+  }
+
+  for (decltype(rank) inIndex = 0, outIndex = 0; inIndex < rank; ++inIndex) {
+    // If it is a reduction axis, there is no relationship among dimensions.
+    if (isReductionAxis[inIndex]) {
+      if (keepdims)
+        outIndex++;
+    } else {
+      OutInDimMap.insert(std::make_pair(outIndex, inIndex));
+      outIndex++;
+    }
+  }
+
+  return OutInDimMap;
+}
+
+// Add bounds associated with the op operand to the KRNL iteration pack.
+// Dynamic dimensions are supported.
+void addDimensionToPack(ConversionPatternRewriter &rewriter,
+                        Location loc, KrnlIterateOperandPack &pack,
+                        Value operand, int index) {
+  auto shape = operand.getType().cast<MemRefType>().getShape();
+  if (shape[index] < 0) {
+    pack.pushConstantBound(0);
+    pack.pushOperandBound(
+        rewriter.create<DimOp>(loc, operand, index).getResult());
+  } else {
+    pack.pushConstantBound(0);
+    pack.pushConstantBound(shape[index]);
+  }
+}
+
+// Function that defines the KRNL dialect loops and their respective
+// optimized version.
+KrnlOptimizeLoopsOp
+emitOptimizedLoops(ConversionPatternRewriter &rewriter, Location loc,
+                   std::vector<Value> &loops,
+                   std::vector<Value> &optimizedLoops, int64_t numLoops) {
+  // Define loops.
+  auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, numLoops);
+  loops.reserve(numLoops);
+  for (auto result : loopsOp.getResults())
+    loops.push_back(result);
+
+  // Define optimized version of the loops.
+  auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, numLoops);
+  optimizedLoops.reserve(numLoops);
+  for (auto result : optimizedLoopsOp.getResults())
+    optimizedLoops.push_back(result);
+
+  return optimizedLoopsOp;
+}
+
+// Function that emits the loops and their optimized version.
+// The function returns a reference to the inner optimization block.
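+// A typical use (sketch, for a hypothetical rank-2 operand):
+//   std::vector<Value> loops, optLoops;
+//   Block *optBlock = defineLoops(rewriter, loc, loops, optLoops, 2);
+//   // ...then set the insertion point into optBlock to emit optimizations.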
+Block *defineLoops(ConversionPatternRewriter &rewriter, Location loc,
+                   std::vector<Value> &loops,
+                   std::vector<Value> &optimizedLoops,
+                   int64_t numLoops) {
+  KrnlOptimizeLoopsOp optimizedLoopsOp =
+      emitOptimizedLoops(rewriter, loc, loops, optimizedLoops, numLoops);
+  return &optimizedLoopsOp.region().front();
+}
+
+// Function which emits a basic set of loops and optimized loops
+// for a given operation argument. A reference to the loop optimization
+// block is returned in the last argument of the function.
+void emitKrnlLoopsAndIterationForOperand(
+    ConversionPatternRewriter &rewriter, Location loc, Value operand,
+    std::vector<Value> &originalLoops, KrnlOptimizeLoopsOp &optimizedLoopsOp,
+    KrnlIterateOp &iterateOp) {
+  // Operand shape.
+  auto shape = operand.getType().cast<MemRefType>().getShape();
+
+  // Number of loops.
+  int64_t rank = shape.size();
+
+  // Define loops and optimized loops.
+  std::vector<Value> optimizedLoops;
+  optimizedLoopsOp =
+      emitOptimizedLoops(rewriter, loc, originalLoops, optimizedLoops, rank);
+
+  KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
+  // Iterate over the loop nest.
+  for (int i = 0; i < rank; ++i)
+    addDimensionToPack(rewriter, loc, pack, operand, i);
+
+  iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
+}
+
+unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
+  auto elementType = memRefType.getElementType();
+
+  unsigned sizeInBits;
+  if (elementType.isIntOrFloat()) {
+    sizeInBits = elementType.getIntOrFloatBitWidth();
+  } else {
+    auto vectorType = elementType.cast<VectorType>();
+    sizeInBits =
+        vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
+  }
+  return llvm::divideCeil(sizeInBits, 8);
+}
+
+// Get run-time dimension information for unknown dimensions used for
+// broadcasting.
+std::map<int, std::map<int, Value>>
+getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
+                      MemRefType memRefType, ArrayRef<Value> operands) {
+  auto memRefShape = memRefType.getShape();
+  int64_t rank = memRefShape.size();
+  // For unknown dimensions, we need to get dimension values at runtime in
+  // order to do broadcasting.
+  std::map<int, std::map<int, Value>> DimInfo;
+  // For each result dimension, compute the number of sharing operands.
+  // Sharing operands are operands sharing the same index (counting from the
+  // rightmost to the leftmost) for a given dimension.
+  std::map<int, int> sharedDimCount;
+  for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
+    int dimIdx = rank - 1 - reversedIdx;
+    sharedDimCount[dimIdx] = 0;
+    for (int i = 0; i < operands.size(); ++i) {
+      auto shape = operands[i].getType().cast<MemRefType>().getShape();
+      if (reversedIdx <= shape.size() - 1)
+        sharedDimCount[dimIdx]++;
+    }
+  }
+  // An unknown dimension can have a value of 1 or N (N > 1).
+  // If its value is 1, it is a broadcasted dimension.
+  // Otherwise, it is a non-broadcasted dimension.
+  // We only care about unknown dimensions whose number of sharing operands is
+  // more than one, since they are potentially broadcasted dimensions.
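+  // Example (hypothetical shapes): for operands memref<?x10xf32> and
+  // memref<?x10xf32>, dimension 0 is unknown and shared by both operands,
+  // so a runtime (dim == 1) test is recorded below for each operand;
+  // dimension 1 is static and needs no check.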
+  for (int i = 0; i < operands.size(); ++i) {
+    std::map<int, Value> broadcastedDims;
+    auto shape = operands[i].getType().cast<MemRefType>().getShape();
+    int size = shape.size();
+    for (int j = 0; j < shape.size(); ++j) {
+      if (shape[j] < 0 and sharedDimCount[rank - size + j] > 1) {
+        auto dim = rewriter.create<DimOp>(loc, operands[i], j).getResult();
+        auto one = rewriter.create<ConstantIndexOp>(loc, 1);
+        auto isBroadcasted =
+            rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, dim, one);
+        broadcastedDims.insert(std::make_pair(j, isBroadcasted));
+      }
+    }
+    DimInfo.insert(std::make_pair(i, broadcastedDims));
+  }
+  return DimInfo;
+}
+
+// Extract induction variables that are used for broadcasting values of a
+// given operand.
+std::vector<Value>
+getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
+                          ArrayRef<Value> loopIVs, Value operand,
+                          std::map<int, Value> broadcastedDims) {
+  // `operand` must have a ranked type. This should have been checked by the
+  // shape inference pass.
+  auto operandShape = operand.getType().cast<MemRefType>().getShape();
+  auto rank = operandShape.size();
+  auto loopCount = loopIVs.size();
+
+  std::vector<Value> newLoopIVs;
+  for (unsigned reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
+    auto dimIdx = rank - 1 - reversedIdx;
+    auto loopIdx = loopCount - 1 - reversedIdx;
+    if (operandShape[dimIdx] == 1) {
+      // Broadcasted dimension
+      auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
+      newLoopIVs.insert(newLoopIVs.begin(), zero);
+    } else if ((operandShape[dimIdx] == -1) &&
+               (broadcastedDims.find(dimIdx) != broadcastedDims.end())) {
+      // Unknown dimension, it can have a value of 1 or N (N > 1).
+      // If its value is 1, it is a broadcasted dimension.
+      // Otherwise, it is a non-broadcasted dimension.
+      auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
+      auto idx = rewriter.create<SelectOp>(loc, broadcastedDims[dimIdx], zero,
+                                           loopIVs[loopIdx]);
+      newLoopIVs.insert(newLoopIVs.begin(), idx);
+    } else {
+      // Non-broadcasted dimension
+      newLoopIVs.insert(newLoopIVs.begin(), loopIVs[loopIdx]);
+    }
+  }
+  return newLoopIVs;
+}
diff --git a/src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp b/src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp
new file mode 100644
index 0000000..bd22d95
--- /dev/null
+++ b/src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp
@@ -0,0 +1,217 @@
+//====-- onnx_to_krnl_common.hpp - ONNX dialects to Krnl lowering ---------===//
+//
+// Copyright 2019 The IBM Research Authors.
+//
+// =============================================================================
+//
+// This file contains common code shared by the functions performing the
+// lowering to the KRNL dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#pragma once
+
+#include <map>
+
+#include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/Sequence.h"
+#include "mlir/IR/PatternMatch.h"
+
+#include "src/dialect/krnl/krnl_helper.hpp"
+#include "src/dialect/krnl/krnl_ops.hpp"
+#include "src/dialect/onnx/onnx_ops.hpp"
+#include "src/pass/passes.hpp"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Common functions used when lowering the ONNX frontend dialect to KRNL.
+//===----------------------------------------------------------------------===//
+
+/// Check if all dimensions are known at compile time.
+bool hasAllConstantDimensions(MemRefType type);
+
+/// Get the corresponding MemRefType of a given TensorType/MemRefType.
+MemRefType convertToMemRefType(Type type);
+
+/// Insert an allocation and deallocation for the given MemRefType.
+Value insertAllocAndDealloc(MemRefType type, Location loc,
+                            PatternRewriter &rewriter,
+                            bool insertDealloc,
+                            ArrayRef<Value> operands = {});
+
+// Determine if current function returns the result value of the
+// current op being lowered. If it does then dealloc should not be
+// inserted.
+bool checkInsertDealloc(Operation *currentOp);
+
+// Create a mapping from result type's dimensions to input type's dimensions,
+// given that the result type is the result of a reduction op over the input
+// type.
+std::map<int64_t, int64_t>
+getReductionMapping(MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims);
+
+// Add bounds associated with the op operand to the KRNL iteration pack.
+// Dynamic dimensions are supported.
+void addDimensionToPack(ConversionPatternRewriter &rewriter,
+                        Location loc, KrnlIterateOperandPack &pack,
+                        Value operand, int index);
+
+// Function that defines the KRNL dialect loops and their respective
+// optimized version.
+KrnlOptimizeLoopsOp
+emitOptimizedLoops(ConversionPatternRewriter &rewriter, Location loc,
+                   std::vector<Value> &loops,
+                   std::vector<Value> &optimizedLoops, int64_t numLoops);
+
+// Function that emits the loops and their optimized version.
+// The function returns a reference to the inner optimization block.
+Block *defineLoops(ConversionPatternRewriter &rewriter, Location loc,
+                   std::vector<Value> &loops,
+                   std::vector<Value> &optimizedLoops,
+                   int64_t numLoops);
+
+// Function which emits a basic set of loops and optimized loops
+// for a given operation argument. A reference to the loop optimization
+// block is returned in the last argument of the function.
+void emitKrnlLoopsAndIterationForOperand(
+    ConversionPatternRewriter &rewriter, Location loc, Value operand,
+    std::vector<Value> &originalLoops, KrnlOptimizeLoopsOp &optimizedLoopsOp,
+    KrnlIterateOp &iterateOp);
+
+unsigned getMemRefEltSizeInBytes(MemRefType memRefType);
+
+// Get run-time dimension information for unknown dimensions used for
+// broadcasting.
+std::map<int, std::map<int, Value>>
+getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
+                      MemRefType memRefType, ArrayRef<Value> operands);
+
+// Extract induction variables that are used for broadcasting values of a
+// given operand.
+std::vector<Value>
+getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
+                          ArrayRef<Value> loopIVs, Value operand,
+                          std::map<int, Value> broadcastedDims);
+
+//===----------------------------------------------------------------------===//
+// This is to get a scalar operation of a given type for a specific operation.
+//===----------------------------------------------------------------------===//
+template <typename Op>
+struct ScalarOp {
+  using FOp = void;
+  using IOp = void;
+};
+
+template <typename Op>
+using ScalarFOp = typename ScalarOp<Op>::FOp;
+template <typename Op>
+using ScalarIOp = typename ScalarOp<Op>::IOp;
+
+// Get the identity element of an operation.
+// Return NULL if the function does not have identity.
+template <typename DataType, typename Op>
+DataType getIdentityValue() {
+  return NULL;
+}
+
+//===----------------------------------------------------------------------===//
+// This is used in the innermost loop of a KrnlIterateOp to insert computation
+// composed of one or many scalar ops.
+// Use template specialization for each of the different ONNX operations.
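+// For example, elementwise.cpp specializes ScalarOp so that onnx.Add maps
+// to AddFOp for float operands (and, following the same pattern, to an
+// integer op such as AddIOp for integer operands):
+//   template <>
+//   struct ScalarOp<ONNXAddOp> { using FOp = AddFOp; using IOp = AddIOp; };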
+//===----------------------------------------------------------------------===//
+template <typename ElementwiseNaryOp>
+Value mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types,
+                         ArrayRef<Value> operands,
+                         ConversionPatternRewriter &rewriter) {
+  auto loc = op->getLoc();
+  Type element_type = operands.front().getType();
+  if (element_type.isa<FloatType>()) {
+    return rewriter.create<ScalarFOp<ElementwiseNaryOp>>(loc, result_types,
+                                                         operands, mlir::None);
+  } else if (element_type.isa<IntegerType>()) {
+    return rewriter.create<ScalarIOp<ElementwiseNaryOp>>(loc, result_types,
+                                                         operands, mlir::None);
+  } else {
+    emitError(loc, "unsupported element type");
+    return nullptr;
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// Conversion from Tensor type to the Standard dialect MemRef type.
+//===----------------------------------------------------------------------===//
+
+struct TensorTypeConverter : public TypeConverter {
+  using TypeConverter::TypeConverter;
+
+  TensorTypeConverter() {
+    addConversion(convertType);
+  }
+
+  static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
+    if (auto type = convertToMemRefType(t)) {
+      results.push_back(type);
+      return success();
+    }
+
+    results.push_back(t);
+    return success();
+  }
+
+  /// Return true if the inputs and outputs of the given function type are
+  /// legal. [Taken from MLIR and adapted to only check the legality of the
+  /// inputs. Once unranked results can be handled gracefully this
+  /// override needs to be removed in favour of the original MLIR one.]
+  bool isSignatureLegal(FunctionType funcType) {
+    return llvm::all_of(funcType.getInputs(),
+                        [this](Type type) { return isLegal(type); });
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// Functions to add lowering patterns for frontend operations.
+//===----------------------------------------------------------------------===//
+
+// `math` directory methods:
+
+void populateLoweringONNXElementwiseOpPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx);
+
+void populateLoweringONNXGemmOpPattern(OwningRewritePatternList &patterns,
+                                       MLIRContext *ctx);
+
+void populateLoweringONNXMatMulOpPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx);
+
+void populateLoweringONNXReductionOpPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx);
+
+void populateLoweringONNXSoftmaxOpPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx);
+
+// `nn` directory methods:
+
+void populateLoweringONNXConvOpPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx);
+
+void populateLoweringONNXNormalizationOpPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx);
+
+// `tensor` directory methods:
+
+void populateLoweringONNXUnsqueezeOpPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx);
+
+void populateLoweringONNXTransposeOpPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx);
+
+void populateLoweringONNXReshapeOpPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx);
+
+void populateLoweringONNXIdentityOpPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx);
diff --git a/src/conversion/onnx_to_krnl/rewrite_patterns/tensor/identity.inc b/src/conversion/onnx_to_krnl/tensor/identity.cpp
similarity index 85%
rename from src/conversion/onnx_to_krnl/rewrite_patterns/tensor/identity.inc
rename to src/conversion/onnx_to_krnl/tensor/identity.cpp
index 2ff1633..45985af 100644
--- a/src/conversion/onnx_to_krnl/rewrite_patterns/tensor/identity.inc
+++ b/src/conversion/onnx_to_krnl/tensor/identity.cpp
@@ -1,4 +1,4 @@
-//===----- identity.inc - Lowering Identity Op ----------------------------===//
+//===----- identity.cpp - Lowering Identity Op ----------------------------===//
 //
 // Copyright 2019 The IBM Research Authors.
 //
@@ -8,6 +8,10 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
+
+using namespace mlir;
+
 struct ONNXIdentityOpLowering : public ConversionPattern {
   ONNXIdentityOpLowering(MLIRContext *ctx)
       : ConversionPattern(mlir::ONNXIdentityOp::getOperationName(), 1, ctx) {}
diff --git a/src/conversion/onnx_to_krnl/rewrite_patterns/tensor/reshape.inc b/src/conversion/onnx_to_krnl/tensor/reshape.cpp
similarity index 97%
rename from src/conversion/onnx_to_krnl/rewrite_patterns/tensor/reshape.inc
rename to src/conversion/onnx_to_krnl/tensor/reshape.cpp
index b64494f..6489a71 100644
--- a/src/conversion/onnx_to_krnl/rewrite_patterns/tensor/reshape.inc
+++ b/src/conversion/onnx_to_krnl/tensor/reshape.cpp
@@ -1,4 +1,4 @@
-//===----- reshape.inc - Lowering Reshape Op ------------------------------===//
+//===----- reshape.cpp - Lowering Reshape Op ------------------------------===//
 //
 // Copyright 2019 The IBM Research Authors.
 //
@@ -8,6 +8,10 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
+
+using namespace mlir;
+
 struct ONNXReshapeOpLowering : public ConversionPattern {
   ONNXReshapeOpLowering(MLIRContext *ctx)
      : ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {}
diff --git a/src/conversion/onnx_to_krnl/rewrite_patterns/tensor/transpose.inc b/src/conversion/onnx_to_krnl/tensor/transpose.cpp
similarity index 96%
rename from src/conversion/onnx_to_krnl/rewrite_patterns/tensor/transpose.inc
rename to src/conversion/onnx_to_krnl/tensor/transpose.cpp
index 3bb897a..0a6c8f4 100644
--- a/src/conversion/onnx_to_krnl/rewrite_patterns/tensor/transpose.inc
+++ b/src/conversion/onnx_to_krnl/tensor/transpose.cpp
@@ -1,4 +1,4 @@
-//===----- transpose.inc - Lowering Transpose Op --------------------------===//
+//===----- transpose.cpp - Lowering Transpose Op --------------------------===//
 //
 // Copyright 2019 The IBM Research Authors.
 //
@@ -8,6 +8,10 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
+
+using namespace mlir;
+
 struct ONNXTransposeOpLowering : public ConversionPattern {
   ONNXTransposeOpLowering(MLIRContext *ctx)
       : ConversionPattern(mlir::ONNXTransposeOp::getOperationName(), 1, ctx) {}
diff --git a/src/conversion/onnx_to_krnl/rewrite_patterns/tensor/unsqueeze.inc b/src/conversion/onnx_to_krnl/tensor/unsqueeze.cpp
similarity index 95%
rename from src/conversion/onnx_to_krnl/rewrite_patterns/tensor/unsqueeze.inc
rename to src/conversion/onnx_to_krnl/tensor/unsqueeze.cpp
index 6d5289d..070a91c 100644
--- a/src/conversion/onnx_to_krnl/rewrite_patterns/tensor/unsqueeze.inc
+++ b/src/conversion/onnx_to_krnl/tensor/unsqueeze.cpp
@@ -1,4 +1,4 @@
-//===----- unsqueeze.inc - Lowering Unsqueeze Op --------------------------===//
+//===----- unsqueeze.cpp - Lowering Unsqueeze Op --------------------------===//
 //
 // Copyright 2019 The IBM Research Authors.
 //
@@ -8,6 +8,10 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
+
+using namespace mlir;
+
 struct ONNXUnsqueezeOpLowering : public ConversionPattern {
   ONNXUnsqueezeOpLowering(MLIRContext *ctx)
       : ConversionPattern(mlir::ONNXUnsqueezeOp::getOperationName(), 1, ctx) {}
diff --git a/src/dialect/krnl/krnl_helper.cpp b/src/dialect/krnl/krnl_helper.cpp
index 4f75a43..91e9825 100644
--- a/src/dialect/krnl/krnl_helper.cpp
+++ b/src/dialect/krnl/krnl_helper.cpp
@@ -131,7 +131,7 @@ void KrnlIterateOperandPack::pushConstantBound(int64_t bound) {
   boundMaps.emplace_back(AffineMapAttr::get(map));
 }
 
-void KrnlIterateOperandPack::pushOperandBound(mlir::Value operand) {
+void KrnlIterateOperandPack::pushOperandBound(Value operand) {
   if (boundMaps.size() % 2 == 0)
     _operands.emplace_back(inputLoops[boundMaps.size() / 2]);
   AffineMap map = builder.getSymbolIdentityMap();
@@ -145,7 +145,7 @@ BuildKrnlLoop::BuildKrnlLoop(
       pushCount(0), createdDefineOp(false), createdOptimizeOp(false),
       createdIterateOp(false) {
   if (originalLoopNum <= 0)
-    emitError(loc, "expected positive number of original loops");
+    emitError(loc, "Expected positive number of original loops.");
 }
 
 BuildKrnlLoop::BuildKrnlLoop(
@@ -154,25 +154,24 @@ BuildKrnlLoop::BuildKrnlLoop(
       memRefOperand.getType().cast<MemRefType>().getShape().size()) {}
 
 BuildKrnlLoop::~BuildKrnlLoop() {
-  if (!createdDefineOp)
-    emitError(loc, "expected to create define op");
-  if (!createdIterateOp)
-    emitError(loc, "expected to create iteration op");
   if (pack)
     free(pack);
 }
 
 void BuildKrnlLoop::createDefineAndOptimizeOp(bool withEmptyOptimization) {
-  // insert define loop op
+  // Insert define loop operation.
   auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, originalLoopNum);
   originalLoops.reserve(originalLoopNum);
   for (auto result : loopsOp.getResults())
     originalLoops.push_back(result);
-  // inserte optimize loop op.
+  createdDefineOp = true;
+
+  // Insert optimize loop operation.
   auto optimizedLoopsOp =
       rewriter.create<KrnlOptimizeLoopsOp>(loc, originalLoopNum);
   optLoops.reserve(originalLoopNum);
-  // Emit empty optimizations
+
+  // Emit empty optimizations if flag is set.
   if (withEmptyOptimization) {
     for (auto result : optimizedLoopsOp.getResults())
       optLoops.push_back(result);
@@ -182,12 +181,12 @@ void BuildKrnlLoop::createDefineAndOptimizeOp(bool withEmptyOptimization) {
     rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
     rewriter.restoreInsertionPoint(ip);
   }
+  createdOptimizeOp = true;
 
+  // prepare data structure to push bounds
   pack = new KrnlIterateOperandPack(rewriter, originalLoops, optLoops);
-  createdOptimizeOp = true;
 }
 
-// push bounds (lower and upper) and return index for loop info
 int BuildKrnlLoop::pushBounds(int64_t lowerBound, int64_t upperBound) {
   pack->pushConstantBound(lowerBound);
   pack->pushConstantBound(upperBound);
@@ -203,17 +202,20 @@ int BuildKrnlLoop::pushBounds(int64_t lowerBound, Value upperBound) {
 
 int BuildKrnlLoop::pushBounds(int64_t lowerBound, Value upperBoundMemRefOperand,
     int upperBoundMemRefIndex, bool upperBoundMustBeConstant) {
   pack->pushConstantBound(lowerBound);
-  // process upperBound as a dimension of mem ref, possibly non-constant
+
+  // Process upperBound as a dimension of the MemRef. Non-constant dimensions
+  // are supported.
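+  // E.g., for a hypothetical memref<?x10xf32> operand, index 0 emits a
+  // DimOp-backed operand bound, while index 1 pushes the constant bound 10.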
auto shape = upperBoundMemRefOperand.getType().cast().getShape(); if (shape[upperBoundMemRefIndex] < 0) { if (upperBoundMustBeConstant) - emitError(loc, "bound expected to be constant"); + emitError(loc, "Bound expected to be constant."); pack->pushOperandBound( rewriter .create(loc, upperBoundMemRefOperand, upperBoundMemRefIndex) .getResult()); } else pack->pushConstantBound(shape[upperBoundMemRefIndex]); + return pushCount++; } @@ -223,19 +225,20 @@ int BuildKrnlLoop::pushBounds(Value lowerBound, Value upperBound) { return pushCount++; } -// create iter void BuildKrnlLoop::createIterateOp() { + // Loop definition operation is mandatory. if (!createdDefineOp) - emitError(loc, "must create define op before iterate op"); - // Tight now, optimize (possibly empty) is mandatory. This may change + emitError(loc, "Must create define op before iterate op."); + + // Loop optimization operation is mandatory (for now). if (!createdOptimizeOp) - emitError(loc, "must create optimize op before iterate op"); - // have to have defined all bounds - if (pushCount != originalLoopNum) { - printf(" push count %d, original loop %d\n", pushCount, originalLoopNum); - emitError(loc, "must push bounds for all original loops"); - } - // create iterate op + emitError(loc, "Must create optimize op before iterate op."); + + // Check if all bounds have been defined. + if (pushCount != originalLoopNum) + emitError(loc, "Must push bounds for all original loops."); + + // Emit iteration operation. auto iterateOp = rewriter.create(loc, *pack); iterBlock = &iterateOp.bodyRegion().front(); createdIterateOp = true; @@ -243,19 +246,27 @@ void BuildKrnlLoop::createIterateOp() { void BuildKrnlLoop::createDefineOptimizeAndIterateOp( Value memRefOperand, bool withEmptyOptimization) { + // Rank of the MemRef operand. We will emit a loop for each dimension. int loopNum = memRefOperand.getType().cast().getShape().size(); if (originalLoopNum != loopNum) - emitError(loc, "mismatch in loop numbers from constructor and define"); + emitError(loc, "Mismatch in loop numbers from constructor and define."); + + // Emit the definition and the optimization operations for the loop nest. createDefineAndOptimizeOp(withEmptyOptimization); + + // Push a lower-upper bound pair for each dimension of the MemRef operand. + // The lower bound in this case is always zero. for (int i = 0; i < originalLoopNum; ++i) pushBounds(0, memRefOperand, i); + + // Emit the iteration operation over the current loop nest. createIterateOp(); } -// get induction variable to be use within iter BlockArgument &BuildKrnlLoop::getInductionVar(int originalLoopIndex) { + // Check if loop iteration variable is within bounds. if (originalLoopIndex < 0 || originalLoopIndex >= originalLoopNum) - emitError(loc, "original loop index is out of bound"); + emitError(loc, "Original loop index is out of bounds."); return iterBlock->getArguments()[originalLoopIndex]; } diff --git a/src/dialect/krnl/krnl_helper.hpp b/src/dialect/krnl/krnl_helper.hpp index cfe1787..aebbe0b 100644 --- a/src/dialect/krnl/krnl_helper.hpp +++ b/src/dialect/krnl/krnl_helper.hpp @@ -106,19 +106,21 @@ private: // // The sequence is as follow: // -// 1) Create a object giving the rewriter, location, and number of loop in the -// original (non optimized) loop. +// 1) Create an object giving the rewriter, location, and number of loop in +// the original (non optimized) loop. // // 2) Create define & optimize ops (currently paired). Optimizations can then -// be added to the inner block of the optimize operation. 
Make sure to set the -// insertion point to that block for optimizations to go in the right place. +// be added to the inner block of the optimize operation. Make sure to set +// the insertion point to that block for optimizations to go in the right +// place. // // 3) Push the bounds for each of the original loops. Bounds are pushed in -// pairs (lower & upper bounds). THere are a few methods to do it depending on -// the type of the bounds. When pushing bounds, the method returns a number -// that represent the index associated with that iteration (induction variable -// and bounds). That index can be used later to extract the induction variable -// for reference in computation and/or index calculations of mem refs. +// pairs (lower & upper bounds). There are a few methods to do it depending +// on the type of the bounds. When pushing bounds, the method returns a +// number that represent the index associated with that iteration (induction +// variable and bounds). That index can be used later to extract the +// induction variable for reference in computation and/or index calculations +// of mem refs. // // 4) Once all the bounds are pushed, create the iterate operation. Once this // is done, we can add operations within the iterate blocks by setting the @@ -127,67 +129,90 @@ private: class BuildKrnlLoop { public: - // Create a build kernel loop for the given location and loop number. + // Create kernel loop builder for a loop nest of depth loopNum. BuildKrnlLoop(ConversionPatternRewriter &rewriter, Location loc, int loopNum); - // Do the same, but where the loop number corresponds to the dimensionality of - // the mem ref operand. + + // Create kernel loop builder for a loop nest of depth equal to the + // dimensionality of the operand. An operand of MemRef type is requied. BuildKrnlLoop( ConversionPatternRewriter &rewriter, Location loc, Value memRefOperand); ~BuildKrnlLoop(); // Create define and optimize loop with loopNum original loops. If - // withEmptyOptimization, the optimization is simply the identity function (no - // optimizations). + // withEmptyOptimization is true, the optimization is simply the identity + // function (no optimizations). void createDefineAndOptimizeOp(bool withEmptyOptimization = true); - // Push bounds (lower and upper) for each of the loops, in order. It returns - // the index associated with the loop iteration. This index is in the range - // from zero to original loop number -1, and is monotonally increasing from - // call to call. This index is later used in the getInductionVar call. + // Push bounds (lower and upper) for each of the loops (order matters). + // The function returns the order number associated with the loop iteration. + // This index is used by the getInductionVar call. Non-constant operands + // must be of MemRef type. int pushBounds(int64_t lowerBound, int64_t upperBound); int pushBounds(int64_t lowerBound, Value upperBound); int pushBounds(Value lowerBound, Value upperBound); - // same, where the lower bound is an integer, and the uppoer bound is given by - // the size of the mem ref operand along the upperBoundMemRefIndex dimension. int pushBounds(int64_t lowerBound, Value upperBoundMemRefOperand, int upperBoundMemRefIndex, bool upperBoundMustBeConstant = false); - // Create an iterate op. + // Create the KrnlIterateOp assiciated with this loop nest. The loops + // iteration will be created if the definition and the optimization + // operations associated with this loop nest have been emitted already. 
   void createIterateOp();
 
-  // Create an define, optimize and iterate op, with the same loop nummber as
-  // the rank of the memRefOperand. The lower bound of each loops is zero, and
-  // the upper bound of each loops is the dimension given by the mem refs
+
+  // Create the loop nest definition, optimization and iteration operations
+  // for a given operand of MemRef type. The loop nest has a depth equal to the
+  // rank of the MemRef operand. The lower bound of each loop is zero. The
+  // upper bound of each loop is given by the corresponding dimension of the
+  // MemRef operand.
   void createDefineOptimizeAndIterateOp(
       Value memRefOperand, bool withEmptyOptimization = true);
 
-  // Get the (original loop) induction variable associated with the given index.
-  // Use the index returned when pushing the bounds.
+  // Get the (original loop) induction variable associated with the given
+  // index. Use the index returned when pushing the bounds.
   BlockArgument &getInductionVar(int originalLoopIndex);
 
-  // Get blocks. This allow us to set the insertion point to the inner block of
-  // the optimize and the iterate Operation
+  // Get a reference to the code region of the optimization operation.
+  // This allows us to set the insertion point to the inner block of the
+  // loop nest optimization operation.
   Block *getOptimizationBlock() { return optBlock; }
+
+  // Get a reference to the code region of the iteration operation.
+  // This allows us to set the insertion point to the inner block of the
+  // loop nest iteration operation.
   Block *getIterateBlock() { return iterBlock; }
 
-  // get original or optimized loops
+  // Get original loop nest.
   std::vector<Value> &getOriginalLoops() { return originalLoops; }
+
+  // Get optimized loop nest.
   std::vector<Value> &getOptimizedLoops() { return optLoops; }
 
 private:
-  // inputs
+  // Required for emitting operations.
   ConversionPatternRewriter &rewriter;
   Location loc;
   int originalLoopNum;
-  // track loops and bounds
+
+  // List of original, un-optimized loops.
   std::vector<Value> originalLoops;
+
+  // List of optimized loops.
   std::vector<Value> optLoops;
+
+  // List of lower-upper bound pairs needed by the KrnlIterateOp.
   KrnlIterateOperandPack *pack;
+
+  // Number of lower-upper bound pairs pushed.
   int pushCount;
+
+  // Flags that keep track of emitted operations.
   bool createdDefineOp;
   bool createdOptimizeOp;
   bool createdIterateOp;
-  // insertion points (opt block, iterate)
+
+  // Saved insertion point in the code region of the KrnlOptimizeLoopsOp.
   Block *optBlock;
+
+  // Saved insertion point in the code region of the KrnlIterateOp.
   Block *iterBlock;
 };
diff --git a/src/dialect/onnx/onnx.td b/src/dialect/onnx/onnx.td
index ef4e62a..1dde6cc 100644
--- a/src/dialect/onnx/onnx.td
+++ b/src/dialect/onnx/onnx.td
@@ -90,25 +90,6 @@ def ONNXEntryPointOp: ONNX_Op<"EntryPoint"> {
 // or outputs. This decision affects only ONNX operations with optional
 // arguments not ONNX operations with variadic operands.
 
-def ONNXGemmNoBiasOp: ONNX_Op<"GemmNoBias",
-    [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
-  let summary = "ONNX general matrix multiply operation without bias.";
-  let description = [{
-
-  The "onnx.Gemm" generic matrix multiplication without bias.
-
-  }];
-
-  let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$A,
-    AnyTypeOf<[AnyMemRef, AnyTensor]>:$B,
-    DefaultValuedAttr<F32Attr, "1.0">:$alpha,
-    DefaultValuedAttr<F32Attr, "1.0">:$beta,
-    DefaultValuedAttr<I64Attr, "0">:$transA,
-    DefaultValuedAttr<I64Attr, "0">:$transB);
-
-  let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_Y);
-}
-
 def ONNXConvNoBiasOp:ONNX_Op<"ConvNoBias",
   [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let hasCanonicalizer = 1;
diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp
index 8757f71..9668e68 100644
--- a/src/dialect/onnx/onnx_ops.cpp
+++ b/src/dialect/onnx/onnx_ops.cpp
@@ -565,32 +565,6 @@ void ONNXGemmOp::inferShapes() {
   getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
 }
 
-// GemmNoBias
-
-void ONNXGemmNoBiasOp::inferShapes() {
-  // Cannot infer shape if no shape exists.
-  if (!getOperand(0).getType().isa<RankedTensorType>() ||
-      !getOperand(1).getType().isa<RankedTensorType>())
-    return;
-  auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
-  auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
-
-  int64_t M, N, K_A, K_B;
-  M = (transA() == 0) ? lhsTy.getShape()[0] : lhsTy.getShape()[1];
-  K_A = (transA() == 0) ? lhsTy.getShape()[1] : lhsTy.getShape()[0];
-  N = (transB() == 0) ? rhsTy.getShape()[1] : rhsTy.getShape()[0];
-  K_B = (transB() == 0) ? rhsTy.getShape()[0] : rhsTy.getShape()[1];
-
-  if ((K_A != -1) and (K_B != -1) and (K_A != K_B)) {
-    emitError("Tensor shapes mismatched.");
-  }
-
-  SmallVector<int64_t, 2> dims;
-  dims.emplace_back(M);
-  dims.emplace_back(N);
-  getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
-}
-
 /// BatchNormalizationTestMode
 void ONNXBatchNormalizationTestModeOp::inferShapes() {
   // Cannot infer shape if no shape exists.
diff --git a/src/pass/shape_inference_pass.cpp b/src/pass/shape_inference_pass.cpp
index 2f80ea7..47826af 100644
--- a/src/pass/shape_inference_pass.cpp
+++ b/src/pass/shape_inference_pass.cpp
@@ -118,7 +118,6 @@ public:
             op->getName().getStringRef() != "onnx.Identity" &&
             op->getName().getStringRef() != "onnx.MatMul" &&
             op->getName().getStringRef() != "onnx.Gemm" &&
-            op->getName().getStringRef() != "onnx.GemmNoBias" &&
             op->getName().getStringRef() != "onnx.Reshape" &&
             op->getName().getStringRef() != "onnx.Transpose" &&
             op->getName().getStringRef() != "onnx.ReduceMax" &&
diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir
index 9da12ac..c35536d 100644
--- a/test/mlir/onnx/onnx_lowering.mlir
+++ b/test/mlir/onnx/onnx_lowering.mlir
@@ -806,35 +806,6 @@ func @test_gemm(%arg0 : tensor<5x10xf32>, %arg1 : tensor<5x10xf32>, %arg2: tenso
   // CHECK: }
 }
 
-func @test_gemm_no_bias(%arg0 : tensor<5x10xf32>, %arg1 : tensor<5x10xf32>) -> tensor<*xf32> {
-  %0 ="onnx.GemmNoBias"(%arg0, %arg1) {alpha = 1.0 : f32, beta = 5.0 : f32, transA = 1, transB = 0} : (tensor<5x10xf32>, tensor<5x10xf32>) -> tensor<*xf32>
-  "std.return"(%0) : (tensor<*xf32>) -> ()
-
-  // CHECK-LABEL: test_gemm_no_bias
-  // CHECK: [[RES:%.+]] = alloc() : memref<10x10xf32>
-  // CHECK: [[ALPHA:%.+]] = constant 1.000000e+00 : f32
-  // CHECK: [[BETA:%.+]] = constant 5.000000e+00 : f32
-  // CHECK: [[DEF_LOOPS:%.+]]:3 = krnl.define_loops 3
-  // CHECK: [[OPT_LOOPS:%.+]]:3 = krnl.optimize_loops {
-  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1, [[DEF_LOOPS]]#2
-  // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
-  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
-  // CHECK: krnl.iterate([[OPT_LOOPS]]#2) with ([[DEF_LOOPS]]#2 -> %arg4 = 0 to 5) {
-  // CHECK: [[A:%.+]] = load %arg0[%arg4, %arg2] : memref<5x10xf32>
-  // CHECK: [[B:%.+]] = load %arg1[%arg4, %arg3] : memref<5x10xf32>
-  // CHECK: [[Y:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xf32>
-  // CHECK: [[AB:%.+]] = mulf [[A]], [[B]] : f32
-  // CHECK: [[SUM:%.+]] = addf [[Y]], [[AB]] : f32
-  // CHECK: store [[SUM]], [[RES]][%arg2, %arg3] : memref<10x10xf32>
-  // CHECK: }
-  // CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xf32>
-  // CHECK: [[ALPHA_AB:%.+]] = mulf [[ALPHA]], [[LOAD_Y]] : f32
-  // CHECK: store [[ALPHA_AB]], [[RES]][%arg2, %arg3] : memref<10x10xf32>
-  // CHECK: }
-  // CHECK: return [[RES]] : memref<10x10xf32>
-  // CHECK: }
-}
-
 func @test_sqrt(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
   %0 = "onnx.Sqrt"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()
diff --git a/third_party/onnx b/third_party/onnx
index 1439eab..553df22 160000
--- a/third_party/onnx
+++ b/third_party/onnx
@@ -1 +1 @@
-Subproject commit 1439eab5542c625bb3da49860f0cd68c3eafdc18
+Subproject commit 553df22c67bee5f0fe6599cff60f1afc6748c635
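
To make the four-step BuildKrnlLoop protocol documented in krnl_helper.hpp concrete, here is a minimal usage sketch. It is not code from this patch: the function name emitCopyLoopNest is invented for illustration, and the snippet assumes the usual `using namespace mlir;`, `using llvm::SmallVector;`, and the standard-dialect LoadOp/StoreOp that the lowering files in this patch already use.

    // Hypothetical sketch (not part of this patch): lowering an element-wise
    // copy with the four-step BuildKrnlLoop sequence from krnl_helper.hpp.
    void emitCopyLoopNest(ConversionPatternRewriter &rewriter, Location loc,
        Value src, Value dst) {
      int rank = src.getType().cast<MemRefType>().getShape().size();

      // 1) Construct the builder; the loop nest depth equals the MemRef rank.
      BuildKrnlLoop loopNest(rewriter, loc, rank);

      // 2) Emit krnl.define_loops and an identity krnl.optimize_loops.
      loopNest.createDefineAndOptimizeOp();

      // 3) Push one [0, dim) bound pair per dimension; pushBounds returns the
      // index that is later passed to getInductionVar.
      for (int i = 0; i < rank; ++i)
        loopNest.pushBounds(0, src, i);

      // 4) Emit krnl.iterate and move the insertion point into its body.
      loopNest.createIterateOp();
      rewriter.setInsertionPointToStart(loopNest.getIterateBlock());

      // Loop body: dst[ivs] = src[ivs].
      SmallVector<Value, 4> ivs;
      for (int i = 0; i < rank; ++i)
        ivs.push_back(loopNest.getInductionVar(i));
      Value elem = rewriter.create<LoadOp>(loc, src, ivs);
      rewriter.create<StoreOp>(loc, elem, dst, ivs);
    }

For the common case where every lower bound is zero and every upper bound comes from the operand's shape, the createDefineOptimizeAndIterateOp helper documented above collapses steps 2 through 4 into a single call.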