//===-- BundleMemoryPools.cpp - Bundle Memory Pools for internal MemRefs -===// // // Copyright 2019-2020 The IBM Research Authors. // // ============================================================================= // // For certain cases the number of individual memory allocations required for // all internal tensors is large and needs to be mitigated. This pass bundles // all the internal MemRef memory pools emitted by the EnableMemoryPool pass // int a single memory pool. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp" #include "src/Dialect/Krnl/KrnlOps.hpp" #include "src/Pass/Passes.hpp" using namespace mlir; namespace { KrnlGetRefOp getUnbundledGetRef(AllocOp *memPool) { auto parentBlock = memPool->getOperation()->getBlock(); KrnlGetRefOp unbundledGetRef = nullptr; parentBlock->walk([&unbundledGetRef, memPool](KrnlGetRefOp op) { auto result = memPool->getResult(); if (op.getOperands()[0] != result) unbundledGetRef = op; }); return unbundledGetRef; } /*! * RewritePattern that replaces: * %mem1 = alloc() : memref<x> * %mem2 = alloc() : memref<x> * %1 = krnl.getref %mem2 0 : memref<x> * => * %mem1 = alloc() : memref<x> * %1 = krnl.getref %mem1 : memref<x> * * * ASSUMPTION: All krnl.getref operations in the program have been emitted * by the EnableMemoryPool pass i.e. there are no krnl.getref * operations which are not related to the memory pool. * krnl.getref is an operation specific to memory management * for other use cases use MLIR Standard dialect operations. * This assumption simplifies the code and avoids additional * checks to ensure that all the participating krnl.getref * operations are part of memory pooling. */ class KrnlBundleMemoryPools : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite( AllocOp allocOp, PatternRewriter &rewriter) const override { auto loc = allocOp.getLoc(); auto memRefType = convertToMemRefType(allocOp.getResult().getType()); auto memRefShape = memRefType.getShape(); // If alloca result is not used by getref then it cannot be part of // the memory pool. if (!checkOpResultIsUsedByGetRef(&allocOp)) return failure(); // Alloc memory type must be byte. if (getMemRefEltSizeInBytes(memRefType) != 1) return failure(); // Rank of the allocated MemRef must be 1. if (memRefShape.size() != 1) return failure(); // TODO: Change this when dyanmic shapes are supported. // TODO: Add support for dynamic shapes. int64_t currentMemPoolSize = memRefShape[0]; // Get a KrnlGetRefOp which does not use the current alloc. if (KrnlGetRefOp unbundledGetRef = getUnbundledGetRef(&allocOp)) { // Current memory pool size is the offset for the newly bundled // internal MemRef. Emit the offset as a constant. auto offset = rewriter.create( loc, rewriter.getIntegerAttr( rewriter.getIntegerType(64), currentMemPoolSize)); // Size in bytes of the output of the krnl.getref operation. int64_t unbundledTotalSize = getMemRefSizeInBytes(unbundledGetRef.getResult()); // Compute new size. int64_t bundleTotalSize = unbundledTotalSize + currentMemPoolSize; // We need to emit a new alloc which contains the additional MemRef. SmallVector newMemPoolShape; newMemPoolShape.emplace_back(bundleTotalSize); auto bundledMemPoolMemRefType = MemRefType::get(newMemPoolShape, rewriter.getIntegerType(8)); auto bundledAlloc = rewriter.create(loc, bundledMemPoolMemRefType); // The newly bundled MemRef expressed as a KrnlGetRefOp. auto bundledMemRef = rewriter.create( loc, unbundledGetRef.getResult().getType(), bundledAlloc, offset); rewriter.replaceOp(unbundledGetRef, bundledMemRef.getResult()); // Replace old memory pool with new one. rewriter.replaceOp(allocOp, bundledAlloc.getResult()); return success(); } return failure(); } }; /*! * Function pass that enables memory pooling for MemRefs. */ class KrnlBundleMemoryPoolsPass : public PassWrapper { public: void runOnFunction() override { auto function = getFunction(); ConversionTarget target(getContext()); OwningRewritePatternList patterns; patterns.insert(&getContext()); applyPatternsAndFoldGreedily(function, patterns); } }; } // namespace std::unique_ptr mlir::createKrnlBundleMemoryPoolsPass() { return std::make_unique(); }