diff --git a/src/Transform/BundleMemoryPools.cpp b/src/Transform/BundleMemoryPools.cpp
index b213fcf..199a664 100644
--- a/src/Transform/BundleMemoryPools.cpp
+++ b/src/Transform/BundleMemoryPools.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/SetVector.h"
 
 #include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
 #include "src/Dialect/Krnl/KrnlOps.hpp"
@@ -24,6 +25,102 @@ using namespace mlir;
 
 namespace {
 
+//===----------------------------------------------------------------------===//
+// Insertion point for initialization instructions and the blocks used for
+// inserting the initialization and main code. These blocks will disappear
+// when the first canonicalization is performed because the init block
+// unconditionally branches into the second block. These blocks exist only
+// for the purpose of this optimization.
+// The information is recorded on a per-function basis.
+//===----------------------------------------------------------------------===//
+
+typedef struct ONNXOperandsInitState {
+  Block *initBlock;
+  Block *mainBlock;
+  BranchOp branchInit;
+  AllocOp dynamicMemoryPool;
+} ONNXOperandsInitState;
+
+// Helper data structure for the bundling of dynamic AllocOps.
+std::map<FuncOp, std::unique_ptr<ONNXOperandsInitState>> initMap;
+
+//===----------------------------------------------------------------------===//
+// Helper functions.
+//===----------------------------------------------------------------------===//
+
+/// Retrieve the function which contains the current operation.
+FuncOp getContainingFunction(AllocOp op) {
+  Operation *parentFuncOp = op.getParentOp();
+
+  // While parent is not a FuncOp and its cast to a FuncOp is null.
+  while (!llvm::dyn_cast_or_null<FuncOp>(parentFuncOp))
+    parentFuncOp = parentFuncOp->getParentOp();
+
+  return cast<FuncOp>(parentFuncOp);
+}
+
+bool hasInitBlock(FuncOp function) {
+  std::unique_ptr<ONNXOperandsInitState> &initState = initMap.at(function);
+  return initState->initBlock != nullptr;
+}
+
+bool addInitBlock(PatternRewriter &rewriter, Location loc, AllocOp allocOp) {
+  // If this is the first time we encounter an operation in this
+  // function, create an entry inside the initMap and split the
+  // function body into an init block and a main block:
+  //
+  //   function func_name() {
+  //     ... init block ...
+  //     br ^bb1
+  //   ^bb1:  // pred: ^bb0
+  //     ... main block ...
+  //     return
+  //   }
+  //
+  // Note: the block ^bb0, being the first block, has its label omitted.
+  //
+  FuncOp function = getContainingFunction(allocOp);
+
+  // If the function does not contain an init block, create one.
+  if (!hasInitBlock(function)) {
+    std::unique_ptr<ONNXOperandsInitState> &initState = initMap.at(function);
+    initState = std::make_unique<ONNXOperandsInitState>();
+
+    // All input arguments are considered part of the initialization block.
+    Block *functionBlock = &function.front();
+    PatternRewriter::InsertionGuard insertGuard(rewriter);
+    rewriter.setInsertionPointToStart(functionBlock);
+
+    initState->initBlock = rewriter.getInsertionBlock();
+    auto currentPoint = rewriter.getInsertionPoint();
+    initState->mainBlock =
+        rewriter.splitBlock(initState->initBlock, currentPoint);
+
+    rewriter.setInsertionPointToEnd(initState->initBlock);
+
+    // Insert a branch operation from initBlock to mainBlock. This
+    // ensures the final code contains legal blocks.
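+    // At this point the function CFG has the shape sketched below (labels
+    // are illustrative only; ^bb0 carries no explicit label in printed IR):
+    //
+    //   ^bb0:   // init block: size computations are moved here
+    //     br ^bb1
+    //   ^bb1:   // main block: original function body
+    //
+    // The unconditional branch keeps both blocks legal and lets a later
+    // canonicalization fold them back into a single block.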
+    initState->branchInit =
+        rewriter.create<BranchOp>(loc, initState->mainBlock);
+
+    rewriter.setInsertionPointToStart(initState->mainBlock);
+
+    // Save a reference to the current dynamic memory pool value.
+    initState->dynamicMemoryPool = allocOp;
+
+    return true;
+  }
+
+  return false;
+}
+
+bool isBlockArgument(Block *block, Value operand) {
+  for (auto arg : block->getArguments())
+    if (operand == arg)
+      return true;
+  return false;
+}
+
 KrnlGetRefOp getUnbundledGetRef(AllocOp *memPool) {
   auto parentBlock = memPool->getOperation()->getBlock();
 
@@ -37,6 +134,23 @@ KrnlGetRefOp getUnbundledGetRef(AllocOp *memPool) {
   return unbundledGetRef;
 }
 
+KrnlGetRefOp getCurrentAllocGetRef(AllocOp *allocOp) {
+  auto parentBlock = allocOp->getOperation()->getBlock();
+
+  KrnlGetRefOp currentAllocGetRef = nullptr;
+  parentBlock->walk([&currentAllocGetRef, allocOp](KrnlGetRefOp op) {
+    auto result = allocOp->getResult();
+    if (op.getOperands()[0] == result)
+      currentAllocGetRef = op;
+  });
+
+  return currentAllocGetRef;
+}
+
+//===----------------------------------------------------------------------===//
+// Rewrite patterns.
+//===----------------------------------------------------------------------===//
+
 /*!
  * RewritePattern that replaces:
  *   %mem1 = alloc() : memref<<dims1>x<type>>
@@ -73,7 +187,7 @@ public:
     if (!checkOpResultIsUsedByGetRef(&allocOp))
       return failure();
 
-    // TODO: remove once we support the bundling of dynamic memory pools.
+    // Only handle constant AllocOps.
     if (!hasAllConstantDimensions(memRefType))
       return failure();
 
@@ -85,12 +199,16 @@ public:
     if (memRefShape.size() != 1)
       return failure();
 
-    // TODO: Change this when dyanmic shapes are supported.
-    // TODO: Add support for dynamic shapes.
     int64_t currentMemPoolSize = memRefShape[0];
 
     // Get a KrnlGetRefOp which does not use the current alloc.
    if (KrnlGetRefOp unbundledGetRef = getUnbundledGetRef(&allocOp)) {
+      // Make sure that this getref uses a static alloc.
+      auto unbundledGetRefType =
+          convertToMemRefType(unbundledGetRef.getResult().getType());
+      if (!hasAllConstantDimensions(unbundledGetRefType))
+        return failure();
+
       // Current memory pool size is the offset for the newly bundled
       // internal MemRef. Emit the offset as a constant.
       auto offset = rewriter.create<ConstantOp>(
@@ -127,6 +245,201 @@ public:
   }
 };
 
+/*!
+ * RewritePattern that merges a new dynamic AllocOp with the existing
+ * dynamic memory pool:
+ *   %dyn_mempool = alloc(%a) : memref<?xi8>
+ *   %new_alloc = alloc(%b) : memref<?xi8>
+ *   %new_ref = krnl.getref %new_alloc 0 : memref<?xi8>
+ * =>
+ *   %c = addi %a, %b
+ *   %dyn_mempool = alloc(%c) : memref<?xi8>
+ *   %new_ref = krnl.getref %dyn_mempool %a : memref<?xi8>
+ */
+
+class KrnlBundleDynamicMemoryPools : public OpRewritePattern<AllocOp> {
+public:
+  using OpRewritePattern<AllocOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(
+      AllocOp allocOp, PatternRewriter &rewriter) const override {
+    auto loc = allocOp.getLoc();
+
+    auto memRefType = convertToMemRefType(allocOp.getResult().getType());
+    auto memRefShape = memRefType.getShape();
+
+    // If the alloc result is not used by a getref then it cannot be part
+    // of the memory pool.
+    if (!checkOpResultIsUsedByGetRef(&allocOp))
+      return failure();
+
+    // Only handle dynamic allocs here.
+    if (hasAllConstantDimensions(memRefType))
+      return failure();
+
+    // Alloc memory type must be byte.
+    if (getMemRefEltSizeInBytes(memRefType) != 1)
+      return failure();
+
+    // Rank of the allocated MemRef must be 1.
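+    // (The memory pool is modeled as a flat, rank-1 array of bytes such as
+    // memref<?xi8>; typed buffers are carved out of it via krnl.getref, so
+    // any candidate alloc must already have that shape.)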
+    if (memRefShape.size() != 1)
+      return failure();
+
+    // Visit dependent operations in the current parent block and assemble
+    // a trace of operations which participate in the computation of the
+    // size of the AllocOp.
+    auto parentBlock = allocOp.getOperation()->getBlock();
+
+    // Get the containing function.
+    FuncOp function = getContainingFunction(allocOp);
+    Block *firstBlock = &function.getBody().front();
+
+    // If this is the alloc representing the memory pool and the function
+    // already has an init block, pattern matching must fail to avoid
+    // processing the dynamic memory pool a second time.
+    if (hasInitBlock(function)) {
+      std::unique_ptr<ONNXOperandsInitState> &initState = initMap.at(function);
+      if (allocOp == initState->dynamicMemoryPool)
+        return failure();
+    }
+
+    // Initialize the work queue data structure.
+    Operation *op = allocOp.getOperation();
+    std::vector<Value> operandList;
+    for (const auto &operand : allocOp.getOperands()) {
+      operandList.emplace_back(operand);
+    }
+
+    // Check if the list of operations depends on a dynamic local AllocOp.
+    bool dependsOnLocalDynamicAlloc = false;
+
+    // Construct the list of Values on which the current AllocOp depends.
+    llvm::SetVector<Operation *> dependentOps;
+    while (operandList.size() > 0) {
+      Value currentElement = operandList[0];
+      Operation *definingOperation = currentElement.getDefiningOp();
+
+      // If this value has not been seen before, process it.
+      if (dependentOps.count(definingOperation) == 0) {
+        // Add value to dependent values list.
+        dependentOps.insert(definingOperation);
+
+        // Add operands to work queue.
+        for (const auto &operand : definingOperation->getOperands()) {
+          // Check that the operand is not a block argument. If it is, skip
+          // it; block arguments are considered leaves.
+          if (!isBlockArgument(firstBlock, operand)) {
+            operandList.emplace_back(operand);
+
+            // Check if the current operation is a DimOp or a LoadOp whose
+            // argument list involves a local AllocOp with dynamic sizes.
+            // If that is the case, the whole set of instructions cannot
+            // be moved.
+            if (llvm::dyn_cast<DimOp>(definingOperation) ||
+                llvm::dyn_cast<LoadOp>(definingOperation)) {
+              Operation *operandOp = operand.getDefiningOp();
+              if (operandOp) {
+                auto localAlloc = llvm::dyn_cast<AllocOp>(operandOp);
+                if (localAlloc) {
+                  auto memRefType =
+                      convertToMemRefType(localAlloc.getResult().getType());
+                  if (!hasAllConstantDimensions(memRefType))
+                    dependsOnLocalDynamicAlloc = true;
+                }
+
+                // If the operand is a getref then this alloc cannot be
+                // bundled.
+                auto memPool = llvm::dyn_cast<KrnlGetRefOp>(operandOp);
+                if (memPool)
+                  dependsOnLocalDynamicAlloc = true;
+              }
+            }
+          }
+        }
+      }
+
+      // Erase the first element from the work queue.
+      operandList.erase(operandList.begin());
+    }
+
+    if (dependsOnLocalDynamicAlloc)
+      return failure();
+
+    // Order the dependent values in the same order they appear in the code.
+    // One cannot iterate over and make changes to the order of the
+    // operations of a block, so a temporary ordered list of dependent
+    // instructions is necessary.
+    llvm::SmallVector<Operation *, 32> orderedDependentOps;
+    for (auto &op : llvm::make_range(
+             parentBlock->begin(), std::prev(parentBlock->end())))
+      if (dependentOps.count(&op) > 0)
+        orderedDependentOps.emplace_back(&op);
+
+    // If no dynamic alloc is in the trace of the dependent operations,
+    // emit the size calculation in the init block if one already exists;
+    // if not, create the init block.
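+    // Note: addInitBlock returns true only when it actually creates the
+    // init block for this function. In that case no dynamic memory pool
+    // exists yet, so the running pool size below is seeded with a fresh
+    // zero constant instead of the existing pool size.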
+    bool addedInitBlock = addInitBlock(rewriter, loc, allocOp);
+
+    // Move the ordered dependent size calculation to the init block.
+    std::unique_ptr<ONNXOperandsInitState> &initState = initMap.at(function);
+    for (auto &op : orderedDependentOps)
+      op->moveBefore(initState->branchInit);
+
+    // Bundled MemRef type: memref<?xi8>.
+    SmallVector<int64_t, 1> memPoolShape;
+    memPoolShape.emplace_back(-1);
+    auto bundledMemPoolMemRefType =
+        MemRefType::get(memPoolShape, rewriter.getIntegerType(8));
+
+    // Get the getref of the current allocOp. There is exactly one such
+    // getref.
+    KrnlGetRefOp currentAllocGetRef = getCurrentAllocGetRef(&allocOp);
+    if (!currentAllocGetRef)
+      return failure();
+
+    // Add the current alloc size to the current MemPool size.
+    Value dynamicMemoryPoolSize = initState->dynamicMemoryPool.getOperand(0);
+    if (addedInitBlock) {
+      Value zero = emitConstantOp(rewriter, loc, rewriter.getIndexType(), 0);
+      zero.getDefiningOp()->moveBefore(initState->branchInit);
+      dynamicMemoryPoolSize = zero;
+    }
+
+    AddIOp bundledAllocOperand = rewriter.create<AddIOp>(
+        loc, dynamicMemoryPoolSize, allocOp.getOperand(0));
+    bundledAllocOperand.getOperation()->moveBefore(initState->branchInit);
+
+    // The newly bundled MemRef is expressed as a KrnlGetRefOp. The current
+    // memory pool size is the offset for the newly bundled internal MemRef.
+    Value integerDynamicMemoryPoolSize = rewriter.create<IndexCastOp>(
+        loc, dynamicMemoryPoolSize, rewriter.getIntegerType(64));
+    integerDynamicMemoryPoolSize.getDefiningOp()->moveBefore(
+        initState->branchInit);
+
+    // Emit a new alloc which contains the additional MemRef.
+    AllocOp bundledAlloc = rewriter.create<AllocOp>(
+        loc, bundledMemPoolMemRefType, bundledAllocOperand.getResult());
+    bundledAlloc.getOperation()->moveBefore(&initState->mainBlock->front());
+
+    KrnlGetRefOp bundledMemRef = rewriter.create<KrnlGetRefOp>(loc,
+        currentAllocGetRef.getResult().getType(), bundledAlloc,
+        integerDynamicMemoryPoolSize);
+    bundledMemRef.getOperation()->moveAfter(bundledAlloc);
+
+    // Replace the old memory pool with the new one.
+    rewriter.replaceOp(initState->dynamicMemoryPool, bundledAlloc.getResult());
+
+    // Replace the old getref with a new getref into the new memory pool.
+    rewriter.replaceOp(currentAllocGetRef, bundledMemRef.getResult());
+
+    // Update the MemPool size.
+    initState->dynamicMemoryPool = bundledAlloc;
+
+    return success();
+  }
+};
+
 /*!
  * Function pass that enables memory pooling for MemRefs.
  */
@@ -136,11 +449,22 @@ public:
   void runOnFunction() override {
     auto function = getFunction();
 
+    initMap.insert(std::pair<FuncOp, std::unique_ptr<ONNXOperandsInitState>>(
+        function, std::make_unique<ONNXOperandsInitState>()));
+
+    // Initialize state for this function.
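+    // A null initBlock marks the function as not yet split into init and
+    // main blocks; the first dynamic bundling rewrite creates the split.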
+    std::unique_ptr<ONNXOperandsInitState> &initState = initMap.at(function);
+    initState->initBlock = nullptr;
+
     ConversionTarget target(getContext());
     OwningRewritePatternList patterns;
-    patterns.insert<KrnlBundleMemoryPools>(&getContext());
+    patterns.insert<KrnlBundleMemoryPools, KrnlBundleDynamicMemoryPools>(
+        &getContext());
 
     applyPatternsAndFoldGreedily(function, patterns);
+
+    initMap.erase(function);
   }
 };
 } // namespace
diff --git a/test/mlir/krnl/krnl_bundle_memory_pool.mlir b/test/mlir/krnl/krnl_bundle_memory_pool.mlir
index 750439a..31b224b 100644
--- a/test/mlir/krnl/krnl_bundle_memory_pool.mlir
+++ b/test/mlir/krnl/krnl_bundle_memory_pool.mlir
@@ -51,3 +51,135 @@ func @test_pool_bundling(%arg0: memref<10x10xf32>, %arg1: memref<10x20xf32>) -> memref<10x20xf32> {
   // CHECK: dealloc [[MEMPOOL]] : memref<3200xi8>
   // CHECK: return [[RES]] : memref<10x20xf32>
 }
+
+func @test_dynamic_pool_bundling(%arg0: memref<?x?xf32>) -> memref<?x10xf32> {
+  %c1 = constant 1 : index
+  %c0 = constant 0 : index
+  %cst = constant 0.000000e+00 : f32
+  %ind = constant 0 : index
+  %c4 = constant 4 : index
+  %c10 = constant 10 : index
+  %c0_i64 = constant 0 : i64
+  %0 = dim %arg0, %c0 : memref<?x?xf32>
+  %1 = muli %0, %c4 : index
+  %2 = muli %1, %c10 : index
+  %3 = alloc(%2) : memref<?xi8>
+  %4 = "krnl.getref"(%3, %c0_i64) : (memref<?xi8>, i64) -> memref<?x10xf32>
+  %6 = cmpi "sgt", %0, %0 : index
+  %7 = select %6, %0, %0 : index
+  %8 = muli %7, %c4 : index
+  %9 = muli %8, %c10 : index
+  %10 = alloc(%9) : memref<?xi8>
+  %11 = "krnl.getref"(%10, %c0_i64) : (memref<?xi8>, i64) -> memref<?x10xf32>
+  %12 = cmpi "eq", %0, %c1 : index
+  %13 = cmpi "eq", %0, %c1 : index
+  %15 = alloc(%0) : memref<?x10xf32>
+  affine.store %cst, %4[%ind, %ind] : memref<?x10xf32>
+  affine.store %cst, %11[%ind, %ind] : memref<?x10xf32>
+  affine.store %cst, %15[%ind, %ind] : memref<?x10xf32>
+  dealloc %10 : memref<?xi8>
+  dealloc %3 : memref<?xi8>
+  return %15 : memref<?x10xf32>
+
+  // CHECK-LABEL: test_dynamic_pool_bundling
+  // CHECK: [[CST:%.+]] = constant 0.000000e+00 : f32
+  // CHECK: [[C0:%.+]] = constant 0 : index
+  // CHECK: [[C4:%.+]] = constant 4 : index
+  // CHECK: [[C10:%.+]] = constant 10 : index
+  // CHECK: [[C0_I64:%.+]] = constant 0 : i64
+  // CHECK: [[DIM:%.+]] = dim %arg0, [[C0]] : memref<?x?xf32>
+  // CHECK: [[SGT:%.+]] = cmpi "sgt", [[DIM]], [[DIM]] : index
+  // CHECK: [[SELECT:%.+]] = select [[SGT]], [[DIM]], [[DIM]] : index
+  // CHECK: [[MUL1:%.+]] = muli [[SELECT]], [[C4]] : index
+  // CHECK: [[OFFSET1:%.+]] = muli [[MUL1]], [[C10]] : index
+  // CHECK: [[MUL2:%.+]] = muli [[DIM]], [[C4]] : index
+  // CHECK: [[OFFSET2:%.+]] = muli [[MUL2]], [[C10]] : index
+  // CHECK: [[MEMPOOL_SIZE:%.+]] = addi [[OFFSET1]], [[OFFSET2]] : index
+  // CHECK: [[OFFSET1_I64:%.+]] = index_cast [[OFFSET1]] : index to i64
+  // CHECK: [[DYN_MEMPOOL:%.+]] = alloc([[MEMPOOL_SIZE]]) : memref<?xi8>
+  // CHECK: [[DATA1:%.+]] = "krnl.getref"([[DYN_MEMPOOL]], [[OFFSET1_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32>
+  // CHECK: [[DATA2:%.+]] = "krnl.getref"([[DYN_MEMPOOL]], [[C0_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32>
+  // CHECK: [[RES:%.+]] = alloc([[DIM]]) : memref<?x10xf32>
+  // CHECK: affine.store [[CST]], [[DATA1]][0, 0] : memref<?x10xf32>
+  // CHECK: affine.store [[CST]], [[DATA2]][0, 0] : memref<?x10xf32>
+  // CHECK: affine.store [[CST]], [[RES]][0, 0] : memref<?x10xf32>
+  // CHECK: dealloc [[DYN_MEMPOOL]] : memref<?xi8>
+  // CHECK: return [[RES]] : memref<?x10xf32>
+}
+
+func @test_dynamic_and_static_pool_bundling(%arg0: memref<?x?xf32>, %arg1: memref<10x10xf32>) -> memref<?x10xf32> {
+  %c1 = constant 1 : index
+  %c0 = constant 0 : index
+  %cst = constant 0.000000e+00 : f32
+  %ind = constant 0 : index
+  %c4 = constant 4 : index
+  %c10 = constant 10 : index
+  %c0_i64 = constant 0 : i64
+  %0 = dim %arg0, %c0 : memref<?x?xf32>
+  %1 = muli %0, %c4 : index
+  %2 = muli %1, %c10 : index
+  %3 = alloc(%2) : memref<?xi8>
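+  // The two dynamically sized allocs (%3 and %10) should be merged into a
+  // single memref<?xi8> pool, while the three statically sized allocs below
+  // should be merged into a single memref<2800xi8> pool.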
+  %4 = "krnl.getref"(%3, %c0_i64) : (memref<?xi8>, i64) -> memref<?x10xf32>
+  %const_alloc1 = alloc() : memref<800xi8>
+  %const_ref1 = "krnl.getref"(%const_alloc1, %c0_i64) : (memref<800xi8>, i64) -> memref<10x20xf32>
+  %const_alloc2 = alloc() : memref<400xi8>
+  %const_ref2 = "krnl.getref"(%const_alloc2, %c0_i64) : (memref<400xi8>, i64) -> memref<10x10xf32>
+  %6 = cmpi "sgt", %0, %0 : index
+  %7 = select %6, %0, %0 : index
+  %8 = muli %7, %c4 : index
+  %9 = muli %8, %c10 : index
+  %10 = alloc(%9) : memref<?xi8>
+  %11 = "krnl.getref"(%10, %c0_i64) : (memref<?xi8>, i64) -> memref<?x10xf32>
+  %12 = cmpi "eq", %0, %c1 : index
+  %13 = cmpi "eq", %0, %c1 : index
+  %15 = alloc(%0) : memref<?x10xf32>
+  %const_alloc3 = alloc() : memref<1600xi8>
+  %const_ref3 = "krnl.getref"(%const_alloc3, %c0_i64) : (memref<1600xi8>, i64) -> memref<10x40xf32>
+  affine.store %cst, %4[%ind, %ind] : memref<?x10xf32>
+  affine.store %cst, %11[%ind, %ind] : memref<?x10xf32>
+  affine.store %cst, %15[%ind, %ind] : memref<?x10xf32>
+  affine.store %cst, %const_ref2[%ind, %ind] : memref<10x10xf32>
+  affine.store %cst, %const_ref1[%ind, %ind] : memref<10x20xf32>
+  affine.store %cst, %const_ref3[%ind, %ind] : memref<10x40xf32>
+  dealloc %10 : memref<?xi8>
+  dealloc %3 : memref<?xi8>
+  dealloc %const_alloc1 : memref<800xi8>
+  dealloc %const_alloc2 : memref<400xi8>
+  dealloc %const_alloc3 : memref<1600xi8>
+  return %15 : memref<?x10xf32>
+
+  // CHECK-LABEL: test_dynamic_and_static_pool_bundling
+  // CHECK: [[CST:%.+]] = constant 0.000000e+00 : f32
+  // CHECK: [[C0:%.+]] = constant 0 : index
+  // CHECK: [[C4:%.+]] = constant 4 : index
+  // CHECK: [[C10:%.+]] = constant 10 : index
+  // CHECK: [[C400_I64:%.+]] = constant 400 : i64
+  // CHECK: [[C2000_I64:%.+]] = constant 2000 : i64
+  // CHECK: [[C0_I64:%.+]] = constant 0 : i64
+  // CHECK: [[DIM:%.+]] = dim %arg0, [[C0]] : memref<?x?xf32>
+  // CHECK: [[SGT:%.+]] = cmpi "sgt", [[DIM]], [[DIM]] : index
+  // CHECK: [[SELECT:%.+]] = select [[SGT]], [[DIM]], [[DIM]] : index
+  // CHECK: [[MUL1:%.+]] = muli [[SELECT]], [[C4]] : index
+  // CHECK: [[OFFSET1:%.+]] = muli [[MUL1]], [[C10]] : index
+  // CHECK: [[MUL2:%.+]] = muli [[DIM]], [[C4]] : index
+  // CHECK: [[OFFSET2:%.+]] = muli [[MUL2]], [[C10]] : index
+  // CHECK: [[MEMPOOL_SIZE:%.+]] = addi [[OFFSET1]], [[OFFSET2]] : index
+  // CHECK: [[OFFSET1_I64:%.+]] = index_cast [[OFFSET1]] : index to i64
+  // CHECK: [[DYN_MEMPOOL:%.+]] = alloc([[MEMPOOL_SIZE]]) : memref<?xi8>
+  // CHECK: [[DATA1:%.+]] = "krnl.getref"([[DYN_MEMPOOL]], [[OFFSET1_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32>
+  // CHECK: [[DATA2:%.+]] = "krnl.getref"([[DYN_MEMPOOL]], [[C0_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32>
+  // CHECK: [[STATIC_MEMPOOL:%.+]] = alloc() : memref<2800xi8>
+  // CHECK: [[DATA3:%.+]] = "krnl.getref"([[STATIC_MEMPOOL]], [[C2000_I64]]) : (memref<2800xi8>, i64) -> memref<10x20xf32>
+  // CHECK: [[DATA4:%.+]] = "krnl.getref"([[STATIC_MEMPOOL]], [[C400_I64]]) : (memref<2800xi8>, i64) -> memref<10x40xf32>
+  // CHECK: [[DATA5:%.+]] = "krnl.getref"([[STATIC_MEMPOOL]], [[C0_I64]]) : (memref<2800xi8>, i64) -> memref<10x10xf32>
+  // CHECK: [[RES:%.+]] = alloc([[DIM]]) : memref<?x10xf32>
+  // CHECK: affine.store [[CST]], [[DATA1]][0, 0] : memref<?x10xf32>
+  // CHECK: affine.store [[CST]], [[DATA2]][0, 0] : memref<?x10xf32>
+  // CHECK: affine.store [[CST]], [[RES]][0, 0] : memref<?x10xf32>
+  // CHECK: affine.store [[CST]], [[DATA5]][0, 0] : memref<10x10xf32>
+  // CHECK: affine.store [[CST]], [[DATA3]][0, 0] : memref<10x20xf32>
+  // CHECK: affine.store [[CST]], [[DATA4]][0, 0] : memref<10x40xf32>
+  // CHECK: dealloc [[DYN_MEMPOOL]] : memref<?xi8>
+  // CHECK: dealloc [[STATIC_MEMPOOL]] : memref<2800xi8>
+  // CHECK: return [[RES]] : memref<?x10xf32>
+}