Enable bundling of dynamic memory pools on a block basis. (#330)
* Reorganize main function. * Follow review comments. * Emit constants are globals in Krnl and LLVM dialects. * Fix dynamic bundling. * Format. * Fix test.
This commit is contained in:
parent
931127c7e9
commit
7bfb5c93c1
|
@ -26,23 +26,12 @@ using namespace mlir;
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Insertion point for initialization instructions and the blocks used for
|
// Data structures for managing memory pools.
|
||||||
// inserting the initialization and main code. These blocks will disappear
|
|
||||||
// when the first canonicalization is performed because the init block
|
|
||||||
// unconditionally branches into the second block. These blocks exist only for
|
|
||||||
// the purpose of this optimization.
|
|
||||||
// The information is recorded on a per function basis.
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
typedef struct ONNXOperandsInitState {
|
// Data structure for managing dyanmic memory pool.
|
||||||
Block *initBlock;
|
typedef std::map<Block *, AllocOp> BlockToDynamicPool;
|
||||||
Block *mainBlock;
|
std::map<FuncOp, std::unique_ptr<BlockToDynamicPool>> dynamicPoolMap;
|
||||||
BranchOp branchInit;
|
|
||||||
AllocOp dynamicMemoryPool;
|
|
||||||
} ONNXOperandsInitState;
|
|
||||||
|
|
||||||
// Helper data structure for the bundling of dynamic AllocOps.
|
|
||||||
std::map<FuncOp, std::unique_ptr<ONNXOperandsInitState>> initMap;
|
|
||||||
|
|
||||||
// Handling of static memory pool on a block-basis in each function.
|
// Handling of static memory pool on a block-basis in each function.
|
||||||
typedef std::map<Block *, AllocOp> BlockToStaticPool;
|
typedef std::map<Block *, AllocOp> BlockToStaticPool;
|
||||||
|
@ -63,65 +52,24 @@ FuncOp getContainingFunction(AllocOp op) {
|
||||||
return cast<FuncOp>(parentFuncOp);
|
return cast<FuncOp>(parentFuncOp);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool hasInitBlock(FuncOp function) {
|
// Check if this value is an argument of one of the blocks nested
|
||||||
std::unique_ptr<ONNXOperandsInitState> &initState = initMap.at(function);
|
// around it.
|
||||||
return initState->initBlock != nullptr;
|
bool isBlockArgument(AllocOp allocOp, Value operand) {
|
||||||
}
|
// Parent operation of the current block.
|
||||||
|
Operation *parentBlockOp;
|
||||||
|
Block *currentBlock = allocOp.getOperation()->getBlock();
|
||||||
|
|
||||||
bool addInitBlock(PatternRewriter &rewriter, Location loc, AllocOp allocOp) {
|
do {
|
||||||
// If this is the first time we encounter an operation in this
|
// Check the arguments of the current block.
|
||||||
// function, we create an entry inside the initMap and split the
|
for (auto arg : currentBlock->getArguments())
|
||||||
// function body into an init block and a main block.
|
if (operand == arg)
|
||||||
//
|
return true;
|
||||||
// function func_name() {
|
|
||||||
// ... init block ...
|
|
||||||
// br ^bb1
|
|
||||||
// ^bb1: // pred: ^bb0
|
|
||||||
// ... main block ...
|
|
||||||
// return
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// Note: the block ^bb0 being the first block has its label omitted.
|
|
||||||
//
|
|
||||||
FuncOp function = getContainingFunction(allocOp);
|
|
||||||
// If the function does not contain an init block, create one.
|
|
||||||
if (!hasInitBlock(function)) {
|
|
||||||
std::unique_ptr<ONNXOperandsInitState> &initState = initMap.at(function);
|
|
||||||
initState = std::make_unique<ONNXOperandsInitState>();
|
|
||||||
|
|
||||||
// All input arguments are considered as part of the initialization block
|
parentBlockOp = currentBlock->getParentOp();
|
||||||
// so add them to the operandsInInitBlock set.
|
currentBlock = parentBlockOp->getBlock();
|
||||||
Block *functionBlock = &function.front();
|
|
||||||
PatternRewriter::InsertionGuard insertGuard(rewriter);
|
|
||||||
rewriter.setInsertionPointToStart(functionBlock);
|
|
||||||
|
|
||||||
initState->initBlock = rewriter.getInsertionBlock();
|
} while (!llvm::dyn_cast_or_null<FuncOp>(parentBlockOp));
|
||||||
auto currentPoint = rewriter.getInsertionPoint();
|
|
||||||
initState->mainBlock =
|
|
||||||
rewriter.splitBlock(initState->initBlock, currentPoint);
|
|
||||||
|
|
||||||
rewriter.setInsertionPointToEnd(initState->initBlock);
|
|
||||||
|
|
||||||
// Insert a branch operation from initBlock to mainBlock. This
|
|
||||||
// ensures the final code contains legal blocks.
|
|
||||||
initState->branchInit =
|
|
||||||
rewriter.create<BranchOp>(loc, initState->mainBlock);
|
|
||||||
|
|
||||||
rewriter.setInsertionPointToStart(initState->mainBlock);
|
|
||||||
|
|
||||||
// Save a reference to the current dynamic memory pool value.
|
|
||||||
initState->dynamicMemoryPool = allocOp;
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool isBlockArgument(Block *block, Value operand) {
|
|
||||||
for (auto arg : block->getArguments())
|
|
||||||
if (operand == arg)
|
|
||||||
return true;
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -332,14 +280,17 @@ public:
|
||||||
|
|
||||||
// Get function.
|
// Get function.
|
||||||
FuncOp function = getContainingFunction(allocOp);
|
FuncOp function = getContainingFunction(allocOp);
|
||||||
Block *firstBlock = &function.getBody().front();
|
|
||||||
|
|
||||||
// If this is the alloc representing the memory pool and the function
|
// Use function to retrieve the list of blocks for this function.
|
||||||
// already has an init block, pattern matching must fail to avoid
|
std::unique_ptr<BlockToDynamicPool> &blockToDynamicPool =
|
||||||
// processing the dynamic memory pool a second time.
|
dynamicPoolMap.at(function);
|
||||||
if (hasInitBlock(function)) {
|
|
||||||
std::unique_ptr<ONNXOperandsInitState> &initState = initMap.at(function);
|
// If this is not the first time we process an alloc in this block, avoid
|
||||||
if (allocOp == initState->dynamicMemoryPool)
|
// processing the current dynamic memory pool again.
|
||||||
|
if (blockToDynamicPool->count(parentBlock) > 0) {
|
||||||
|
std::unique_ptr<BlockToDynamicPool> &blockToDynamicPool =
|
||||||
|
dynamicPoolMap.at(function);
|
||||||
|
if (allocOp == blockToDynamicPool->at(parentBlock))
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -365,11 +316,10 @@ public:
|
||||||
dependentOps.insert(definingOperation);
|
dependentOps.insert(definingOperation);
|
||||||
|
|
||||||
// Add operands to work queue.
|
// Add operands to work queue.
|
||||||
// printf("Processing the args of the following op:\n");
|
|
||||||
for (const auto &operand : definingOperation->getOperands()) {
|
for (const auto &operand : definingOperation->getOperands()) {
|
||||||
// Check operand is not a block argument. If it is skip it, we
|
// Check operand is not a block argument. If it is skip it, we
|
||||||
// consider block arguments to be leafs.
|
// consider block arguments to be leafs.
|
||||||
if (!isBlockArgument(firstBlock, operand)) {
|
if (!isBlockArgument(allocOp, operand)) {
|
||||||
operandList.emplace_back(operand);
|
operandList.emplace_back(operand);
|
||||||
|
|
||||||
// Check if the current operation is a dim or a load and the
|
// Check if the current operation is a dim or a load and the
|
||||||
|
@ -416,15 +366,23 @@ public:
|
||||||
if (dependentOps.count(&op) > 0)
|
if (dependentOps.count(&op) > 0)
|
||||||
orderedDependentOps.emplace_back(&op);
|
orderedDependentOps.emplace_back(&op);
|
||||||
|
|
||||||
// If no dynamic alloc is in the trace of the dependent operations,
|
// If this is the first valid alloc we can bundle in this block, then we
|
||||||
// emit the size calculation in the init block, if one exists already,
|
// need to move it to the top of the block as it will consitute an
|
||||||
// if not, create the init block.
|
// insertion point for all other bundle-able AllocOps in the block.
|
||||||
bool addedInitBlock = addInitBlock(rewriter, loc, allocOp);
|
bool isFirstBundledAllocOp = blockToDynamicPool->count(parentBlock) == 0;
|
||||||
|
if (isFirstBundledAllocOp) {
|
||||||
|
allocOp.getOperation()->moveBefore(&parentBlock->front());
|
||||||
|
|
||||||
// Move the ordered dependent size calculation to the init block.
|
// Create new entry in the block map.
|
||||||
std::unique_ptr<ONNXOperandsInitState> &initState = initMap.at(function);
|
blockToDynamicPool->insert(
|
||||||
|
std::pair<Block *, AllocOp>(parentBlock, allocOp));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Move the computation instructions at the start of the block.
|
||||||
|
AllocOp oldDynamicMemoryPool = blockToDynamicPool->at(parentBlock);
|
||||||
|
std::reverse(orderedDependentOps.begin(), orderedDependentOps.end());
|
||||||
for (auto &op : orderedDependentOps)
|
for (auto &op : orderedDependentOps)
|
||||||
op->moveBefore(initState->branchInit);
|
op->moveBefore(&parentBlock->front());
|
||||||
|
|
||||||
// Bundle MemRef type: <?xi8>
|
// Bundle MemRef type: <?xi8>
|
||||||
SmallVector<int64_t, 1> memPoolShape;
|
SmallVector<int64_t, 1> memPoolShape;
|
||||||
|
@ -438,16 +396,16 @@ public:
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
// Add the current alloc size to the current MemPool size.
|
// Add the current alloc size to the current MemPool size.
|
||||||
Value dynamicMemoryPoolSize = initState->dynamicMemoryPool.getOperand(0);
|
Value dynamicMemoryPoolSize = oldDynamicMemoryPool.getOperand(0);
|
||||||
if (addedInitBlock) {
|
if (isFirstBundledAllocOp) {
|
||||||
Value zero = emitConstantOp(rewriter, loc, rewriter.getIndexType(), 0);
|
Value zero = emitConstantOp(rewriter, loc, rewriter.getIndexType(), 0);
|
||||||
zero.getDefiningOp()->moveBefore(initState->branchInit);
|
zero.getDefiningOp()->moveBefore(oldDynamicMemoryPool);
|
||||||
dynamicMemoryPoolSize = zero;
|
dynamicMemoryPoolSize = zero;
|
||||||
}
|
}
|
||||||
|
|
||||||
AddIOp bundledAllocOperand = rewriter.create<AddIOp>(
|
AddIOp bundledAllocOperand = rewriter.create<AddIOp>(
|
||||||
loc, dynamicMemoryPoolSize, allocOp.getOperand(0));
|
loc, dynamicMemoryPoolSize, allocOp.getOperand(0));
|
||||||
bundledAllocOperand.getOperation()->moveBefore(initState->branchInit);
|
bundledAllocOperand.getOperation()->moveBefore(oldDynamicMemoryPool);
|
||||||
|
|
||||||
// The newly bundled MemRef expressed as a KrnlGetRefOp.
|
// The newly bundled MemRef expressed as a KrnlGetRefOp.
|
||||||
// Current memory pool size is the offset for the newly bundled
|
// Current memory pool size is the offset for the newly bundled
|
||||||
|
@ -455,26 +413,27 @@ public:
|
||||||
Value integerDynamicMemoryPoolSize = rewriter.create<IndexCastOp>(
|
Value integerDynamicMemoryPoolSize = rewriter.create<IndexCastOp>(
|
||||||
loc, dynamicMemoryPoolSize, rewriter.getIntegerType(64));
|
loc, dynamicMemoryPoolSize, rewriter.getIntegerType(64));
|
||||||
integerDynamicMemoryPoolSize.getDefiningOp()->moveBefore(
|
integerDynamicMemoryPoolSize.getDefiningOp()->moveBefore(
|
||||||
initState->branchInit);
|
oldDynamicMemoryPool);
|
||||||
|
|
||||||
// We need to emit a new alloc which contains the additional MemRef.
|
// We need to emit a new alloc which contains the additional MemRef.
|
||||||
AllocOp bundledAlloc = rewriter.create<AllocOp>(
|
AllocOp bundledAlloc = rewriter.create<AllocOp>(
|
||||||
loc, bundledMemPoolMemRefType, bundledAllocOperand.getResult());
|
loc, bundledMemPoolMemRefType, bundledAllocOperand.getResult());
|
||||||
bundledAlloc.getOperation()->moveBefore(&initState->mainBlock->front());
|
bundledAlloc.getOperation()->moveBefore(oldDynamicMemoryPool);
|
||||||
|
|
||||||
KrnlGetRefOp bundledMemRef = rewriter.create<KrnlGetRefOp>(loc,
|
KrnlGetRefOp bundledMemRef = rewriter.create<KrnlGetRefOp>(loc,
|
||||||
currentAllocGetRef.getResult().getType(), bundledAlloc,
|
currentAllocGetRef.getResult().getType(), bundledAlloc,
|
||||||
integerDynamicMemoryPoolSize);
|
integerDynamicMemoryPoolSize);
|
||||||
bundledMemRef.getOperation()->moveAfter(bundledAlloc);
|
|
||||||
|
|
||||||
// Replace old memory pool with new one.
|
// Replace old memory pool with new one.
|
||||||
rewriter.replaceOp(initState->dynamicMemoryPool, bundledAlloc.getResult());
|
rewriter.replaceOp(oldDynamicMemoryPool, bundledAlloc.getResult());
|
||||||
|
|
||||||
// Replace old getref with new getref from new memory pool.
|
// Replace old getref with new getref from new memory pool.
|
||||||
rewriter.replaceOp(currentAllocGetRef, bundledMemRef.getResult());
|
rewriter.replaceOp(currentAllocGetRef, bundledMemRef.getResult());
|
||||||
|
|
||||||
// Update MemPool size.
|
// Update MemPool data structure.
|
||||||
initState->dynamicMemoryPool = bundledAlloc;
|
blockToDynamicPool->erase(parentBlock);
|
||||||
|
blockToDynamicPool->insert(
|
||||||
|
std::pair<Block *, AllocOp>(parentBlock, bundledAlloc));
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -488,16 +447,14 @@ class KrnlBundleMemoryPoolsPass
|
||||||
public:
|
public:
|
||||||
void runOnFunction() override {
|
void runOnFunction() override {
|
||||||
auto function = getFunction();
|
auto function = getFunction();
|
||||||
initMap.insert(std::pair<FuncOp, std::unique_ptr<ONNXOperandsInitState>>(
|
|
||||||
function, std::make_unique<ONNXOperandsInitState>()));
|
dynamicPoolMap.insert(
|
||||||
|
std::pair<FuncOp, std::unique_ptr<BlockToDynamicPool>>(
|
||||||
|
function, std::make_unique<BlockToDynamicPool>()));
|
||||||
|
|
||||||
staticPoolMap.insert(std::pair<FuncOp, std::unique_ptr<BlockToStaticPool>>(
|
staticPoolMap.insert(std::pair<FuncOp, std::unique_ptr<BlockToStaticPool>>(
|
||||||
function, std::make_unique<BlockToStaticPool>()));
|
function, std::make_unique<BlockToStaticPool>()));
|
||||||
|
|
||||||
// Initialize state for this function.
|
|
||||||
std::unique_ptr<ONNXOperandsInitState> &initState = initMap.at(function);
|
|
||||||
initState->initBlock = nullptr;
|
|
||||||
|
|
||||||
ConversionTarget target(getContext());
|
ConversionTarget target(getContext());
|
||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
patterns.insert<KrnlBundleStaticMemoryPools, KrnlBundleDynamicMemoryPools>(
|
patterns.insert<KrnlBundleStaticMemoryPools, KrnlBundleDynamicMemoryPools>(
|
||||||
|
@ -505,7 +462,7 @@ public:
|
||||||
|
|
||||||
applyPatternsAndFoldGreedily(function, patterns);
|
applyPatternsAndFoldGreedily(function, patterns);
|
||||||
|
|
||||||
initMap.erase(function);
|
dynamicPoolMap.erase(function);
|
||||||
staticPoolMap.erase(function);
|
staticPoolMap.erase(function);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: onnx-mlir-opt --bundle-memory-pools --canonicalize %s -split-input-file | FileCheck %s
|
// RUN: onnx-mlir-opt --bundle-memory-pools --canonicalize %s | FileCheck %s
|
||||||
|
|
||||||
func @test_pool_bundling(%arg0: memref<10x10xf32>, %arg1: memref<10x20xf32>) -> memref<10x20xf32> {
|
func @test_pool_bundling(%arg0: memref<10x10xf32>, %arg1: memref<10x20xf32>) -> memref<10x20xf32> {
|
||||||
%c0_i64 = constant 0 : i64
|
%c0_i64 = constant 0 : i64
|
||||||
|
@ -88,17 +88,17 @@ func @test_dynamic_pool_bundling(%arg0: memref<?x?xf32>) -> memref<?x10xf32> {
|
||||||
// CHECK: [[C10:%.+]] = constant 10 : index
|
// CHECK: [[C10:%.+]] = constant 10 : index
|
||||||
// CHECK: [[C0_I64:%.+]] = constant 0 : i64
|
// CHECK: [[C0_I64:%.+]] = constant 0 : i64
|
||||||
// CHECK: [[DIM:%.+]] = dim %arg0, [[C0]] : memref<?x?xf32>
|
// CHECK: [[DIM:%.+]] = dim %arg0, [[C0]] : memref<?x?xf32>
|
||||||
|
// CHECK: [[MUL2:%.+]] = muli [[DIM]], [[C4]] : index
|
||||||
|
// CHECK: [[OFFSET2:%.+]] = muli [[MUL2]], [[C10]] : index
|
||||||
// CHECK: [[SGT:%.+]] = cmpi "sgt", [[DIM]], [[DIM]] : index
|
// CHECK: [[SGT:%.+]] = cmpi "sgt", [[DIM]], [[DIM]] : index
|
||||||
// CHECK: [[SELECT:%.+]] = select [[SGT]], [[DIM]], [[DIM]] : index
|
// CHECK: [[SELECT:%.+]] = select [[SGT]], [[DIM]], [[DIM]] : index
|
||||||
// CHECK: [[MUL1:%.+]] = muli [[SELECT]], [[C4]] : index
|
// CHECK: [[MUL1:%.+]] = muli [[SELECT]], [[C4]] : index
|
||||||
// CHECK: [[OFFSET1:%.+]] = muli [[MUL1]], [[C10]] : index
|
// CHECK: [[OFFSET1:%.+]] = muli [[MUL1]], [[C10]] : index
|
||||||
// CHECK: [[MUL2:%.+]] = muli [[DIM]], [[C4]] : index
|
|
||||||
// CHECK: [[OFFSET2:%.+]] = muli [[MUL2]], [[C10]] : index
|
|
||||||
// CHECK: [[MEMPOOL_SIZE:%.+]] = addi [[OFFSET1]], [[OFFSET2]] : index
|
// CHECK: [[MEMPOOL_SIZE:%.+]] = addi [[OFFSET1]], [[OFFSET2]] : index
|
||||||
// CHECK: [[OFFSET1_I64:%.+]] = index_cast [[OFFSET1]] : index to i64
|
// CHECK: [[OFFSET1_I64:%.+]] = index_cast [[OFFSET1]] : index to i64
|
||||||
// CHECK: [[DYN_MEMPOOL:%.+]] = alloc([[MEMPOOL_SIZE]]) : memref<?xi8>
|
// CHECK: [[DYN_MEMPOOL:%.+]] = alloc([[MEMPOOL_SIZE]]) : memref<?xi8>
|
||||||
// CHECK: [[DATA1:%.+]] = "krnl.getref"([[DYN_MEMPOOL]], [[OFFSET1_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32>
|
|
||||||
// CHECK: [[DATA2:%.+]] = "krnl.getref"([[DYN_MEMPOOL]], [[C0_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32>
|
// CHECK: [[DATA2:%.+]] = "krnl.getref"([[DYN_MEMPOOL]], [[C0_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32>
|
||||||
|
// CHECK: [[DATA1:%.+]] = "krnl.getref"([[DYN_MEMPOOL]], [[OFFSET1_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32>
|
||||||
// CHECK: [[RES:%.+]] = alloc([[DIM]]) : memref<?x10xf32>
|
// CHECK: [[RES:%.+]] = alloc([[DIM]]) : memref<?x10xf32>
|
||||||
// CHECK: affine.store [[CST]], [[DATA1]][0, 0] : memref<?x10xf32>
|
// CHECK: affine.store [[CST]], [[DATA1]][0, 0] : memref<?x10xf32>
|
||||||
// CHECK: affine.store [[CST]], [[DATA2]][0, 0] : memref<?x10xf32>
|
// CHECK: affine.store [[CST]], [[DATA2]][0, 0] : memref<?x10xf32>
|
||||||
|
@ -149,36 +149,36 @@ func @test_dynamic_and_static_pool_bundling(%arg0: memref<?x?xf32>, %arg1: memre
|
||||||
return %15 : memref<?x10xf32>
|
return %15 : memref<?x10xf32>
|
||||||
|
|
||||||
// CHECK-LABEL: test_dynamic_and_static_pool_bundling
|
// CHECK-LABEL: test_dynamic_and_static_pool_bundling
|
||||||
// CHECK: [[C1200_I64:%.+]] = constant 1200 : i64
|
|
||||||
// CHECK: [[CST:%.+]] = constant 0.000000e+00 : f32
|
// CHECK: [[CST:%.+]] = constant 0.000000e+00 : f32
|
||||||
// CHECK: [[C0:%.+]] = constant 0 : index
|
// CHECK: [[C0:%.+]] = constant 0 : index
|
||||||
// CHECK: [[C4:%.+]] = constant 4 : index
|
// CHECK: [[C4:%.+]] = constant 4 : index
|
||||||
// CHECK: [[C10:%.+]] = constant 10 : index
|
// CHECK: [[C10:%.+]] = constant 10 : index
|
||||||
// CHECK: [[C400_I64:%.+]] = constant 400 : i64
|
// CHECK: [[C2000_I64:%.+]] = constant 2000 : i64
|
||||||
|
// CHECK: [[C1600_I64:%.+]] = constant 1600 : i64
|
||||||
// CHECK: [[C0_I64:%.+]] = constant 0 : i64
|
// CHECK: [[C0_I64:%.+]] = constant 0 : i64
|
||||||
// CHECK: [[DIM:%.+]] = dim %arg0, [[C0]] : memref<?x?xf32>
|
// CHECK: [[DIM:%.+]] = dim %arg0, [[C0]] : memref<?x?xf32>
|
||||||
|
// CHECK: [[MUL2:%.+]] = muli [[DIM]], [[C4]] : index
|
||||||
|
// CHECK: [[OFFSET2:%.+]] = muli [[MUL2]], [[C10]] : index
|
||||||
// CHECK: [[SGT:%.+]] = cmpi "sgt", [[DIM]], [[DIM]] : index
|
// CHECK: [[SGT:%.+]] = cmpi "sgt", [[DIM]], [[DIM]] : index
|
||||||
// CHECK: [[SELECT:%.+]] = select [[SGT]], [[DIM]], [[DIM]] : index
|
// CHECK: [[SELECT:%.+]] = select [[SGT]], [[DIM]], [[DIM]] : index
|
||||||
// CHECK: [[MUL1:%.+]] = muli [[SELECT]], [[C4]] : index
|
// CHECK: [[MUL1:%.+]] = muli [[SELECT]], [[C4]] : index
|
||||||
// CHECK: [[OFFSET1:%.+]] = muli [[MUL1]], [[C10]] : index
|
// CHECK: [[OFFSET1:%.+]] = muli [[MUL1]], [[C10]] : index
|
||||||
// CHECK: [[MUL2:%.+]] = muli [[DIM]], [[C4]] : index
|
|
||||||
// CHECK: [[OFFSET2:%.+]] = muli [[MUL2]], [[C10]] : index
|
|
||||||
// CHECK: [[MEMPOOL_SIZE:%.+]] = addi [[OFFSET1]], [[OFFSET2]] : index
|
// CHECK: [[MEMPOOL_SIZE:%.+]] = addi [[OFFSET1]], [[OFFSET2]] : index
|
||||||
// CHECK: [[OFFSET1_I64:%.+]] = index_cast [[OFFSET1]] : index to i64
|
// CHECK: [[OFFSET1_I64:%.+]] = index_cast [[OFFSET1]] : index to i64
|
||||||
// CHECK: [[DYN_MEMPOOL:%.+]] = alloc([[MEMPOOL_SIZE]]) : memref<?xi8>
|
// CHECK: [[DYN_MEMPOOL:%.+]] = alloc([[MEMPOOL_SIZE]]) : memref<?xi8>
|
||||||
// CHECK: [[DATA1:%.+]] = "krnl.getref"([[DYN_MEMPOOL]], [[OFFSET1_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32>
|
|
||||||
// CHECK: [[DATA2:%.+]] = "krnl.getref"([[DYN_MEMPOOL]], [[C0_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32>
|
// CHECK: [[DATA2:%.+]] = "krnl.getref"([[DYN_MEMPOOL]], [[C0_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32>
|
||||||
|
// CHECK: [[DATA1:%.+]] = "krnl.getref"([[DYN_MEMPOOL]], [[OFFSET1_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32>
|
||||||
// CHECK: [[STATIC_MEMPOOL:%.+]] = alloc() : memref<2800xi8>
|
// CHECK: [[STATIC_MEMPOOL:%.+]] = alloc() : memref<2800xi8>
|
||||||
// CHECK: [[DATA3:%.+]] = "krnl.getref"([[STATIC_MEMPOOL]], [[C1200_I64]]) : (memref<2800xi8>, i64) -> memref<10x40xf32>
|
// CHECK: [[DATA3:%.+]] = "krnl.getref"([[STATIC_MEMPOOL]], [[C2000_I64]]) : (memref<2800xi8>, i64) -> memref<10x20xf32>
|
||||||
// CHECK: [[DATA4:%.+]] = "krnl.getref"([[STATIC_MEMPOOL]], [[C400_I64]]) : (memref<2800xi8>, i64) -> memref<10x20xf32>
|
// CHECK: [[DATA4:%.+]] = "krnl.getref"([[STATIC_MEMPOOL]], [[C1600_I64]]) : (memref<2800xi8>, i64) -> memref<10x10xf32>
|
||||||
// CHECK: [[DATA5:%.+]] = "krnl.getref"([[STATIC_MEMPOOL]], [[C0_I64]]) : (memref<2800xi8>, i64) -> memref<10x10xf32>
|
|
||||||
// CHECK: [[RES:%.+]] = alloc([[DIM]]) : memref<?x10xf32>
|
// CHECK: [[RES:%.+]] = alloc([[DIM]]) : memref<?x10xf32>
|
||||||
|
// CHECK: [[DATA5:%.+]] = "krnl.getref"([[STATIC_MEMPOOL]], [[C0_I64]]) : (memref<2800xi8>, i64) -> memref<10x40xf32>
|
||||||
// CHECK: affine.store [[CST]], [[DATA1]][0, 0] : memref<?x10xf32>
|
// CHECK: affine.store [[CST]], [[DATA1]][0, 0] : memref<?x10xf32>
|
||||||
// CHECK: affine.store [[CST]], [[DATA2]][0, 0] : memref<?x10xf32>
|
// CHECK: affine.store [[CST]], [[DATA2]][0, 0] : memref<?x10xf32>
|
||||||
// CHECK: affine.store [[CST]], [[RES]][0, 0] : memref<?x10xf32>
|
// CHECK: affine.store [[CST]], [[RES]][0, 0] : memref<?x10xf32>
|
||||||
// CHECK: affine.store [[CST]], [[DATA5]][0, 0] : memref<10x10xf32>
|
// CHECK: affine.store [[CST]], [[DATA4]][0, 0] : memref<10x10xf32>
|
||||||
// CHECK: affine.store [[CST]], [[DATA4]][0, 0] : memref<10x20xf32>
|
// CHECK: affine.store [[CST]], [[DATA3]][0, 0] : memref<10x20xf32>
|
||||||
// CHECK: affine.store [[CST]], [[DATA3]][0, 0] : memref<10x40xf32>
|
// CHECK: affine.store [[CST]], [[DATA5]][0, 0] : memref<10x40xf32>
|
||||||
// CHECK: dealloc [[DYN_MEMPOOL]] : memref<?xi8>
|
// CHECK: dealloc [[DYN_MEMPOOL]] : memref<?xi8>
|
||||||
// CHECK: dealloc [[STATIC_MEMPOOL]] : memref<2800xi8>
|
// CHECK: dealloc [[STATIC_MEMPOOL]] : memref<2800xi8>
|
||||||
// CHECK: return [[RES]] : memref<?x10xf32>
|
// CHECK: return [[RES]] : memref<?x10xf32>
|
||||||
|
@ -365,3 +365,97 @@ func @static_mem_pool_rnn_sub_and_main_block(%arg0: memref<1x3x2xf32>, %arg1: me
|
||||||
// CHECK: dealloc [[STATIC_MEM_POOL_MAIN]] : memref<12xi8>
|
// CHECK: dealloc [[STATIC_MEM_POOL_MAIN]] : memref<12xi8>
|
||||||
// CHECK: return [[RES]] : memref<1x3x4xf32>
|
// CHECK: return [[RES]] : memref<1x3x4xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Test dynamic pooling in sub-block.
|
||||||
|
func @test_dynamic_pool_rnn(%arg0: memref<1x3x2xf32>, %arg1: memref<1x4x2xf32>, %arg2: memref<1x?x?xf32>) -> memref<1x3x?xf32> attributes {input_names = ["X", "W", "R"], output_names = ["Y"]} {
|
||||||
|
%cst = constant 0.000000e+00 : f32
|
||||||
|
%c0_i64 = constant 0 : i64
|
||||||
|
%c_ind = constant 1 : index
|
||||||
|
%dim0 = dim %arg1, %c_ind : memref<1x4x2xf32>
|
||||||
|
%0 = alloc(%dim0) : memref<1x3x?xf32>
|
||||||
|
%1:3 = krnl.define_loops 3
|
||||||
|
krnl.iterate(%1#0, %1#1, %1#2) with (%1#0 -> %arg3 = 0 to 1, %1#1 -> %arg4 = 0 to 3, %1#2 -> %arg5 = 0 to 4) {
|
||||||
|
affine.store %cst, %0[symbol(%arg3), symbol(%arg4), symbol(%arg5)] : memref<1x3x?xf32>
|
||||||
|
}
|
||||||
|
%2 = krnl.define_loops 1
|
||||||
|
krnl.iterate(%2) with (%2 -> %arg3 = 0 to 1) {
|
||||||
|
%3:2 = krnl.define_loops 2
|
||||||
|
krnl.iterate(%3#0, %3#1) with (%3#0 -> %arg4 = 0 to 3, %3#1 -> %arg5 = 0 to 4) {
|
||||||
|
%dim1 = dim %arg2, %c_ind : memref<1x?x?xf32>
|
||||||
|
%4 = alloc(%dim1) : memref<?xi8>
|
||||||
|
%5 = "krnl.getref"(%4, %c0_i64) : (memref<?xi8>, i64) -> memref<f32>
|
||||||
|
%6 = affine.load %0[0, symbol(%arg4), symbol(%arg5)] : memref<1x3x?xf32>
|
||||||
|
%7 = alloc(%dim1) : memref<?xi8>
|
||||||
|
%8 = "krnl.getref"(%7, %c0_i64) : (memref<?xi8>, i64) -> memref<f32>
|
||||||
|
affine.store %cst, %8[] : memref<f32>
|
||||||
|
%c_ind_2 = constant 2 : index
|
||||||
|
%dim2 = dim %arg2, %c_ind_2 : memref<1x?x?xf32>
|
||||||
|
%9 = alloc(%dim2) : memref<?xi8>
|
||||||
|
%10 = "krnl.getref"(%9, %c0_i64) : (memref<?xi8>, i64) -> memref<f32>
|
||||||
|
affine.store %cst, %10[] : memref<f32>
|
||||||
|
%11 = krnl.define_loops 1
|
||||||
|
krnl.iterate(%11) with (%11 -> %arg6 = 0 to 2) {
|
||||||
|
%25 = affine.load %arg0[symbol(%arg3), symbol(%arg4), symbol(%arg6)] : memref<1x3x2xf32>
|
||||||
|
%26 = affine.load %arg1[0, symbol(%arg5), symbol(%arg6)] : memref<1x4x2xf32>
|
||||||
|
%27 = mulf %25, %26 : f32
|
||||||
|
%28 = affine.load %8[] : memref<f32>
|
||||||
|
%29 = addf %28, %27 : f32
|
||||||
|
affine.store %29, %8[] : memref<f32>
|
||||||
|
%30 = affine.load %arg2[0, symbol(%arg5), symbol(%arg6)] : memref<1x?x?xf32>
|
||||||
|
%31 = mulf %6, %30 : f32
|
||||||
|
%32 = affine.load %10[] : memref<f32>
|
||||||
|
%33 = addf %32, %31 : f32
|
||||||
|
affine.store %33, %10[] : memref<f32>
|
||||||
|
}
|
||||||
|
%12 = affine.load %8[] : memref<f32>
|
||||||
|
%13 = affine.load %10[] : memref<f32>
|
||||||
|
%14 = addf %12, %13 : f32
|
||||||
|
%15 = alloc(%dim2) : memref<?xi8>
|
||||||
|
%16 = "krnl.getref"(%15, %c0_i64) : (memref<?xi8>, i64) -> memref<f32>
|
||||||
|
affine.store %14, %16[] : memref<f32>
|
||||||
|
%17 = affine.load %16[] : memref<f32>
|
||||||
|
%18 = subf %cst, %17 : f32
|
||||||
|
%19 = exp %17 : f32
|
||||||
|
%20 = exp %18 : f32
|
||||||
|
%21 = subf %19, %20 : f32
|
||||||
|
%22 = addf %19, %20 : f32
|
||||||
|
%23 = divf %21, %22 : f32
|
||||||
|
affine.store %23, %5[] : memref<f32>
|
||||||
|
%24 = affine.load %5[] : memref<f32>
|
||||||
|
affine.store %24, %0[0, symbol(%arg4), symbol(%arg5)] : memref<1x3x?xf32>
|
||||||
|
dealloc %15 : memref<?xi8>
|
||||||
|
dealloc %9 : memref<?xi8>
|
||||||
|
dealloc %7 : memref<?xi8>
|
||||||
|
dealloc %4 : memref<?xi8>
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return %0 : memref<1x3x?xf32>
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_dynamic_pool_rnn
|
||||||
|
// CHECK: [[C1:%.+]] = constant 1 : index
|
||||||
|
// CHECK: [[C2:%.+]] = constant 2 : index
|
||||||
|
// CHECK: [[C0:%.+]] = constant 0 : i64
|
||||||
|
// CHECK: krnl.define_loops 1
|
||||||
|
// CHECK: krnl.iterate
|
||||||
|
// CHECK: krnl.define_loops 2
|
||||||
|
// CHECK: krnl.iterate
|
||||||
|
// CHECK: [[DIM1:%.+]] = dim %arg2, [[C1]] : memref<1x?x?xf32>
|
||||||
|
// CHECK: [[DIM2:%.+]] = dim %arg2, [[C2]] : memref<1x?x?xf32>
|
||||||
|
// CHECK: [[ADD1:%.+]] = addi [[DIM2]], [[DIM2]] : index
|
||||||
|
// CHECK: [[OFFSET1:%.+]] = index_cast [[DIM2]] : index to i64
|
||||||
|
// CHECK: [[ADD2:%.+]] = addi [[ADD1]], [[DIM1]] : index
|
||||||
|
// CHECK: [[OFFSET2:%.+]] = index_cast [[ADD1]] : index to i64
|
||||||
|
// CHECK: [[ADD3:%.+]] = addi [[ADD2]], [[DIM1]] : index
|
||||||
|
// CHECK: [[OFFSET3:%.+]] = index_cast [[ADD2]] : index to i64
|
||||||
|
// CHECK: [[DYNAMIC_MEMORY_POOL:%.+]] = alloc([[ADD3]]) : memref<?xi8>
|
||||||
|
// CHECK: [[DATA1:%.+]] = "krnl.getref"([[DYNAMIC_MEMORY_POOL]], [[C0]]) : (memref<?xi8>, i64) -> memref<f32>
|
||||||
|
// CHECK: [[DATA2:%.+]] = "krnl.getref"([[DYNAMIC_MEMORY_POOL]], [[OFFSET3]]) : (memref<?xi8>, i64) -> memref<f32>
|
||||||
|
// CHECK: affine.load
|
||||||
|
// CHECK: [[DATA3:%.+]] = "krnl.getref"([[DYNAMIC_MEMORY_POOL]], [[OFFSET2]]) : (memref<?xi8>, i64) -> memref<f32>
|
||||||
|
// CHECK: affine.store
|
||||||
|
// CHECK: [[DATA4:%.+]] = "krnl.getref"([[DYNAMIC_MEMORY_POOL]], [[OFFSET1]]) : (memref<?xi8>, i64) -> memref<f32>
|
||||||
|
// CHECK: affine.store
|
||||||
|
// CHECK: krnl.define_loops
|
||||||
|
// CHECK: krnl.iterate
|
||||||
|
// CHECK: dealloc [[DYNAMIC_MEMORY_POOL]] : memref<?xi8>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue