diff --git a/src/Transform/BundleMemoryPools.cpp b/src/Transform/BundleMemoryPools.cpp index 9cbac6c..5a5db70 100644 --- a/src/Transform/BundleMemoryPools.cpp +++ b/src/Transform/BundleMemoryPools.cpp @@ -26,23 +26,12 @@ using namespace mlir; namespace { //===----------------------------------------------------------------------===// -// Insertion point for initialization instructions and the blocks used for -// inserting the initialization and main code. These blocks will disappear -// when the first canonicalization is performed because the init block -// unconditionally branches into the second block. These blocks exist only for -// the purpose of this optimization. -// The information is recorded on a per function basis. +// Data structures for managing memory pools. //===----------------------------------------------------------------------===// -typedef struct ONNXOperandsInitState { - Block *initBlock; - Block *mainBlock; - BranchOp branchInit; - AllocOp dynamicMemoryPool; -} ONNXOperandsInitState; - -// Helper data structure for the bundling of dynamic AllocOps. -std::map> initMap; +// Data structure for managing dyanmic memory pool. +typedef std::map BlockToDynamicPool; +std::map> dynamicPoolMap; // Handling of static memory pool on a block-basis in each function. typedef std::map BlockToStaticPool; @@ -63,65 +52,24 @@ FuncOp getContainingFunction(AllocOp op) { return cast(parentFuncOp); } -bool hasInitBlock(FuncOp function) { - std::unique_ptr &initState = initMap.at(function); - return initState->initBlock != nullptr; -} +// Check if this value is an argument of one of the blocks nested +// around it. +bool isBlockArgument(AllocOp allocOp, Value operand) { + // Parent operation of the current block. + Operation *parentBlockOp; + Block *currentBlock = allocOp.getOperation()->getBlock(); -bool addInitBlock(PatternRewriter &rewriter, Location loc, AllocOp allocOp) { - // If this is the first time we encounter an operation in this - // function, we create an entry inside the initMap and split the - // function body into an init block and a main block. - // - // function func_name() { - // ... init block ... - // br ^bb1 - // ^bb1: // pred: ^bb0 - // ... main block ... - // return - // } - // - // Note: the block ^bb0 being the first block has its label omitted. - // - FuncOp function = getContainingFunction(allocOp); - // If the function does not contain an init block, create one. - if (!hasInitBlock(function)) { - std::unique_ptr &initState = initMap.at(function); - initState = std::make_unique(); + do { + // Check the arguments of the current block. + for (auto arg : currentBlock->getArguments()) + if (operand == arg) + return true; - // All input arguments are considered as part of the initialization block - // so add them to the operandsInInitBlock set. - Block *functionBlock = &function.front(); - PatternRewriter::InsertionGuard insertGuard(rewriter); - rewriter.setInsertionPointToStart(functionBlock); + parentBlockOp = currentBlock->getParentOp(); + currentBlock = parentBlockOp->getBlock(); - initState->initBlock = rewriter.getInsertionBlock(); - auto currentPoint = rewriter.getInsertionPoint(); - initState->mainBlock = - rewriter.splitBlock(initState->initBlock, currentPoint); + } while (!llvm::dyn_cast_or_null(parentBlockOp)); - rewriter.setInsertionPointToEnd(initState->initBlock); - - // Insert a branch operation from initBlock to mainBlock. This - // ensures the final code contains legal blocks. - initState->branchInit = - rewriter.create(loc, initState->mainBlock); - - rewriter.setInsertionPointToStart(initState->mainBlock); - - // Save a reference to the current dynamic memory pool value. - initState->dynamicMemoryPool = allocOp; - - return true; - } - - return false; -} - -bool isBlockArgument(Block *block, Value operand) { - for (auto arg : block->getArguments()) - if (operand == arg) - return true; return false; } @@ -332,14 +280,17 @@ public: // Get function. FuncOp function = getContainingFunction(allocOp); - Block *firstBlock = &function.getBody().front(); - // If this is the alloc representing the memory pool and the function - // already has an init block, pattern matching must fail to avoid - // processing the dynamic memory pool a second time. - if (hasInitBlock(function)) { - std::unique_ptr &initState = initMap.at(function); - if (allocOp == initState->dynamicMemoryPool) + // Use function to retrieve the list of blocks for this function. + std::unique_ptr &blockToDynamicPool = + dynamicPoolMap.at(function); + + // If this is not the first time we process an alloc in this block, avoid + // processing the current dynamic memory pool again. + if (blockToDynamicPool->count(parentBlock) > 0) { + std::unique_ptr &blockToDynamicPool = + dynamicPoolMap.at(function); + if (allocOp == blockToDynamicPool->at(parentBlock)) return failure(); } @@ -365,11 +316,10 @@ public: dependentOps.insert(definingOperation); // Add operands to work queue. - // printf("Processing the args of the following op:\n"); for (const auto &operand : definingOperation->getOperands()) { // Check operand is not a block argument. If it is skip it, we // consider block arguments to be leafs. - if (!isBlockArgument(firstBlock, operand)) { + if (!isBlockArgument(allocOp, operand)) { operandList.emplace_back(operand); // Check if the current operation is a dim or a load and the @@ -416,15 +366,23 @@ public: if (dependentOps.count(&op) > 0) orderedDependentOps.emplace_back(&op); - // If no dynamic alloc is in the trace of the dependent operations, - // emit the size calculation in the init block, if one exists already, - // if not, create the init block. - bool addedInitBlock = addInitBlock(rewriter, loc, allocOp); + // If this is the first valid alloc we can bundle in this block, then we + // need to move it to the top of the block as it will consitute an + // insertion point for all other bundle-able AllocOps in the block. + bool isFirstBundledAllocOp = blockToDynamicPool->count(parentBlock) == 0; + if (isFirstBundledAllocOp) { + allocOp.getOperation()->moveBefore(&parentBlock->front()); - // Move the ordered dependent size calculation to the init block. - std::unique_ptr &initState = initMap.at(function); + // Create new entry in the block map. + blockToDynamicPool->insert( + std::pair(parentBlock, allocOp)); + } + + // Move the computation instructions at the start of the block. + AllocOp oldDynamicMemoryPool = blockToDynamicPool->at(parentBlock); + std::reverse(orderedDependentOps.begin(), orderedDependentOps.end()); for (auto &op : orderedDependentOps) - op->moveBefore(initState->branchInit); + op->moveBefore(&parentBlock->front()); // Bundle MemRef type: SmallVector memPoolShape; @@ -438,16 +396,16 @@ public: return failure(); // Add the current alloc size to the current MemPool size. - Value dynamicMemoryPoolSize = initState->dynamicMemoryPool.getOperand(0); - if (addedInitBlock) { + Value dynamicMemoryPoolSize = oldDynamicMemoryPool.getOperand(0); + if (isFirstBundledAllocOp) { Value zero = emitConstantOp(rewriter, loc, rewriter.getIndexType(), 0); - zero.getDefiningOp()->moveBefore(initState->branchInit); + zero.getDefiningOp()->moveBefore(oldDynamicMemoryPool); dynamicMemoryPoolSize = zero; } AddIOp bundledAllocOperand = rewriter.create( loc, dynamicMemoryPoolSize, allocOp.getOperand(0)); - bundledAllocOperand.getOperation()->moveBefore(initState->branchInit); + bundledAllocOperand.getOperation()->moveBefore(oldDynamicMemoryPool); // The newly bundled MemRef expressed as a KrnlGetRefOp. // Current memory pool size is the offset for the newly bundled @@ -455,26 +413,27 @@ public: Value integerDynamicMemoryPoolSize = rewriter.create( loc, dynamicMemoryPoolSize, rewriter.getIntegerType(64)); integerDynamicMemoryPoolSize.getDefiningOp()->moveBefore( - initState->branchInit); + oldDynamicMemoryPool); // We need to emit a new alloc which contains the additional MemRef. AllocOp bundledAlloc = rewriter.create( loc, bundledMemPoolMemRefType, bundledAllocOperand.getResult()); - bundledAlloc.getOperation()->moveBefore(&initState->mainBlock->front()); + bundledAlloc.getOperation()->moveBefore(oldDynamicMemoryPool); KrnlGetRefOp bundledMemRef = rewriter.create(loc, currentAllocGetRef.getResult().getType(), bundledAlloc, integerDynamicMemoryPoolSize); - bundledMemRef.getOperation()->moveAfter(bundledAlloc); // Replace old memory pool with new one. - rewriter.replaceOp(initState->dynamicMemoryPool, bundledAlloc.getResult()); + rewriter.replaceOp(oldDynamicMemoryPool, bundledAlloc.getResult()); // Replace old getref with new getref from new memory pool. rewriter.replaceOp(currentAllocGetRef, bundledMemRef.getResult()); - // Update MemPool size. - initState->dynamicMemoryPool = bundledAlloc; + // Update MemPool data structure. + blockToDynamicPool->erase(parentBlock); + blockToDynamicPool->insert( + std::pair(parentBlock, bundledAlloc)); return success(); } @@ -488,16 +447,14 @@ class KrnlBundleMemoryPoolsPass public: void runOnFunction() override { auto function = getFunction(); - initMap.insert(std::pair>( - function, std::make_unique())); + + dynamicPoolMap.insert( + std::pair>( + function, std::make_unique())); staticPoolMap.insert(std::pair>( function, std::make_unique())); - // Initialize state for this function. - std::unique_ptr &initState = initMap.at(function); - initState->initBlock = nullptr; - ConversionTarget target(getContext()); OwningRewritePatternList patterns; patterns.insert( @@ -505,7 +462,7 @@ public: applyPatternsAndFoldGreedily(function, patterns); - initMap.erase(function); + dynamicPoolMap.erase(function); staticPoolMap.erase(function); } }; diff --git a/test/mlir/krnl/krnl_bundle_memory_pool.mlir b/test/mlir/krnl/krnl_bundle_memory_pool.mlir index 196e6b4..33859d8 100644 --- a/test/mlir/krnl/krnl_bundle_memory_pool.mlir +++ b/test/mlir/krnl/krnl_bundle_memory_pool.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --bundle-memory-pools --canonicalize %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --bundle-memory-pools --canonicalize %s | FileCheck %s func @test_pool_bundling(%arg0: memref<10x10xf32>, %arg1: memref<10x20xf32>) -> memref<10x20xf32> { %c0_i64 = constant 0 : i64 @@ -88,17 +88,17 @@ func @test_dynamic_pool_bundling(%arg0: memref) -> memref { // CHECK: [[C10:%.+]] = constant 10 : index // CHECK: [[C0_I64:%.+]] = constant 0 : i64 // CHECK: [[DIM:%.+]] = dim %arg0, [[C0]] : memref + // CHECK: [[MUL2:%.+]] = muli [[DIM]], [[C4]] : index + // CHECK: [[OFFSET2:%.+]] = muli [[MUL2]], [[C10]] : index // CHECK: [[SGT:%.+]] = cmpi "sgt", [[DIM]], [[DIM]] : index // CHECK: [[SELECT:%.+]] = select [[SGT]], [[DIM]], [[DIM]] : index // CHECK: [[MUL1:%.+]] = muli [[SELECT]], [[C4]] : index // CHECK: [[OFFSET1:%.+]] = muli [[MUL1]], [[C10]] : index - // CHECK: [[MUL2:%.+]] = muli [[DIM]], [[C4]] : index - // CHECK: [[OFFSET2:%.+]] = muli [[MUL2]], [[C10]] : index // CHECK: [[MEMPOOL_SIZE:%.+]] = addi [[OFFSET1]], [[OFFSET2]] : index // CHECK: [[OFFSET1_I64:%.+]] = index_cast [[OFFSET1]] : index to i64 // CHECK: [[DYN_MEMPOOL:%.+]] = alloc([[MEMPOOL_SIZE]]) : memref - // CHECK: [[DATA1:%.+]] = "krnl.getref"([[DYN_MEMPOOL]], [[OFFSET1_I64]]) : (memref, i64) -> memref // CHECK: [[DATA2:%.+]] = "krnl.getref"([[DYN_MEMPOOL]], [[C0_I64]]) : (memref, i64) -> memref + // CHECK: [[DATA1:%.+]] = "krnl.getref"([[DYN_MEMPOOL]], [[OFFSET1_I64]]) : (memref, i64) -> memref // CHECK: [[RES:%.+]] = alloc([[DIM]]) : memref // CHECK: affine.store [[CST]], [[DATA1]][0, 0] : memref // CHECK: affine.store [[CST]], [[DATA2]][0, 0] : memref @@ -149,36 +149,36 @@ func @test_dynamic_and_static_pool_bundling(%arg0: memref, %arg1: memre return %15 : memref // CHECK-LABEL: test_dynamic_and_static_pool_bundling - // CHECK: [[C1200_I64:%.+]] = constant 1200 : i64 // CHECK: [[CST:%.+]] = constant 0.000000e+00 : f32 // CHECK: [[C0:%.+]] = constant 0 : index // CHECK: [[C4:%.+]] = constant 4 : index // CHECK: [[C10:%.+]] = constant 10 : index - // CHECK: [[C400_I64:%.+]] = constant 400 : i64 + // CHECK: [[C2000_I64:%.+]] = constant 2000 : i64 + // CHECK: [[C1600_I64:%.+]] = constant 1600 : i64 // CHECK: [[C0_I64:%.+]] = constant 0 : i64 // CHECK: [[DIM:%.+]] = dim %arg0, [[C0]] : memref + // CHECK: [[MUL2:%.+]] = muli [[DIM]], [[C4]] : index + // CHECK: [[OFFSET2:%.+]] = muli [[MUL2]], [[C10]] : index // CHECK: [[SGT:%.+]] = cmpi "sgt", [[DIM]], [[DIM]] : index // CHECK: [[SELECT:%.+]] = select [[SGT]], [[DIM]], [[DIM]] : index // CHECK: [[MUL1:%.+]] = muli [[SELECT]], [[C4]] : index // CHECK: [[OFFSET1:%.+]] = muli [[MUL1]], [[C10]] : index - // CHECK: [[MUL2:%.+]] = muli [[DIM]], [[C4]] : index - // CHECK: [[OFFSET2:%.+]] = muli [[MUL2]], [[C10]] : index // CHECK: [[MEMPOOL_SIZE:%.+]] = addi [[OFFSET1]], [[OFFSET2]] : index // CHECK: [[OFFSET1_I64:%.+]] = index_cast [[OFFSET1]] : index to i64 // CHECK: [[DYN_MEMPOOL:%.+]] = alloc([[MEMPOOL_SIZE]]) : memref - // CHECK: [[DATA1:%.+]] = "krnl.getref"([[DYN_MEMPOOL]], [[OFFSET1_I64]]) : (memref, i64) -> memref // CHECK: [[DATA2:%.+]] = "krnl.getref"([[DYN_MEMPOOL]], [[C0_I64]]) : (memref, i64) -> memref + // CHECK: [[DATA1:%.+]] = "krnl.getref"([[DYN_MEMPOOL]], [[OFFSET1_I64]]) : (memref, i64) -> memref // CHECK: [[STATIC_MEMPOOL:%.+]] = alloc() : memref<2800xi8> - // CHECK: [[DATA3:%.+]] = "krnl.getref"([[STATIC_MEMPOOL]], [[C1200_I64]]) : (memref<2800xi8>, i64) -> memref<10x40xf32> - // CHECK: [[DATA4:%.+]] = "krnl.getref"([[STATIC_MEMPOOL]], [[C400_I64]]) : (memref<2800xi8>, i64) -> memref<10x20xf32> - // CHECK: [[DATA5:%.+]] = "krnl.getref"([[STATIC_MEMPOOL]], [[C0_I64]]) : (memref<2800xi8>, i64) -> memref<10x10xf32> + // CHECK: [[DATA3:%.+]] = "krnl.getref"([[STATIC_MEMPOOL]], [[C2000_I64]]) : (memref<2800xi8>, i64) -> memref<10x20xf32> + // CHECK: [[DATA4:%.+]] = "krnl.getref"([[STATIC_MEMPOOL]], [[C1600_I64]]) : (memref<2800xi8>, i64) -> memref<10x10xf32> // CHECK: [[RES:%.+]] = alloc([[DIM]]) : memref + // CHECK: [[DATA5:%.+]] = "krnl.getref"([[STATIC_MEMPOOL]], [[C0_I64]]) : (memref<2800xi8>, i64) -> memref<10x40xf32> // CHECK: affine.store [[CST]], [[DATA1]][0, 0] : memref // CHECK: affine.store [[CST]], [[DATA2]][0, 0] : memref // CHECK: affine.store [[CST]], [[RES]][0, 0] : memref - // CHECK: affine.store [[CST]], [[DATA5]][0, 0] : memref<10x10xf32> - // CHECK: affine.store [[CST]], [[DATA4]][0, 0] : memref<10x20xf32> - // CHECK: affine.store [[CST]], [[DATA3]][0, 0] : memref<10x40xf32> + // CHECK: affine.store [[CST]], [[DATA4]][0, 0] : memref<10x10xf32> + // CHECK: affine.store [[CST]], [[DATA3]][0, 0] : memref<10x20xf32> + // CHECK: affine.store [[CST]], [[DATA5]][0, 0] : memref<10x40xf32> // CHECK: dealloc [[DYN_MEMPOOL]] : memref // CHECK: dealloc [[STATIC_MEMPOOL]] : memref<2800xi8> // CHECK: return [[RES]] : memref @@ -365,3 +365,97 @@ func @static_mem_pool_rnn_sub_and_main_block(%arg0: memref<1x3x2xf32>, %arg1: me // CHECK: dealloc [[STATIC_MEM_POOL_MAIN]] : memref<12xi8> // CHECK: return [[RES]] : memref<1x3x4xf32> } + +/// Test dynamic pooling in sub-block. +func @test_dynamic_pool_rnn(%arg0: memref<1x3x2xf32>, %arg1: memref<1x4x2xf32>, %arg2: memref<1x?x?xf32>) -> memref<1x3x?xf32> attributes {input_names = ["X", "W", "R"], output_names = ["Y"]} { + %cst = constant 0.000000e+00 : f32 + %c0_i64 = constant 0 : i64 + %c_ind = constant 1 : index + %dim0 = dim %arg1, %c_ind : memref<1x4x2xf32> + %0 = alloc(%dim0) : memref<1x3x?xf32> + %1:3 = krnl.define_loops 3 + krnl.iterate(%1#0, %1#1, %1#2) with (%1#0 -> %arg3 = 0 to 1, %1#1 -> %arg4 = 0 to 3, %1#2 -> %arg5 = 0 to 4) { + affine.store %cst, %0[symbol(%arg3), symbol(%arg4), symbol(%arg5)] : memref<1x3x?xf32> + } + %2 = krnl.define_loops 1 + krnl.iterate(%2) with (%2 -> %arg3 = 0 to 1) { + %3:2 = krnl.define_loops 2 + krnl.iterate(%3#0, %3#1) with (%3#0 -> %arg4 = 0 to 3, %3#1 -> %arg5 = 0 to 4) { + %dim1 = dim %arg2, %c_ind : memref<1x?x?xf32> + %4 = alloc(%dim1) : memref + %5 = "krnl.getref"(%4, %c0_i64) : (memref, i64) -> memref + %6 = affine.load %0[0, symbol(%arg4), symbol(%arg5)] : memref<1x3x?xf32> + %7 = alloc(%dim1) : memref + %8 = "krnl.getref"(%7, %c0_i64) : (memref, i64) -> memref + affine.store %cst, %8[] : memref + %c_ind_2 = constant 2 : index + %dim2 = dim %arg2, %c_ind_2 : memref<1x?x?xf32> + %9 = alloc(%dim2) : memref + %10 = "krnl.getref"(%9, %c0_i64) : (memref, i64) -> memref + affine.store %cst, %10[] : memref + %11 = krnl.define_loops 1 + krnl.iterate(%11) with (%11 -> %arg6 = 0 to 2) { + %25 = affine.load %arg0[symbol(%arg3), symbol(%arg4), symbol(%arg6)] : memref<1x3x2xf32> + %26 = affine.load %arg1[0, symbol(%arg5), symbol(%arg6)] : memref<1x4x2xf32> + %27 = mulf %25, %26 : f32 + %28 = affine.load %8[] : memref + %29 = addf %28, %27 : f32 + affine.store %29, %8[] : memref + %30 = affine.load %arg2[0, symbol(%arg5), symbol(%arg6)] : memref<1x?x?xf32> + %31 = mulf %6, %30 : f32 + %32 = affine.load %10[] : memref + %33 = addf %32, %31 : f32 + affine.store %33, %10[] : memref + } + %12 = affine.load %8[] : memref + %13 = affine.load %10[] : memref + %14 = addf %12, %13 : f32 + %15 = alloc(%dim2) : memref + %16 = "krnl.getref"(%15, %c0_i64) : (memref, i64) -> memref + affine.store %14, %16[] : memref + %17 = affine.load %16[] : memref + %18 = subf %cst, %17 : f32 + %19 = exp %17 : f32 + %20 = exp %18 : f32 + %21 = subf %19, %20 : f32 + %22 = addf %19, %20 : f32 + %23 = divf %21, %22 : f32 + affine.store %23, %5[] : memref + %24 = affine.load %5[] : memref + affine.store %24, %0[0, symbol(%arg4), symbol(%arg5)] : memref<1x3x?xf32> + dealloc %15 : memref + dealloc %9 : memref + dealloc %7 : memref + dealloc %4 : memref + } + } + return %0 : memref<1x3x?xf32> + + // CHECK-LABEL: test_dynamic_pool_rnn + // CHECK: [[C1:%.+]] = constant 1 : index + // CHECK: [[C2:%.+]] = constant 2 : index + // CHECK: [[C0:%.+]] = constant 0 : i64 + // CHECK: krnl.define_loops 1 + // CHECK: krnl.iterate + // CHECK: krnl.define_loops 2 + // CHECK: krnl.iterate + // CHECK: [[DIM1:%.+]] = dim %arg2, [[C1]] : memref<1x?x?xf32> + // CHECK: [[DIM2:%.+]] = dim %arg2, [[C2]] : memref<1x?x?xf32> + // CHECK: [[ADD1:%.+]] = addi [[DIM2]], [[DIM2]] : index + // CHECK: [[OFFSET1:%.+]] = index_cast [[DIM2]] : index to i64 + // CHECK: [[ADD2:%.+]] = addi [[ADD1]], [[DIM1]] : index + // CHECK: [[OFFSET2:%.+]] = index_cast [[ADD1]] : index to i64 + // CHECK: [[ADD3:%.+]] = addi [[ADD2]], [[DIM1]] : index + // CHECK: [[OFFSET3:%.+]] = index_cast [[ADD2]] : index to i64 + // CHECK: [[DYNAMIC_MEMORY_POOL:%.+]] = alloc([[ADD3]]) : memref + // CHECK: [[DATA1:%.+]] = "krnl.getref"([[DYNAMIC_MEMORY_POOL]], [[C0]]) : (memref, i64) -> memref + // CHECK: [[DATA2:%.+]] = "krnl.getref"([[DYNAMIC_MEMORY_POOL]], [[OFFSET3]]) : (memref, i64) -> memref + // CHECK: affine.load + // CHECK: [[DATA3:%.+]] = "krnl.getref"([[DYNAMIC_MEMORY_POOL]], [[OFFSET2]]) : (memref, i64) -> memref + // CHECK: affine.store + // CHECK: [[DATA4:%.+]] = "krnl.getref"([[DYNAMIC_MEMORY_POOL]], [[OFFSET1]]) : (memref, i64) -> memref + // CHECK: affine.store + // CHECK: krnl.define_loops + // CHECK: krnl.iterate + // CHECK: dealloc [[DYNAMIC_MEMORY_POOL]] : memref +}