diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 2725545..da88be1 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -22,7 +22,8 @@ set(OMLibs OMElideConstants OMElideKrnlGlobalConstants OMPackKrnlGlobalConstants - OMEnableMemoryPool) + OMEnableMemoryPool + OMBundleMemoryPools) set(OMLibs ${OMLibs} PARENT_SCOPE) message(SATUS "OMLibs" ${OMLibs}) diff --git a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp index 8882622..c57d0a4 100644 --- a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp +++ b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp @@ -461,3 +461,28 @@ Value emitNegativeInfinityConstantOp( int64_t ArrayAttrIntVal(ArrayAttr a, int i) { return (a.getValue()[i]).cast().getInt(); } + +bool checkOpResultIsUsedByGetRef(AllocOp *allocOp) { + auto parentBlock = allocOp->getOperation()->getBlock(); + + bool opIsUsedInGetRef = false; + parentBlock->walk([&opIsUsedInGetRef, allocOp](KrnlGetRefOp op) { + auto result = allocOp->getResult(); + for (const auto &operand : op.getOperands()) + if (operand == result) + opIsUsedInGetRef = true; + }); + + return opIsUsedInGetRef; +} + +// TODO: support dynamic sizes. 
+int64_t getMemRefSizeInBytes(Value val) { + auto memRefType = convertToMemRefType(val.getType()); + auto memRefShape = memRefType.getShape(); + int64_t size = 1; + for (int i = 0; i < memRefShape.size(); i++) + size *= memRefShape[i]; + size *= getMemRefEltSizeInBytes(memRefType); + return size; +} diff --git a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp index f1c6b14..6b6660c 100644 --- a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp +++ b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp @@ -242,3 +242,7 @@ void populateLoweringONNXSqueezeOpPattern( void populateLoweringONNXSplitOpPattern( OwningRewritePatternList &patterns, MLIRContext *ctx); + +bool checkOpResultIsUsedByGetRef(AllocOp *allocOp); + +int64_t getMemRefSizeInBytes(Value val); diff --git a/src/InitOMPasses.hpp b/src/InitOMPasses.hpp index 88318d2..feb6fcc 100644 --- a/src/InitOMPasses.hpp +++ b/src/InitOMPasses.hpp @@ -42,6 +42,12 @@ void initOMPasses() { return mlir::createKrnlEnableMemoryPoolPass(); }); + mlir::registerPass("bundle-memory-pools", + "Bundle memory pools of internal MemRefs into a single memory pool.", + []() -> std::unique_ptr { + return mlir::createKrnlBundleMemoryPoolsPass(); + }); + mlir::registerPass( "lower-krnl", "Lower Krnl dialect.", []() -> std::unique_ptr { return mlir::createLowerKrnlPass(); diff --git a/src/MainUtils.cpp b/src/MainUtils.cpp index 50f2a7f..b916bb1 100644 --- a/src/MainUtils.cpp +++ b/src/MainUtils.cpp @@ -269,6 +269,8 @@ void addONNXToKrnlPasses(mlir::PassManager &pm) { // TODO: make this pass optional: pm.addPass(mlir::createKrnlEnableMemoryPoolPass()); + pm.addPass(mlir::createKrnlBundleMemoryPoolsPass()); + pm.addPass(mlir::createCanonicalizerPass()); } void addKrnlToAffinePasses(mlir::PassManager &pm) { diff --git a/src/Pass/Passes.hpp b/src/Pass/Passes.hpp index 9988537..2a3e362 100644 --- a/src/Pass/Passes.hpp +++ b/src/Pass/Passes.hpp @@ -31,6 +31,9 @@ std::unique_ptr 
createElideConstantValuePass(); /// Pass for enabling a memory pool for MemRefs. std::unique_ptr createKrnlEnableMemoryPoolPass(); +/// Pass for bundling the memory pools of internal MemRefs into a single pool. +std::unique_ptr createKrnlBundleMemoryPoolsPass(); + /// Add pass for lowering to Krnl IR. std::unique_ptr createLowerToKrnlPass(); diff --git a/src/Transform/BundleMemoryPools.cpp b/src/Transform/BundleMemoryPools.cpp new file mode 100644 index 0000000..f8e1ddc --- /dev/null +++ b/src/Transform/BundleMemoryPools.cpp @@ -0,0 +1,148 @@ +//===-- BundleMemoryPools.cpp - Bundle Memory Pools for internal MemRefs -===// +// +// Copyright 2019-2020 The IBM Research Authors. +// +// ============================================================================= +// +// For certain cases the number of individual memory allocations required for +// all internal tensors is large and needs to be mitigated. This pass bundles +// all the internal MemRef memory pools emitted by the EnableMemoryPool pass +// into a single memory pool. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp" +#include "src/Dialect/Krnl/KrnlOps.hpp" +#include "src/Pass/Passes.hpp" + +using namespace mlir; + +namespace { + +KrnlGetRefOp getUnbundledGetRef(AllocOp *memPool) { + auto parentBlock = memPool->getOperation()->getBlock(); + + KrnlGetRefOp unbundledGetRef = nullptr; + parentBlock->walk([&unbundledGetRef, memPool](KrnlGetRefOp op) { + auto result = memPool->getResult(); + if (op.getOperands()[0] != result) + unbundledGetRef = op; + }); + + return unbundledGetRef; +} + +/*!
+ * RewritePattern that replaces: + * %mem1 = alloc() : memref<x> + * %mem2 = alloc() : memref<x> + * %1 = krnl.getref %mem2 0 : memref<x> + * => + * %mem1 = alloc() : memref<x> + * %1 = krnl.getref %mem1 : memref<x> + * + * + * ASSUMPTION: All krnl.getref operations in the program have been emitted + * by the EnableMemoryPool pass i.e. there are no krnl.getref + * operations which are not related to the memory pool. + * krnl.getref is an operation specific to memory management + * for other use cases use MLIR Standard dialect operations. + * This assumption simplifies the code and avoids additional + * checks to ensure that all the participating krnl.getref + * operations are part of memory pooling. + */ + +class KrnlBundleMemoryPools : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite( + AllocOp allocOp, PatternRewriter &rewriter) const override { + auto loc = allocOp.getLoc(); + + auto memRefType = convertToMemRefType(allocOp.getResult().getType()); + auto memRefShape = memRefType.getShape(); + + // If alloca result is not used by getref then it cannot be part of + // the memory pool. + if (!checkOpResultIsUsedByGetRef(&allocOp)) + return failure(); + + // Alloc memory type must be byte. + if (getMemRefEltSizeInBytes(memRefType) != 1) + return failure(); + + // Rank of the allocated MemRef must be 1. + if (memRefShape.size() != 1) + return failure(); + + // TODO: Change this when dynamic shapes are supported. + // TODO: Add support for dynamic shapes. + int64_t currentMemPoolSize = memRefShape[0]; + + // Get a KrnlGetRefOp which does not use the current alloc. + if (KrnlGetRefOp unbundledGetRef = getUnbundledGetRef(&allocOp)) { + unbundledGetRef.dump(); + + // Current memory pool size is the offset for the newly bundled + // internal MemRef. Emit the offset as a constant.
+ auto offset = rewriter.create( + loc, rewriter.getIntegerAttr( + rewriter.getIntegerType(64), currentMemPoolSize)); + + // Size in bytes of the output of the krnl.getref operation. + int64_t unbundledTotalSize = + getMemRefSizeInBytes(unbundledGetRef.getResult()); + + // Compute new size. + int64_t bundleTotalSize = unbundledTotalSize + currentMemPoolSize; + + // We need to emit a new alloc which contains the additional MemRef. + SmallVector newMemPoolShape; + newMemPoolShape.emplace_back(bundleTotalSize); + auto bundledMemPoolMemRefType = + MemRefType::get(newMemPoolShape, rewriter.getIntegerType(8)); + auto bundledAlloc = + rewriter.create(loc, bundledMemPoolMemRefType); + + // The newly bundled MemRef expressed as a KrnlGetRefOp. + auto bundledMemRef = rewriter.create( + loc, unbundledGetRef.getResult().getType(), bundledAlloc, offset); + rewriter.replaceOp(unbundledGetRef, bundledMemRef.getResult()); + + // Replace old memory pool with new one. + rewriter.replaceOp(allocOp, bundledAlloc.getResult()); + + return success(); + } + + return failure(); + } +}; + +/*! + * Function pass that bundles the memory pools of internal MemRefs into a single pool. 
+ */ +class KrnlBundleMemoryPoolsPass + : public PassWrapper { +public: + void runOnFunction() override { + auto function = getFunction(); + + ConversionTarget target(getContext()); + OwningRewritePatternList patterns; + patterns.insert(&getContext()); + + applyPatternsAndFoldGreedily(function, patterns); + } +}; +} // namespace + +std::unique_ptr mlir::createKrnlBundleMemoryPoolsPass() { + return std::make_unique(); +} diff --git a/src/Transform/CMakeLists.txt b/src/Transform/CMakeLists.txt index af2bc9a..0626abf 100644 --- a/src/Transform/CMakeLists.txt +++ b/src/Transform/CMakeLists.txt @@ -46,6 +46,8 @@ target_include_directories(OMPackKrnlGlobalConstants ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT} ${ONNX_MLIR_SRC_ROOT}) +add_dependencies(OMPackKrnlGlobalConstants + OMKrnlOps) add_library(OMEnableMemoryPool EnableMemoryPool.cpp) @@ -54,13 +56,23 @@ target_include_directories(OMEnableMemoryPool ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT} ${ONNX_MLIR_SRC_ROOT}) -add_dependencies(OMPackKrnlGlobalConstants - OMKrnlOps) - target_link_libraries(OMEnableMemoryPool onnx) add_dependencies(OMEnableMemoryPool OMKrnlOps OMONNXOps) +add_library(OMBundleMemoryPools + BundleMemoryPools.cpp) +target_include_directories(OMBundleMemoryPools + PRIVATE + ${ONNX_MLIR_SRC_ROOT} + ${ONNX_MLIR_BIN_ROOT} + ${ONNX_MLIR_SRC_ROOT}) +target_link_libraries(OMBundleMemoryPools + onnx) +add_dependencies(OMBundleMemoryPools + OMKrnlOps + OMONNXOps) + add_subdirectory(ONNX) diff --git a/src/Transform/EnableMemoryPool.cpp b/src/Transform/EnableMemoryPool.cpp index 94a5446..c41293f 100644 --- a/src/Transform/EnableMemoryPool.cpp +++ b/src/Transform/EnableMemoryPool.cpp @@ -37,20 +37,6 @@ bool checkOpResultIsReturned(AllocOp *allocOp) { return opIsReturned; } -bool checkOpResultIsUsedByGetRef(AllocOp *allocOp) { - auto parentBlock = allocOp->getOperation()->getBlock(); - - bool opIsUsedInGetRef = false; - parentBlock->walk([&opIsUsedInGetRef, allocOp](KrnlGetRefOp op) { - auto result = 
allocOp->getResult(); - for (const auto &operand : op.getOperands()) - if (operand == result) - opIsUsedInGetRef = true; - }); - - return opIsUsedInGetRef; -} - /*! * RewritePattern that replaces: * %0 = alloc() : memref<x> @@ -84,11 +70,7 @@ public: return failure(); // Compute total size. - auto memRefShape = memRefType.getShape(); - int64_t totalSize = 1; - for (int i = 0; i < memRefShape.size(); i++) - totalSize *= memRefShape[i]; - totalSize *= getMemRefEltSizeInBytes(memRefType); + int64_t totalSize = getMemRefSizeInBytes(allocOp.getResult()); // Emit new alloc. SmallVector memPoolShape; diff --git a/test/mlir/krnl/krnl_bundle_memory_pool.mlir b/test/mlir/krnl/krnl_bundle_memory_pool.mlir new file mode 100644 index 0000000..750439a --- /dev/null +++ b/test/mlir/krnl/krnl_bundle_memory_pool.mlir @@ -0,0 +1,53 @@ +// RUN: onnx-mlir-opt --bundle-memory-pools --canonicalize %s -split-input-file | FileCheck %s + +func @test_pool_bundling(%arg0: memref<10x10xf32>, %arg1: memref<10x20xf32>) -> memref<10x20xf32> { + %c0_i64 = constant 0 : i64 + %ind = constant 0 : index + %cst = constant 0.000000e+00 : f32 + %0 = alloc() : memref<10x20xf32> + %1 = alloc() : memref<800xi8> + %2 = "krnl.getref"(%1, %c0_i64) : (memref<800xi8>, i64) -> memref<10x20xf32> + %3 = alloc() : memref<400xi8> + %4 = "krnl.getref"(%3, %c0_i64) : (memref<400xi8>, i64) -> memref<10x10xf32> + %5 = alloc() : memref<800xi8> + %6 = "krnl.getref"(%5, %c0_i64) : (memref<800xi8>, i64) -> memref<10x20xf32> + %7 = alloc() : memref<800xi8> + %8 = "krnl.getref"(%7, %c0_i64) : (memref<800xi8>, i64) -> memref<10x20xf32> + %9 = alloc() : memref<400xi8> + %10 = "krnl.getref"(%9, %c0_i64) : (memref<400xi8>, i64) -> memref<10x10xf32> + affine.store %cst, %10[%ind, %ind] : memref<10x10xf32> + affine.store %cst, %8[%ind, %ind] : memref<10x20xf32> + affine.store %cst, %6[%ind, %ind] : memref<10x20xf32> + affine.store %cst, %4[%ind, %ind] : memref<10x10xf32> + affine.store %cst, %2[%ind, %ind] : memref<10x20xf32> + 
affine.store %cst, %0[%ind, %ind] : memref<10x20xf32> + dealloc %9 : memref<400xi8> + dealloc %7 : memref<800xi8> + dealloc %5 : memref<800xi8> + dealloc %3 : memref<400xi8> + dealloc %1 : memref<800xi8> + return %0 : memref<10x20xf32> + + // CHECK-LABEL: test_pool_bundling + // CHECK: [[CONST_0:%.+]] = constant 0 : i64 + // CHECK: [[CONST_CST:%.+]] = constant 0.000000e+00 : f32 + // CHECK: [[CONST_400:%.+]] = constant 400 : i64 + // CHECK: [[CONST_1200:%.+]] = constant 1200 : i64 + // CHECK: [[CONST_2000:%.+]] = constant 2000 : i64 + // CHECK: [[CONST_2400:%.+]] = constant 2400 : i64 + // CHECK: [[RES:%.+]] = alloc() : memref<10x20xf32> + // CHECK: [[MEMPOOL:%.+]] = alloc() : memref<3200xi8> + // CHECK: [[MEMREF1:%.+]] = "krnl.getref"([[MEMPOOL]], [[CONST_2400]]) : (memref<3200xi8>, i64) -> memref<10x20xf32> + // CHECK: [[MEMREF2:%.+]] = "krnl.getref"([[MEMPOOL]], [[CONST_2000]]) : (memref<3200xi8>, i64) -> memref<10x10xf32> + // CHECK: [[MEMREF3:%.+]] = "krnl.getref"([[MEMPOOL]], [[CONST_1200]]) : (memref<3200xi8>, i64) -> memref<10x20xf32> + // CHECK: [[MEMREF4:%.+]] = "krnl.getref"([[MEMPOOL]], [[CONST_400]]) : (memref<3200xi8>, i64) -> memref<10x20xf32> + // CHECK: [[MEMREF5:%.+]] = "krnl.getref"([[MEMPOOL]], [[CONST_0]]) : (memref<3200xi8>, i64) -> memref<10x10xf32> + // CHECK: affine.store %cst, [[MEMREF5]][0, 0] : memref<10x10xf32> + // CHECK: affine.store %cst, [[MEMREF4]][0, 0] : memref<10x20xf32> + // CHECK: affine.store %cst, [[MEMREF3]][0, 0] : memref<10x20xf32> + // CHECK: affine.store %cst, [[MEMREF2]][0, 0] : memref<10x10xf32> + // CHECK: affine.store %cst, [[MEMREF1]][0, 0] : memref<10x20xf32> + // CHECK: affine.store %cst, [[RES]][0, 0] : memref<10x20xf32> + // CHECK: dealloc [[MEMPOOL]] : memref<3200xi8> + // CHECK: return [[RES]] : memref<10x20xf32> +} diff --git a/test/mlir/onnx/onnx_bundle_memory_pool.mlir b/test/mlir/onnx/onnx_bundle_memory_pool.mlir new file mode 100644 index 0000000..f6288a8 --- /dev/null +++ 
b/test/mlir/onnx/onnx_bundle_memory_pool.mlir @@ -0,0 +1,28 @@ +// RUN: onnx-mlir-opt --shape-inference --lower-frontend --enable-memory-pool --bundle-memory-pools --canonicalize %s -split-input-file | FileCheck %s + +func @test_bundle_memory_pool(%arg0: tensor<10x10xf32>, %arg1: tensor<10x20xf32>) -> tensor<10x20xf32> { + %0 = "onnx.Add"(%arg0, %arg0) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> + %1 = "onnx.MatMul"(%0, %arg1) : (tensor<10x10xf32>, tensor<10x20xf32>) -> tensor<10x20xf32> + %2 = "onnx.Add"(%1, %arg1) : (tensor<10x20xf32>, tensor<10x20xf32>) -> tensor<10x20xf32> + %3 = "onnx.Add"(%0, %arg0) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> + %4 = "onnx.MatMul"(%3, %arg1) : (tensor<10x10xf32>, tensor<10x20xf32>) -> tensor<10x20xf32> + %5 = "onnx.Add"(%4, %arg1) : (tensor<10x20xf32>, tensor<10x20xf32>) -> tensor<10x20xf32> + return %5 : tensor<10x20xf32> + + // CHECK-LABEL: test_bundle_memory_pool + // CHECK: [[CONST0:%.+]] = constant 0 : i64 + // CHECK: [[CONST00:%.+]] = constant 0.000000e+00 : f32 + // CHECK: [[CONST400:%.+]] = constant 400 : i64 + // CHECK: [[CONST1200:%.+]] = constant 1200 : i64 + // CHECK: [[CONST2000:%.+]] = constant 2000 : i64 + // CHECK: [[CONST2400:%.+]] = constant 2400 : i64 + // CHECK: [[RES:%.+]] = alloc() : memref<10x20xf32> + // CHECK: [[MEMPOOL:%.+]] = alloc() : memref<3200xi8> + // CHECK: "krnl.getref"([[MEMPOOL]], [[CONST2400]]) : (memref<3200xi8>, i64) -> memref<10x20xf32> + // CHECK: "krnl.getref"([[MEMPOOL]], [[CONST2000]]) : (memref<3200xi8>, i64) -> memref<10x10xf32> + // CHECK: "krnl.getref"([[MEMPOOL]], [[CONST1200]]) : (memref<3200xi8>, i64) -> memref<10x20xf32> + // CHECK: "krnl.getref"([[MEMPOOL]], [[CONST400]]) : (memref<3200xi8>, i64) -> memref<10x20xf32> + // CHECK: "krnl.getref"([[MEMPOOL]], [[CONST0]]) : (memref<3200xi8>, i64) -> memref<10x10xf32> + // CHECK: dealloc [[MEMPOOL]] : memref<3200xi8> + // CHECK: return [[RES]] : memref<10x20xf32> +} \ No newline at end of file