Enable bundling of dynamic memory pools on a block basis. (#330)
* Reorganize main function. * Follow review comments. * Emit constants as globals in Krnl and LLVM dialects. * Fix dynamic bundling. * Format. * Fix test.
This commit is contained in:
parent
931127c7e9
commit
7bfb5c93c1
|
@ -26,23 +26,12 @@ using namespace mlir;
|
|||
namespace {
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Insertion point for initialization instructions and the blocks used for
|
||||
// inserting the initialization and main code. These blocks will disappear
|
||||
// when the first canonicalization is performed because the init block
|
||||
// unconditionally branches into the second block. These blocks exist only for
|
||||
// the purpose of this optimization.
|
||||
// The information is recorded on a per function basis.
|
||||
// Data structures for managing memory pools.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
typedef struct ONNXOperandsInitState {
|
||||
Block *initBlock;
|
||||
Block *mainBlock;
|
||||
BranchOp branchInit;
|
||||
AllocOp dynamicMemoryPool;
|
||||
} ONNXOperandsInitState;
|
||||
|
||||
// Helper data structure for the bundling of dynamic AllocOps.
|
||||
std::map<FuncOp, std::unique_ptr<ONNXOperandsInitState>> initMap;
|
||||
// Data structure for managing dynamic memory pool.
|
||||
typedef std::map<Block *, AllocOp> BlockToDynamicPool;
|
||||
std::map<FuncOp, std::unique_ptr<BlockToDynamicPool>> dynamicPoolMap;
|
||||
|
||||
// Handling of static memory pool on a block-basis in each function.
|
||||
typedef std::map<Block *, AllocOp> BlockToStaticPool;
|
||||
|
@ -63,65 +52,24 @@ FuncOp getContainingFunction(AllocOp op) {
|
|||
return cast<FuncOp>(parentFuncOp);
|
||||
}
|
||||
|
||||
bool hasInitBlock(FuncOp function) {
|
||||
std::unique_ptr<ONNXOperandsInitState> &initState = initMap.at(function);
|
||||
return initState->initBlock != nullptr;
|
||||
}
|
||||
// Check if this value is an argument of one of the blocks nested
|
||||
// around it.
|
||||
bool isBlockArgument(AllocOp allocOp, Value operand) {
|
||||
// Parent operation of the current block.
|
||||
Operation *parentBlockOp;
|
||||
Block *currentBlock = allocOp.getOperation()->getBlock();
|
||||
|
||||
bool addInitBlock(PatternRewriter &rewriter, Location loc, AllocOp allocOp) {
|
||||
// If this is the first time we encounter an operation in this
|
||||
// function, we create an entry inside the initMap and split the
|
||||
// function body into an init block and a main block.
|
||||
//
|
||||
// function func_name() {
|
||||
// ... init block ...
|
||||
// br ^bb1
|
||||
// ^bb1: // pred: ^bb0
|
||||
// ... main block ...
|
||||
// return
|
||||
// }
|
||||
//
|
||||
// Note: the block ^bb0 being the first block has its label omitted.
|
||||
//
|
||||
FuncOp function = getContainingFunction(allocOp);
|
||||
// If the function does not contain an init block, create one.
|
||||
if (!hasInitBlock(function)) {
|
||||
std::unique_ptr<ONNXOperandsInitState> &initState = initMap.at(function);
|
||||
initState = std::make_unique<ONNXOperandsInitState>();
|
||||
do {
|
||||
// Check the arguments of the current block.
|
||||
for (auto arg : currentBlock->getArguments())
|
||||
if (operand == arg)
|
||||
return true;
|
||||
|
||||
// All input arguments are considered as part of the initialization block
|
||||
// so add them to the operandsInInitBlock set.
|
||||
Block *functionBlock = &function.front();
|
||||
PatternRewriter::InsertionGuard insertGuard(rewriter);
|
||||
rewriter.setInsertionPointToStart(functionBlock);
|
||||
parentBlockOp = currentBlock->getParentOp();
|
||||
currentBlock = parentBlockOp->getBlock();
|
||||
|
||||
initState->initBlock = rewriter.getInsertionBlock();
|
||||
auto currentPoint = rewriter.getInsertionPoint();
|
||||
initState->mainBlock =
|
||||
rewriter.splitBlock(initState->initBlock, currentPoint);
|
||||
} while (!llvm::dyn_cast_or_null<FuncOp>(parentBlockOp));
|
||||
|
||||
rewriter.setInsertionPointToEnd(initState->initBlock);
|
||||
|
||||
// Insert a branch operation from initBlock to mainBlock. This
|
||||
// ensures the final code contains legal blocks.
|
||||
initState->branchInit =
|
||||
rewriter.create<BranchOp>(loc, initState->mainBlock);
|
||||
|
||||
rewriter.setInsertionPointToStart(initState->mainBlock);
|
||||
|
||||
// Save a reference to the current dynamic memory pool value.
|
||||
initState->dynamicMemoryPool = allocOp;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool isBlockArgument(Block *block, Value operand) {
|
||||
for (auto arg : block->getArguments())
|
||||
if (operand == arg)
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -332,14 +280,17 @@ public:
|
|||
|
||||
// Get function.
|
||||
FuncOp function = getContainingFunction(allocOp);
|
||||
Block *firstBlock = &function.getBody().front();
|
||||
|
||||
// If this is the alloc representing the memory pool and the function
|
||||
// already has an init block, pattern matching must fail to avoid
|
||||
// processing the dynamic memory pool a second time.
|
||||
if (hasInitBlock(function)) {
|
||||
std::unique_ptr<ONNXOperandsInitState> &initState = initMap.at(function);
|
||||
if (allocOp == initState->dynamicMemoryPool)
|
||||
// Use function to retrieve the list of blocks for this function.
|
||||
std::unique_ptr<BlockToDynamicPool> &blockToDynamicPool =
|
||||
dynamicPoolMap.at(function);
|
||||
|
||||
// If this is not the first time we process an alloc in this block, avoid
|
||||
// processing the current dynamic memory pool again.
|
||||
if (blockToDynamicPool->count(parentBlock) > 0) {
|
||||
std::unique_ptr<BlockToDynamicPool> &blockToDynamicPool =
|
||||
dynamicPoolMap.at(function);
|
||||
if (allocOp == blockToDynamicPool->at(parentBlock))
|
||||
return failure();
|
||||
}
|
||||
|
||||
|
@ -365,11 +316,10 @@ public:
|
|||
dependentOps.insert(definingOperation);
|
||||
|
||||
// Add operands to work queue.
|
||||
// printf("Processing the args of the following op:\n");
|
||||
for (const auto &operand : definingOperation->getOperands()) {
|
||||
// Check operand is not a block argument. If it is skip it, we
|
||||
// consider block arguments to be leafs.
|
||||
if (!isBlockArgument(firstBlock, operand)) {
|
||||
if (!isBlockArgument(allocOp, operand)) {
|
||||
operandList.emplace_back(operand);
|
||||
|
||||
// Check if the current operation is a dim or a load and the
|
||||
|
@ -416,15 +366,23 @@ public:
|
|||
if (dependentOps.count(&op) > 0)
|
||||
orderedDependentOps.emplace_back(&op);
|
||||
|
||||
// If no dynamic alloc is in the trace of the dependent operations,
|
||||
// emit the size calculation in the init block, if one exists already,
|
||||
// if not, create the init block.
|
||||
bool addedInitBlock = addInitBlock(rewriter, loc, allocOp);
|
||||
// If this is the first valid alloc we can bundle in this block, then we
|
||||
// need to move it to the top of the block as it will constitute an
|
||||
// insertion point for all other bundle-able AllocOps in the block.
|
||||
bool isFirstBundledAllocOp = blockToDynamicPool->count(parentBlock) == 0;
|
||||
if (isFirstBundledAllocOp) {
|
||||
allocOp.getOperation()->moveBefore(&parentBlock->front());
|
||||
|
||||
// Move the ordered dependent size calculation to the init block.
|
||||
std::unique_ptr<ONNXOperandsInitState> &initState = initMap.at(function);
|
||||
// Create new entry in the block map.
|
||||
blockToDynamicPool->insert(
|
||||
std::pair<Block *, AllocOp>(parentBlock, allocOp));
|
||||
}
|
||||
|
||||
// Move the computation instructions at the start of the block.
|
||||
AllocOp oldDynamicMemoryPool = blockToDynamicPool->at(parentBlock);
|
||||
std::reverse(orderedDependentOps.begin(), orderedDependentOps.end());
|
||||
for (auto &op : orderedDependentOps)
|
||||
op->moveBefore(initState->branchInit);
|
||||
op->moveBefore(&parentBlock->front());
|
||||
|
||||
// Bundle MemRef type: <?xi8>
|
||||
SmallVector<int64_t, 1> memPoolShape;
|
||||
|
@ -438,16 +396,16 @@ public:
|
|||
return failure();
|
||||
|
||||
// Add the current alloc size to the current MemPool size.
|
||||
Value dynamicMemoryPoolSize = initState->dynamicMemoryPool.getOperand(0);
|
||||
if (addedInitBlock) {
|
||||
Value dynamicMemoryPoolSize = oldDynamicMemoryPool.getOperand(0);
|
||||
if (isFirstBundledAllocOp) {
|
||||
Value zero = emitConstantOp(rewriter, loc, rewriter.getIndexType(), 0);
|
||||
zero.getDefiningOp()->moveBefore(initState->branchInit);
|
||||
zero.getDefiningOp()->moveBefore(oldDynamicMemoryPool);
|
||||
dynamicMemoryPoolSize = zero;
|
||||
}
|
||||
|
||||
AddIOp bundledAllocOperand = rewriter.create<AddIOp>(
|
||||
loc, dynamicMemoryPoolSize, allocOp.getOperand(0));
|
||||
bundledAllocOperand.getOperation()->moveBefore(initState->branchInit);
|
||||
bundledAllocOperand.getOperation()->moveBefore(oldDynamicMemoryPool);
|
||||
|
||||
// The newly bundled MemRef expressed as a KrnlGetRefOp.
|
||||
// Current memory pool size is the offset for the newly bundled
|
||||
|
@ -455,26 +413,27 @@ public:
|
|||
Value integerDynamicMemoryPoolSize = rewriter.create<IndexCastOp>(
|
||||
loc, dynamicMemoryPoolSize, rewriter.getIntegerType(64));
|
||||
integerDynamicMemoryPoolSize.getDefiningOp()->moveBefore(
|
||||
initState->branchInit);
|
||||
oldDynamicMemoryPool);
|
||||
|
||||
// We need to emit a new alloc which contains the additional MemRef.
|
||||
AllocOp bundledAlloc = rewriter.create<AllocOp>(
|
||||
loc, bundledMemPoolMemRefType, bundledAllocOperand.getResult());
|
||||
bundledAlloc.getOperation()->moveBefore(&initState->mainBlock->front());
|
||||
bundledAlloc.getOperation()->moveBefore(oldDynamicMemoryPool);
|
||||
|
||||
KrnlGetRefOp bundledMemRef = rewriter.create<KrnlGetRefOp>(loc,
|
||||
currentAllocGetRef.getResult().getType(), bundledAlloc,
|
||||
integerDynamicMemoryPoolSize);
|
||||
bundledMemRef.getOperation()->moveAfter(bundledAlloc);
|
||||
|
||||
// Replace old memory pool with new one.
|
||||
rewriter.replaceOp(initState->dynamicMemoryPool, bundledAlloc.getResult());
|
||||
rewriter.replaceOp(oldDynamicMemoryPool, bundledAlloc.getResult());
|
||||
|
||||
// Replace old getref with new getref from new memory pool.
|
||||
rewriter.replaceOp(currentAllocGetRef, bundledMemRef.getResult());
|
||||
|
||||
// Update MemPool size.
|
||||
initState->dynamicMemoryPool = bundledAlloc;
|
||||
// Update MemPool data structure.
|
||||
blockToDynamicPool->erase(parentBlock);
|
||||
blockToDynamicPool->insert(
|
||||
std::pair<Block *, AllocOp>(parentBlock, bundledAlloc));
|
||||
|
||||
return success();
|
||||
}
|
||||
|
@ -488,16 +447,14 @@ class KrnlBundleMemoryPoolsPass
|
|||
public:
|
||||
void runOnFunction() override {
|
||||
auto function = getFunction();
|
||||
initMap.insert(std::pair<FuncOp, std::unique_ptr<ONNXOperandsInitState>>(
|
||||
function, std::make_unique<ONNXOperandsInitState>()));
|
||||
|
||||
dynamicPoolMap.insert(
|
||||
std::pair<FuncOp, std::unique_ptr<BlockToDynamicPool>>(
|
||||
function, std::make_unique<BlockToDynamicPool>()));
|
||||
|
||||
staticPoolMap.insert(std::pair<FuncOp, std::unique_ptr<BlockToStaticPool>>(
|
||||
function, std::make_unique<BlockToStaticPool>()));
|
||||
|
||||
// Initialize state for this function.
|
||||
std::unique_ptr<ONNXOperandsInitState> &initState = initMap.at(function);
|
||||
initState->initBlock = nullptr;
|
||||
|
||||
ConversionTarget target(getContext());
|
||||
OwningRewritePatternList patterns;
|
||||
patterns.insert<KrnlBundleStaticMemoryPools, KrnlBundleDynamicMemoryPools>(
|
||||
|
@ -505,7 +462,7 @@ public:
|
|||
|
||||
applyPatternsAndFoldGreedily(function, patterns);
|
||||
|
||||
initMap.erase(function);
|
||||
dynamicPoolMap.erase(function);
|
||||
staticPoolMap.erase(function);
|
||||
}
|
||||
};
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: onnx-mlir-opt --bundle-memory-pools --canonicalize %s -split-input-file | FileCheck %s
|
||||
// RUN: onnx-mlir-opt --bundle-memory-pools --canonicalize %s | FileCheck %s
|
||||
|
||||
func @test_pool_bundling(%arg0: memref<10x10xf32>, %arg1: memref<10x20xf32>) -> memref<10x20xf32> {
|
||||
%c0_i64 = constant 0 : i64
|
||||
|
@ -88,17 +88,17 @@ func @test_dynamic_pool_bundling(%arg0: memref<?x?xf32>) -> memref<?x10xf32> {
|
|||
// CHECK: [[C10:%.+]] = constant 10 : index
|
||||
// CHECK: [[C0_I64:%.+]] = constant 0 : i64
|
||||
// CHECK: [[DIM:%.+]] = dim %arg0, [[C0]] : memref<?x?xf32>
|
||||
// CHECK: [[MUL2:%.+]] = muli [[DIM]], [[C4]] : index
|
||||
// CHECK: [[OFFSET2:%.+]] = muli [[MUL2]], [[C10]] : index
|
||||
// CHECK: [[SGT:%.+]] = cmpi "sgt", [[DIM]], [[DIM]] : index
|
||||
// CHECK: [[SELECT:%.+]] = select [[SGT]], [[DIM]], [[DIM]] : index
|
||||
// CHECK: [[MUL1:%.+]] = muli [[SELECT]], [[C4]] : index
|
||||
// CHECK: [[OFFSET1:%.+]] = muli [[MUL1]], [[C10]] : index
|
||||
// CHECK: [[MUL2:%.+]] = muli [[DIM]], [[C4]] : index
|
||||
// CHECK: [[OFFSET2:%.+]] = muli [[MUL2]], [[C10]] : index
|
||||
// CHECK: [[MEMPOOL_SIZE:%.+]] = addi [[OFFSET1]], [[OFFSET2]] : index
|
||||
// CHECK: [[OFFSET1_I64:%.+]] = index_cast [[OFFSET1]] : index to i64
|
||||
// CHECK: [[DYN_MEMPOOL:%.+]] = alloc([[MEMPOOL_SIZE]]) : memref<?xi8>
|
||||
// CHECK: [[DATA1:%.+]] = "krnl.getref"([[DYN_MEMPOOL]], [[OFFSET1_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32>
|
||||
// CHECK: [[DATA2:%.+]] = "krnl.getref"([[DYN_MEMPOOL]], [[C0_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32>
|
||||
// CHECK: [[DATA1:%.+]] = "krnl.getref"([[DYN_MEMPOOL]], [[OFFSET1_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32>
|
||||
// CHECK: [[RES:%.+]] = alloc([[DIM]]) : memref<?x10xf32>
|
||||
// CHECK: affine.store [[CST]], [[DATA1]][0, 0] : memref<?x10xf32>
|
||||
// CHECK: affine.store [[CST]], [[DATA2]][0, 0] : memref<?x10xf32>
|
||||
|
@ -149,36 +149,36 @@ func @test_dynamic_and_static_pool_bundling(%arg0: memref<?x?xf32>, %arg1: memre
|
|||
return %15 : memref<?x10xf32>
|
||||
|
||||
// CHECK-LABEL: test_dynamic_and_static_pool_bundling
|
||||
// CHECK: [[C1200_I64:%.+]] = constant 1200 : i64
|
||||
// CHECK: [[CST:%.+]] = constant 0.000000e+00 : f32
|
||||
// CHECK: [[C0:%.+]] = constant 0 : index
|
||||
// CHECK: [[C4:%.+]] = constant 4 : index
|
||||
// CHECK: [[C10:%.+]] = constant 10 : index
|
||||
// CHECK: [[C400_I64:%.+]] = constant 400 : i64
|
||||
// CHECK: [[C2000_I64:%.+]] = constant 2000 : i64
|
||||
// CHECK: [[C1600_I64:%.+]] = constant 1600 : i64
|
||||
// CHECK: [[C0_I64:%.+]] = constant 0 : i64
|
||||
// CHECK: [[DIM:%.+]] = dim %arg0, [[C0]] : memref<?x?xf32>
|
||||
// CHECK: [[MUL2:%.+]] = muli [[DIM]], [[C4]] : index
|
||||
// CHECK: [[OFFSET2:%.+]] = muli [[MUL2]], [[C10]] : index
|
||||
// CHECK: [[SGT:%.+]] = cmpi "sgt", [[DIM]], [[DIM]] : index
|
||||
// CHECK: [[SELECT:%.+]] = select [[SGT]], [[DIM]], [[DIM]] : index
|
||||
// CHECK: [[MUL1:%.+]] = muli [[SELECT]], [[C4]] : index
|
||||
// CHECK: [[OFFSET1:%.+]] = muli [[MUL1]], [[C10]] : index
|
||||
// CHECK: [[MUL2:%.+]] = muli [[DIM]], [[C4]] : index
|
||||
// CHECK: [[OFFSET2:%.+]] = muli [[MUL2]], [[C10]] : index
|
||||
// CHECK: [[MEMPOOL_SIZE:%.+]] = addi [[OFFSET1]], [[OFFSET2]] : index
|
||||
// CHECK: [[OFFSET1_I64:%.+]] = index_cast [[OFFSET1]] : index to i64
|
||||
// CHECK: [[DYN_MEMPOOL:%.+]] = alloc([[MEMPOOL_SIZE]]) : memref<?xi8>
|
||||
// CHECK: [[DATA1:%.+]] = "krnl.getref"([[DYN_MEMPOOL]], [[OFFSET1_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32>
|
||||
// CHECK: [[DATA2:%.+]] = "krnl.getref"([[DYN_MEMPOOL]], [[C0_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32>
|
||||
// CHECK: [[DATA1:%.+]] = "krnl.getref"([[DYN_MEMPOOL]], [[OFFSET1_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32>
|
||||
// CHECK: [[STATIC_MEMPOOL:%.+]] = alloc() : memref<2800xi8>
|
||||
// CHECK: [[DATA3:%.+]] = "krnl.getref"([[STATIC_MEMPOOL]], [[C1200_I64]]) : (memref<2800xi8>, i64) -> memref<10x40xf32>
|
||||
// CHECK: [[DATA4:%.+]] = "krnl.getref"([[STATIC_MEMPOOL]], [[C400_I64]]) : (memref<2800xi8>, i64) -> memref<10x20xf32>
|
||||
// CHECK: [[DATA5:%.+]] = "krnl.getref"([[STATIC_MEMPOOL]], [[C0_I64]]) : (memref<2800xi8>, i64) -> memref<10x10xf32>
|
||||
// CHECK: [[DATA3:%.+]] = "krnl.getref"([[STATIC_MEMPOOL]], [[C2000_I64]]) : (memref<2800xi8>, i64) -> memref<10x20xf32>
|
||||
// CHECK: [[DATA4:%.+]] = "krnl.getref"([[STATIC_MEMPOOL]], [[C1600_I64]]) : (memref<2800xi8>, i64) -> memref<10x10xf32>
|
||||
// CHECK: [[RES:%.+]] = alloc([[DIM]]) : memref<?x10xf32>
|
||||
// CHECK: [[DATA5:%.+]] = "krnl.getref"([[STATIC_MEMPOOL]], [[C0_I64]]) : (memref<2800xi8>, i64) -> memref<10x40xf32>
|
||||
// CHECK: affine.store [[CST]], [[DATA1]][0, 0] : memref<?x10xf32>
|
||||
// CHECK: affine.store [[CST]], [[DATA2]][0, 0] : memref<?x10xf32>
|
||||
// CHECK: affine.store [[CST]], [[RES]][0, 0] : memref<?x10xf32>
|
||||
// CHECK: affine.store [[CST]], [[DATA5]][0, 0] : memref<10x10xf32>
|
||||
// CHECK: affine.store [[CST]], [[DATA4]][0, 0] : memref<10x20xf32>
|
||||
// CHECK: affine.store [[CST]], [[DATA3]][0, 0] : memref<10x40xf32>
|
||||
// CHECK: affine.store [[CST]], [[DATA4]][0, 0] : memref<10x10xf32>
|
||||
// CHECK: affine.store [[CST]], [[DATA3]][0, 0] : memref<10x20xf32>
|
||||
// CHECK: affine.store [[CST]], [[DATA5]][0, 0] : memref<10x40xf32>
|
||||
// CHECK: dealloc [[DYN_MEMPOOL]] : memref<?xi8>
|
||||
// CHECK: dealloc [[STATIC_MEMPOOL]] : memref<2800xi8>
|
||||
// CHECK: return [[RES]] : memref<?x10xf32>
|
||||
|
@ -365,3 +365,97 @@ func @static_mem_pool_rnn_sub_and_main_block(%arg0: memref<1x3x2xf32>, %arg1: me
|
|||
// CHECK: dealloc [[STATIC_MEM_POOL_MAIN]] : memref<12xi8>
|
||||
// CHECK: return [[RES]] : memref<1x3x4xf32>
|
||||
}
|
||||
|
||||
/// Test dynamic pooling in sub-block.
|
||||
func @test_dynamic_pool_rnn(%arg0: memref<1x3x2xf32>, %arg1: memref<1x4x2xf32>, %arg2: memref<1x?x?xf32>) -> memref<1x3x?xf32> attributes {input_names = ["X", "W", "R"], output_names = ["Y"]} {
|
||||
%cst = constant 0.000000e+00 : f32
|
||||
%c0_i64 = constant 0 : i64
|
||||
%c_ind = constant 1 : index
|
||||
%dim0 = dim %arg1, %c_ind : memref<1x4x2xf32>
|
||||
%0 = alloc(%dim0) : memref<1x3x?xf32>
|
||||
%1:3 = krnl.define_loops 3
|
||||
krnl.iterate(%1#0, %1#1, %1#2) with (%1#0 -> %arg3 = 0 to 1, %1#1 -> %arg4 = 0 to 3, %1#2 -> %arg5 = 0 to 4) {
|
||||
affine.store %cst, %0[symbol(%arg3), symbol(%arg4), symbol(%arg5)] : memref<1x3x?xf32>
|
||||
}
|
||||
%2 = krnl.define_loops 1
|
||||
krnl.iterate(%2) with (%2 -> %arg3 = 0 to 1) {
|
||||
%3:2 = krnl.define_loops 2
|
||||
krnl.iterate(%3#0, %3#1) with (%3#0 -> %arg4 = 0 to 3, %3#1 -> %arg5 = 0 to 4) {
|
||||
%dim1 = dim %arg2, %c_ind : memref<1x?x?xf32>
|
||||
%4 = alloc(%dim1) : memref<?xi8>
|
||||
%5 = "krnl.getref"(%4, %c0_i64) : (memref<?xi8>, i64) -> memref<f32>
|
||||
%6 = affine.load %0[0, symbol(%arg4), symbol(%arg5)] : memref<1x3x?xf32>
|
||||
%7 = alloc(%dim1) : memref<?xi8>
|
||||
%8 = "krnl.getref"(%7, %c0_i64) : (memref<?xi8>, i64) -> memref<f32>
|
||||
affine.store %cst, %8[] : memref<f32>
|
||||
%c_ind_2 = constant 2 : index
|
||||
%dim2 = dim %arg2, %c_ind_2 : memref<1x?x?xf32>
|
||||
%9 = alloc(%dim2) : memref<?xi8>
|
||||
%10 = "krnl.getref"(%9, %c0_i64) : (memref<?xi8>, i64) -> memref<f32>
|
||||
affine.store %cst, %10[] : memref<f32>
|
||||
%11 = krnl.define_loops 1
|
||||
krnl.iterate(%11) with (%11 -> %arg6 = 0 to 2) {
|
||||
%25 = affine.load %arg0[symbol(%arg3), symbol(%arg4), symbol(%arg6)] : memref<1x3x2xf32>
|
||||
%26 = affine.load %arg1[0, symbol(%arg5), symbol(%arg6)] : memref<1x4x2xf32>
|
||||
%27 = mulf %25, %26 : f32
|
||||
%28 = affine.load %8[] : memref<f32>
|
||||
%29 = addf %28, %27 : f32
|
||||
affine.store %29, %8[] : memref<f32>
|
||||
%30 = affine.load %arg2[0, symbol(%arg5), symbol(%arg6)] : memref<1x?x?xf32>
|
||||
%31 = mulf %6, %30 : f32
|
||||
%32 = affine.load %10[] : memref<f32>
|
||||
%33 = addf %32, %31 : f32
|
||||
affine.store %33, %10[] : memref<f32>
|
||||
}
|
||||
%12 = affine.load %8[] : memref<f32>
|
||||
%13 = affine.load %10[] : memref<f32>
|
||||
%14 = addf %12, %13 : f32
|
||||
%15 = alloc(%dim2) : memref<?xi8>
|
||||
%16 = "krnl.getref"(%15, %c0_i64) : (memref<?xi8>, i64) -> memref<f32>
|
||||
affine.store %14, %16[] : memref<f32>
|
||||
%17 = affine.load %16[] : memref<f32>
|
||||
%18 = subf %cst, %17 : f32
|
||||
%19 = exp %17 : f32
|
||||
%20 = exp %18 : f32
|
||||
%21 = subf %19, %20 : f32
|
||||
%22 = addf %19, %20 : f32
|
||||
%23 = divf %21, %22 : f32
|
||||
affine.store %23, %5[] : memref<f32>
|
||||
%24 = affine.load %5[] : memref<f32>
|
||||
affine.store %24, %0[0, symbol(%arg4), symbol(%arg5)] : memref<1x3x?xf32>
|
||||
dealloc %15 : memref<?xi8>
|
||||
dealloc %9 : memref<?xi8>
|
||||
dealloc %7 : memref<?xi8>
|
||||
dealloc %4 : memref<?xi8>
|
||||
}
|
||||
}
|
||||
return %0 : memref<1x3x?xf32>
|
||||
|
||||
// CHECK-LABEL: test_dynamic_pool_rnn
|
||||
// CHECK: [[C1:%.+]] = constant 1 : index
|
||||
// CHECK: [[C2:%.+]] = constant 2 : index
|
||||
// CHECK: [[C0:%.+]] = constant 0 : i64
|
||||
// CHECK: krnl.define_loops 1
|
||||
// CHECK: krnl.iterate
|
||||
// CHECK: krnl.define_loops 2
|
||||
// CHECK: krnl.iterate
|
||||
// CHECK: [[DIM1:%.+]] = dim %arg2, [[C1]] : memref<1x?x?xf32>
|
||||
// CHECK: [[DIM2:%.+]] = dim %arg2, [[C2]] : memref<1x?x?xf32>
|
||||
// CHECK: [[ADD1:%.+]] = addi [[DIM2]], [[DIM2]] : index
|
||||
// CHECK: [[OFFSET1:%.+]] = index_cast [[DIM2]] : index to i64
|
||||
// CHECK: [[ADD2:%.+]] = addi [[ADD1]], [[DIM1]] : index
|
||||
// CHECK: [[OFFSET2:%.+]] = index_cast [[ADD1]] : index to i64
|
||||
// CHECK: [[ADD3:%.+]] = addi [[ADD2]], [[DIM1]] : index
|
||||
// CHECK: [[OFFSET3:%.+]] = index_cast [[ADD2]] : index to i64
|
||||
// CHECK: [[DYNAMIC_MEMORY_POOL:%.+]] = alloc([[ADD3]]) : memref<?xi8>
|
||||
// CHECK: [[DATA1:%.+]] = "krnl.getref"([[DYNAMIC_MEMORY_POOL]], [[C0]]) : (memref<?xi8>, i64) -> memref<f32>
|
||||
// CHECK: [[DATA2:%.+]] = "krnl.getref"([[DYNAMIC_MEMORY_POOL]], [[OFFSET3]]) : (memref<?xi8>, i64) -> memref<f32>
|
||||
// CHECK: affine.load
|
||||
// CHECK: [[DATA3:%.+]] = "krnl.getref"([[DYNAMIC_MEMORY_POOL]], [[OFFSET2]]) : (memref<?xi8>, i64) -> memref<f32>
|
||||
// CHECK: affine.store
|
||||
// CHECK: [[DATA4:%.+]] = "krnl.getref"([[DYNAMIC_MEMORY_POOL]], [[OFFSET1]]) : (memref<?xi8>, i64) -> memref<f32>
|
||||
// CHECK: affine.store
|
||||
// CHECK: krnl.define_loops
|
||||
// CHECK: krnl.iterate
|
||||
// CHECK: dealloc [[DYNAMIC_MEMORY_POOL]] : memref<?xi8>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue