[MLIR] Add support for dealloc insertion (#386)

* Add support for dealloc op.

* Check dealloc for returned result not present.
This commit is contained in:
GHEORGHE-TEOD BERCEA 2019-11-27 23:52:05 -05:00 committed by Tian Jin
parent b2a1103915
commit 7fb2f80dce
2 changed files with 82 additions and 6 deletions

View File

@ -44,8 +44,9 @@ static MemRefType convertTensorToMemRef(TensorType type) {
} }
/// Insert an allocation and deallocation for the given MemRefType. /// Insert an allocation and deallocation for the given MemRefType.
static Value* insertAllocAndDealloc(MemRefType type, Location loc, static Value* insertAllocAndDealloc(
PatternRewriter& rewriter, Value* oldMemRef = nullptr) { MemRefType type, Location loc, PatternRewriter& rewriter,
bool insertDealloc, Value *oldMemRef = nullptr) {
// Put together alloc operands for any dynamic dimensions of the memref. // Put together alloc operands for any dynamic dimensions of the memref.
AllocOp alloc; AllocOp alloc;
if (oldMemRef) { if (oldMemRef) {
@ -54,7 +55,6 @@ static Value* insertAllocAndDealloc(MemRefType type, Location loc,
for (int i = 0; i < memRefShape.size(); ++i) for (int i = 0; i < memRefShape.size(); ++i)
if (memRefShape[i] < 0) if (memRefShape[i] < 0)
allocOperands.push_back(rewriter.create<DimOp>(loc, oldMemRef, i)); allocOperands.push_back(rewriter.create<DimOp>(loc, oldMemRef, i));
alloc = rewriter.create<AllocOp>(loc, type, allocOperands); alloc = rewriter.create<AllocOp>(loc, type, allocOperands);
} else { } else {
alloc = rewriter.create<AllocOp>(loc, type); alloc = rewriter.create<AllocOp>(loc, type);
@ -66,9 +66,36 @@ static Value* insertAllocAndDealloc(MemRefType type, Location loc,
if (hasAllConstantDimensions(type)) if (hasAllConstantDimensions(type))
alloc.getOperation()->moveBefore(&parentBlock->front()); alloc.getOperation()->moveBefore(&parentBlock->front());
if (insertDealloc) {
auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
dealloc.getOperation()->moveBefore(&parentBlock->back());
}
return alloc; return alloc;
} }
// Determine if current function returns the result value of the
// current op being lowered. If it does then dealloc should not be
// inserted.
static bool checkInsertDealloc(Operation *currentOp) {
auto parentBlock = currentOp->getBlock();
bool insertDealloc = true;
parentBlock->walk([&insertDealloc, currentOp](ReturnOp op) {
assert(currentOp->getNumResults() < 2 &&
"No more than one result supported (for now).");
// If there is at least one result to investigate.
if (currentOp->getNumResults() > 0) {
auto result = currentOp->getResult(0);
for(auto operand : op.getOperands())
if (operand == result)
insertDealloc = false;
}
});
return insertDealloc;
}
namespace { namespace {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -95,11 +122,15 @@ struct ONNXAddOpLowering : public ConversionPattern {
// dimensions with the result at this pre-optimization phase. // dimensions with the result at this pre-optimization phase.
// TODO: verify that dimensions match. // TODO: verify that dimensions match.
// TODO: can the dimension of the result differ after optimizations? // TODO: can the dimension of the result differ after optimizations?
Value* alloc; Value *alloc;
bool insertDealloc = checkInsertDealloc(op);
if (hasAllConstantDimensions(memRefType)) if (hasAllConstantDimensions(memRefType))
alloc = insertAllocAndDealloc(memRefType, loc, rewriter); alloc = insertAllocAndDealloc(
memRefType, loc, rewriter, insertDealloc);
else else
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, operands[0]); alloc = insertAllocAndDealloc(
memRefType, loc, rewriter, insertDealloc, operands[0]);
// Number of loops // Number of loops
auto memRefShape = memRefType.getShape(); auto memRefShape = memRefType.getShape();

View File

@ -0,0 +1,45 @@
// RUN: dlc-opt --shape-inference --lower-frontend %s -split-input-file | FileCheck %s
module {
func @test_sigmoid(%a1 : tensor<?x10xf32>, %a2 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Add"(%a1, %a2) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<*xf32>
%1 = "onnx.Add"(%0, %a2) : (tensor<*xf32>, tensor<?x10xf32>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> ()
}
}
// CHECK: func @test_sigmoid([[ARG0:%.+]]: memref<?x10xf32>, [[ARG1:%.+]]: memref<?x10xf32>) -> memref<?x10xf32> {
/// First Add
// CHECK: [[DIM_0:%.+]] = dim [[ARG0]], 0 : memref<?x10xf32>
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim [[ARG0]], 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load [[ARG0]][%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[LOAD2:%.+]] = load [[ARG1]][%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[ADDF:%.+]] = addf [[LOAD1]], [[LOAD2]] : f32
// CHECK: store [[ADDF]], [[RES]][%arg2, %arg3] : memref<?x10xf32>
/// Second Add
// CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
// CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[LOAD2:%.+]] = load [[ARG1]][%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[ADDF:%.+]] = addf [[LOAD1]], [[LOAD2]] : f32
// CHECK: store [[ADDF]], [[RET_RES]][%arg2, %arg3] : memref<?x10xf32>
/// Dealloc of first result.
// CHECK: dealloc [[RES]] : memref<?x10xf32>
// CHECK-NOT: dealloc [[RET_RES]] : memref<?x10xf32>
// CHECK: return [[RET_RES]] : memref<?x10xf32>