From d3dcee73665d5b7aead7dec67b2bf6532fc88efe Mon Sep 17 00:00:00 2001 From: Gheorghe-Teodor Bercea Date: Fri, 14 Aug 2020 12:59:33 -0400 Subject: [PATCH] Add krnl.dim operation and lowering pass (#261) * Reorganize main function. * Follow review comments. * Emit constants are globals in Krnl and LLVM dialects. * Add krnl.dim op. * Add test with alloc with map. * Code clean-up. * Code clean-up. * Add comment for function. * Update DisconnectKrnlDimFromAlloc.cpp --- src/CMakeLists.txt | 3 +- src/Conversion/KrnlToAffine/KrnlToAffine.cpp | 2 + .../ONNXToKrnl/ONNXToKrnlCommon.cpp | 17 +++ .../ONNXToKrnl/ONNXToKrnlCommon.hpp | 11 ++ src/Dialect/Krnl/KrnlOps.td | 21 ++++ src/InitOMPasses.hpp | 5 + src/Pass/Passes.hpp | 3 + src/Transform/CMakeLists.txt | 13 +++ src/Transform/DisconnectKrnlDimFromAlloc.cpp | 103 ++++++++++++++++++ .../krnl/krnl_disconnect_dim_from_alloc.mlir | 43 ++++++++ 10 files changed, 220 insertions(+), 1 deletion(-) create mode 100644 src/Transform/DisconnectKrnlDimFromAlloc.cpp create mode 100644 test/mlir/krnl/krnl_disconnect_dim_from_alloc.mlir diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 08f9c1a..d834ac1 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -23,7 +23,8 @@ set(OMLibs OMElideKrnlGlobalConstants OMPackKrnlGlobalConstants OMEnableMemoryPool - OMBundleMemoryPools) + OMBundleMemoryPools + OMDisconnectKrnlDimFromAlloc) set(OMLibs ${OMLibs} PARENT_SCOPE) add_subdirectory(Tool) diff --git a/src/Conversion/KrnlToAffine/KrnlToAffine.cpp b/src/Conversion/KrnlToAffine/KrnlToAffine.cpp index 5c9626a..73b8fba 100644 --- a/src/Conversion/KrnlToAffine/KrnlToAffine.cpp +++ b/src/Conversion/KrnlToAffine/KrnlToAffine.cpp @@ -233,6 +233,8 @@ void ConvertKrnlToAffinePass::runOnFunction() { ConversionTarget target(getContext()); target.addIllegalOp(); + // krnl.dim operations must be lowered prior to this pass. + target.addIllegalOp(); target.addLegalOp(); OwningRewritePatternList patterns; patterns.insert(&getContext()); diff --git a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp index b854698..76bddc8 100644 --- a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp +++ b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp @@ -525,3 +525,20 @@ Value getDynamicMemRefSizeInBytes( return result; } + +int64_t getAllocArgIndex(AllocOp allocOp, int64_t index) { + auto memRefShape = + convertToMemRefType(allocOp.getResult().getType()).getShape(); + auto rank = memRefShape.size(); + + int dynDimIdx = 0; + for (int idx = 0; idx < rank; ++idx) { + if (memRefShape[idx] < 0) { + if (idx == index) + return dynDimIdx; + dynDimIdx++; + } + } + + return -1; +} diff --git a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp index 68e481f..fe1c74c 100644 --- a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp +++ b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp @@ -252,3 +252,14 @@ int64_t getMemRefSizeInBytes(Value val); Value getDynamicMemRefSizeInBytes( MemRefType type, Location loc, PatternRewriter &rewriter, AllocOp allocOp); + +/// This function returns the index in the list of alloc arguments of the +/// dynamic dimension corresponding to `index` in the MemRef shape. +/// As an example: +/// +/// alloc(%d0, %d1, %d2) : memref<10x?x?x20x?x30xf32> +/// +/// In the above alloc the list of alloc arguments is being represented by +/// %d0, %d1 and %d2. Their indices 0, 1, 2 correspond to `index` values +/// 1, 2 and 4 in the MemRef shape respectively +int64_t getAllocArgIndex(AllocOp allocOp, int64_t index); diff --git a/src/Dialect/Krnl/KrnlOps.td b/src/Dialect/Krnl/KrnlOps.td index d074d25..0f44171 100644 --- a/src/Dialect/Krnl/KrnlOps.td +++ b/src/Dialect/Krnl/KrnlOps.td @@ -334,3 +334,24 @@ def KrnlUnrollOp : Op { $loop attr-dict `:` type($loop) }]; } + +def KrnlDimOp : Op { + let summary = "Krnl dimensions operation."; + let description = [{ + Emits the dimension of a MemRef independent of the MemRef alloc: + + "krnl.dim"(%memref, %index) + + The index identifies the dimension within the shape which is going to be emitted. + Initially the krnl.dim operation depends on the alloc of the MemRef. + Unlike the std.dim operation which maintains a dependency on the alloc of the MemRef, the dimension emitted by krnl.dim will not depend on the alloc operation of the MemRef once the krnl.dim operation is lowered. + + Any changes to the original MemRef size after the krnl.dim has been lowered will not be picked up by the emitted dimension. This allows the original MemRef to be safely modified via code transformations or affine map normalization without the risk of changing the value already emitted via krnl.dim. + }]; + + let arguments = (ins AnyTypeOf<[AnyMemRef]>:$alloc, Index:$index); + let results = (outs Index:$dimension); + + let parser = ?; + let printer = ?; +} diff --git a/src/InitOMPasses.hpp b/src/InitOMPasses.hpp index 4dd09bd..4521ad8 100644 --- a/src/InitOMPasses.hpp +++ b/src/InitOMPasses.hpp @@ -76,5 +76,10 @@ void initOMPasses() { []() -> std::unique_ptr { return mlir::createPackKrnlGlobalConstantsPass(); }); + + mlir::registerPass("disconnect-dims", "Disconnect dims from allocs.", + []() -> std::unique_ptr { + return mlir::createDisconnectKrnlDimFromAllocPass(); + }); } } // namespace onnx_mlir \ No newline at end of file diff --git a/src/Pass/Passes.hpp b/src/Pass/Passes.hpp index 5197553..72c9857 100644 --- a/src/Pass/Passes.hpp +++ b/src/Pass/Passes.hpp @@ -40,6 +40,9 @@ std::unique_ptr createLowerToKrnlPass(); /// Pass for lowering frontend dialects to Krnl IR dialect. std::unique_ptr createConvertKrnlToAffinePass(); +/// Pass for lowering krnl.dim operations to standard dialect. +std::unique_ptr createDisconnectKrnlDimFromAllocPass(); + /// Pass for eliding the values of global Krnl operations. std::unique_ptr createElideConstGlobalValuePass(); diff --git a/src/Transform/CMakeLists.txt b/src/Transform/CMakeLists.txt index 8961998..dda900d 100644 --- a/src/Transform/CMakeLists.txt +++ b/src/Transform/CMakeLists.txt @@ -47,4 +47,17 @@ add_dependencies(OMBundleMemoryPools OMKrnlOps OMONNXOps) +add_library(OMDisconnectKrnlDimFromAlloc + DisconnectKrnlDimFromAlloc.cpp) +target_include_directories(OMDisconnectKrnlDimFromAlloc + PRIVATE + ${ONNX_MLIR_SRC_ROOT} + ${ONNX_MLIR_BIN_ROOT} + ${ONNX_MLIR_SRC_ROOT}) +target_link_libraries(OMDisconnectKrnlDimFromAlloc + onnx) +add_dependencies(OMDisconnectKrnlDimFromAlloc + OMKrnlOps + OMONNXOps) + add_subdirectory(ONNX) diff --git a/src/Transform/DisconnectKrnlDimFromAlloc.cpp b/src/Transform/DisconnectKrnlDimFromAlloc.cpp new file mode 100644 index 0000000..58693ba --- /dev/null +++ b/src/Transform/DisconnectKrnlDimFromAlloc.cpp @@ -0,0 +1,103 @@ +//===-------- DisconnectKrnlDimFromAlloc.cpp ------------------------------===// +// +// Copyright 2019-2020 The IBM Research Authors. +// +// ============================================================================= +// +// This pass enables the lowering of the krnl.dim operation to a series of +// instruction which do not depend on the alloc of the MemRef whose dim is +// being taken. The krnl.dim operation works in the presence of MemRefs +// which contain affine maps by ignoring the map if present. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp" +#include "src/Dialect/Krnl/KrnlOps.hpp" +#include "src/Pass/Passes.hpp" + +using namespace mlir; + +namespace { + +/*! + * RewritePattern that replaces: + * %0 = alloc(%d) : memref, #map> + * %1 = krnl.dim(%0, 0) : memref> + * %2 = krnl.dim(%0, 1) : memref> + * %3 = add %1, %2 + * with: + * %0 = alloc(%d) : memref, #map> + * %2 = constant 10 : index + * %3 = add %d, %2 + */ + +class DisconnectKrnlDimFromAlloc : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite( + KrnlDimOp krnlDimOp, PatternRewriter &rewriter) const override { + auto loc = krnlDimOp.getLoc(); + + // If index is not constant, return failure. + ConstantOp indexOp = + dyn_cast(krnlDimOp.index().getDefiningOp()); + if (!indexOp) + return failure(); + + // Get the integer value of the index. + int64_t index = indexOp.getAttrOfType("value").getInt(); + + // Get defining operation for the MemRef argument. + AllocOp allocOp = dyn_cast(krnlDimOp.alloc().getDefiningOp()); + auto memRefShape = + convertToMemRefType(allocOp.getResult().getType()).getShape(); + auto rank = memRefShape.size(); + assert(index >= 0 && index < rank && "Index must be in bounds"); + + Value result; + if (memRefShape[index] > -1) { + // If dimension is static, then we can just emit the constant value. + result = rewriter.create(loc, + rewriter.getIntegerAttr(rewriter.getIndexType(), memRefShape[index])); + } else { + // If dimension is dynamic we need to return the input alloc Value which + // corresponds to it. + int64_t dynDimIdx = getAllocArgIndex(allocOp, index); + assert(dynDimIdx >= 0 && dynDimIdx < allocOp.getOperands().size() && + "Dynamic index outside range of alloc argument list."); + result = allocOp.getOperands()[dynDimIdx]; + } + + rewriter.replaceOp(krnlDimOp, result); + + return success(); + } +}; + +/*! + * Function pass that disconnects krnl.dim emission from its MemRef alloc. + */ +class DisconnectKrnlDimFromAllocPass + : public PassWrapper { +public: + void runOnFunction() override { + auto function = getFunction(); + + ConversionTarget target(getContext()); + OwningRewritePatternList patterns; + patterns.insert(&getContext()); + + applyPatternsAndFoldGreedily(function, patterns); + } +}; +} // namespace + +std::unique_ptr mlir::createDisconnectKrnlDimFromAllocPass() { + return std::make_unique(); +} diff --git a/test/mlir/krnl/krnl_disconnect_dim_from_alloc.mlir b/test/mlir/krnl/krnl_disconnect_dim_from_alloc.mlir new file mode 100644 index 0000000..1d95db3 --- /dev/null +++ b/test/mlir/krnl/krnl_disconnect_dim_from_alloc.mlir @@ -0,0 +1,43 @@ +// RUN: onnx-mlir-opt --disconnect-dims %s -split-input-file | FileCheck %s + +func @test_krnl_dim_lowering(%arg0: memref) -> index { + %c1 = constant 1 : index + %c0 = constant 0 : index + %0 = dim %arg0, %c0 : memref + %1 = alloc(%0) : memref + %d0 = "krnl.dim"(%1, %c0) : (memref, index) -> index + %d1 = "krnl.dim"(%1, %c1) : (memref, index) -> index + %e = addi %d0, %d1 : index + return %e : index + + // CHECK-LABEL: test_krnl_dim_lowering + // CHECK: [[CONST0:%.+]] = constant 0 : index + // CHECK: [[CONST10:%.+]] = constant 10 : index + // CHECK: [[DIM:%.+]] = dim %arg0, [[CONST0]] : memref + // CHECK: [[ALLOC:%.+]] = alloc([[DIM]]) : memref + // CHECK: [[SUM:%.+]] = addi [[DIM]], [[CONST10]] : index + // CHECK: return [[SUM]] : index +} + +// ----- + +#map = affine_map<(d0, d1) -> (d1, d0)> +func @test_krnl_dim_lowering_with_map(%arg0: memref) -> index { + %c1 = constant 1 : index + %c0 = constant 0 : index + %0 = dim %arg0, %c0 : memref + %1 = alloc(%0) : memref + %d0 = "krnl.dim"(%1, %c0) : (memref, index) -> index + %d1 = "krnl.dim"(%1, %c1) : (memref, index) -> index + %e = addi %d0, %d1 : index + return %e : index + + // CHECK: [[MAP:#.+]] = affine_map<(d0, d1) -> (d1, d0)> + // CHECK-LABEL: test_krnl_dim_lowering_with_map + // CHECK: [[CONST0:%.+]] = constant 0 : index + // CHECK: [[CONST10:%.+]] = constant 10 : index + // CHECK: [[DIM:%.+]] = dim %arg0, [[CONST0]] : memref + // CHECK: [[ALLOC:%.+]] = alloc([[DIM]]) : memref + // CHECK: [[SUM:%.+]] = addi [[DIM]], [[CONST10]] : index + // CHECK: return [[SUM]] : index +}