Assorted fixes for memory pool related passes (#244)

* Reorganize main function.

* Follow review comments.

* Emit constants are globals in Krnl and LLVM dialects.

* Make mempooling more robust.

* Fix.

* Update MainUtils.cpp

Additional canonicalization not required anymore.
This commit is contained in:
Gheorghe-Teodor Bercea 2020-08-11 17:34:59 -04:00 committed by GitHub
parent 2ee725d939
commit 2f41c2bf5b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 20 additions and 9 deletions

View File

@ -1,5 +1,4 @@
//====------ ConvertONNXToKrnl.cpp - ONNX dialects to Krnl lowering //====------ ConvertONNXToKrnl.cpp - ONNX dialects to Krnl lowering -------===//
//--------===//
// //
// Copyright 2019 The IBM Research Authors. // Copyright 2019 The IBM Research Authors.
// //

View File

@ -43,6 +43,17 @@ MemRefType convertToMemRefType(Type type) {
return memRefType; return memRefType;
} }
/// Retrieve function which contains the current operation.
FuncOp getContainingFunction(Operation *op) {
Operation *parentFuncOp = op->getParentOp();
// While parent is not a FuncOp and its cast to a FuncOp is null.
while (!llvm::dyn_cast_or_null<FuncOp>(parentFuncOp))
parentFuncOp = parentFuncOp->getParentOp();
return cast<FuncOp>(parentFuncOp);
}
/// Insert an allocation and deallocation for the given MemRefType. /// Insert an allocation and deallocation for the given MemRefType.
Value insertAllocAndDealloc(MemRefType type, Location loc, Value insertAllocAndDealloc(MemRefType type, Location loc,
PatternRewriter &rewriter, bool insertDealloc, ArrayRef<Value> operands, PatternRewriter &rewriter, bool insertDealloc, ArrayRef<Value> operands,
@ -463,10 +474,10 @@ int64_t ArrayAttrIntVal(ArrayAttr a, int i) {
} }
bool checkOpResultIsUsedByGetRef(AllocOp *allocOp) { bool checkOpResultIsUsedByGetRef(AllocOp *allocOp) {
auto parentBlock = allocOp->getOperation()->getBlock(); FuncOp function = getContainingFunction(allocOp->getOperation());
bool opIsUsedInGetRef = false; bool opIsUsedInGetRef = false;
parentBlock->walk([&opIsUsedInGetRef, allocOp](KrnlGetRefOp op) { function.walk([&opIsUsedInGetRef, allocOp](KrnlGetRefOp op) {
auto result = allocOp->getResult(); auto result = allocOp->getResult();
for (const auto &operand : op.getOperands()) for (const auto &operand : op.getOperands())
if (operand == result) if (operand == result)

View File

@ -42,6 +42,9 @@ bool hasAllScalarValues(ArrayRef<Value> values);
/// Get the corresponding MemRefType of a given TensorType/MemRefType. /// Get the corresponding MemRefType of a given TensorType/MemRefType.
MemRefType convertToMemRefType(Type type); MemRefType convertToMemRefType(Type type);
/// Retrieve function which contains the current operation.
FuncOp getContainingFunction(Operation *op);
/// Insert an allocation and deallocation for the given MemRefType. /// Insert an allocation and deallocation for the given MemRefType.
Value insertAllocAndDealloc(MemRefType type, Location loc, Value insertAllocAndDealloc(MemRefType type, Location loc,
PatternRewriter &rewriter, bool insertDealloc, PatternRewriter &rewriter, bool insertDealloc,

View File

@ -91,8 +91,6 @@ public:
// Get a KrnlGetRefOp which does not use the current alloc. // Get a KrnlGetRefOp which does not use the current alloc.
if (KrnlGetRefOp unbundledGetRef = getUnbundledGetRef(&allocOp)) { if (KrnlGetRefOp unbundledGetRef = getUnbundledGetRef(&allocOp)) {
unbundledGetRef.dump();
// Current memory pool size is the offset for the newly bundled // Current memory pool size is the offset for the newly bundled
// internal MemRef. Emit the offset as a constant. // internal MemRef. Emit the offset as a constant.
auto offset = rewriter.create<ConstantOp>( auto offset = rewriter.create<ConstantOp>(

View File

@ -24,10 +24,10 @@ using namespace mlir;
namespace { namespace {
bool checkOpResultIsReturned(AllocOp *allocOp) { bool checkOpResultIsReturned(AllocOp *allocOp) {
auto parentBlock = allocOp->getOperation()->getBlock(); FuncOp function = getContainingFunction(allocOp->getOperation());
bool opIsReturned = false; bool opIsReturned = false;
parentBlock->walk([&opIsReturned, allocOp](ReturnOp op) { function.walk([&opIsReturned, allocOp](ReturnOp op) {
auto result = allocOp->getResult(); auto result = allocOp->getResult();
for (const auto &operand : op.getOperands()) for (const auto &operand : op.getOperands())
if (operand == result) if (operand == result)

View File

@ -25,4 +25,4 @@ func @test_bundle_memory_pool(%arg0: tensor<10x10xf32>, %arg1: tensor<10x20xf32>
// CHECK: "krnl.getref"([[MEMPOOL]], [[CONST0]]) : (memref<3200xi8>, i64) -> memref<10x10xf32> // CHECK: "krnl.getref"([[MEMPOOL]], [[CONST0]]) : (memref<3200xi8>, i64) -> memref<10x10xf32>
// CHECK: dealloc [[MEMPOOL]] : memref<3200xi8> // CHECK: dealloc [[MEMPOOL]] : memref<3200xi8>
// CHECK: return [[RES]] : memref<10x20xf32> // CHECK: return [[RES]] : memref<10x20xf32>
} }