Emit the dynamic memory pool (#290)
* Reorganize main function. * Follow review comments. * Emit constants are globals in Krnl and LLVM dialects. * Add support for bundling dynamic memory pools. * Add dynamic bundling. * Clean-up code. * Clean-up file. * Add test for bundling dynamic memory pool. * Fixes. Simplify data structure. Add mixed test. * Remove unused import.
This commit is contained in:
		
							parent
							
								
									930e20f682
								
							
						
					
					
						commit
						9f69b2f317
					
				|  | @ -15,6 +15,7 @@ | |||
| #include "mlir/Dialect/StandardOps/IR/Ops.h" | ||||
| #include "mlir/Pass/Pass.h" | ||||
| #include "mlir/Transforms/DialectConversion.h" | ||||
| #include "llvm/ADT/SetVector.h" | ||||
| 
 | ||||
| #include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp" | ||||
| #include "src/Dialect/Krnl/KrnlOps.hpp" | ||||
|  | @ -24,6 +25,102 @@ using namespace mlir; | |||
| 
 | ||||
| namespace { | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| // Insertion point for initialization instructions and the blocks used for
 | ||||
| // inserting the initialization and main code. These blocks will disappear
 | ||||
| // when the first canonicalization is performed because the init block
 | ||||
| // unconditionally branches into the second block. These blocks exist only for
 | ||||
| // the purpose of this optimization.
 | ||||
| // The information is recorded on a per function basis.
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| 
 | ||||
| typedef struct ONNXOperandsInitState { | ||||
|   Block *initBlock; | ||||
|   Block *mainBlock; | ||||
|   BranchOp branchInit; | ||||
|   AllocOp dynamicMemoryPool; | ||||
| } ONNXOperandsInitState; | ||||
| 
 | ||||
| // Helper data structure for the bundling of dynamic AllocOps.
 | ||||
| std::map<FuncOp, std::unique_ptr<ONNXOperandsInitState>> initMap; | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| // Helper functions.
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| 
 | ||||
| /// Retrieve function which contains the current operation.
 | ||||
| FuncOp getContainingFunction(AllocOp op) { | ||||
|   Operation *parentFuncOp = op.getParentOp(); | ||||
| 
 | ||||
|   // While parent is not a FuncOp and its cast to a FuncOp is null.
 | ||||
|   while (!llvm::dyn_cast_or_null<FuncOp>(parentFuncOp)) | ||||
|     parentFuncOp = parentFuncOp->getParentOp(); | ||||
| 
 | ||||
|   return cast<FuncOp>(parentFuncOp); | ||||
| } | ||||
| 
 | ||||
| bool hasInitBlock(FuncOp function) { | ||||
|   std::unique_ptr<ONNXOperandsInitState> &initState = initMap.at(function); | ||||
|   return initState->initBlock != nullptr; | ||||
| } | ||||
| 
 | ||||
| bool addInitBlock(PatternRewriter &rewriter, Location loc, AllocOp allocOp) { | ||||
|   // If this is the first time we encounter an operation in this
 | ||||
|   // function, we create an entry inside the initMap and split the
 | ||||
|   // function body into an init block and a main block.
 | ||||
|   //
 | ||||
|   // function func_name() {
 | ||||
|   //    ... init block ...
 | ||||
|   //    br ^bb1
 | ||||
|   //  ^bb1:  // pred: ^bb0
 | ||||
|   //    ... main block ...
 | ||||
|   //    return
 | ||||
|   // }
 | ||||
|   //
 | ||||
|   // Note: the block ^bb0 being the first block has its label omitted.
 | ||||
|   //
 | ||||
|   FuncOp function = getContainingFunction(allocOp); | ||||
|   // If the function does not contain an init block, create one.
 | ||||
|   if (!hasInitBlock(function)) { | ||||
|     std::unique_ptr<ONNXOperandsInitState> &initState = initMap.at(function); | ||||
|     initState = std::make_unique<ONNXOperandsInitState>(); | ||||
| 
 | ||||
|     // All input arguments are considered as part of the initialization block
 | ||||
|     // so add them to the operandsInInitBlock set.
 | ||||
|     Block *functionBlock = &function.front(); | ||||
|     PatternRewriter::InsertionGuard insertGuard(rewriter); | ||||
|     rewriter.setInsertionPointToStart(functionBlock); | ||||
| 
 | ||||
|     initState->initBlock = rewriter.getInsertionBlock(); | ||||
|     auto currentPoint = rewriter.getInsertionPoint(); | ||||
|     initState->mainBlock = | ||||
|         rewriter.splitBlock(initState->initBlock, currentPoint); | ||||
| 
 | ||||
|     rewriter.setInsertionPointToEnd(initState->initBlock); | ||||
| 
 | ||||
|     // Insert a branch operation from initBlock to mainBlock. This
 | ||||
|     // ensures the final code contains legal blocks.
 | ||||
|     initState->branchInit = | ||||
|         rewriter.create<BranchOp>(loc, initState->mainBlock); | ||||
| 
 | ||||
|     rewriter.setInsertionPointToStart(initState->mainBlock); | ||||
| 
 | ||||
|     // Save a reference to the current dynamic memory pool value.
 | ||||
|     initState->dynamicMemoryPool = allocOp; | ||||
| 
 | ||||
|     return true; | ||||
|   } | ||||
| 
 | ||||
|   return false; | ||||
| } | ||||
| 
 | ||||
| bool isBlockArgument(Block *block, Value operand) { | ||||
|   for (auto arg : block->getArguments()) | ||||
|     if (operand == arg) | ||||
|       return true; | ||||
|   return false; | ||||
| } | ||||
| 
 | ||||
| KrnlGetRefOp getUnbundledGetRef(AllocOp *memPool) { | ||||
|   auto parentBlock = memPool->getOperation()->getBlock(); | ||||
| 
 | ||||
|  | @ -37,6 +134,23 @@ KrnlGetRefOp getUnbundledGetRef(AllocOp *memPool) { | |||
|   return unbundledGetRef; | ||||
| } | ||||
| 
 | ||||
| KrnlGetRefOp getCurrentAllocGetRef(AllocOp *allocOp) { | ||||
|   auto parentBlock = allocOp->getOperation()->getBlock(); | ||||
| 
 | ||||
|   KrnlGetRefOp currentAllocGetRef = nullptr; | ||||
|   parentBlock->walk([¤tAllocGetRef, allocOp](KrnlGetRefOp op) { | ||||
|     auto result = allocOp->getResult(); | ||||
|     if (op.getOperands()[0] == result) | ||||
|       currentAllocGetRef = op; | ||||
|   }); | ||||
| 
 | ||||
|   return currentAllocGetRef; | ||||
| } | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| // Rewrite patterns.
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| 
 | ||||
| /*!
 | ||||
|  *  RewritePattern that replaces: | ||||
|  *    %mem1 = alloc() : memref<<dims1>x<type>> | ||||
|  | @ -73,7 +187,7 @@ public: | |||
|     if (!checkOpResultIsUsedByGetRef(&allocOp)) | ||||
|       return failure(); | ||||
| 
 | ||||
|     // TODO: remove once we support the bundling of dynamic memory pools.
 | ||||
|     // Only handle constant AllocOps.
 | ||||
|     if (!hasAllConstantDimensions(memRefType)) | ||||
|       return failure(); | ||||
| 
 | ||||
|  | @ -85,12 +199,16 @@ public: | |||
|     if (memRefShape.size() != 1) | ||||
|       return failure(); | ||||
| 
 | ||||
|     // TODO: Change this when dyanmic shapes are supported.
 | ||||
|     // TODO: Add support for dynamic shapes.
 | ||||
|     int64_t currentMemPoolSize = memRefShape[0]; | ||||
| 
 | ||||
|     // Get a KrnlGetRefOp which does not use the current alloc.
 | ||||
|     if (KrnlGetRefOp unbundledGetRef = getUnbundledGetRef(&allocOp)) { | ||||
|       // Make sure that this get ref uses a static alloc.
 | ||||
|       auto unbundledGetRefType = | ||||
|           convertToMemRefType(unbundledGetRef.getResult().getType()); | ||||
|       if (!hasAllConstantDimensions(unbundledGetRefType)) | ||||
|         return failure(); | ||||
| 
 | ||||
|       // Current memory pool size is the offset for the newly bundled
 | ||||
|       // internal MemRef. Emit the offset as a constant.
 | ||||
|       auto offset = rewriter.create<ConstantOp>( | ||||
|  | @ -127,6 +245,201 @@ public: | |||
|   } | ||||
| }; | ||||
| 
 | ||||
| /*!
 | ||||
|  *  RewritePattern that merges a new dynamic AllocOp with the existing dynamic | ||||
|  *  memory pool. | ||||
|  *    %dyn_mempool = alloc(%a) : memref<?xi8> | ||||
|  *    %new_alloc = alloc(%b) : memref<?xi8> | ||||
|  *    %new_ref = krnl.getref %new_alloc 0 : memref<?xi8> | ||||
|  *  => | ||||
|  *    %c = addi %a, %b | ||||
|  *    %dyn_mempool = alloc(%c) : memref<?xi8> | ||||
|  *    %new_ref = krnl.getref %dyn_mempool %a : memref<?xi8> | ||||
|  */ | ||||
| 
 | ||||
| class KrnlBundleDynamicMemoryPools : public OpRewritePattern<AllocOp> { | ||||
| public: | ||||
|   using OpRewritePattern<AllocOp>::OpRewritePattern; | ||||
| 
 | ||||
|   LogicalResult matchAndRewrite( | ||||
|       AllocOp allocOp, PatternRewriter &rewriter) const override { | ||||
|     auto loc = allocOp.getLoc(); | ||||
| 
 | ||||
|     auto memRefType = convertToMemRefType(allocOp.getResult().getType()); | ||||
|     auto memRefShape = memRefType.getShape(); | ||||
| 
 | ||||
|     // If alloca result is not used by getref then it cannot be part of
 | ||||
|     // the memory pool.
 | ||||
|     if (!checkOpResultIsUsedByGetRef(&allocOp)) | ||||
|       return failure(); | ||||
| 
 | ||||
|     // Only handle dynamic allocs here.
 | ||||
|     if (hasAllConstantDimensions(memRefType)) | ||||
|       return failure(); | ||||
| 
 | ||||
|     // Alloc memory type must be byte.
 | ||||
|     if (getMemRefEltSizeInBytes(memRefType) != 1) | ||||
|       return failure(); | ||||
| 
 | ||||
|     // Rank of the allocated MemRef must be 1.
 | ||||
|     if (memRefShape.size() != 1) | ||||
|       return failure(); | ||||
| 
 | ||||
|     // Visit dependendent operations in the current parent block and assemble
 | ||||
|     // a trace of operations which participate in the computation of the size
 | ||||
|     // of the AllocOp.
 | ||||
|     auto parentBlock = allocOp.getOperation()->getBlock(); | ||||
| 
 | ||||
|     // Get function.
 | ||||
|     FuncOp function = getContainingFunction(allocOp); | ||||
|     Block *firstBlock = &function.getBody().front(); | ||||
| 
 | ||||
|     // If this is the alloc representing the memory pool and the function
 | ||||
|     // already has an init block, pattern matching must fail to avoid
 | ||||
|     // processing the dynamic memory pool a second time.
 | ||||
|     if (hasInitBlock(function)) { | ||||
|       std::unique_ptr<ONNXOperandsInitState> &initState = initMap.at(function); | ||||
|       if (allocOp == initState->dynamicMemoryPool) | ||||
|         return failure(); | ||||
|     } | ||||
| 
 | ||||
|     // Initialize work queue data structure.
 | ||||
|     Operation *op = allocOp.getOperation(); | ||||
|     std::vector<Value> operandList; | ||||
|     for (const auto &operand : allocOp.getOperands()) { | ||||
|       operandList.emplace_back(operand); | ||||
|     } | ||||
| 
 | ||||
|     // Check if list of operations depends on dynamic local AllocOp.
 | ||||
|     bool dependsOnLocalDynamicAlloc = false; | ||||
| 
 | ||||
|     // Construct the list of Values on which the current AllocOp depends on.
 | ||||
|     llvm::SetVector<Operation *> dependentOps; | ||||
|     while (operandList.size() > 0) { | ||||
|       Value currentElement = operandList[0]; | ||||
|       Operation *definingOperation = currentElement.getDefiningOp(); | ||||
| 
 | ||||
|       // If this value has not been seen before, process it.
 | ||||
|       if (dependentOps.count(definingOperation) == 0) { | ||||
|         // Add value to dependent values list.
 | ||||
|         dependentOps.insert(definingOperation); | ||||
| 
 | ||||
|         // Add operands to work queue.
 | ||||
|         // printf("Processing the args of the following op:\n");
 | ||||
|         for (const auto &operand : definingOperation->getOperands()) { | ||||
|           // Check operand is not a block argument. If it is skip it, we
 | ||||
|           // consider block arguments to be leafs.
 | ||||
|           if (!isBlockArgument(firstBlock, operand)) { | ||||
|             operandList.emplace_back(operand); | ||||
| 
 | ||||
|             // Check if the current operation is a dim or a load and the
 | ||||
|             // argument list involves a local AllocOp with dynamic sizes.
 | ||||
|             // If that's the case then it means that the whole set of
 | ||||
|             // instructions cannot be moved.
 | ||||
|             // Check if the current operation is a DimOp or a LoadOp.
 | ||||
|             if (llvm::dyn_cast<DimOp>(definingOperation) || | ||||
|                 llvm::dyn_cast<LoadOp>(definingOperation)) { | ||||
|               Operation *operandOp = operand.getDefiningOp(); | ||||
|               if (operandOp) { | ||||
|                 auto localAlloc = llvm::dyn_cast<AllocOp>(operandOp); | ||||
|                 if (localAlloc) { | ||||
|                   auto memRefType = | ||||
|                       convertToMemRefType(localAlloc.getResult().getType()); | ||||
|                   if (!hasAllConstantDimensions(memRefType)) | ||||
|                     dependsOnLocalDynamicAlloc = true; | ||||
|                 } | ||||
| 
 | ||||
|                 // If operand is a getref then this alloc cannot be bundled.
 | ||||
|                 auto memPool = llvm::dyn_cast<KrnlGetRefOp>(operandOp); | ||||
|                 if (memPool) | ||||
|                   dependsOnLocalDynamicAlloc = true; | ||||
|               } | ||||
|             } | ||||
|           } | ||||
|         } | ||||
|       } | ||||
| 
 | ||||
|       // Erase first element from work queue.
 | ||||
|       operandList.erase(operandList.begin()); | ||||
|     } | ||||
| 
 | ||||
|     if (dependsOnLocalDynamicAlloc) | ||||
|       return failure(); | ||||
| 
 | ||||
|     // Order the dependent values in the same order they appear in the code.
 | ||||
|     // One cannot iterate over and make changes to the order of the operations
 | ||||
|     // of a block. A temporary ordered list of dependent instructions is
 | ||||
|     // necessary.
 | ||||
|     llvm::SmallVector<Operation *, 32> orderedDependentOps; | ||||
|     for (auto &op : | ||||
|         llvm::make_range(parentBlock->begin(), std::prev(parentBlock->end()))) | ||||
|       if (dependentOps.count(&op) > 0) | ||||
|         orderedDependentOps.emplace_back(&op); | ||||
| 
 | ||||
|     // If no dynamic alloc is in the trace of the dependent operations,
 | ||||
|     // emit the size calculation in the init block, if one exists already,
 | ||||
|     // if not, create the init block.
 | ||||
|     bool addedInitBlock = addInitBlock(rewriter, loc, allocOp); | ||||
| 
 | ||||
|     // Move the ordered dependent size calculation to the init block.
 | ||||
|     std::unique_ptr<ONNXOperandsInitState> &initState = initMap.at(function); | ||||
|     for (auto &op : orderedDependentOps) | ||||
|       op->moveBefore(initState->branchInit); | ||||
| 
 | ||||
|     // Bundle MemRef type: <?xi8>
 | ||||
|     SmallVector<int64_t, 1> memPoolShape; | ||||
|     memPoolShape.emplace_back(-1); | ||||
|     auto bundledMemPoolMemRefType = | ||||
|         MemRefType::get(memPoolShape, rewriter.getIntegerType(8)); | ||||
| 
 | ||||
|     // Get the getref of the current allocOp. There is exactly one such getref.
 | ||||
|     KrnlGetRefOp currentAllocGetRef = getCurrentAllocGetRef(&allocOp); | ||||
|     if (!currentAllocGetRef) | ||||
|       return failure(); | ||||
| 
 | ||||
|     // Add the current alloc size to the current MemPool size.
 | ||||
|     Value dynamicMemoryPoolSize = initState->dynamicMemoryPool.getOperand(0); | ||||
|     if (addedInitBlock) { | ||||
|       Value zero = emitConstantOp(rewriter, loc, rewriter.getIndexType(), 0); | ||||
|       zero.getDefiningOp()->moveBefore(initState->branchInit); | ||||
|       dynamicMemoryPoolSize = zero; | ||||
|     } | ||||
| 
 | ||||
|     AddIOp bundledAllocOperand = rewriter.create<AddIOp>( | ||||
|         loc, dynamicMemoryPoolSize, allocOp.getOperand(0)); | ||||
|     bundledAllocOperand.getOperation()->moveBefore(initState->branchInit); | ||||
| 
 | ||||
|     // The newly bundled MemRef expressed as a KrnlGetRefOp.
 | ||||
|     // Current memory pool size is the offset for the newly bundled
 | ||||
|     // internal MemRef.
 | ||||
|     Value integerDynamicMemoryPoolSize = rewriter.create<IndexCastOp>( | ||||
|         loc, dynamicMemoryPoolSize, rewriter.getIntegerType(64)); | ||||
|     integerDynamicMemoryPoolSize.getDefiningOp()->moveBefore( | ||||
|         initState->branchInit); | ||||
| 
 | ||||
|     // We need to emit a new alloc which contains the additional MemRef.
 | ||||
|     AllocOp bundledAlloc = rewriter.create<AllocOp>( | ||||
|         loc, bundledMemPoolMemRefType, bundledAllocOperand.getResult()); | ||||
|     bundledAlloc.getOperation()->moveBefore(&initState->mainBlock->front()); | ||||
| 
 | ||||
|     KrnlGetRefOp bundledMemRef = rewriter.create<KrnlGetRefOp>(loc, | ||||
|         currentAllocGetRef.getResult().getType(), bundledAlloc, | ||||
|         integerDynamicMemoryPoolSize); | ||||
|     bundledMemRef.getOperation()->moveAfter(bundledAlloc); | ||||
| 
 | ||||
|     // Replace old memory pool with new one.
 | ||||
|     rewriter.replaceOp(initState->dynamicMemoryPool, bundledAlloc.getResult()); | ||||
| 
 | ||||
|     // Replace old getref with new getref from new memory pool.
 | ||||
|     rewriter.replaceOp(currentAllocGetRef, bundledMemRef.getResult()); | ||||
| 
 | ||||
|     // Update MemPool size.
 | ||||
|     initState->dynamicMemoryPool = bundledAlloc; | ||||
| 
 | ||||
|     return success(); | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| /*!
 | ||||
|  *  Function pass that enables memory pooling for MemRefs. | ||||
|  */ | ||||
|  | @ -136,11 +449,22 @@ public: | |||
|   void runOnFunction() override { | ||||
|     auto function = getFunction(); | ||||
| 
 | ||||
|     // ModuleOp module = cast<ModuleOp>(function.getParentOp());
 | ||||
|     initMap.insert(std::pair<FuncOp, std::unique_ptr<ONNXOperandsInitState>>( | ||||
|         function, std::make_unique<ONNXOperandsInitState>())); | ||||
| 
 | ||||
|     // Initialize state for this function.
 | ||||
|     std::unique_ptr<ONNXOperandsInitState> &initState = initMap.at(function); | ||||
|     initState->initBlock = nullptr; | ||||
| 
 | ||||
|     ConversionTarget target(getContext()); | ||||
|     OwningRewritePatternList patterns; | ||||
|     patterns.insert<KrnlBundleMemoryPools>(&getContext()); | ||||
|     patterns.insert<KrnlBundleMemoryPools, KrnlBundleDynamicMemoryPools>( | ||||
|         &getContext()); | ||||
| 
 | ||||
|     applyPatternsAndFoldGreedily(function, patterns); | ||||
| 
 | ||||
|     initMap.erase(function); | ||||
|   } | ||||
| }; | ||||
| } // namespace
 | ||||
|  |  | |||
|  | @ -51,3 +51,135 @@ func @test_pool_bundling(%arg0: memref<10x10xf32>, %arg1: memref<10x20xf32>) -> | |||
|   // CHECK: dealloc [[MEMPOOL]] : memref<3200xi8> | ||||
|   // CHECK: return [[RES]] : memref<10x20xf32> | ||||
| } | ||||
| 
 | ||||
| func @test_dynamic_pool_bundling(%arg0: memref<?x?xf32>) -> memref<?x10xf32> { | ||||
|   %c1 = constant 1 : index | ||||
|   %c0 = constant 0 : index | ||||
|   %cst = constant 0.000000e+00 : f32 | ||||
|   %ind = constant 0 : index | ||||
|   %c4 = constant 4 : index | ||||
|   %c10 = constant 10 : index | ||||
|   %c0_i64 = constant 0 : i64 | ||||
|   %0 = dim %arg0, %c0 : memref<?x?xf32> | ||||
|   %1 = muli %0, %c4 : index | ||||
|   %2 = muli %1, %c10 : index | ||||
|   %3 = alloc(%2) : memref<?xi8> | ||||
|   %4 = "krnl.getref"(%3, %c0_i64) : (memref<?xi8>, i64) -> memref<?x10xf32> | ||||
|   %6 = cmpi "sgt", %0, %0 : index | ||||
|   %7 = select %6, %0, %0 : index | ||||
|   %8 = muli %7, %c4 : index | ||||
|   %9 = muli %8, %c10 : index | ||||
|   %10 = alloc(%9) : memref<?xi8> | ||||
|   %11 = "krnl.getref"(%10, %c0_i64) : (memref<?xi8>, i64) -> memref<?x10xf32> | ||||
|   %12 = cmpi "eq", %0, %c1 : index | ||||
|   %13 = cmpi "eq", %0, %c1 : index | ||||
|   %15 = alloc(%0) : memref<?x10xf32> | ||||
|   affine.store %cst, %4[%ind, %ind] : memref<?x10xf32> | ||||
|   affine.store %cst, %11[%ind, %ind] : memref<?x10xf32> | ||||
|   affine.store %cst, %15[%ind, %ind] : memref<?x10xf32> | ||||
|   dealloc %10 : memref<?xi8> | ||||
|   dealloc %3 : memref<?xi8> | ||||
|   return %15 : memref<?x10xf32> | ||||
| 
 | ||||
|   // CHECK-LABEL: test_dynamic_pool_bundling | ||||
|   // CHECK: [[CST:%.+]] = constant 0.000000e+00 : f32 | ||||
|   // CHECK: [[C0:%.+]] = constant 0 : index | ||||
|   // CHECK: [[C4:%.+]] = constant 4 : index | ||||
|   // CHECK: [[C10:%.+]] = constant 10 : index | ||||
|   // CHECK: [[C0_I64:%.+]] = constant 0 : i64 | ||||
|   // CHECK: [[DIM:%.+]] = dim %arg0, [[C0]] : memref<?x?xf32> | ||||
|   // CHECK: [[SGT:%.+]] = cmpi "sgt", [[DIM]], [[DIM]] : index | ||||
|   // CHECK: [[SELECT:%.+]] = select [[SGT]], [[DIM]], [[DIM]] : index | ||||
|   // CHECK: [[MUL1:%.+]] = muli [[SELECT]], [[C4]] : index | ||||
|   // CHECK: [[OFFSET1:%.+]] = muli [[MUL1]], [[C10]] : index | ||||
|   // CHECK: [[MUL2:%.+]] = muli [[DIM]], [[C4]] : index | ||||
|   // CHECK: [[OFFSET2:%.+]] = muli [[MUL2]], [[C10]] : index | ||||
|   // CHECK: [[MEMPOOL_SIZE:%.+]] = addi [[OFFSET1]], [[OFFSET2]] : index | ||||
|   // CHECK: [[OFFSET1_I64:%.+]] = index_cast [[OFFSET1]] : index to i64 | ||||
|   // CHECK: [[DYN_MEMPOOL:%.+]] = alloc([[MEMPOOL_SIZE]]) : memref<?xi8> | ||||
|   // CHECK: [[DATA1:%.+]] = "krnl.getref"([[DYN_MEMPOOL]], [[OFFSET1_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32> | ||||
|   // CHECK: [[DATA2:%.+]] = "krnl.getref"([[DYN_MEMPOOL]], [[C0_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32> | ||||
|   // CHECK: [[RES:%.+]] = alloc([[DIM]]) : memref<?x10xf32> | ||||
|   // CHECK: affine.store [[CST]], [[DATA1]][0, 0] : memref<?x10xf32> | ||||
|   // CHECK: affine.store [[CST]], [[DATA2]][0, 0] : memref<?x10xf32> | ||||
|   // CHECK: affine.store [[CST]], [[RES]][0, 0] : memref<?x10xf32> | ||||
|   // CHECK: dealloc [[DYN_MEMPOOL]] : memref<?xi8> | ||||
|   // CHECK: return [[RES]] : memref<?x10xf32> | ||||
| } | ||||
| 
 | ||||
| func @test_dynamic_and_static_pool_bundling(%arg0: memref<?x?xf32>, %arg1: memref<10x10xf32>) -> memref<?x10xf32> { | ||||
|   %c1 = constant 1 : index | ||||
|   %c0 = constant 0 : index | ||||
|   %cst = constant 0.000000e+00 : f32 | ||||
|   %ind = constant 0 : index | ||||
|   %c4 = constant 4 : index | ||||
|   %c10 = constant 10 : index | ||||
|   %c0_i64 = constant 0 : i64 | ||||
|   %0 = dim %arg0, %c0 : memref<?x?xf32> | ||||
|   %1 = muli %0, %c4 : index | ||||
|   %2 = muli %1, %c10 : index | ||||
|   %3 = alloc(%2) : memref<?xi8> | ||||
|   %4 = "krnl.getref"(%3, %c0_i64) : (memref<?xi8>, i64) -> memref<?x10xf32> | ||||
|   %const_alloc1 = alloc() : memref<800xi8> | ||||
|   %const_ref1 = "krnl.getref"(%const_alloc1, %c0_i64) : (memref<800xi8>, i64) -> memref<10x20xf32> | ||||
|   %const_alloc2 = alloc() : memref<400xi8> | ||||
|   %const_ref2 = "krnl.getref"(%const_alloc2, %c0_i64) : (memref<400xi8>, i64) -> memref<10x10xf32> | ||||
|   %6 = cmpi "sgt", %0, %0 : index | ||||
|   %7 = select %6, %0, %0 : index | ||||
|   %8 = muli %7, %c4 : index | ||||
|   %9 = muli %8, %c10 : index | ||||
|   %10 = alloc(%9) : memref<?xi8> | ||||
|   %11 = "krnl.getref"(%10, %c0_i64) : (memref<?xi8>, i64) -> memref<?x10xf32> | ||||
|   %12 = cmpi "eq", %0, %c1 : index | ||||
|   %13 = cmpi "eq", %0, %c1 : index | ||||
|   %15 = alloc(%0) : memref<?x10xf32> | ||||
|   %const_alloc3 = alloc() : memref<1600xi8> | ||||
|   %const_ref3 = "krnl.getref"(%const_alloc3, %c0_i64) : (memref<1600xi8>, i64) -> memref<10x40xf32> | ||||
|   affine.store %cst, %4[%ind, %ind] : memref<?x10xf32> | ||||
|   affine.store %cst, %11[%ind, %ind] : memref<?x10xf32> | ||||
|   affine.store %cst, %15[%ind, %ind] : memref<?x10xf32> | ||||
|   affine.store %cst, %const_ref2[%ind, %ind] : memref<10x10xf32> | ||||
|   affine.store %cst, %const_ref1[%ind, %ind] : memref<10x20xf32> | ||||
|   affine.store %cst, %const_ref3[%ind, %ind] : memref<10x40xf32> | ||||
|   dealloc %10 : memref<?xi8> | ||||
|   dealloc %3 : memref<?xi8> | ||||
|   dealloc %const_alloc1 : memref<800xi8> | ||||
|   dealloc %const_alloc2 : memref<400xi8> | ||||
|   dealloc %const_alloc3 : memref<1600xi8> | ||||
|   return %15 : memref<?x10xf32> | ||||
| 
 | ||||
|   // CHECK-LABEL: test_dynamic_and_static_pool_bundling | ||||
|   // CHECK: [[CST:%.+]] = constant 0.000000e+00 : f32 | ||||
|   // CHECK: [[C0:%.+]] = constant 0 : index | ||||
|   // CHECK: [[C4:%.+]] = constant 4 : index | ||||
|   // CHECK: [[C10:%.+]] = constant 10 : index | ||||
|   // CHECK: [[C400_I64:%.+]] = constant 400 : i64 | ||||
|   // CHECK: [[C2000_I64:%.+]] = constant 2000 : i64 | ||||
|   // CHECK: [[C0_I64:%.+]] = constant 0 : i64 | ||||
|   // CHECK: [[DIM:%.+]] = dim %arg0, [[C0]] : memref<?x?xf32> | ||||
|   // CHECK: [[SGT:%.+]] = cmpi "sgt", [[DIM]], [[DIM]] : index | ||||
|   // CHECK: [[SELECT:%.+]] = select [[SGT]], [[DIM]], [[DIM]] : index | ||||
|   // CHECK: [[MUL1:%.+]] = muli [[SELECT]], [[C4]] : index | ||||
|   // CHECK: [[OFFSET1:%.+]] = muli [[MUL1]], [[C10]] : index | ||||
|   // CHECK: [[MUL2:%.+]] = muli [[DIM]], [[C4]] : index | ||||
|   // CHECK: [[OFFSET2:%.+]] = muli [[MUL2]], [[C10]] : index | ||||
|   // CHECK: [[MEMPOOL_SIZE:%.+]] = addi [[OFFSET1]], [[OFFSET2]] : index | ||||
|   // CHECK: [[OFFSET1_I64:%.+]] = index_cast [[OFFSET1]] : index to i64 | ||||
|   // CHECK: [[DYN_MEMPOOL:%.+]] = alloc([[MEMPOOL_SIZE]]) : memref<?xi8> | ||||
|   // CHECK: [[DATA1:%.+]] = "krnl.getref"([[DYN_MEMPOOL]], [[OFFSET1_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32> | ||||
|   // CHECK: [[DATA2:%.+]] = "krnl.getref"([[DYN_MEMPOOL]], [[C0_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32> | ||||
|   // CHECK: [[STATIC_MEMPOOL:%.+]] = alloc() : memref<2800xi8> | ||||
|   // CHECK: [[DATA3:%.+]] = "krnl.getref"([[STATIC_MEMPOOL]], [[C2000_I64]]) : (memref<2800xi8>, i64) -> memref<10x20xf32> | ||||
|   // CHECK: [[DATA4:%.+]] = "krnl.getref"([[STATIC_MEMPOOL]], [[C400_I64]]) : (memref<2800xi8>, i64) -> memref<10x40xf32> | ||||
|   // CHECK: [[DATA5:%.+]] = "krnl.getref"([[STATIC_MEMPOOL]], [[C0_I64]]) : (memref<2800xi8>, i64) -> memref<10x10xf32> | ||||
|   // CHECK: [[RES:%.+]] = alloc([[DIM]]) : memref<?x10xf32> | ||||
|   // CHECK: affine.store [[CST]], [[DATA1]][0, 0] : memref<?x10xf32> | ||||
|   // CHECK: affine.store [[CST]], [[DATA2]][0, 0] : memref<?x10xf32> | ||||
|   // CHECK: affine.store [[CST]], [[RES]][0, 0] : memref<?x10xf32> | ||||
|   // CHECK: affine.store [[CST]], [[DATA5]][0, 0] : memref<10x10xf32> | ||||
|   // CHECK: affine.store [[CST]], [[DATA3]][0, 0] : memref<10x20xf32> | ||||
|   // CHECK: affine.store [[CST]], [[DATA4]][0, 0] : memref<10x40xf32> | ||||
|   // CHECK: dealloc [[DYN_MEMPOOL]] : memref<?xi8> | ||||
|   // CHECK: dealloc [[STATIC_MEMPOOL]] : memref<2800xi8> | ||||
|   // CHECK: return [[RES]] : memref<?x10xf32> | ||||
| } | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue