diff --git a/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp b/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp index 2571458..0639363 100644 --- a/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp +++ b/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp @@ -1,5 +1,4 @@ -//====------ ConvertONNXToKrnl.cpp - ONNX dialects to Krnl lowering -//--------===// +//====------ ConvertONNXToKrnl.cpp - ONNX dialects to Krnl lowering -------===// // // Copyright 2019 The IBM Research Authors. // diff --git a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp index 8938478..b854698 100644 --- a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp +++ b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp @@ -43,6 +43,17 @@ MemRefType convertToMemRefType(Type type) { return memRefType; } +/// Retrieve function which contains the current operation. +FuncOp getContainingFunction(Operation *op) { + Operation *parentFuncOp = op->getParentOp(); + + // While parent is not a FuncOp and its cast to a FuncOp is null. + while (!llvm::dyn_cast_or_null(parentFuncOp)) + parentFuncOp = parentFuncOp->getParentOp(); + + return cast(parentFuncOp); +} + /// Insert an allocation and deallocation for the given MemRefType. Value insertAllocAndDealloc(MemRefType type, Location loc, PatternRewriter &rewriter, bool insertDealloc, ArrayRef operands, @@ -463,10 +474,10 @@ int64_t ArrayAttrIntVal(ArrayAttr a, int i) { } bool checkOpResultIsUsedByGetRef(AllocOp *allocOp) { - auto parentBlock = allocOp->getOperation()->getBlock(); + FuncOp function = getContainingFunction(allocOp->getOperation()); bool opIsUsedInGetRef = false; - parentBlock->walk([&opIsUsedInGetRef, allocOp](KrnlGetRefOp op) { + function.walk([&opIsUsedInGetRef, allocOp](KrnlGetRefOp op) { auto result = allocOp->getResult(); for (const auto &operand : op.getOperands()) if (operand == result) diff --git a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp index 3abc163..68e481f 100644 --- a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp +++ b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp @@ -42,6 +42,9 @@ bool hasAllScalarValues(ArrayRef values); /// Get the corresponding MemRefType of a given TensorType/MemRefType. MemRefType convertToMemRefType(Type type); +/// Retrieve function which contains the current operation. +FuncOp getContainingFunction(Operation *op); + /// Insert an allocation and deallocation for the given MemRefType. Value insertAllocAndDealloc(MemRefType type, Location loc, PatternRewriter &rewriter, bool insertDealloc, diff --git a/src/Transform/BundleMemoryPools.cpp b/src/Transform/BundleMemoryPools.cpp index 81e6700..b213fcf 100644 --- a/src/Transform/BundleMemoryPools.cpp +++ b/src/Transform/BundleMemoryPools.cpp @@ -91,8 +91,6 @@ public: // Get a KrnlGetRefOp which does not use the current alloc. if (KrnlGetRefOp unbundledGetRef = getUnbundledGetRef(&allocOp)) { - unbundledGetRef.dump(); - // Current memory pool size is the offset for the newly bundled // internal MemRef. Emit the offset as a constant. auto offset = rewriter.create( diff --git a/src/Transform/EnableMemoryPool.cpp b/src/Transform/EnableMemoryPool.cpp index 2439b7d..fc0824f 100644 --- a/src/Transform/EnableMemoryPool.cpp +++ b/src/Transform/EnableMemoryPool.cpp @@ -24,10 +24,10 @@ using namespace mlir; namespace { bool checkOpResultIsReturned(AllocOp *allocOp) { - auto parentBlock = allocOp->getOperation()->getBlock(); + FuncOp function = getContainingFunction(allocOp->getOperation()); bool opIsReturned = false; - parentBlock->walk([&opIsReturned, allocOp](ReturnOp op) { + function.walk([&opIsReturned, allocOp](ReturnOp op) { auto result = allocOp->getResult(); for (const auto &operand : op.getOperands()) if (operand == result) diff --git a/test/mlir/onnx/onnx_bundle_memory_pool.mlir b/test/mlir/onnx/onnx_bundle_memory_pool.mlir index def0cda..8f99b9d 100644 --- a/test/mlir/onnx/onnx_bundle_memory_pool.mlir +++ b/test/mlir/onnx/onnx_bundle_memory_pool.mlir @@ -25,4 +25,4 @@ func @test_bundle_memory_pool(%arg0: tensor<10x10xf32>, %arg1: tensor<10x20xf32> // CHECK: "krnl.getref"([[MEMPOOL]], [[CONST0]]) : (memref<3200xi8>, i64) -> memref<10x10xf32> // CHECK: dealloc [[MEMPOOL]] : memref<3200xi8> // CHECK: return [[RES]] : memref<10x20xf32> -} \ No newline at end of file +}