Bundle individual memory pools into a single memory pool (#183)
* Reorganize main function. * Follow review comments. * Emit constants are globals in Krnl and LLVM dialects. * Add memory pooling for constant sized arrays. * Clean code. * Clean code. * Clean code. * Add simple bundling test. Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
		
							parent
							
								
									0a936edf79
								
							
						
					
					
						commit
						100bfc81b4
					
				| 
						 | 
					@ -22,7 +22,8 @@ set(OMLibs
 | 
				
			||||||
        OMElideConstants
 | 
					        OMElideConstants
 | 
				
			||||||
        OMElideKrnlGlobalConstants
 | 
					        OMElideKrnlGlobalConstants
 | 
				
			||||||
        OMPackKrnlGlobalConstants
 | 
					        OMPackKrnlGlobalConstants
 | 
				
			||||||
        OMEnableMemoryPool)
 | 
					        OMEnableMemoryPool
 | 
				
			||||||
 | 
					        OMBundleMemoryPools)
 | 
				
			||||||
set(OMLibs ${OMLibs} PARENT_SCOPE)
 | 
					set(OMLibs ${OMLibs} PARENT_SCOPE)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
message(SATUS "OMLibs" ${OMLibs})
 | 
					message(SATUS "OMLibs" ${OMLibs})
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -461,3 +461,28 @@ Value emitNegativeInfinityConstantOp(
 | 
				
			||||||
int64_t ArrayAttrIntVal(ArrayAttr a, int i) {
 | 
					int64_t ArrayAttrIntVal(ArrayAttr a, int i) {
 | 
				
			||||||
  return (a.getValue()[i]).cast<IntegerAttr>().getInt();
 | 
					  return (a.getValue()[i]).cast<IntegerAttr>().getInt();
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					bool checkOpResultIsUsedByGetRef(AllocOp *allocOp) {
 | 
				
			||||||
 | 
					  auto parentBlock = allocOp->getOperation()->getBlock();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  bool opIsUsedInGetRef = false;
 | 
				
			||||||
 | 
					  parentBlock->walk([&opIsUsedInGetRef, allocOp](KrnlGetRefOp op) {
 | 
				
			||||||
 | 
					    auto result = allocOp->getResult();
 | 
				
			||||||
 | 
					    for (const auto &operand : op.getOperands())
 | 
				
			||||||
 | 
					      if (operand == result)
 | 
				
			||||||
 | 
					        opIsUsedInGetRef = true;
 | 
				
			||||||
 | 
					  });
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return opIsUsedInGetRef;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// TODO: support dynamic sizes.
 | 
				
			||||||
 | 
					int64_t getMemRefSizeInBytes(Value val) {
 | 
				
			||||||
 | 
					  auto memRefType = convertToMemRefType(val.getType());
 | 
				
			||||||
 | 
					  auto memRefShape = memRefType.getShape();
 | 
				
			||||||
 | 
					  int64_t size = 1;
 | 
				
			||||||
 | 
					  for (int i = 0; i < memRefShape.size(); i++)
 | 
				
			||||||
 | 
					    size *= memRefShape[i];
 | 
				
			||||||
 | 
					  size *= getMemRefEltSizeInBytes(memRefType);
 | 
				
			||||||
 | 
					  return size;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -242,3 +242,7 @@ void populateLoweringONNXSqueezeOpPattern(
 | 
				
			||||||
 | 
					
 | 
				
			||||||
void populateLoweringONNXSplitOpPattern(
 | 
					void populateLoweringONNXSplitOpPattern(
 | 
				
			||||||
    OwningRewritePatternList &patterns, MLIRContext *ctx);
 | 
					    OwningRewritePatternList &patterns, MLIRContext *ctx);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					bool checkOpResultIsUsedByGetRef(AllocOp *allocOp);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					int64_t getMemRefSizeInBytes(Value val);
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -42,6 +42,12 @@ void initOMPasses() {
 | 
				
			||||||
        return mlir::createKrnlEnableMemoryPoolPass();
 | 
					        return mlir::createKrnlEnableMemoryPoolPass();
 | 
				
			||||||
      });
 | 
					      });
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  mlir::registerPass("bundle-memory-pools",
 | 
				
			||||||
 | 
					      "Bundle memory pools of internal MemRefs into a single memory pool.",
 | 
				
			||||||
 | 
					      []() -> std::unique_ptr<mlir::Pass> {
 | 
				
			||||||
 | 
					        return mlir::createKrnlBundleMemoryPoolsPass();
 | 
				
			||||||
 | 
					      });
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  mlir::registerPass(
 | 
					  mlir::registerPass(
 | 
				
			||||||
      "lower-krnl", "Lower Krnl dialect.", []() -> std::unique_ptr<mlir::Pass> {
 | 
					      "lower-krnl", "Lower Krnl dialect.", []() -> std::unique_ptr<mlir::Pass> {
 | 
				
			||||||
        return mlir::createLowerKrnlPass();
 | 
					        return mlir::createLowerKrnlPass();
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -269,6 +269,8 @@ void addONNXToKrnlPasses(mlir::PassManager &pm) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // TODO: make this pass optional:
 | 
					  // TODO: make this pass optional:
 | 
				
			||||||
  pm.addPass(mlir::createKrnlEnableMemoryPoolPass());
 | 
					  pm.addPass(mlir::createKrnlEnableMemoryPoolPass());
 | 
				
			||||||
 | 
					  pm.addPass(mlir::createKrnlBundleMemoryPoolsPass());
 | 
				
			||||||
 | 
					  pm.addPass(mlir::createCanonicalizerPass());
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
void addKrnlToAffinePasses(mlir::PassManager &pm) {
 | 
					void addKrnlToAffinePasses(mlir::PassManager &pm) {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -31,6 +31,9 @@ std::unique_ptr<Pass> createElideConstantValuePass();
 | 
				
			||||||
/// Pass for enabling a memory pool for MemRefs.
 | 
					/// Pass for enabling a memory pool for MemRefs.
 | 
				
			||||||
std::unique_ptr<Pass> createKrnlEnableMemoryPoolPass();
 | 
					std::unique_ptr<Pass> createKrnlEnableMemoryPoolPass();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/// Pass for enabling a memory pool for MemRefs.
 | 
				
			||||||
 | 
					std::unique_ptr<Pass> createKrnlBundleMemoryPoolsPass();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/// Add pass for lowering to Krnl IR.
 | 
					/// Add pass for lowering to Krnl IR.
 | 
				
			||||||
std::unique_ptr<Pass> createLowerToKrnlPass();
 | 
					std::unique_ptr<Pass> createLowerToKrnlPass();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,148 @@
 | 
				
			||||||
 | 
					//===-- BundleMemoryPools.cpp - Bundle Memory Pools for  internal MemRefs -===//
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Copyright 2019-2020 The IBM Research Authors.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// =============================================================================
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// For certain cases the number of individual memory allocations required for
 | 
				
			||||||
 | 
					// all internal tensors is large and needs to be mitigated. This pass bundles
 | 
				
			||||||
 | 
					// all the internal MemRef memory pools emitted by the EnableMemoryPool pass
 | 
				
			||||||
 | 
					// int a single memory pool.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "mlir/Dialect/Affine/IR/AffineOps.h"
 | 
				
			||||||
 | 
					#include "mlir/Dialect/StandardOps/IR/Ops.h"
 | 
				
			||||||
 | 
					#include "mlir/Pass/Pass.h"
 | 
				
			||||||
 | 
					#include "mlir/Transforms/DialectConversion.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
 | 
				
			||||||
 | 
					#include "src/Dialect/Krnl/KrnlOps.hpp"
 | 
				
			||||||
 | 
					#include "src/Pass/Passes.hpp"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					using namespace mlir;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					namespace {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					KrnlGetRefOp getUnbundledGetRef(AllocOp *memPool) {
 | 
				
			||||||
 | 
					  auto parentBlock = memPool->getOperation()->getBlock();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  KrnlGetRefOp unbundledGetRef = nullptr;
 | 
				
			||||||
 | 
					  parentBlock->walk([&unbundledGetRef, memPool](KrnlGetRefOp op) {
 | 
				
			||||||
 | 
					    auto result = memPool->getResult();
 | 
				
			||||||
 | 
					    if (op.getOperands()[0] != result)
 | 
				
			||||||
 | 
					      unbundledGetRef = op;
 | 
				
			||||||
 | 
					  });
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return unbundledGetRef;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/*!
 | 
				
			||||||
 | 
					 *  RewritePattern that replaces:
 | 
				
			||||||
 | 
					 *    %mem1 = alloc() : memref<<dims1>x<type>>
 | 
				
			||||||
 | 
					 *    %mem2 = alloc() : memref<<dims2>x<type>>
 | 
				
			||||||
 | 
					 *    %1 = krnl.getref %mem2 0 : memref<<dims2>x<type>>
 | 
				
			||||||
 | 
					 *  =>
 | 
				
			||||||
 | 
					 *    %mem1 = alloc() : memref<<dims1 + dims2>x<type>>
 | 
				
			||||||
 | 
					 *    %1 = krnl.getref %mem1 <dims1> : memref<<dims2>x<type>>
 | 
				
			||||||
 | 
					 *
 | 
				
			||||||
 | 
					 *
 | 
				
			||||||
 | 
					 *  ASSUMPTION: All krnl.getref operations in the program have been emitted
 | 
				
			||||||
 | 
					 *              by the EnableMemoryPool pass i.e. there are no krnl.getref
 | 
				
			||||||
 | 
					 *              operations which are not related to the memory pool.
 | 
				
			||||||
 | 
					 *              krnl.getref is an operation specific to memory management
 | 
				
			||||||
 | 
					 *              for other use cases use MLIR Standard dialect operations.
 | 
				
			||||||
 | 
					 *              This assumption simplifies the code and avoids additional
 | 
				
			||||||
 | 
					 *              checks to ensure that all the participating krnl.getref
 | 
				
			||||||
 | 
					 *              operations are part of memory pooling.
 | 
				
			||||||
 | 
					 */
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class KrnlBundleMemoryPools : public OpRewritePattern<AllocOp> {
 | 
				
			||||||
 | 
					public:
 | 
				
			||||||
 | 
					  using OpRewritePattern<AllocOp>::OpRewritePattern;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  LogicalResult matchAndRewrite(
 | 
				
			||||||
 | 
					      AllocOp allocOp, PatternRewriter &rewriter) const override {
 | 
				
			||||||
 | 
					    auto loc = allocOp.getLoc();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    auto memRefType = convertToMemRefType(allocOp.getResult().getType());
 | 
				
			||||||
 | 
					    auto memRefShape = memRefType.getShape();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // If alloca result is not used by getref then it cannot be part of
 | 
				
			||||||
 | 
					    // the memory pool.
 | 
				
			||||||
 | 
					    if (!checkOpResultIsUsedByGetRef(&allocOp))
 | 
				
			||||||
 | 
					      return failure();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Alloc memory type must be byte.
 | 
				
			||||||
 | 
					    if (getMemRefEltSizeInBytes(memRefType) != 1)
 | 
				
			||||||
 | 
					      return failure();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Rank of the allocated MemRef must be 1.
 | 
				
			||||||
 | 
					    if (memRefShape.size() != 1)
 | 
				
			||||||
 | 
					      return failure();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // TODO: Change this when dyanmic shapes are supported.
 | 
				
			||||||
 | 
					    // TODO: Add support for dynamic shapes.
 | 
				
			||||||
 | 
					    int64_t currentMemPoolSize = memRefShape[0];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Get a KrnlGetRefOp which does not use the current alloc.
 | 
				
			||||||
 | 
					    if (KrnlGetRefOp unbundledGetRef = getUnbundledGetRef(&allocOp)) {
 | 
				
			||||||
 | 
					      unbundledGetRef.dump();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      // Current memory pool size is the offset for the newly bundled
 | 
				
			||||||
 | 
					      // internal MemRef. Emit the offset as a constant.
 | 
				
			||||||
 | 
					      auto offset = rewriter.create<ConstantOp>(
 | 
				
			||||||
 | 
					          loc, rewriter.getIntegerAttr(
 | 
				
			||||||
 | 
					                   rewriter.getIntegerType(64), currentMemPoolSize));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      // Size in bytes of the output of the krnl.getref operation.
 | 
				
			||||||
 | 
					      int64_t unbundledTotalSize =
 | 
				
			||||||
 | 
					          getMemRefSizeInBytes(unbundledGetRef.getResult());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      // Compute new size.
 | 
				
			||||||
 | 
					      int64_t bundleTotalSize = unbundledTotalSize + currentMemPoolSize;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      // We need to emit a new alloc which contains the additional MemRef.
 | 
				
			||||||
 | 
					      SmallVector<int64_t, 1> newMemPoolShape;
 | 
				
			||||||
 | 
					      newMemPoolShape.emplace_back(bundleTotalSize);
 | 
				
			||||||
 | 
					      auto bundledMemPoolMemRefType =
 | 
				
			||||||
 | 
					          MemRefType::get(newMemPoolShape, rewriter.getIntegerType(8));
 | 
				
			||||||
 | 
					      auto bundledAlloc =
 | 
				
			||||||
 | 
					          rewriter.create<AllocOp>(loc, bundledMemPoolMemRefType);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      // The newly bundled MemRef expressed as a KrnlGetRefOp.
 | 
				
			||||||
 | 
					      auto bundledMemRef = rewriter.create<KrnlGetRefOp>(
 | 
				
			||||||
 | 
					          loc, unbundledGetRef.getResult().getType(), bundledAlloc, offset);
 | 
				
			||||||
 | 
					      rewriter.replaceOp(unbundledGetRef, bundledMemRef.getResult());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      // Replace old memory pool with new one.
 | 
				
			||||||
 | 
					      rewriter.replaceOp(allocOp, bundledAlloc.getResult());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      return success();
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return failure();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/*!
 | 
				
			||||||
 | 
					 *  Function pass that enables memory pooling for MemRefs.
 | 
				
			||||||
 | 
					 */
 | 
				
			||||||
 | 
					class KrnlBundleMemoryPoolsPass
 | 
				
			||||||
 | 
					    : public PassWrapper<KrnlBundleMemoryPoolsPass, FunctionPass> {
 | 
				
			||||||
 | 
					public:
 | 
				
			||||||
 | 
					  void runOnFunction() override {
 | 
				
			||||||
 | 
					    auto function = getFunction();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    ConversionTarget target(getContext());
 | 
				
			||||||
 | 
					    OwningRewritePatternList patterns;
 | 
				
			||||||
 | 
					    patterns.insert<KrnlBundleMemoryPools>(&getContext());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    applyPatternsAndFoldGreedily(function, patterns);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					} // namespace
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					std::unique_ptr<Pass> mlir::createKrnlBundleMemoryPoolsPass() {
 | 
				
			||||||
 | 
					  return std::make_unique<KrnlBundleMemoryPoolsPass>();
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -46,6 +46,8 @@ target_include_directories(OMPackKrnlGlobalConstants
 | 
				
			||||||
        ${ONNX_MLIR_SRC_ROOT}
 | 
					        ${ONNX_MLIR_SRC_ROOT}
 | 
				
			||||||
        ${ONNX_MLIR_BIN_ROOT}
 | 
					        ${ONNX_MLIR_BIN_ROOT}
 | 
				
			||||||
        ${ONNX_MLIR_SRC_ROOT})
 | 
					        ${ONNX_MLIR_SRC_ROOT})
 | 
				
			||||||
 | 
					add_dependencies(OMPackKrnlGlobalConstants
 | 
				
			||||||
 | 
					        OMKrnlOps)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
add_library(OMEnableMemoryPool
 | 
					add_library(OMEnableMemoryPool
 | 
				
			||||||
        EnableMemoryPool.cpp)
 | 
					        EnableMemoryPool.cpp)
 | 
				
			||||||
| 
						 | 
					@ -54,13 +56,23 @@ target_include_directories(OMEnableMemoryPool
 | 
				
			||||||
        ${ONNX_MLIR_SRC_ROOT}
 | 
					        ${ONNX_MLIR_SRC_ROOT}
 | 
				
			||||||
        ${ONNX_MLIR_BIN_ROOT}
 | 
					        ${ONNX_MLIR_BIN_ROOT}
 | 
				
			||||||
        ${ONNX_MLIR_SRC_ROOT})
 | 
					        ${ONNX_MLIR_SRC_ROOT})
 | 
				
			||||||
add_dependencies(OMPackKrnlGlobalConstants
 | 
					 | 
				
			||||||
        OMKrnlOps)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
target_link_libraries(OMEnableMemoryPool
 | 
					target_link_libraries(OMEnableMemoryPool
 | 
				
			||||||
        onnx)
 | 
					        onnx)
 | 
				
			||||||
add_dependencies(OMEnableMemoryPool
 | 
					add_dependencies(OMEnableMemoryPool
 | 
				
			||||||
        OMKrnlOps
 | 
					        OMKrnlOps
 | 
				
			||||||
        OMONNXOps)
 | 
					        OMONNXOps)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					add_library(OMBundleMemoryPools
 | 
				
			||||||
 | 
					        BundleMemoryPools.cpp)
 | 
				
			||||||
 | 
					target_include_directories(OMBundleMemoryPools
 | 
				
			||||||
 | 
					        PRIVATE
 | 
				
			||||||
 | 
					        ${ONNX_MLIR_SRC_ROOT}
 | 
				
			||||||
 | 
					        ${ONNX_MLIR_BIN_ROOT}
 | 
				
			||||||
 | 
					        ${ONNX_MLIR_SRC_ROOT})
 | 
				
			||||||
 | 
					target_link_libraries(OMBundleMemoryPools
 | 
				
			||||||
 | 
					        onnx)
 | 
				
			||||||
 | 
					add_dependencies(OMBundleMemoryPools
 | 
				
			||||||
 | 
					        OMKrnlOps
 | 
				
			||||||
 | 
					        OMONNXOps)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
add_subdirectory(ONNX)
 | 
					add_subdirectory(ONNX)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -37,20 +37,6 @@ bool checkOpResultIsReturned(AllocOp *allocOp) {
 | 
				
			||||||
  return opIsReturned;
 | 
					  return opIsReturned;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
bool checkOpResultIsUsedByGetRef(AllocOp *allocOp) {
 | 
					 | 
				
			||||||
  auto parentBlock = allocOp->getOperation()->getBlock();
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  bool opIsUsedInGetRef = false;
 | 
					 | 
				
			||||||
  parentBlock->walk([&opIsUsedInGetRef, allocOp](KrnlGetRefOp op) {
 | 
					 | 
				
			||||||
    auto result = allocOp->getResult();
 | 
					 | 
				
			||||||
    for (const auto &operand : op.getOperands())
 | 
					 | 
				
			||||||
      if (operand == result)
 | 
					 | 
				
			||||||
        opIsUsedInGetRef = true;
 | 
					 | 
				
			||||||
  });
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  return opIsUsedInGetRef;
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
/*!
 | 
					/*!
 | 
				
			||||||
 *  RewritePattern that replaces:
 | 
					 *  RewritePattern that replaces:
 | 
				
			||||||
 *    %0 = alloc() : memref<<dims>x<type>>
 | 
					 *    %0 = alloc() : memref<<dims>x<type>>
 | 
				
			||||||
| 
						 | 
					@ -84,11 +70,7 @@ public:
 | 
				
			||||||
      return failure();
 | 
					      return failure();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // Compute total size.
 | 
					    // Compute total size.
 | 
				
			||||||
    auto memRefShape = memRefType.getShape();
 | 
					    int64_t totalSize = getMemRefSizeInBytes(allocOp.getResult());
 | 
				
			||||||
    int64_t totalSize = 1;
 | 
					 | 
				
			||||||
    for (int i = 0; i < memRefShape.size(); i++)
 | 
					 | 
				
			||||||
      totalSize *= memRefShape[i];
 | 
					 | 
				
			||||||
    totalSize *= getMemRefEltSizeInBytes(memRefType);
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // Emit new alloc.
 | 
					    // Emit new alloc.
 | 
				
			||||||
    SmallVector<int64_t, 1> memPoolShape;
 | 
					    SmallVector<int64_t, 1> memPoolShape;
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,53 @@
 | 
				
			||||||
 | 
					// RUN: onnx-mlir-opt --bundle-memory-pools --canonicalize %s -split-input-file | FileCheck %s
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func @test_pool_bundling(%arg0: memref<10x10xf32>, %arg1: memref<10x20xf32>) -> memref<10x20xf32> {
 | 
				
			||||||
 | 
					  %c0_i64 = constant 0 : i64
 | 
				
			||||||
 | 
					  %ind = constant 0 : index
 | 
				
			||||||
 | 
					  %cst = constant 0.000000e+00 : f32
 | 
				
			||||||
 | 
					  %0 = alloc() : memref<10x20xf32>
 | 
				
			||||||
 | 
					  %1 = alloc() : memref<800xi8>
 | 
				
			||||||
 | 
					  %2 = "krnl.getref"(%1, %c0_i64) : (memref<800xi8>, i64) -> memref<10x20xf32>
 | 
				
			||||||
 | 
					  %3 = alloc() : memref<400xi8>
 | 
				
			||||||
 | 
					  %4 = "krnl.getref"(%3, %c0_i64) : (memref<400xi8>, i64) -> memref<10x10xf32>
 | 
				
			||||||
 | 
					  %5 = alloc() : memref<800xi8>
 | 
				
			||||||
 | 
					  %6 = "krnl.getref"(%5, %c0_i64) : (memref<800xi8>, i64) -> memref<10x20xf32>
 | 
				
			||||||
 | 
					  %7 = alloc() : memref<800xi8>
 | 
				
			||||||
 | 
					  %8 = "krnl.getref"(%7, %c0_i64) : (memref<800xi8>, i64) -> memref<10x20xf32>
 | 
				
			||||||
 | 
					  %9 = alloc() : memref<400xi8>
 | 
				
			||||||
 | 
					  %10 = "krnl.getref"(%9, %c0_i64) : (memref<400xi8>, i64) -> memref<10x10xf32>
 | 
				
			||||||
 | 
					  affine.store %cst, %10[%ind, %ind] : memref<10x10xf32>
 | 
				
			||||||
 | 
					  affine.store %cst, %8[%ind, %ind] : memref<10x20xf32>
 | 
				
			||||||
 | 
					  affine.store %cst, %6[%ind, %ind] : memref<10x20xf32>
 | 
				
			||||||
 | 
					  affine.store %cst, %4[%ind, %ind] : memref<10x10xf32>
 | 
				
			||||||
 | 
					  affine.store %cst, %2[%ind, %ind] : memref<10x20xf32>
 | 
				
			||||||
 | 
					  affine.store %cst, %0[%ind, %ind] : memref<10x20xf32>
 | 
				
			||||||
 | 
					  dealloc %9 : memref<400xi8>
 | 
				
			||||||
 | 
					  dealloc %7 : memref<800xi8>
 | 
				
			||||||
 | 
					  dealloc %5 : memref<800xi8>
 | 
				
			||||||
 | 
					  dealloc %3 : memref<400xi8>
 | 
				
			||||||
 | 
					  dealloc %1 : memref<800xi8>
 | 
				
			||||||
 | 
					  return %0 : memref<10x20xf32>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // CHECK-LABEL: test_pool_bundling
 | 
				
			||||||
 | 
					  // CHECK: [[CONST_0:%.+]] = constant 0 : i64
 | 
				
			||||||
 | 
					  // CHECK: [[CONST_CST:%.+]] = constant 0.000000e+00 : f32
 | 
				
			||||||
 | 
					  // CHECK: [[CONST_400:%.+]] = constant 400 : i64
 | 
				
			||||||
 | 
					  // CHECK: [[CONST_1200:%.+]] = constant 1200 : i64
 | 
				
			||||||
 | 
					  // CHECK: [[CONST_2000:%.+]] = constant 2000 : i64
 | 
				
			||||||
 | 
					  // CHECK: [[CONST_2400:%.+]] = constant 2400 : i64
 | 
				
			||||||
 | 
					  // CHECK: [[RES:%.+]] = alloc() : memref<10x20xf32>
 | 
				
			||||||
 | 
					  // CHECK: [[MEMPOOL:%.+]] = alloc() : memref<3200xi8>
 | 
				
			||||||
 | 
					  // CHECK: [[MEMREF1:%.+]] = "krnl.getref"([[MEMPOOL]], [[CONST_2400]]) : (memref<3200xi8>, i64) -> memref<10x20xf32>
 | 
				
			||||||
 | 
					  // CHECK: [[MEMREF2:%.+]] = "krnl.getref"([[MEMPOOL]], [[CONST_2000]]) : (memref<3200xi8>, i64) -> memref<10x10xf32>
 | 
				
			||||||
 | 
					  // CHECK: [[MEMREF3:%.+]] = "krnl.getref"([[MEMPOOL]], [[CONST_1200]]) : (memref<3200xi8>, i64) -> memref<10x20xf32>
 | 
				
			||||||
 | 
					  // CHECK: [[MEMREF4:%.+]] = "krnl.getref"([[MEMPOOL]], [[CONST_400]]) : (memref<3200xi8>, i64) -> memref<10x20xf32>
 | 
				
			||||||
 | 
					  // CHECK: [[MEMREF5:%.+]] = "krnl.getref"([[MEMPOOL]], [[CONST_0]]) : (memref<3200xi8>, i64) -> memref<10x10xf32>
 | 
				
			||||||
 | 
					  // CHECK: affine.store %cst, [[MEMREF5]][0, 0] : memref<10x10xf32>
 | 
				
			||||||
 | 
					  // CHECK: affine.store %cst, [[MEMREF4]][0, 0] : memref<10x20xf32>
 | 
				
			||||||
 | 
					  // CHECK: affine.store %cst, [[MEMREF3]][0, 0] : memref<10x20xf32>
 | 
				
			||||||
 | 
					  // CHECK: affine.store %cst, [[MEMREF2]][0, 0] : memref<10x10xf32>
 | 
				
			||||||
 | 
					  // CHECK: affine.store %cst, [[MEMREF1]][0, 0] : memref<10x20xf32>
 | 
				
			||||||
 | 
					  // CHECK: affine.store %cst, [[RES]][0, 0] : memref<10x20xf32>
 | 
				
			||||||
 | 
					  // CHECK: dealloc [[MEMPOOL]] : memref<3200xi8>
 | 
				
			||||||
 | 
					  // CHECK: return [[RES]] : memref<10x20xf32>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,28 @@
 | 
				
			||||||
 | 
					// RUN: onnx-mlir-opt --shape-inference --lower-frontend --enable-memory-pool --bundle-memory-pools --canonicalize %s -split-input-file | FileCheck %s
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func @test_bundle_memory_pool(%arg0: tensor<10x10xf32>, %arg1: tensor<10x20xf32>) -> tensor<10x20xf32> {
 | 
				
			||||||
 | 
					  %0 = "onnx.Add"(%arg0, %arg0) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
 | 
				
			||||||
 | 
					  %1 = "onnx.MatMul"(%0, %arg1) : (tensor<10x10xf32>, tensor<10x20xf32>) -> tensor<10x20xf32>
 | 
				
			||||||
 | 
					  %2 = "onnx.Add"(%1, %arg1) : (tensor<10x20xf32>, tensor<10x20xf32>) -> tensor<10x20xf32>
 | 
				
			||||||
 | 
					  %3 = "onnx.Add"(%0, %arg0) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
 | 
				
			||||||
 | 
					  %4 = "onnx.MatMul"(%3, %arg1) : (tensor<10x10xf32>, tensor<10x20xf32>) -> tensor<10x20xf32>
 | 
				
			||||||
 | 
					  %5 = "onnx.Add"(%4, %arg1) : (tensor<10x20xf32>, tensor<10x20xf32>) -> tensor<10x20xf32>
 | 
				
			||||||
 | 
					  return %5 : tensor<10x20xf32>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // CHECK-LABEL: test_bundle_memory_pool
 | 
				
			||||||
 | 
					  // CHECK: [[CONST0:%.+]] = constant 0 : i64
 | 
				
			||||||
 | 
					  // CHECK: [[CONST00:%.+]] = constant 0.000000e+00 : f32
 | 
				
			||||||
 | 
					  // CHECK: [[CONST400:%.+]] = constant 400 : i64
 | 
				
			||||||
 | 
					  // CHECK: [[CONST1200:%.+]] = constant 1200 : i64
 | 
				
			||||||
 | 
					  // CHECK: [[CONST2000:%.+]] = constant 2000 : i64
 | 
				
			||||||
 | 
					  // CHECK: [[CONST2400:%.+]] = constant 2400 : i64
 | 
				
			||||||
 | 
					  // CHECK: [[RES:%.+]] = alloc() : memref<10x20xf32>
 | 
				
			||||||
 | 
					  // CHECK: [[MEMPOOL:%.+]] = alloc() : memref<3200xi8>
 | 
				
			||||||
 | 
					  // CHECK: "krnl.getref"([[MEMPOOL]], [[CONST2400]]) : (memref<3200xi8>, i64) -> memref<10x20xf32>
 | 
				
			||||||
 | 
					  // CHECK: "krnl.getref"([[MEMPOOL]], [[CONST2000]]) : (memref<3200xi8>, i64) -> memref<10x10xf32>
 | 
				
			||||||
 | 
					  // CHECK: "krnl.getref"([[MEMPOOL]], [[CONST1200]]) : (memref<3200xi8>, i64) -> memref<10x20xf32>
 | 
				
			||||||
 | 
					  // CHECK: "krnl.getref"([[MEMPOOL]], [[CONST400]]) : (memref<3200xi8>, i64) -> memref<10x20xf32>
 | 
				
			||||||
 | 
					  // CHECK: "krnl.getref"([[MEMPOOL]], [[CONST0]]) : (memref<3200xi8>, i64) -> memref<10x10xf32>
 | 
				
			||||||
 | 
					  // CHECK: dealloc [[MEMPOOL]] : memref<3200xi8>
 | 
				
			||||||
 | 
					  // CHECK: return [[RES]] : memref<10x20xf32>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
		Loading…
	
		Reference in New Issue