Assorted fixes for memory pool related passes (#244)
* Reorganize main function. * Follow review comments. * Emit constants are globals in Krnl and LLVM dialects. * Make mempooling more robust. * Fix. * Update MainUtils.cpp Additional canonicalization not required anymore.
This commit is contained in:
parent
2ee725d939
commit
2f41c2bf5b
|
@ -1,5 +1,4 @@
|
||||||
//====------ ConvertONNXToKrnl.cpp - ONNX dialects to Krnl lowering
|
//====------ ConvertONNXToKrnl.cpp - ONNX dialects to Krnl lowering -------===//
|
||||||
//--------===//
|
|
||||||
//
|
//
|
||||||
// Copyright 2019 The IBM Research Authors.
|
// Copyright 2019 The IBM Research Authors.
|
||||||
//
|
//
|
||||||
|
|
|
@ -43,6 +43,17 @@ MemRefType convertToMemRefType(Type type) {
|
||||||
return memRefType;
|
return memRefType;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Retrieve function which contains the current operation.
|
||||||
|
FuncOp getContainingFunction(Operation *op) {
|
||||||
|
Operation *parentFuncOp = op->getParentOp();
|
||||||
|
|
||||||
|
// While parent is not a FuncOp and its cast to a FuncOp is null.
|
||||||
|
while (!llvm::dyn_cast_or_null<FuncOp>(parentFuncOp))
|
||||||
|
parentFuncOp = parentFuncOp->getParentOp();
|
||||||
|
|
||||||
|
return cast<FuncOp>(parentFuncOp);
|
||||||
|
}
|
||||||
|
|
||||||
/// Insert an allocation and deallocation for the given MemRefType.
|
/// Insert an allocation and deallocation for the given MemRefType.
|
||||||
Value insertAllocAndDealloc(MemRefType type, Location loc,
|
Value insertAllocAndDealloc(MemRefType type, Location loc,
|
||||||
PatternRewriter &rewriter, bool insertDealloc, ArrayRef<Value> operands,
|
PatternRewriter &rewriter, bool insertDealloc, ArrayRef<Value> operands,
|
||||||
|
@ -463,10 +474,10 @@ int64_t ArrayAttrIntVal(ArrayAttr a, int i) {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool checkOpResultIsUsedByGetRef(AllocOp *allocOp) {
|
bool checkOpResultIsUsedByGetRef(AllocOp *allocOp) {
|
||||||
auto parentBlock = allocOp->getOperation()->getBlock();
|
FuncOp function = getContainingFunction(allocOp->getOperation());
|
||||||
|
|
||||||
bool opIsUsedInGetRef = false;
|
bool opIsUsedInGetRef = false;
|
||||||
parentBlock->walk([&opIsUsedInGetRef, allocOp](KrnlGetRefOp op) {
|
function.walk([&opIsUsedInGetRef, allocOp](KrnlGetRefOp op) {
|
||||||
auto result = allocOp->getResult();
|
auto result = allocOp->getResult();
|
||||||
for (const auto &operand : op.getOperands())
|
for (const auto &operand : op.getOperands())
|
||||||
if (operand == result)
|
if (operand == result)
|
||||||
|
|
|
@ -42,6 +42,9 @@ bool hasAllScalarValues(ArrayRef<Value> values);
|
||||||
/// Get the corresponding MemRefType of a given TensorType/MemRefType.
|
/// Get the corresponding MemRefType of a given TensorType/MemRefType.
|
||||||
MemRefType convertToMemRefType(Type type);
|
MemRefType convertToMemRefType(Type type);
|
||||||
|
|
||||||
|
/// Retrieve function which contains the current operation.
|
||||||
|
FuncOp getContainingFunction(Operation *op);
|
||||||
|
|
||||||
/// Insert an allocation and deallocation for the given MemRefType.
|
/// Insert an allocation and deallocation for the given MemRefType.
|
||||||
Value insertAllocAndDealloc(MemRefType type, Location loc,
|
Value insertAllocAndDealloc(MemRefType type, Location loc,
|
||||||
PatternRewriter &rewriter, bool insertDealloc,
|
PatternRewriter &rewriter, bool insertDealloc,
|
||||||
|
|
|
@ -91,8 +91,6 @@ public:
|
||||||
|
|
||||||
// Get a KrnlGetRefOp which does not use the current alloc.
|
// Get a KrnlGetRefOp which does not use the current alloc.
|
||||||
if (KrnlGetRefOp unbundledGetRef = getUnbundledGetRef(&allocOp)) {
|
if (KrnlGetRefOp unbundledGetRef = getUnbundledGetRef(&allocOp)) {
|
||||||
unbundledGetRef.dump();
|
|
||||||
|
|
||||||
// Current memory pool size is the offset for the newly bundled
|
// Current memory pool size is the offset for the newly bundled
|
||||||
// internal MemRef. Emit the offset as a constant.
|
// internal MemRef. Emit the offset as a constant.
|
||||||
auto offset = rewriter.create<ConstantOp>(
|
auto offset = rewriter.create<ConstantOp>(
|
||||||
|
|
|
@ -24,10 +24,10 @@ using namespace mlir;
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
bool checkOpResultIsReturned(AllocOp *allocOp) {
|
bool checkOpResultIsReturned(AllocOp *allocOp) {
|
||||||
auto parentBlock = allocOp->getOperation()->getBlock();
|
FuncOp function = getContainingFunction(allocOp->getOperation());
|
||||||
|
|
||||||
bool opIsReturned = false;
|
bool opIsReturned = false;
|
||||||
parentBlock->walk([&opIsReturned, allocOp](ReturnOp op) {
|
function.walk([&opIsReturned, allocOp](ReturnOp op) {
|
||||||
auto result = allocOp->getResult();
|
auto result = allocOp->getResult();
|
||||||
for (const auto &operand : op.getOperands())
|
for (const auto &operand : op.getOperands())
|
||||||
if (operand == result)
|
if (operand == result)
|
||||||
|
|
Loading…
Reference in New Issue