Emit the dynamic memory pool (#290)
* Reorganize main function. * Follow review comments. * Emit constants as globals in Krnl and LLVM dialects. * Add support for bundling dynamic memory pools. * Add dynamic bundling. * Clean-up code. * Clean-up file. * Add test for bundling dynamic memory pool. * Fixes. Simplify data structure. Add mixed test. * Remove unused import.
This commit is contained in:
		
							parent
							
								
									930e20f682
								
							
						
					
					
						commit
						9f69b2f317
					
				|  | @ -15,6 +15,7 @@ | ||||||
| #include "mlir/Dialect/StandardOps/IR/Ops.h" | #include "mlir/Dialect/StandardOps/IR/Ops.h" | ||||||
| #include "mlir/Pass/Pass.h" | #include "mlir/Pass/Pass.h" | ||||||
| #include "mlir/Transforms/DialectConversion.h" | #include "mlir/Transforms/DialectConversion.h" | ||||||
|  | #include "llvm/ADT/SetVector.h" | ||||||
| 
 | 
 | ||||||
| #include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp" | #include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp" | ||||||
| #include "src/Dialect/Krnl/KrnlOps.hpp" | #include "src/Dialect/Krnl/KrnlOps.hpp" | ||||||
|  | @ -24,6 +25,102 @@ using namespace mlir; | ||||||
| 
 | 
 | ||||||
| namespace { | namespace { | ||||||
| 
 | 
 | ||||||
|  | //===----------------------------------------------------------------------===//
 | ||||||
|  | // Insertion point for initialization instructions and the blocks used for
 | ||||||
|  | // inserting the initialization and main code. These blocks will disappear
 | ||||||
|  | // when the first canonicalization is performed because the init block
 | ||||||
|  | // unconditionally branches into the second block. These blocks exist only for
 | ||||||
|  | // the purpose of this optimization.
 | ||||||
|  | // The information is recorded on a per function basis.
 | ||||||
|  | //===----------------------------------------------------------------------===//
 | ||||||
|  | 
 | ||||||
|  | typedef struct ONNXOperandsInitState { | ||||||
|  |   Block *initBlock; | ||||||
|  |   Block *mainBlock; | ||||||
|  |   BranchOp branchInit; | ||||||
|  |   AllocOp dynamicMemoryPool; | ||||||
|  | } ONNXOperandsInitState; | ||||||
|  | 
 | ||||||
|  | // Helper data structure for the bundling of dynamic AllocOps.
 | ||||||
|  | std::map<FuncOp, std::unique_ptr<ONNXOperandsInitState>> initMap; | ||||||
|  | 
 | ||||||
|  | //===----------------------------------------------------------------------===//
 | ||||||
|  | // Helper functions.
 | ||||||
|  | //===----------------------------------------------------------------------===//
 | ||||||
|  | 
 | ||||||
|  | /// Retrieve function which contains the current operation.
 | ||||||
|  | FuncOp getContainingFunction(AllocOp op) { | ||||||
|  |   Operation *parentFuncOp = op.getParentOp(); | ||||||
|  | 
 | ||||||
|  |   // While parent is not a FuncOp and its cast to a FuncOp is null.
 | ||||||
|  |   while (!llvm::dyn_cast_or_null<FuncOp>(parentFuncOp)) | ||||||
|  |     parentFuncOp = parentFuncOp->getParentOp(); | ||||||
|  | 
 | ||||||
|  |   return cast<FuncOp>(parentFuncOp); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | bool hasInitBlock(FuncOp function) { | ||||||
|  |   std::unique_ptr<ONNXOperandsInitState> &initState = initMap.at(function); | ||||||
|  |   return initState->initBlock != nullptr; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | bool addInitBlock(PatternRewriter &rewriter, Location loc, AllocOp allocOp) { | ||||||
|  |   // If this is the first time we encounter an operation in this
 | ||||||
|  |   // function, we create an entry inside the initMap and split the
 | ||||||
|  |   // function body into an init block and a main block.
 | ||||||
|  |   //
 | ||||||
|  |   // function func_name() {
 | ||||||
|  |   //    ... init block ...
 | ||||||
|  |   //    br ^bb1
 | ||||||
|  |   //  ^bb1:  // pred: ^bb0
 | ||||||
|  |   //    ... main block ...
 | ||||||
|  |   //    return
 | ||||||
|  |   // }
 | ||||||
|  |   //
 | ||||||
|  |   // Note: the block ^bb0 being the first block has its label omitted.
 | ||||||
|  |   //
 | ||||||
|  |   FuncOp function = getContainingFunction(allocOp); | ||||||
|  |   // If the function does not contain an init block, create one.
 | ||||||
|  |   if (!hasInitBlock(function)) { | ||||||
|  |     std::unique_ptr<ONNXOperandsInitState> &initState = initMap.at(function); | ||||||
|  |     initState = std::make_unique<ONNXOperandsInitState>(); | ||||||
|  | 
 | ||||||
|  |     // All input arguments are considered as part of the initialization block
 | ||||||
|  |     // so add them to the operandsInInitBlock set.
 | ||||||
|  |     Block *functionBlock = &function.front(); | ||||||
|  |     PatternRewriter::InsertionGuard insertGuard(rewriter); | ||||||
|  |     rewriter.setInsertionPointToStart(functionBlock); | ||||||
|  | 
 | ||||||
|  |     initState->initBlock = rewriter.getInsertionBlock(); | ||||||
|  |     auto currentPoint = rewriter.getInsertionPoint(); | ||||||
|  |     initState->mainBlock = | ||||||
|  |         rewriter.splitBlock(initState->initBlock, currentPoint); | ||||||
|  | 
 | ||||||
|  |     rewriter.setInsertionPointToEnd(initState->initBlock); | ||||||
|  | 
 | ||||||
|  |     // Insert a branch operation from initBlock to mainBlock. This
 | ||||||
|  |     // ensures the final code contains legal blocks.
 | ||||||
|  |     initState->branchInit = | ||||||
|  |         rewriter.create<BranchOp>(loc, initState->mainBlock); | ||||||
|  | 
 | ||||||
|  |     rewriter.setInsertionPointToStart(initState->mainBlock); | ||||||
|  | 
 | ||||||
|  |     // Save a reference to the current dynamic memory pool value.
 | ||||||
|  |     initState->dynamicMemoryPool = allocOp; | ||||||
|  | 
 | ||||||
|  |     return true; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   return false; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | bool isBlockArgument(Block *block, Value operand) { | ||||||
|  |   for (auto arg : block->getArguments()) | ||||||
|  |     if (operand == arg) | ||||||
|  |       return true; | ||||||
|  |   return false; | ||||||
|  | } | ||||||
|  | 
 | ||||||
| KrnlGetRefOp getUnbundledGetRef(AllocOp *memPool) { | KrnlGetRefOp getUnbundledGetRef(AllocOp *memPool) { | ||||||
|   auto parentBlock = memPool->getOperation()->getBlock(); |   auto parentBlock = memPool->getOperation()->getBlock(); | ||||||
| 
 | 
 | ||||||
|  | @ -37,6 +134,23 @@ KrnlGetRefOp getUnbundledGetRef(AllocOp *memPool) { | ||||||
|   return unbundledGetRef; |   return unbundledGetRef; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | KrnlGetRefOp getCurrentAllocGetRef(AllocOp *allocOp) { | ||||||
|  |   auto parentBlock = allocOp->getOperation()->getBlock(); | ||||||
|  | 
 | ||||||
|  |   KrnlGetRefOp currentAllocGetRef = nullptr; | ||||||
|  |   parentBlock->walk([¤tAllocGetRef, allocOp](KrnlGetRefOp op) { | ||||||
|  |     auto result = allocOp->getResult(); | ||||||
|  |     if (op.getOperands()[0] == result) | ||||||
|  |       currentAllocGetRef = op; | ||||||
|  |   }); | ||||||
|  | 
 | ||||||
|  |   return currentAllocGetRef; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | //===----------------------------------------------------------------------===//
 | ||||||
|  | // Rewrite patterns.
 | ||||||
|  | //===----------------------------------------------------------------------===//
 | ||||||
|  | 
 | ||||||
| /*!
 | /*!
 | ||||||
|  *  RewritePattern that replaces: |  *  RewritePattern that replaces: | ||||||
|  *    %mem1 = alloc() : memref<<dims1>x<type>> |  *    %mem1 = alloc() : memref<<dims1>x<type>> | ||||||
|  | @ -73,7 +187,7 @@ public: | ||||||
|     if (!checkOpResultIsUsedByGetRef(&allocOp)) |     if (!checkOpResultIsUsedByGetRef(&allocOp)) | ||||||
|       return failure(); |       return failure(); | ||||||
| 
 | 
 | ||||||
|     // TODO: remove once we support the bundling of dynamic memory pools.
 |     // Only handle constant AllocOps.
 | ||||||
|     if (!hasAllConstantDimensions(memRefType)) |     if (!hasAllConstantDimensions(memRefType)) | ||||||
|       return failure(); |       return failure(); | ||||||
| 
 | 
 | ||||||
|  | @ -85,12 +199,16 @@ public: | ||||||
|     if (memRefShape.size() != 1) |     if (memRefShape.size() != 1) | ||||||
|       return failure(); |       return failure(); | ||||||
| 
 | 
 | ||||||
|     // TODO: Change this when dyanmic shapes are supported.
 |  | ||||||
|     // TODO: Add support for dynamic shapes.
 |  | ||||||
|     int64_t currentMemPoolSize = memRefShape[0]; |     int64_t currentMemPoolSize = memRefShape[0]; | ||||||
| 
 | 
 | ||||||
|     // Get a KrnlGetRefOp which does not use the current alloc.
 |     // Get a KrnlGetRefOp which does not use the current alloc.
 | ||||||
|     if (KrnlGetRefOp unbundledGetRef = getUnbundledGetRef(&allocOp)) { |     if (KrnlGetRefOp unbundledGetRef = getUnbundledGetRef(&allocOp)) { | ||||||
|  |       // Make sure that this get ref uses a static alloc.
 | ||||||
|  |       auto unbundledGetRefType = | ||||||
|  |           convertToMemRefType(unbundledGetRef.getResult().getType()); | ||||||
|  |       if (!hasAllConstantDimensions(unbundledGetRefType)) | ||||||
|  |         return failure(); | ||||||
|  | 
 | ||||||
|       // Current memory pool size is the offset for the newly bundled
 |       // Current memory pool size is the offset for the newly bundled
 | ||||||
|       // internal MemRef. Emit the offset as a constant.
 |       // internal MemRef. Emit the offset as a constant.
 | ||||||
|       auto offset = rewriter.create<ConstantOp>( |       auto offset = rewriter.create<ConstantOp>( | ||||||
|  | @ -127,6 +245,201 @@ public: | ||||||
|   } |   } | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
/*!
 *  RewritePattern that merges a new dynamic AllocOp with the existing dynamic
 *  memory pool.
 *    %dyn_mempool = alloc(%a) : memref<?xi8>
 *    %new_alloc = alloc(%b) : memref<?xi8>
 *    %new_ref = krnl.getref %new_alloc 0 : memref<?xi8>
 *  =>
 *    %c = addi %a, %b
 *    %dyn_mempool = alloc(%c) : memref<?xi8>
 *    %new_ref = krnl.getref %dyn_mempool %a : memref<?xi8>
 *
 *  The size computation of the merged alloc is moved into the function's
 *  init block (created lazily by addInitBlock) so that the single bundled
 *  alloc can be emitted before the main body.
 */

class KrnlBundleDynamicMemoryPools : public OpRewritePattern<AllocOp> {
public:
  using OpRewritePattern<AllocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(
      AllocOp allocOp, PatternRewriter &rewriter) const override {
    auto loc = allocOp.getLoc();

    auto memRefType = convertToMemRefType(allocOp.getResult().getType());
    auto memRefShape = memRefType.getShape();

    // If the alloc result is not used by a getref then it cannot be part of
    // the memory pool.
    if (!checkOpResultIsUsedByGetRef(&allocOp))
      return failure();

    // Only handle dynamic allocs here; static ones are bundled by
    // KrnlBundleMemoryPools.
    if (hasAllConstantDimensions(memRefType))
      return failure();

    // Alloc memory type must be byte.
    if (getMemRefEltSizeInBytes(memRefType) != 1)
      return failure();

    // Rank of the allocated MemRef must be 1.
    if (memRefShape.size() != 1)
      return failure();

    // Visit dependent operations in the current parent block and assemble
    // a trace of operations which participate in the computation of the size
    // of the AllocOp.
    auto parentBlock = allocOp.getOperation()->getBlock();

    // Get function.
    FuncOp function = getContainingFunction(allocOp);
    Block *firstBlock = &function.getBody().front();

    // If this is the alloc representing the memory pool and the function
    // already has an init block, pattern matching must fail to avoid
    // processing the dynamic memory pool a second time.
    if (hasInitBlock(function)) {
      std::unique_ptr<ONNXOperandsInitState> &initState = initMap.at(function);
      if (allocOp == initState->dynamicMemoryPool)
        return failure();
    }

    // Initialize work queue data structure.
    // NOTE(review): 'op' is never read below (the loop variables shadow it);
    // candidate for removal.
    Operation *op = allocOp.getOperation();
    std::vector<Value> operandList;
    for (const auto &operand : allocOp.getOperands()) {
      operandList.emplace_back(operand);
    }

    // Check if the list of operations depends on a dynamic local AllocOp.
    bool dependsOnLocalDynamicAlloc = false;

    // Construct the set of operations on which the current AllocOp depends,
    // by breadth-first traversal of the defining ops of its size operands.
    llvm::SetVector<Operation *> dependentOps;
    while (operandList.size() > 0) {
      Value currentElement = operandList[0];
      Operation *definingOperation = currentElement.getDefiningOp();

      // If this value has not been seen before, process it.
      if (dependentOps.count(definingOperation) == 0) {
        // Add value to dependent values list.
        dependentOps.insert(definingOperation);

        // Add operands to work queue.
        for (const auto &operand : definingOperation->getOperands()) {
          // Check operand is not a block argument. If it is skip it, we
          // consider block arguments to be leafs.
          if (!isBlockArgument(firstBlock, operand)) {
            operandList.emplace_back(operand);

            // Check if the current operation is a dim or a load and the
            // argument list involves a local AllocOp with dynamic sizes.
            // If that's the case then it means that the whole set of
            // instructions cannot be moved.
            // Check if the current operation is a DimOp or a LoadOp.
            if (llvm::dyn_cast<DimOp>(definingOperation) ||
                llvm::dyn_cast<LoadOp>(definingOperation)) {
              Operation *operandOp = operand.getDefiningOp();
              if (operandOp) {
                auto localAlloc = llvm::dyn_cast<AllocOp>(operandOp);
                if (localAlloc) {
                  // NOTE(review): this 'memRefType' intentionally shadows the
                  // outer one; it describes the local alloc, not allocOp.
                  auto memRefType =
                      convertToMemRefType(localAlloc.getResult().getType());
                  if (!hasAllConstantDimensions(memRefType))
                    dependsOnLocalDynamicAlloc = true;
                }

                // If operand is a getref then this alloc cannot be bundled.
                auto memPool = llvm::dyn_cast<KrnlGetRefOp>(operandOp);
                if (memPool)
                  dependsOnLocalDynamicAlloc = true;
              }
            }
          }
        }
      }

      // Erase first element from work queue.
      operandList.erase(operandList.begin());
    }

    if (dependsOnLocalDynamicAlloc)
      return failure();

    // Order the dependent values in the same order they appear in the code.
    // One cannot iterate over and make changes to the order of the operations
    // of a block. A temporary ordered list of dependent instructions is
    // necessary.
    llvm::SmallVector<Operation *, 32> orderedDependentOps;
    for (auto &op :
        llvm::make_range(parentBlock->begin(), std::prev(parentBlock->end())))
      if (dependentOps.count(&op) > 0)
        orderedDependentOps.emplace_back(&op);

    // If no dynamic alloc is in the trace of the dependent operations,
    // emit the size calculation in the init block, if one exists already,
    // if not, create the init block.
    bool addedInitBlock = addInitBlock(rewriter, loc, allocOp);

    // Move the ordered dependent size calculation to the init block,
    // just before its terminating branch.
    std::unique_ptr<ONNXOperandsInitState> &initState = initMap.at(function);
    for (auto &op : orderedDependentOps)
      op->moveBefore(initState->branchInit);

    // Bundle MemRef type: <?xi8>
    SmallVector<int64_t, 1> memPoolShape;
    memPoolShape.emplace_back(-1);
    auto bundledMemPoolMemRefType =
        MemRefType::get(memPoolShape, rewriter.getIntegerType(8));

    // Get the getref of the current allocOp. There is exactly one such getref.
    KrnlGetRefOp currentAllocGetRef = getCurrentAllocGetRef(&allocOp);
    if (!currentAllocGetRef)
      return failure();

    // Add the current alloc size to the current MemPool size. When the init
    // block was just created there is no accumulated size yet, so start the
    // sum from an explicit zero constant emitted in the init block.
    Value dynamicMemoryPoolSize = initState->dynamicMemoryPool.getOperand(0);
    if (addedInitBlock) {
      Value zero = emitConstantOp(rewriter, loc, rewriter.getIndexType(), 0);
      zero.getDefiningOp()->moveBefore(initState->branchInit);
      dynamicMemoryPoolSize = zero;
    }

    AddIOp bundledAllocOperand = rewriter.create<AddIOp>(
        loc, dynamicMemoryPoolSize, allocOp.getOperand(0));
    bundledAllocOperand.getOperation()->moveBefore(initState->branchInit);

    // The newly bundled MemRef expressed as a KrnlGetRefOp.
    // Current memory pool size is the offset for the newly bundled
    // internal MemRef; getref takes the offset as an i64.
    Value integerDynamicMemoryPoolSize = rewriter.create<IndexCastOp>(
        loc, dynamicMemoryPoolSize, rewriter.getIntegerType(64));
    integerDynamicMemoryPoolSize.getDefiningOp()->moveBefore(
        initState->branchInit);

    // We need to emit a new alloc which contains the additional MemRef.
    // It is placed at the very top of the main block.
    AllocOp bundledAlloc = rewriter.create<AllocOp>(
        loc, bundledMemPoolMemRefType, bundledAllocOperand.getResult());
    bundledAlloc.getOperation()->moveBefore(&initState->mainBlock->front());

    KrnlGetRefOp bundledMemRef = rewriter.create<KrnlGetRefOp>(loc,
        currentAllocGetRef.getResult().getType(), bundledAlloc,
        integerDynamicMemoryPoolSize);
    bundledMemRef.getOperation()->moveAfter(bundledAlloc);

    // Replace old memory pool with new one.
    rewriter.replaceOp(initState->dynamicMemoryPool, bundledAlloc.getResult());

    // Replace old getref with new getref from new memory pool.
    rewriter.replaceOp(currentAllocGetRef, bundledMemRef.getResult());

    // Update MemPool size.
    initState->dynamicMemoryPool = bundledAlloc;

    return success();
  }
};
|  | 
 | ||||||
| /*!
 | /*!
 | ||||||
|  *  Function pass that enables memory pooling for MemRefs. |  *  Function pass that enables memory pooling for MemRefs. | ||||||
|  */ |  */ | ||||||
|  | @ -136,11 +449,22 @@ public: | ||||||
|   void runOnFunction() override { |   void runOnFunction() override { | ||||||
|     auto function = getFunction(); |     auto function = getFunction(); | ||||||
| 
 | 
 | ||||||
|  |     // ModuleOp module = cast<ModuleOp>(function.getParentOp());
 | ||||||
|  |     initMap.insert(std::pair<FuncOp, std::unique_ptr<ONNXOperandsInitState>>( | ||||||
|  |         function, std::make_unique<ONNXOperandsInitState>())); | ||||||
|  | 
 | ||||||
|  |     // Initialize state for this function.
 | ||||||
|  |     std::unique_ptr<ONNXOperandsInitState> &initState = initMap.at(function); | ||||||
|  |     initState->initBlock = nullptr; | ||||||
|  | 
 | ||||||
|     ConversionTarget target(getContext()); |     ConversionTarget target(getContext()); | ||||||
|     OwningRewritePatternList patterns; |     OwningRewritePatternList patterns; | ||||||
|     patterns.insert<KrnlBundleMemoryPools>(&getContext()); |     patterns.insert<KrnlBundleMemoryPools, KrnlBundleDynamicMemoryPools>( | ||||||
|  |         &getContext()); | ||||||
| 
 | 
 | ||||||
|     applyPatternsAndFoldGreedily(function, patterns); |     applyPatternsAndFoldGreedily(function, patterns); | ||||||
|  | 
 | ||||||
|  |     initMap.erase(function); | ||||||
|   } |   } | ||||||
| }; | }; | ||||||
| } // namespace
 | } // namespace
 | ||||||
|  |  | ||||||
|  | @ -51,3 +51,135 @@ func @test_pool_bundling(%arg0: memref<10x10xf32>, %arg1: memref<10x20xf32>) -> | ||||||
|   // CHECK: dealloc [[MEMPOOL]] : memref<3200xi8> |   // CHECK: dealloc [[MEMPOOL]] : memref<3200xi8> | ||||||
|   // CHECK: return [[RES]] : memref<10x20xf32> |   // CHECK: return [[RES]] : memref<10x20xf32> | ||||||
| } | } | ||||||
|  | 
 | ||||||
// Two dynamic memory pools (%3 and %10) are bundled into a single dynamic
// pool whose size is the sum of both; the second getref is rebased at the
// offset equal to the first pool's size (index_cast'ed to i64). The size
// computations are hoisted ahead of the single bundled alloc.
func @test_dynamic_pool_bundling(%arg0: memref<?x?xf32>) -> memref<?x10xf32> {
  %c1 = constant 1 : index
  %c0 = constant 0 : index
  %cst = constant 0.000000e+00 : f32
  %ind = constant 0 : index
  %c4 = constant 4 : index
  %c10 = constant 10 : index
  %c0_i64 = constant 0 : i64
  %0 = dim %arg0, %c0 : memref<?x?xf32>
  %1 = muli %0, %c4 : index
  %2 = muli %1, %c10 : index
  %3 = alloc(%2) : memref<?xi8>
  %4 = "krnl.getref"(%3, %c0_i64) : (memref<?xi8>, i64) -> memref<?x10xf32>
  %6 = cmpi "sgt", %0, %0 : index
  %7 = select %6, %0, %0 : index
  %8 = muli %7, %c4 : index
  %9 = muli %8, %c10 : index
  %10 = alloc(%9) : memref<?xi8>
  %11 = "krnl.getref"(%10, %c0_i64) : (memref<?xi8>, i64) -> memref<?x10xf32>
  %12 = cmpi "eq", %0, %c1 : index
  %13 = cmpi "eq", %0, %c1 : index
  %15 = alloc(%0) : memref<?x10xf32>
  affine.store %cst, %4[%ind, %ind] : memref<?x10xf32>
  affine.store %cst, %11[%ind, %ind] : memref<?x10xf32>
  affine.store %cst, %15[%ind, %ind] : memref<?x10xf32>
  dealloc %10 : memref<?xi8>
  dealloc %3 : memref<?xi8>
  return %15 : memref<?x10xf32>

  // CHECK-LABEL: test_dynamic_pool_bundling
  // CHECK: [[CST:%.+]] = constant 0.000000e+00 : f32
  // CHECK: [[C0:%.+]] = constant 0 : index
  // CHECK: [[C4:%.+]] = constant 4 : index
  // CHECK: [[C10:%.+]] = constant 10 : index
  // CHECK: [[C0_I64:%.+]] = constant 0 : i64
  // CHECK: [[DIM:%.+]] = dim %arg0, [[C0]] : memref<?x?xf32>
  // CHECK: [[SGT:%.+]] = cmpi "sgt", [[DIM]], [[DIM]] : index
  // CHECK: [[SELECT:%.+]] = select [[SGT]], [[DIM]], [[DIM]] : index
  // CHECK: [[MUL1:%.+]] = muli [[SELECT]], [[C4]] : index
  // CHECK: [[OFFSET1:%.+]] = muli [[MUL1]], [[C10]] : index
  // CHECK: [[MUL2:%.+]] = muli [[DIM]], [[C4]] : index
  // CHECK: [[OFFSET2:%.+]] = muli [[MUL2]], [[C10]] : index
  // CHECK: [[MEMPOOL_SIZE:%.+]] = addi [[OFFSET1]], [[OFFSET2]] : index
  // CHECK: [[OFFSET1_I64:%.+]] = index_cast [[OFFSET1]] : index to i64
  // CHECK: [[DYN_MEMPOOL:%.+]] = alloc([[MEMPOOL_SIZE]]) : memref<?xi8>
  // CHECK: [[DATA1:%.+]] = "krnl.getref"([[DYN_MEMPOOL]], [[OFFSET1_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32>
  // CHECK: [[DATA2:%.+]] = "krnl.getref"([[DYN_MEMPOOL]], [[C0_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32>
  // CHECK: [[RES:%.+]] = alloc([[DIM]]) : memref<?x10xf32>
  // CHECK: affine.store [[CST]], [[DATA1]][0, 0] : memref<?x10xf32>
  // CHECK: affine.store [[CST]], [[DATA2]][0, 0] : memref<?x10xf32>
  // CHECK: affine.store [[CST]], [[RES]][0, 0] : memref<?x10xf32>
  // CHECK: dealloc [[DYN_MEMPOOL]] : memref<?xi8>
  // CHECK: return [[RES]] : memref<?x10xf32>
}
|  | 
 | ||||||
// Mixed case: dynamic allocs (%3, %10) are bundled into one dynamic pool
// while static allocs (%const_alloc1/2/3) are bundled separately into a
// single 2800-byte static pool — the two bundling patterns must not
// interfere with each other.
func @test_dynamic_and_static_pool_bundling(%arg0: memref<?x?xf32>, %arg1: memref<10x10xf32>) -> memref<?x10xf32> {
  %c1 = constant 1 : index
  %c0 = constant 0 : index
  %cst = constant 0.000000e+00 : f32
  %ind = constant 0 : index
  %c4 = constant 4 : index
  %c10 = constant 10 : index
  %c0_i64 = constant 0 : i64
  %0 = dim %arg0, %c0 : memref<?x?xf32>
  %1 = muli %0, %c4 : index
  %2 = muli %1, %c10 : index
  %3 = alloc(%2) : memref<?xi8>
  %4 = "krnl.getref"(%3, %c0_i64) : (memref<?xi8>, i64) -> memref<?x10xf32>
  %const_alloc1 = alloc() : memref<800xi8>
  %const_ref1 = "krnl.getref"(%const_alloc1, %c0_i64) : (memref<800xi8>, i64) -> memref<10x20xf32>
  %const_alloc2 = alloc() : memref<400xi8>
  %const_ref2 = "krnl.getref"(%const_alloc2, %c0_i64) : (memref<400xi8>, i64) -> memref<10x10xf32>
  %6 = cmpi "sgt", %0, %0 : index
  %7 = select %6, %0, %0 : index
  %8 = muli %7, %c4 : index
  %9 = muli %8, %c10 : index
  %10 = alloc(%9) : memref<?xi8>
  %11 = "krnl.getref"(%10, %c0_i64) : (memref<?xi8>, i64) -> memref<?x10xf32>
  %12 = cmpi "eq", %0, %c1 : index
  %13 = cmpi "eq", %0, %c1 : index
  %15 = alloc(%0) : memref<?x10xf32>
  %const_alloc3 = alloc() : memref<1600xi8>
  %const_ref3 = "krnl.getref"(%const_alloc3, %c0_i64) : (memref<1600xi8>, i64) -> memref<10x40xf32>
  affine.store %cst, %4[%ind, %ind] : memref<?x10xf32>
  affine.store %cst, %11[%ind, %ind] : memref<?x10xf32>
  affine.store %cst, %15[%ind, %ind] : memref<?x10xf32>
  affine.store %cst, %const_ref2[%ind, %ind] : memref<10x10xf32>
  affine.store %cst, %const_ref1[%ind, %ind] : memref<10x20xf32>
  affine.store %cst, %const_ref3[%ind, %ind] : memref<10x40xf32>
  dealloc %10 : memref<?xi8>
  dealloc %3 : memref<?xi8>
  dealloc %const_alloc1 : memref<800xi8>
  dealloc %const_alloc2 : memref<400xi8>
  dealloc %const_alloc3 : memref<1600xi8>
  return %15 : memref<?x10xf32>

  // CHECK-LABEL: test_dynamic_and_static_pool_bundling
  // CHECK: [[CST:%.+]] = constant 0.000000e+00 : f32
  // CHECK: [[C0:%.+]] = constant 0 : index
  // CHECK: [[C4:%.+]] = constant 4 : index
  // CHECK: [[C10:%.+]] = constant 10 : index
  // CHECK: [[C400_I64:%.+]] = constant 400 : i64
  // CHECK: [[C2000_I64:%.+]] = constant 2000 : i64
  // CHECK: [[C0_I64:%.+]] = constant 0 : i64
  // CHECK: [[DIM:%.+]] = dim %arg0, [[C0]] : memref<?x?xf32>
  // CHECK: [[SGT:%.+]] = cmpi "sgt", [[DIM]], [[DIM]] : index
  // CHECK: [[SELECT:%.+]] = select [[SGT]], [[DIM]], [[DIM]] : index
  // CHECK: [[MUL1:%.+]] = muli [[SELECT]], [[C4]] : index
  // CHECK: [[OFFSET1:%.+]] = muli [[MUL1]], [[C10]] : index
  // CHECK: [[MUL2:%.+]] = muli [[DIM]], [[C4]] : index
  // CHECK: [[OFFSET2:%.+]] = muli [[MUL2]], [[C10]] : index
  // CHECK: [[MEMPOOL_SIZE:%.+]] = addi [[OFFSET1]], [[OFFSET2]] : index
  // CHECK: [[OFFSET1_I64:%.+]] = index_cast [[OFFSET1]] : index to i64
  // CHECK: [[DYN_MEMPOOL:%.+]] = alloc([[MEMPOOL_SIZE]]) : memref<?xi8>
  // CHECK: [[DATA1:%.+]] = "krnl.getref"([[DYN_MEMPOOL]], [[OFFSET1_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32>
  // CHECK: [[DATA2:%.+]] = "krnl.getref"([[DYN_MEMPOOL]], [[C0_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32>
  // CHECK: [[STATIC_MEMPOOL:%.+]] = alloc() : memref<2800xi8>
  // CHECK: [[DATA3:%.+]] = "krnl.getref"([[STATIC_MEMPOOL]], [[C2000_I64]]) : (memref<2800xi8>, i64) -> memref<10x20xf32>
  // CHECK: [[DATA4:%.+]] = "krnl.getref"([[STATIC_MEMPOOL]], [[C400_I64]]) : (memref<2800xi8>, i64) -> memref<10x40xf32>
  // CHECK: [[DATA5:%.+]] = "krnl.getref"([[STATIC_MEMPOOL]], [[C0_I64]]) : (memref<2800xi8>, i64) -> memref<10x10xf32>
  // CHECK: [[RES:%.+]] = alloc([[DIM]]) : memref<?x10xf32>
  // CHECK: affine.store [[CST]], [[DATA1]][0, 0] : memref<?x10xf32>
  // CHECK: affine.store [[CST]], [[DATA2]][0, 0] : memref<?x10xf32>
  // CHECK: affine.store [[CST]], [[RES]][0, 0] : memref<?x10xf32>
  // CHECK: affine.store [[CST]], [[DATA5]][0, 0] : memref<10x10xf32>
  // CHECK: affine.store [[CST]], [[DATA3]][0, 0] : memref<10x20xf32>
  // CHECK: affine.store [[CST]], [[DATA4]][0, 0] : memref<10x40xf32>
  // CHECK: dealloc [[DYN_MEMPOOL]] : memref<?xi8>
  // CHECK: dealloc [[STATIC_MEMPOOL]] : memref<2800xi8>
  // CHECK: return [[RES]] : memref<?x10xf32>
}
|  |  | ||||||
		Loading…
	
		Reference in New Issue