Enable bundling of dynamic memory pools on a block basis. (#330)

* Reorganize main function.

* Address review comments.

* Emit constants as globals in Krnl and LLVM dialects.

* Fix dynamic bundling.

* Format.

* Fix test.
Gheorghe-Teodor Bercea, 2020-10-05 13:55:46 -04:00 (committed via GitHub)
parent 931127c7e9 | commit 7bfb5c93c1
2 changed files with 169 additions and 118 deletions
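In short, the pass now merges all dynamically sized AllocOps of a block into a single per-block memory pool. A minimal before/after sketch of the effect (hypothetical IR with illustrative value names; %dim, %c4, and %c40 are assumed to be previously defined index values, modeled on the test_dynamic_pool_bundling test below):

    // Before: each dynamically sized buffer has its own alloc.
    %c0_i64 = constant 0 : i64
    %size0 = muli %dim, %c40 : index
    %buf0 = alloc(%size0) : memref<?xi8>
    %ref0 = "krnl.getref"(%buf0, %c0_i64) : (memref<?xi8>, i64) -> memref<?x10xf32>
    %size1 = muli %dim, %c4 : index
    %buf1 = alloc(%size1) : memref<?xi8>
    %ref1 = "krnl.getref"(%buf1, %c0_i64) : (memref<?xi8>, i64) -> memref<?x10xf32>

    // After: one pool sized by the sum of the two sizes; each krnl.getref
    // addresses its own slice, the second one at offset %size0.
    %c0_i64 = constant 0 : i64
    %pool_size = addi %size0, %size1 : index
    %offset0 = index_cast %size0 : index to i64
    %pool = alloc(%pool_size) : memref<?xi8>
    %ref0 = "krnl.getref"(%pool, %c0_i64) : (memref<?xi8>, i64) -> memref<?x10xf32>
    %ref1 = "krnl.getref"(%pool, %offset0) : (memref<?xi8>, i64) -> memref<?x10xf32>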

Changed file 1 of 2: the KrnlBundleMemoryPools pass (C++)

@@ -26,23 +26,12 @@ using namespace mlir;
 namespace {

 //===----------------------------------------------------------------------===//
-// Insertion point for initialization instructions and the blocks used for
-// inserting the initialization and main code. These blocks will disappear
-// when the first canonicalization is performed because the init block
-// unconditionally branches into the second block. These blocks exist only for
-// the purpose of this optimization.
-// The information is recorded on a per function basis.
+// Data structures for managing memory pools.
 //===----------------------------------------------------------------------===//
-typedef struct ONNXOperandsInitState {
-  Block *initBlock;
-  Block *mainBlock;
-  BranchOp branchInit;
-  AllocOp dynamicMemoryPool;
-} ONNXOperandsInitState;
-
-// Helper data structure for the bundling of dynamic AllocOps.
-std::map<FuncOp, std::unique_ptr<ONNXOperandsInitState>> initMap;
+// Data structure for managing the dynamic memory pool.
+typedef std::map<Block *, AllocOp> BlockToDynamicPool;
+std::map<FuncOp, std::unique_ptr<BlockToDynamicPool>> dynamicPoolMap;

 // Handling of static memory pool on a block-basis in each function.
 typedef std::map<Block *, AllocOp> BlockToStaticPool;
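Note the new shape of the state: dynamicPoolMap is keyed by function, and each BlockToDynamicPool is keyed by Block*, so every block carries at most one dynamic pool. A hypothetical sketch of the resulting IR shape (the new test_dynamic_pool_rnn test below exercises the nested-block case):

    func @sketch(%arg0: memref<?xf32>) {
      %c0 = constant 0 : index
      %n = dim %arg0, %c0 : memref<?xf32>
      // Entry block: all dynamic allocs here bundle into one pool.
      %entry_pool = alloc(%n) : memref<?xi8>
      %l = krnl.define_loops 1
      krnl.iterate(%l) with (%l -> %i = 0 to 10) {
        // Loop-body block: bundled separately, into its own pool.
        %body_pool = alloc(%n) : memref<?xi8>
        dealloc %body_pool : memref<?xi8>
      }
      dealloc %entry_pool : memref<?xi8>
      return
    }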
@@ -63,65 +52,24 @@ FuncOp getContainingFunction(AllocOp op) {
   return cast<FuncOp>(parentFuncOp);
 }

-bool hasInitBlock(FuncOp function) {
-  std::unique_ptr<ONNXOperandsInitState> &initState = initMap.at(function);
-  return initState->initBlock != nullptr;
-}
-
-bool addInitBlock(PatternRewriter &rewriter, Location loc, AllocOp allocOp) {
-  // If this is the first time we encounter an operation in this
-  // function, we create an entry inside the initMap and split the
-  // function body into an init block and a main block.
-  //
-  // function func_name() {
-  //    ... init block ...
-  //    br ^bb1
-  //  ^bb1: // pred: ^bb0
-  //    ... main block ...
-  //    return
-  // }
-  //
-  // Note: the block ^bb0 being the first block has its label omitted.
-  //
-  FuncOp function = getContainingFunction(allocOp);
-
-  // If the function does not contain an init block, create one.
-  if (!hasInitBlock(function)) {
-    std::unique_ptr<ONNXOperandsInitState> &initState = initMap.at(function);
-    initState = std::make_unique<ONNXOperandsInitState>();
-
-    // All input arguments are considered as part of the initialization block
-    // so add them to the operandsInInitBlock set.
-    Block *functionBlock = &function.front();
-    PatternRewriter::InsertionGuard insertGuard(rewriter);
-    rewriter.setInsertionPointToStart(functionBlock);
-    initState->initBlock = rewriter.getInsertionBlock();
-    auto currentPoint = rewriter.getInsertionPoint();
-    initState->mainBlock =
-        rewriter.splitBlock(initState->initBlock, currentPoint);
-    rewriter.setInsertionPointToEnd(initState->initBlock);
-
-    // Insert a branch operation from initBlock to mainBlock. This
-    // ensures the final code contains legal blocks.
-    initState->branchInit =
-        rewriter.create<BranchOp>(loc, initState->mainBlock);
-    rewriter.setInsertionPointToStart(initState->mainBlock);
-
-    // Save a reference to the current dynamic memory pool value.
-    initState->dynamicMemoryPool = allocOp;
-    return true;
-  }
-  return false;
-}
-
-bool isBlockArgument(Block *block, Value operand) {
-  for (auto arg : block->getArguments())
-    if (operand == arg)
-      return true;
+// Check if this value is an argument of one of the blocks nested
+// around it.
+bool isBlockArgument(AllocOp allocOp, Value operand) {
+  // Parent operation of the current block.
+  Operation *parentBlockOp;
+  Block *currentBlock = allocOp.getOperation()->getBlock();
+  do {
+    // Check the arguments of the current block.
+    for (auto arg : currentBlock->getArguments())
+      if (operand == arg)
+        return true;
+    parentBlockOp = currentBlock->getParentOp();
+    currentBlock = parentBlockOp->getBlock();
+  } while (!llvm::dyn_cast_or_null<FuncOp>(parentBlockOp));
+
   return false;
 }
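The rewritten isBlockArgument matters because, with per-block bundling, a size computation can live in a nested region where loop induction variables are arguments of an inner block rather than of the function's entry block. A hypothetical sketch of the case the upward walk now classifies as a leaf:

    func @sketch(%arg0: memref<?xf32>) {
      %c0 = constant 0 : index
      %l = krnl.define_loops 1
      krnl.iterate(%l) with (%l -> %i = 0 to 10) {
        // %i is an argument of the loop-body block, not of the entry
        // block; the do/while above walks the enclosing blocks up to
        // the FuncOp, finds it, and stops the operand trace at %i.
        %d = dim %arg0, %c0 : memref<?xf32>
        %s = addi %d, %i : index
        %a = alloc(%s) : memref<?xi8>
        dealloc %a : memref<?xi8>
      }
      return
    }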
@@ -332,14 +280,17 @@ public:
     // Get function.
     FuncOp function = getContainingFunction(allocOp);
-    Block *firstBlock = &function.getBody().front();

-    // If this is the alloc representing the memory pool and the function
-    // already has an init block, pattern matching must fail to avoid
-    // processing the dynamic memory pool a second time.
-    if (hasInitBlock(function)) {
-      std::unique_ptr<ONNXOperandsInitState> &initState = initMap.at(function);
-      if (allocOp == initState->dynamicMemoryPool)
+    // Use function to retrieve the list of blocks for this function.
+    std::unique_ptr<BlockToDynamicPool> &blockToDynamicPool =
+        dynamicPoolMap.at(function);
+
+    // If this is not the first time we process an alloc in this block, avoid
+    // processing the current dynamic memory pool again.
+    if (blockToDynamicPool->count(parentBlock) > 0) {
+      std::unique_ptr<BlockToDynamicPool> &blockToDynamicPool =
+          dynamicPoolMap.at(function);
+      if (allocOp == blockToDynamicPool->at(parentBlock))
         return failure();
     }
@@ -365,11 +316,10 @@ public:
       dependentOps.insert(definingOperation);

       // Add operands to work queue.
-      // printf("Processing the args of the following op:\n");
       for (const auto &operand : definingOperation->getOperands()) {
         // Check operand is not a block argument. If it is skip it, we
         // consider block arguments to be leafs.
-        if (!isBlockArgument(firstBlock, operand)) {
+        if (!isBlockArgument(allocOp, operand)) {
           operandList.emplace_back(operand);

           // Check if the current operation is a dim or a load and the
@@ -416,15 +366,23 @@ public:
       if (dependentOps.count(&op) > 0)
         orderedDependentOps.emplace_back(&op);

-    // If no dynamic alloc is in the trace of the dependent operations,
-    // emit the size calculation in the init block, if one exists already,
-    // if not, create the init block.
-    bool addedInitBlock = addInitBlock(rewriter, loc, allocOp);
+    // If this is the first valid alloc we can bundle in this block, then we
+    // need to move it to the top of the block as it will constitute an
+    // insertion point for all other bundle-able AllocOps in the block.
+    bool isFirstBundledAllocOp = blockToDynamicPool->count(parentBlock) == 0;
+    if (isFirstBundledAllocOp) {
+      allocOp.getOperation()->moveBefore(&parentBlock->front());

-    // Move the ordered dependent size calculation to the init block.
-    std::unique_ptr<ONNXOperandsInitState> &initState = initMap.at(function);
+      // Create a new entry in the block map.
+      blockToDynamicPool->insert(
+          std::pair<Block *, AllocOp>(parentBlock, allocOp));
+    }
+
+    // Move the computation instructions to the start of the block.
+    AllocOp oldDynamicMemoryPool = blockToDynamicPool->at(parentBlock);
+    std::reverse(orderedDependentOps.begin(), orderedDependentOps.end());
     for (auto &op : orderedDependentOps)
-      op->moveBefore(initState->branchInit);
+      op->moveBefore(&parentBlock->front());

     // Bundle MemRef type: <?xi8>
     SmallVector<int64_t, 1> memPoolShape;
@@ -438,16 +396,16 @@ public:
       return failure();

     // Add the current alloc size to the current MemPool size.
-    Value dynamicMemoryPoolSize = initState->dynamicMemoryPool.getOperand(0);
-    if (addedInitBlock) {
+    Value dynamicMemoryPoolSize = oldDynamicMemoryPool.getOperand(0);
+    if (isFirstBundledAllocOp) {
       Value zero = emitConstantOp(rewriter, loc, rewriter.getIndexType(), 0);
-      zero.getDefiningOp()->moveBefore(initState->branchInit);
+      zero.getDefiningOp()->moveBefore(oldDynamicMemoryPool);
       dynamicMemoryPoolSize = zero;
     }

     AddIOp bundledAllocOperand = rewriter.create<AddIOp>(
         loc, dynamicMemoryPoolSize, allocOp.getOperand(0));
-    bundledAllocOperand.getOperation()->moveBefore(initState->branchInit);
+    bundledAllocOperand.getOperation()->moveBefore(oldDynamicMemoryPool);

     // The newly bundled MemRef expressed as a KrnlGetRefOp.
     // Current memory pool size is the offset for the newly bundled
@@ -455,26 +413,27 @@ public:
     Value integerDynamicMemoryPoolSize = rewriter.create<IndexCastOp>(
         loc, dynamicMemoryPoolSize, rewriter.getIntegerType(64));
     integerDynamicMemoryPoolSize.getDefiningOp()->moveBefore(
-        initState->branchInit);
+        oldDynamicMemoryPool);

     // We need to emit a new alloc which contains the additional MemRef.
     AllocOp bundledAlloc = rewriter.create<AllocOp>(
         loc, bundledMemPoolMemRefType, bundledAllocOperand.getResult());
-    bundledAlloc.getOperation()->moveBefore(&initState->mainBlock->front());
+    bundledAlloc.getOperation()->moveBefore(oldDynamicMemoryPool);

     KrnlGetRefOp bundledMemRef = rewriter.create<KrnlGetRefOp>(loc,
         currentAllocGetRef.getResult().getType(), bundledAlloc,
         integerDynamicMemoryPoolSize);
+    bundledMemRef.getOperation()->moveAfter(bundledAlloc);

     // Replace old memory pool with new one.
-    rewriter.replaceOp(initState->dynamicMemoryPool, bundledAlloc.getResult());
+    rewriter.replaceOp(oldDynamicMemoryPool, bundledAlloc.getResult());

     // Replace old getref with new getref from new memory pool.
     rewriter.replaceOp(currentAllocGetRef, bundledMemRef.getResult());

-    // Update MemPool size.
-    initState->dynamicMemoryPool = bundledAlloc;
+    // Update MemPool data structure.
+    blockToDynamicPool->erase(parentBlock);
+    blockToDynamicPool->insert(
+        std::pair<Block *, AllocOp>(parentBlock, bundledAlloc));

     return success();
   }
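The arithmetic of one bundling step: the new pool size is the old size plus the incoming alloc's size, and the old size becomes the offset of the newly bundled slice (the first bundled alloc is seeded with a zero pool size, so it lands at offset 0). A minimal sketch of the ops this rewrite emits, assuming %old_size and %alloc_size are already-computed index values:

    %new_size = addi %old_size, %alloc_size : index  // operand of the new AllocOp
    %offset = index_cast %old_size : index to i64    // start of the new slice
    %new_pool = alloc(%new_size) : memref<?xi8>      // replaces the old pool
    %slice = "krnl.getref"(%new_pool, %offset) : (memref<?xi8>, i64) -> memref<?x10xf32>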
@@ -488,16 +447,14 @@ class KrnlBundleMemoryPoolsPass
 public:
   void runOnFunction() override {
     auto function = getFunction();

-    initMap.insert(std::pair<FuncOp, std::unique_ptr<ONNXOperandsInitState>>(
-        function, std::make_unique<ONNXOperandsInitState>()));
+    dynamicPoolMap.insert(
+        std::pair<FuncOp, std::unique_ptr<BlockToDynamicPool>>(
+            function, std::make_unique<BlockToDynamicPool>()));
     staticPoolMap.insert(std::pair<FuncOp, std::unique_ptr<BlockToStaticPool>>(
         function, std::make_unique<BlockToStaticPool>()));

-    // Initialize state for this function.
-    std::unique_ptr<ONNXOperandsInitState> &initState = initMap.at(function);
-    initState->initBlock = nullptr;
-
     ConversionTarget target(getContext());
     OwningRewritePatternList patterns;
     patterns.insert<KrnlBundleStaticMemoryPools, KrnlBundleDynamicMemoryPools>(

@@ -505,7 +462,7 @@ public:
     applyPatternsAndFoldGreedily(function, patterns);

-    initMap.erase(function);
+    dynamicPoolMap.erase(function);
     staticPoolMap.erase(function);
   }
 };

Changed file 2 of 2: the memory pool bundling lit test (MLIR/FileCheck)

@@ -1,4 +1,4 @@
-// RUN: onnx-mlir-opt --bundle-memory-pools --canonicalize %s -split-input-file | FileCheck %s
+// RUN: onnx-mlir-opt --bundle-memory-pools --canonicalize %s | FileCheck %s

 func @test_pool_bundling(%arg0: memref<10x10xf32>, %arg1: memref<10x20xf32>) -> memref<10x20xf32> {
   %c0_i64 = constant 0 : i64
@@ -88,17 +88,17 @@ func @test_dynamic_pool_bundling(%arg0: memref<?x?xf32>) -> memref<?x10xf32> {
   // CHECK: [[C10:%.+]] = constant 10 : index
   // CHECK: [[C0_I64:%.+]] = constant 0 : i64
   // CHECK: [[DIM:%.+]] = dim %arg0, [[C0]] : memref<?x?xf32>
+  // CHECK: [[MUL2:%.+]] = muli [[DIM]], [[C4]] : index
+  // CHECK: [[OFFSET2:%.+]] = muli [[MUL2]], [[C10]] : index
   // CHECK: [[SGT:%.+]] = cmpi "sgt", [[DIM]], [[DIM]] : index
   // CHECK: [[SELECT:%.+]] = select [[SGT]], [[DIM]], [[DIM]] : index
   // CHECK: [[MUL1:%.+]] = muli [[SELECT]], [[C4]] : index
   // CHECK: [[OFFSET1:%.+]] = muli [[MUL1]], [[C10]] : index
-  // CHECK: [[MUL2:%.+]] = muli [[DIM]], [[C4]] : index
-  // CHECK: [[OFFSET2:%.+]] = muli [[MUL2]], [[C10]] : index
   // CHECK: [[MEMPOOL_SIZE:%.+]] = addi [[OFFSET1]], [[OFFSET2]] : index
   // CHECK: [[OFFSET1_I64:%.+]] = index_cast [[OFFSET1]] : index to i64
   // CHECK: [[DYN_MEMPOOL:%.+]] = alloc([[MEMPOOL_SIZE]]) : memref<?xi8>
-  // CHECK: [[DATA1:%.+]] = "krnl.getref"([[DYN_MEMPOOL]], [[OFFSET1_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32>
   // CHECK: [[DATA2:%.+]] = "krnl.getref"([[DYN_MEMPOOL]], [[C0_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32>
+  // CHECK: [[DATA1:%.+]] = "krnl.getref"([[DYN_MEMPOOL]], [[OFFSET1_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32>
   // CHECK: [[RES:%.+]] = alloc([[DIM]]) : memref<?x10xf32>
   // CHECK: affine.store [[CST]], [[DATA1]][0, 0] : memref<?x10xf32>
   // CHECK: affine.store [[CST]], [[DATA2]][0, 0] : memref<?x10xf32>
@@ -149,36 +149,36 @@ func @test_dynamic_and_static_pool_bundling(%arg0: memref<?x?xf32>, %arg1: memre
   return %15 : memref<?x10xf32>

   // CHECK-LABEL: test_dynamic_and_static_pool_bundling
-  // CHECK: [[C1200_I64:%.+]] = constant 1200 : i64
   // CHECK: [[CST:%.+]] = constant 0.000000e+00 : f32
   // CHECK: [[C0:%.+]] = constant 0 : index
   // CHECK: [[C4:%.+]] = constant 4 : index
   // CHECK: [[C10:%.+]] = constant 10 : index
-  // CHECK: [[C400_I64:%.+]] = constant 400 : i64
+  // CHECK: [[C2000_I64:%.+]] = constant 2000 : i64
+  // CHECK: [[C1600_I64:%.+]] = constant 1600 : i64
   // CHECK: [[C0_I64:%.+]] = constant 0 : i64
   // CHECK: [[DIM:%.+]] = dim %arg0, [[C0]] : memref<?x?xf32>
+  // CHECK: [[MUL2:%.+]] = muli [[DIM]], [[C4]] : index
+  // CHECK: [[OFFSET2:%.+]] = muli [[MUL2]], [[C10]] : index
   // CHECK: [[SGT:%.+]] = cmpi "sgt", [[DIM]], [[DIM]] : index
   // CHECK: [[SELECT:%.+]] = select [[SGT]], [[DIM]], [[DIM]] : index
   // CHECK: [[MUL1:%.+]] = muli [[SELECT]], [[C4]] : index
   // CHECK: [[OFFSET1:%.+]] = muli [[MUL1]], [[C10]] : index
-  // CHECK: [[MUL2:%.+]] = muli [[DIM]], [[C4]] : index
-  // CHECK: [[OFFSET2:%.+]] = muli [[MUL2]], [[C10]] : index
   // CHECK: [[MEMPOOL_SIZE:%.+]] = addi [[OFFSET1]], [[OFFSET2]] : index
   // CHECK: [[OFFSET1_I64:%.+]] = index_cast [[OFFSET1]] : index to i64
   // CHECK: [[DYN_MEMPOOL:%.+]] = alloc([[MEMPOOL_SIZE]]) : memref<?xi8>
-  // CHECK: [[DATA1:%.+]] = "krnl.getref"([[DYN_MEMPOOL]], [[OFFSET1_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32>
   // CHECK: [[DATA2:%.+]] = "krnl.getref"([[DYN_MEMPOOL]], [[C0_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32>
+  // CHECK: [[DATA1:%.+]] = "krnl.getref"([[DYN_MEMPOOL]], [[OFFSET1_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32>
   // CHECK: [[STATIC_MEMPOOL:%.+]] = alloc() : memref<2800xi8>
-  // CHECK: [[DATA3:%.+]] = "krnl.getref"([[STATIC_MEMPOOL]], [[C1200_I64]]) : (memref<2800xi8>, i64) -> memref<10x40xf32>
-  // CHECK: [[DATA4:%.+]] = "krnl.getref"([[STATIC_MEMPOOL]], [[C400_I64]]) : (memref<2800xi8>, i64) -> memref<10x20xf32>
-  // CHECK: [[DATA5:%.+]] = "krnl.getref"([[STATIC_MEMPOOL]], [[C0_I64]]) : (memref<2800xi8>, i64) -> memref<10x10xf32>
+  // CHECK: [[DATA3:%.+]] = "krnl.getref"([[STATIC_MEMPOOL]], [[C2000_I64]]) : (memref<2800xi8>, i64) -> memref<10x20xf32>
+  // CHECK: [[DATA4:%.+]] = "krnl.getref"([[STATIC_MEMPOOL]], [[C1600_I64]]) : (memref<2800xi8>, i64) -> memref<10x10xf32>
   // CHECK: [[RES:%.+]] = alloc([[DIM]]) : memref<?x10xf32>
+  // CHECK: [[DATA5:%.+]] = "krnl.getref"([[STATIC_MEMPOOL]], [[C0_I64]]) : (memref<2800xi8>, i64) -> memref<10x40xf32>
   // CHECK: affine.store [[CST]], [[DATA1]][0, 0] : memref<?x10xf32>
   // CHECK: affine.store [[CST]], [[DATA2]][0, 0] : memref<?x10xf32>
   // CHECK: affine.store [[CST]], [[RES]][0, 0] : memref<?x10xf32>
-  // CHECK: affine.store [[CST]], [[DATA5]][0, 0] : memref<10x10xf32>
-  // CHECK: affine.store [[CST]], [[DATA4]][0, 0] : memref<10x20xf32>
-  // CHECK: affine.store [[CST]], [[DATA3]][0, 0] : memref<10x40xf32>
+  // CHECK: affine.store [[CST]], [[DATA4]][0, 0] : memref<10x10xf32>
+  // CHECK: affine.store [[CST]], [[DATA3]][0, 0] : memref<10x20xf32>
+  // CHECK: affine.store [[CST]], [[DATA5]][0, 0] : memref<10x40xf32>
   // CHECK: dealloc [[DYN_MEMPOOL]] : memref<?xi8>
   // CHECK: dealloc [[STATIC_MEMPOOL]] : memref<2800xi8>
   // CHECK: return [[RES]] : memref<?x10xf32>
@@ -365,3 +365,97 @@ func @static_mem_pool_rnn_sub_and_main_block(%arg0: memref<1x3x2xf32>, %arg1: me
   // CHECK: dealloc [[STATIC_MEM_POOL_MAIN]] : memref<12xi8>
   // CHECK: return [[RES]] : memref<1x3x4xf32>
 }
+
+/// Test dynamic pooling in sub-block.
+func @test_dynamic_pool_rnn(%arg0: memref<1x3x2xf32>, %arg1: memref<1x4x2xf32>, %arg2: memref<1x?x?xf32>) -> memref<1x3x?xf32> attributes {input_names = ["X", "W", "R"], output_names = ["Y"]} {
+  %cst = constant 0.000000e+00 : f32
+  %c0_i64 = constant 0 : i64
+  %c_ind = constant 1 : index
+  %dim0 = dim %arg1, %c_ind : memref<1x4x2xf32>
+  %0 = alloc(%dim0) : memref<1x3x?xf32>
+  %1:3 = krnl.define_loops 3
+  krnl.iterate(%1#0, %1#1, %1#2) with (%1#0 -> %arg3 = 0 to 1, %1#1 -> %arg4 = 0 to 3, %1#2 -> %arg5 = 0 to 4) {
+    affine.store %cst, %0[symbol(%arg3), symbol(%arg4), symbol(%arg5)] : memref<1x3x?xf32>
+  }
+  %2 = krnl.define_loops 1
+  krnl.iterate(%2) with (%2 -> %arg3 = 0 to 1) {
+    %3:2 = krnl.define_loops 2
+    krnl.iterate(%3#0, %3#1) with (%3#0 -> %arg4 = 0 to 3, %3#1 -> %arg5 = 0 to 4) {
+      %dim1 = dim %arg2, %c_ind : memref<1x?x?xf32>
+      %4 = alloc(%dim1) : memref<?xi8>
+      %5 = "krnl.getref"(%4, %c0_i64) : (memref<?xi8>, i64) -> memref<f32>
+      %6 = affine.load %0[0, symbol(%arg4), symbol(%arg5)] : memref<1x3x?xf32>
+      %7 = alloc(%dim1) : memref<?xi8>
+      %8 = "krnl.getref"(%7, %c0_i64) : (memref<?xi8>, i64) -> memref<f32>
+      affine.store %cst, %8[] : memref<f32>
+      %c_ind_2 = constant 2 : index
+      %dim2 = dim %arg2, %c_ind_2 : memref<1x?x?xf32>
+      %9 = alloc(%dim2) : memref<?xi8>
+      %10 = "krnl.getref"(%9, %c0_i64) : (memref<?xi8>, i64) -> memref<f32>
+      affine.store %cst, %10[] : memref<f32>
+      %11 = krnl.define_loops 1
+      krnl.iterate(%11) with (%11 -> %arg6 = 0 to 2) {
+        %25 = affine.load %arg0[symbol(%arg3), symbol(%arg4), symbol(%arg6)] : memref<1x3x2xf32>
+        %26 = affine.load %arg1[0, symbol(%arg5), symbol(%arg6)] : memref<1x4x2xf32>
+        %27 = mulf %25, %26 : f32
+        %28 = affine.load %8[] : memref<f32>
+        %29 = addf %28, %27 : f32
+        affine.store %29, %8[] : memref<f32>
+        %30 = affine.load %arg2[0, symbol(%arg5), symbol(%arg6)] : memref<1x?x?xf32>
+        %31 = mulf %6, %30 : f32
+        %32 = affine.load %10[] : memref<f32>
+        %33 = addf %32, %31 : f32
+        affine.store %33, %10[] : memref<f32>
+      }
+      %12 = affine.load %8[] : memref<f32>
+      %13 = affine.load %10[] : memref<f32>
+      %14 = addf %12, %13 : f32
+      %15 = alloc(%dim2) : memref<?xi8>
+      %16 = "krnl.getref"(%15, %c0_i64) : (memref<?xi8>, i64) -> memref<f32>
+      affine.store %14, %16[] : memref<f32>
+      %17 = affine.load %16[] : memref<f32>
+      %18 = subf %cst, %17 : f32
+      %19 = exp %17 : f32
+      %20 = exp %18 : f32
+      %21 = subf %19, %20 : f32
+      %22 = addf %19, %20 : f32
+      %23 = divf %21, %22 : f32
+      affine.store %23, %5[] : memref<f32>
+      %24 = affine.load %5[] : memref<f32>
+      affine.store %24, %0[0, symbol(%arg4), symbol(%arg5)] : memref<1x3x?xf32>
+      dealloc %15 : memref<?xi8>
+      dealloc %9 : memref<?xi8>
+      dealloc %7 : memref<?xi8>
+      dealloc %4 : memref<?xi8>
+    }
+  }
+  return %0 : memref<1x3x?xf32>
+
+  // CHECK-LABEL: test_dynamic_pool_rnn
+  // CHECK: [[C1:%.+]] = constant 1 : index
+  // CHECK: [[C2:%.+]] = constant 2 : index
+  // CHECK: [[C0:%.+]] = constant 0 : i64
+  // CHECK: krnl.define_loops 1
+  // CHECK: krnl.iterate
+  // CHECK: krnl.define_loops 2
+  // CHECK: krnl.iterate
+  // CHECK: [[DIM1:%.+]] = dim %arg2, [[C1]] : memref<1x?x?xf32>
+  // CHECK: [[DIM2:%.+]] = dim %arg2, [[C2]] : memref<1x?x?xf32>
+  // CHECK: [[ADD1:%.+]] = addi [[DIM2]], [[DIM2]] : index
+  // CHECK: [[OFFSET1:%.+]] = index_cast [[DIM2]] : index to i64
+  // CHECK: [[ADD2:%.+]] = addi [[ADD1]], [[DIM1]] : index
+  // CHECK: [[OFFSET2:%.+]] = index_cast [[ADD1]] : index to i64
+  // CHECK: [[ADD3:%.+]] = addi [[ADD2]], [[DIM1]] : index
+  // CHECK: [[OFFSET3:%.+]] = index_cast [[ADD2]] : index to i64
+  // CHECK: [[DYNAMIC_MEMORY_POOL:%.+]] = alloc([[ADD3]]) : memref<?xi8>
+  // CHECK: [[DATA1:%.+]] = "krnl.getref"([[DYNAMIC_MEMORY_POOL]], [[C0]]) : (memref<?xi8>, i64) -> memref<f32>
+  // CHECK: [[DATA2:%.+]] = "krnl.getref"([[DYNAMIC_MEMORY_POOL]], [[OFFSET3]]) : (memref<?xi8>, i64) -> memref<f32>
+  // CHECK: affine.load
+  // CHECK: [[DATA3:%.+]] = "krnl.getref"([[DYNAMIC_MEMORY_POOL]], [[OFFSET2]]) : (memref<?xi8>, i64) -> memref<f32>
+  // CHECK: affine.store
+  // CHECK: [[DATA4:%.+]] = "krnl.getref"([[DYNAMIC_MEMORY_POOL]], [[OFFSET1]]) : (memref<?xi8>, i64) -> memref<f32>
+  // CHECK: affine.store
+  // CHECK: krnl.define_loops
+  // CHECK: krnl.iterate
+  // CHECK: dealloc [[DYNAMIC_MEMORY_POOL]] : memref<?xi8>
+}