Emit the dynamic memory pool (#290)

* Reorganize main function.

* Follow review comments.

* Emit constants as globals in Krnl and LLVM dialects.

* Add support for bundling dynamic memory pools.

* Add dynamic bundling.

* Clean-up code.

* Clean-up file.

* Add test for bundling dynamic memory pool.

* Fixes. Simplify data structure. Add mixed test.

* Remove unused import.
This commit is contained in:
Gheorghe-Teodor Bercea 2020-09-03 10:31:06 -04:00 committed by GitHub
parent 930e20f682
commit 9f69b2f317
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 460 additions and 4 deletions

View File

@ -15,6 +15,7 @@
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SetVector.h"
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
#include "src/Dialect/Krnl/KrnlOps.hpp"
@ -24,6 +25,102 @@ using namespace mlir;
namespace {
//===----------------------------------------------------------------------===//
// Insertion point for initialization instructions and the blocks used for
// inserting the initialization and main code. These blocks will disappear
// when the first canonicalization is performed because the init block
// unconditionally branches into the second block. These blocks exist only for
// the purpose of this optimization.
// The information is recorded on a per function basis.
//===----------------------------------------------------------------------===//
/// Per-function bookkeeping used while bundling dynamic memory pools.
///
/// The block pointers default to null so that a freshly constructed state
/// reports "no init block yet" regardless of how it was initialized.
struct ONNXOperandsInitState {
  /// Block holding the initialization code; null until the function body has
  /// been split.
  Block *initBlock = nullptr;
  /// Block holding the remainder of the function body after the split.
  Block *mainBlock = nullptr;
  /// Unconditional branch from initBlock to mainBlock; moved operations are
  /// inserted right before it.
  BranchOp branchInit;
  /// The AllocOp currently acting as the dynamic memory pool.
  AllocOp dynamicMemoryPool;
};
// Helper data structure for the bundling of dynamic AllocOps.
// Maps a function to its init/main block state. The pass inserts the entry
// for a function at the start of runOnFunction and erases it at the end; the
// helpers in this file look it up with at(), so the entry must exist before
// any of them is called.
std::map<FuncOp, std::unique_ptr<ONNXOperandsInitState>> initMap;
//===----------------------------------------------------------------------===//
// Helper functions.
//===----------------------------------------------------------------------===//
/// Retrieve the function which contains the given AllocOp.
///
/// Walks up the chain of parent operations until a FuncOp is found.
/// \param op the AllocOp whose enclosing function is requested.
/// \returns the enclosing FuncOp, or a null FuncOp when the parent chain ends
///          without reaching one.
FuncOp getContainingFunction(AllocOp op) {
  Operation *ancestor = op.getParentOp();
  // Stop at the first FuncOp or when there is no parent left. The explicit
  // null check prevents calling getParentOp() on a null pointer.
  while (ancestor && !llvm::dyn_cast_or_null<FuncOp>(ancestor))
    ancestor = ancestor->getParentOp();
  // Null-safe: yields a null FuncOp if no enclosing function exists.
  return llvm::dyn_cast_or_null<FuncOp>(ancestor);
}
/// Tell whether the body of `function` has already been split into an init
/// block and a main block.
///
/// The function must already have an entry in initMap; at() throws otherwise.
bool hasInitBlock(FuncOp function) {
  const ONNXOperandsInitState *state = initMap.at(function).get();
  return state->initBlock != nullptr;
}
/// Split the body of the function enclosing `allocOp` into an init block and
/// a main block, unless this has already been done for that function.
///
///   function func_name() {
///     ... init block ...
///     br ^bb1
///   ^bb1:  // pred: ^bb0
///     ... main block ...
///     return
///   }
///
/// Note: the block ^bb0 being the first block has its label omitted.
///
/// \param rewriter rewriter used to split the block and emit the branch.
/// \param loc location attached to the emitted branch.
/// \param allocOp alloc recorded as the function's dynamic memory pool when
///        the split is performed.
/// \returns true if the init block was created by this call, false if the
///          function already had one.
bool addInitBlock(PatternRewriter &rewriter, Location loc, AllocOp allocOp) {
  FuncOp function = getContainingFunction(allocOp);

  // Nothing to do when the function has already been split.
  if (hasInitBlock(function))
    return false;

  // Start from a fresh state record for this function.
  std::unique_ptr<ONNXOperandsInitState> &initState = initMap.at(function);
  initState = std::make_unique<ONNXOperandsInitState>();

  Block *entryBlock = &function.front();

  // NOTE(review): the guard is expected to restore the rewriter's insertion
  // point when this function returns - confirm against the MLIR builder docs.
  PatternRewriter::InsertionGuard insertGuard(rewriter);

  // Split at the very start of the entry block: the (now empty) first part
  // becomes the init block, everything else becomes the main block.
  rewriter.setInsertionPointToStart(entryBlock);
  initState->initBlock = rewriter.getInsertionBlock();
  auto splitPoint = rewriter.getInsertionPoint();
  initState->mainBlock = rewriter.splitBlock(initState->initBlock, splitPoint);

  // Insert a branch operation from initBlock to mainBlock. This ensures the
  // final code contains legal blocks.
  rewriter.setInsertionPointToEnd(initState->initBlock);
  initState->branchInit = rewriter.create<BranchOp>(loc, initState->mainBlock);

  rewriter.setInsertionPointToStart(initState->mainBlock);

  // Remember which alloc currently plays the role of dynamic memory pool.
  initState->dynamicMemoryPool = allocOp;
  return true;
}
/// Tell whether `operand` is one of the arguments of `block`.
bool isBlockArgument(Block *block, Value operand) {
  bool found = false;
  for (auto candidate : block->getArguments()) {
    if (operand == candidate) {
      found = true;
      break;
    }
  }
  return found;
}
KrnlGetRefOp getUnbundledGetRef(AllocOp *memPool) {
auto parentBlock = memPool->getOperation()->getBlock();
@ -37,6 +134,23 @@ KrnlGetRefOp getUnbundledGetRef(AllocOp *memPool) {
return unbundledGetRef;
}
/// Find a KrnlGetRefOp whose first operand is the result of `allocOp`.
///
/// The search walks the block containing the alloc. If several getrefs match,
/// the last one visited by the walk is returned; a null op is returned when
/// none matches.
KrnlGetRefOp getCurrentAllocGetRef(AllocOp *allocOp) {
  Block *enclosingBlock = allocOp->getOperation()->getBlock();
  auto allocResult = allocOp->getResult();

  KrnlGetRefOp match = nullptr;
  enclosingBlock->walk([&match, &allocResult](KrnlGetRefOp candidate) {
    if (candidate.getOperands()[0] == allocResult)
      match = candidate;
  });
  return match;
}
//===----------------------------------------------------------------------===//
// Rewrite patterns.
//===----------------------------------------------------------------------===//
/*!
* RewritePattern that replaces:
* %mem1 = alloc() : memref<<dims1>x<type>>
@ -73,7 +187,7 @@ public:
if (!checkOpResultIsUsedByGetRef(&allocOp))
return failure();
// TODO: remove once we support the bundling of dynamic memory pools.
// Only handle constant AllocOps.
if (!hasAllConstantDimensions(memRefType))
return failure();
@ -85,12 +199,16 @@ public:
if (memRefShape.size() != 1)
return failure();
// TODO: Change this when dyanmic shapes are supported.
// TODO: Add support for dynamic shapes.
int64_t currentMemPoolSize = memRefShape[0];
// Get a KrnlGetRefOp which does not use the current alloc.
if (KrnlGetRefOp unbundledGetRef = getUnbundledGetRef(&allocOp)) {
// Make sure that this get ref uses a static alloc.
auto unbundledGetRefType =
convertToMemRefType(unbundledGetRef.getResult().getType());
if (!hasAllConstantDimensions(unbundledGetRefType))
return failure();
// Current memory pool size is the offset for the newly bundled
// internal MemRef. Emit the offset as a constant.
auto offset = rewriter.create<ConstantOp>(
@ -127,6 +245,201 @@ public:
}
};
/*!
 *  RewritePattern that merges a new dynamic AllocOp with the existing dynamic
 *  memory pool.
 *
 *    %dyn_mempool = alloc(%a) : memref<?xi8>
 *    %new_alloc = alloc(%b) : memref<?xi8>
 *    %new_ref = krnl.getref %new_alloc 0 : memref<?xi8>
 *  =>
 *    %c = addi %a, %b
 *    %dyn_mempool = alloc(%c) : memref<?xi8>
 *    %new_ref = krnl.getref %dyn_mempool %a : memref<?xi8>
 *
 *  The operations computing the size of the alloc are moved into the
 *  function's init block (created on first use) so that the pool size is
 *  available before the pool itself is allocated at the start of the main
 *  block.
 *
 *  The pattern fails to match, without touching the IR, when:
 *   - the alloc result is not used by a krnl.getref;
 *   - the alloc has only constant dimensions, is not a byte MemRef, or is not
 *     of rank 1;
 *   - the alloc is the current dynamic memory pool of the function;
 *   - the size computation goes through a dim/load of a dynamically sized
 *     AllocOp or of a krnl.getref.
 */
class KrnlBundleDynamicMemoryPools : public OpRewritePattern<AllocOp> {
public:
  using OpRewritePattern<AllocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(
      AllocOp allocOp, PatternRewriter &rewriter) const override {
    auto loc = allocOp.getLoc();

    auto memRefType = convertToMemRefType(allocOp.getResult().getType());
    auto memRefShape = memRefType.getShape();

    // If alloca result is not used by getref then it cannot be part of
    // the memory pool.
    if (!checkOpResultIsUsedByGetRef(&allocOp))
      return failure();

    // Only handle dynamic allocs here.
    if (hasAllConstantDimensions(memRefType))
      return failure();

    // Alloc memory type must be byte.
    if (getMemRefEltSizeInBytes(memRefType) != 1)
      return failure();

    // Rank of the allocated MemRef must be 1.
    if (memRefShape.size() != 1)
      return failure();

    auto parentBlock = allocOp.getOperation()->getBlock();

    // Get function.
    FuncOp function = getContainingFunction(allocOp);
    Block *firstBlock = &function.getBody().front();

    // If this is the alloc representing the memory pool and the function
    // already has an init block, pattern matching must fail to avoid
    // processing the dynamic memory pool a second time.
    if (hasInitBlock(function)) {
      std::unique_ptr<ONNXOperandsInitState> &initState = initMap.at(function);
      if (allocOp == initState->dynamicMemoryPool)
        return failure();
    }

    // Get the getref of the current allocOp. This lookup is done before any
    // change is made to the IR so that a failure here leaves the function
    // and the per-function state untouched.
    KrnlGetRefOp currentAllocGetRef = getCurrentAllocGetRef(&allocOp);
    if (!currentAllocGetRef)
      return failure();

    // Visit dependent operations and assemble the set of operations which
    // participate in the computation of the size of the AllocOp. The work
    // list is traversed by index: entries are appended while iterating and
    // never removed, which keeps the original visiting order without
    // repeatedly erasing the front of the vector.
    std::vector<Value> operandList;
    for (const auto &operand : allocOp.getOperands())
      operandList.emplace_back(operand);

    llvm::SetVector<Operation *> dependentOps;
    for (size_t next = 0; next < operandList.size(); ++next) {
      // Copy the value: operandList may reallocate while we append to it.
      Value currentElement = operandList[next];
      Operation *definingOperation = currentElement.getDefiningOp();

      // Values without a defining operation (block arguments) are leaves.
      if (!definingOperation)
        continue;

      // Skip operations which have been processed already.
      if (!dependentOps.insert(definingOperation))
        continue;

      // A dim or a load of a local dynamically sized alloc (or of a getref)
      // pins the size computation after that alloc, so it cannot be moved
      // to the init block.
      const bool isDimOrLoad = llvm::dyn_cast<DimOp>(definingOperation) ||
                               llvm::dyn_cast<LoadOp>(definingOperation);

      for (const auto &operand : definingOperation->getOperands()) {
        // Arguments of the first block are considered to be leaves.
        if (isBlockArgument(firstBlock, operand))
          continue;

        operandList.emplace_back(operand);

        if (!isDimOrLoad)
          continue;

        Operation *operandOp = operand.getDefiningOp();
        if (!operandOp)
          continue;

        if (auto localAlloc = llvm::dyn_cast<AllocOp>(operandOp)) {
          auto localMemRefType =
              convertToMemRefType(localAlloc.getResult().getType());
          if (!hasAllConstantDimensions(localMemRefType))
            return failure();
        }

        // If operand is a getref then this alloc cannot be bundled.
        if (llvm::dyn_cast<KrnlGetRefOp>(operandOp))
          return failure();
      }
    }

    // Order the dependent operations in the same order they appear in the
    // code. One cannot iterate over and make changes to the order of the
    // operations of a block, so a temporary ordered list is necessary.
    // The last operation of the block (its terminator) is excluded.
    llvm::SmallVector<Operation *, 32> orderedDependentOps;
    for (auto &blockOp :
        llvm::make_range(parentBlock->begin(), std::prev(parentBlock->end())))
      if (dependentOps.count(&blockOp) > 0)
        orderedDependentOps.emplace_back(&blockOp);

    // Emit the size calculation in the init block, creating the init block
    // first if the function does not have one yet.
    bool addedInitBlock = addInitBlock(rewriter, loc, allocOp);

    // Move the ordered dependent size calculation to the init block.
    std::unique_ptr<ONNXOperandsInitState> &initState = initMap.at(function);
    for (Operation *dependentOp : orderedDependentOps)
      dependentOp->moveBefore(initState->branchInit);

    // Bundle MemRef type: <?xi8>
    SmallVector<int64_t, 1> memPoolShape;
    memPoolShape.emplace_back(-1);
    auto bundledMemPoolMemRefType =
        MemRefType::get(memPoolShape, rewriter.getIntegerType(8));

    // Add the current alloc size to the current MemPool size. When the init
    // block has just been created, the current alloc is itself the pool, so
    // the running pool size starts from zero.
    Value dynamicMemoryPoolSize = initState->dynamicMemoryPool.getOperand(0);
    if (addedInitBlock) {
      Value zero = emitConstantOp(rewriter, loc, rewriter.getIndexType(), 0);
      zero.getDefiningOp()->moveBefore(initState->branchInit);
      dynamicMemoryPoolSize = zero;
    }

    AddIOp bundledAllocOperand = rewriter.create<AddIOp>(
        loc, dynamicMemoryPoolSize, allocOp.getOperand(0));
    bundledAllocOperand.getOperation()->moveBefore(initState->branchInit);

    // The newly bundled MemRef expressed as a KrnlGetRefOp.
    // Current memory pool size is the offset for the newly bundled
    // internal MemRef.
    Value integerDynamicMemoryPoolSize = rewriter.create<IndexCastOp>(
        loc, dynamicMemoryPoolSize, rewriter.getIntegerType(64));
    integerDynamicMemoryPoolSize.getDefiningOp()->moveBefore(
        initState->branchInit);

    // We need to emit a new alloc which contains the additional MemRef.
    AllocOp bundledAlloc = rewriter.create<AllocOp>(
        loc, bundledMemPoolMemRefType, bundledAllocOperand.getResult());
    bundledAlloc.getOperation()->moveBefore(&initState->mainBlock->front());

    KrnlGetRefOp bundledMemRef = rewriter.create<KrnlGetRefOp>(loc,
        currentAllocGetRef.getResult().getType(), bundledAlloc,
        integerDynamicMemoryPoolSize);
    bundledMemRef.getOperation()->moveAfter(bundledAlloc);

    // Replace old memory pool with new one.
    rewriter.replaceOp(initState->dynamicMemoryPool, bundledAlloc.getResult());

    // Replace old getref with new getref from new memory pool.
    rewriter.replaceOp(currentAllocGetRef, bundledMemRef.getResult());

    // Update MemPool size.
    initState->dynamicMemoryPool = bundledAlloc;

    return success();
  }
};
/*!
* Function pass that enables memory pooling for MemRefs.
*/
@ -136,11 +449,22 @@ public:
void runOnFunction() override {
auto function = getFunction();
// ModuleOp module = cast<ModuleOp>(function.getParentOp());
initMap.insert(std::pair<FuncOp, std::unique_ptr<ONNXOperandsInitState>>(
function, std::make_unique<ONNXOperandsInitState>()));
// Initialize state for this function.
std::unique_ptr<ONNXOperandsInitState> &initState = initMap.at(function);
initState->initBlock = nullptr;
ConversionTarget target(getContext());
OwningRewritePatternList patterns;
patterns.insert<KrnlBundleMemoryPools>(&getContext());
patterns.insert<KrnlBundleMemoryPools, KrnlBundleDynamicMemoryPools>(
&getContext());
applyPatternsAndFoldGreedily(function, patterns);
initMap.erase(function);
}
};
} // namespace

View File

@ -51,3 +51,135 @@ func @test_pool_bundling(%arg0: memref<10x10xf32>, %arg1: memref<10x20xf32>) ->
// CHECK: dealloc [[MEMPOOL]] : memref<3200xi8>
// CHECK: return [[RES]] : memref<10x20xf32>
}
// Two dynamically sized byte allocs (%3 and %10), each consumed by a
// krnl.getref, are expected to be merged into a single dynamic memory pool.
// %15 is not consumed by a krnl.getref and is expected to remain a separate
// alloc.
func @test_dynamic_pool_bundling(%arg0: memref<?x?xf32>) -> memref<?x10xf32> {
%c1 = constant 1 : index
%c0 = constant 0 : index
%cst = constant 0.000000e+00 : f32
%ind = constant 0 : index
%c4 = constant 4 : index
%c10 = constant 10 : index
%c0_i64 = constant 0 : i64
%0 = dim %arg0, %c0 : memref<?x?xf32>
%1 = muli %0, %c4 : index
%2 = muli %1, %c10 : index
%3 = alloc(%2) : memref<?xi8>
%4 = "krnl.getref"(%3, %c0_i64) : (memref<?xi8>, i64) -> memref<?x10xf32>
%6 = cmpi "sgt", %0, %0 : index
%7 = select %6, %0, %0 : index
%8 = muli %7, %c4 : index
%9 = muli %8, %c10 : index
%10 = alloc(%9) : memref<?xi8>
%11 = "krnl.getref"(%10, %c0_i64) : (memref<?xi8>, i64) -> memref<?x10xf32>
%12 = cmpi "eq", %0, %c1 : index
%13 = cmpi "eq", %0, %c1 : index
%15 = alloc(%0) : memref<?x10xf32>
affine.store %cst, %4[%ind, %ind] : memref<?x10xf32>
affine.store %cst, %11[%ind, %ind] : memref<?x10xf32>
affine.store %cst, %15[%ind, %ind] : memref<?x10xf32>
dealloc %10 : memref<?xi8>
dealloc %3 : memref<?xi8>
return %15 : memref<?x10xf32>
// Expected output: both size computations come first, followed by one pooled
// alloc whose size is their sum. One getref addresses the pool at the offset
// of the first size (cast to i64), the other at offset 0, and only a single
// dealloc of the pool remains.
// CHECK-LABEL: test_dynamic_pool_bundling
// CHECK: [[CST:%.+]] = constant 0.000000e+00 : f32
// CHECK: [[C0:%.+]] = constant 0 : index
// CHECK: [[C4:%.+]] = constant 4 : index
// CHECK: [[C10:%.+]] = constant 10 : index
// CHECK: [[C0_I64:%.+]] = constant 0 : i64
// CHECK: [[DIM:%.+]] = dim %arg0, [[C0]] : memref<?x?xf32>
// CHECK: [[SGT:%.+]] = cmpi "sgt", [[DIM]], [[DIM]] : index
// CHECK: [[SELECT:%.+]] = select [[SGT]], [[DIM]], [[DIM]] : index
// CHECK: [[MUL1:%.+]] = muli [[SELECT]], [[C4]] : index
// CHECK: [[OFFSET1:%.+]] = muli [[MUL1]], [[C10]] : index
// CHECK: [[MUL2:%.+]] = muli [[DIM]], [[C4]] : index
// CHECK: [[OFFSET2:%.+]] = muli [[MUL2]], [[C10]] : index
// CHECK: [[MEMPOOL_SIZE:%.+]] = addi [[OFFSET1]], [[OFFSET2]] : index
// CHECK: [[OFFSET1_I64:%.+]] = index_cast [[OFFSET1]] : index to i64
// CHECK: [[DYN_MEMPOOL:%.+]] = alloc([[MEMPOOL_SIZE]]) : memref<?xi8>
// CHECK: [[DATA1:%.+]] = "krnl.getref"([[DYN_MEMPOOL]], [[OFFSET1_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32>
// CHECK: [[DATA2:%.+]] = "krnl.getref"([[DYN_MEMPOOL]], [[C0_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32>
// CHECK: [[RES:%.+]] = alloc([[DIM]]) : memref<?x10xf32>
// CHECK: affine.store [[CST]], [[DATA1]][0, 0] : memref<?x10xf32>
// CHECK: affine.store [[CST]], [[DATA2]][0, 0] : memref<?x10xf32>
// CHECK: affine.store [[CST]], [[RES]][0, 0] : memref<?x10xf32>
// CHECK: dealloc [[DYN_MEMPOOL]] : memref<?xi8>
// CHECK: return [[RES]] : memref<?x10xf32>
}
// Mixed case: two dynamically sized byte allocs (%3 and %10) interleaved with
// three constant-size byte allocs (800, 400 and 1600 bytes), all consumed by
// krnl.getref. The dynamic allocs are expected to end up in one dynamic pool
// and the static allocs in one separate static pool.
func @test_dynamic_and_static_pool_bundling(%arg0: memref<?x?xf32>, %arg1: memref<10x10xf32>) -> memref<?x10xf32> {
%c1 = constant 1 : index
%c0 = constant 0 : index
%cst = constant 0.000000e+00 : f32
%ind = constant 0 : index
%c4 = constant 4 : index
%c10 = constant 10 : index
%c0_i64 = constant 0 : i64
%0 = dim %arg0, %c0 : memref<?x?xf32>
%1 = muli %0, %c4 : index
%2 = muli %1, %c10 : index
%3 = alloc(%2) : memref<?xi8>
%4 = "krnl.getref"(%3, %c0_i64) : (memref<?xi8>, i64) -> memref<?x10xf32>
%const_alloc1 = alloc() : memref<800xi8>
%const_ref1 = "krnl.getref"(%const_alloc1, %c0_i64) : (memref<800xi8>, i64) -> memref<10x20xf32>
%const_alloc2 = alloc() : memref<400xi8>
%const_ref2 = "krnl.getref"(%const_alloc2, %c0_i64) : (memref<400xi8>, i64) -> memref<10x10xf32>
%6 = cmpi "sgt", %0, %0 : index
%7 = select %6, %0, %0 : index
%8 = muli %7, %c4 : index
%9 = muli %8, %c10 : index
%10 = alloc(%9) : memref<?xi8>
%11 = "krnl.getref"(%10, %c0_i64) : (memref<?xi8>, i64) -> memref<?x10xf32>
%12 = cmpi "eq", %0, %c1 : index
%13 = cmpi "eq", %0, %c1 : index
%15 = alloc(%0) : memref<?x10xf32>
%const_alloc3 = alloc() : memref<1600xi8>
%const_ref3 = "krnl.getref"(%const_alloc3, %c0_i64) : (memref<1600xi8>, i64) -> memref<10x40xf32>
affine.store %cst, %4[%ind, %ind] : memref<?x10xf32>
affine.store %cst, %11[%ind, %ind] : memref<?x10xf32>
affine.store %cst, %15[%ind, %ind] : memref<?x10xf32>
affine.store %cst, %const_ref2[%ind, %ind] : memref<10x10xf32>
affine.store %cst, %const_ref1[%ind, %ind] : memref<10x20xf32>
affine.store %cst, %const_ref3[%ind, %ind] : memref<10x40xf32>
dealloc %10 : memref<?xi8>
dealloc %3 : memref<?xi8>
dealloc %const_alloc1 : memref<800xi8>
dealloc %const_alloc2 : memref<400xi8>
dealloc %const_alloc3 : memref<1600xi8>
return %15 : memref<?x10xf32>
// Expected output: the dynamic pool is built as in the purely dynamic test.
// The static pool is a single 2800-byte alloc (800 + 400 + 1600) addressed at
// offsets 2000, 400 and 0. One dealloc per pool remains.
// CHECK-LABEL: test_dynamic_and_static_pool_bundling
// CHECK: [[CST:%.+]] = constant 0.000000e+00 : f32
// CHECK: [[C0:%.+]] = constant 0 : index
// CHECK: [[C4:%.+]] = constant 4 : index
// CHECK: [[C10:%.+]] = constant 10 : index
// CHECK: [[C400_I64:%.+]] = constant 400 : i64
// CHECK: [[C2000_I64:%.+]] = constant 2000 : i64
// CHECK: [[C0_I64:%.+]] = constant 0 : i64
// CHECK: [[DIM:%.+]] = dim %arg0, [[C0]] : memref<?x?xf32>
// CHECK: [[SGT:%.+]] = cmpi "sgt", [[DIM]], [[DIM]] : index
// CHECK: [[SELECT:%.+]] = select [[SGT]], [[DIM]], [[DIM]] : index
// CHECK: [[MUL1:%.+]] = muli [[SELECT]], [[C4]] : index
// CHECK: [[OFFSET1:%.+]] = muli [[MUL1]], [[C10]] : index
// CHECK: [[MUL2:%.+]] = muli [[DIM]], [[C4]] : index
// CHECK: [[OFFSET2:%.+]] = muli [[MUL2]], [[C10]] : index
// CHECK: [[MEMPOOL_SIZE:%.+]] = addi [[OFFSET1]], [[OFFSET2]] : index
// CHECK: [[OFFSET1_I64:%.+]] = index_cast [[OFFSET1]] : index to i64
// CHECK: [[DYN_MEMPOOL:%.+]] = alloc([[MEMPOOL_SIZE]]) : memref<?xi8>
// CHECK: [[DATA1:%.+]] = "krnl.getref"([[DYN_MEMPOOL]], [[OFFSET1_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32>
// CHECK: [[DATA2:%.+]] = "krnl.getref"([[DYN_MEMPOOL]], [[C0_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32>
// CHECK: [[STATIC_MEMPOOL:%.+]] = alloc() : memref<2800xi8>
// CHECK: [[DATA3:%.+]] = "krnl.getref"([[STATIC_MEMPOOL]], [[C2000_I64]]) : (memref<2800xi8>, i64) -> memref<10x20xf32>
// CHECK: [[DATA4:%.+]] = "krnl.getref"([[STATIC_MEMPOOL]], [[C400_I64]]) : (memref<2800xi8>, i64) -> memref<10x40xf32>
// CHECK: [[DATA5:%.+]] = "krnl.getref"([[STATIC_MEMPOOL]], [[C0_I64]]) : (memref<2800xi8>, i64) -> memref<10x10xf32>
// CHECK: [[RES:%.+]] = alloc([[DIM]]) : memref<?x10xf32>
// CHECK: affine.store [[CST]], [[DATA1]][0, 0] : memref<?x10xf32>
// CHECK: affine.store [[CST]], [[DATA2]][0, 0] : memref<?x10xf32>
// CHECK: affine.store [[CST]], [[RES]][0, 0] : memref<?x10xf32>
// CHECK: affine.store [[CST]], [[DATA5]][0, 0] : memref<10x10xf32>
// CHECK: affine.store [[CST]], [[DATA3]][0, 0] : memref<10x20xf32>
// CHECK: affine.store [[CST]], [[DATA4]][0, 0] : memref<10x40xf32>
// CHECK: dealloc [[DYN_MEMPOOL]] : memref<?xi8>
// CHECK: dealloc [[STATIC_MEMPOOL]] : memref<2800xi8>
// CHECK: return [[RES]] : memref<?x10xf32>
}