[MLIR] Add support for dealloc insertion (#386)
* Add support for dealloc op. * Check dealloc for returned result not present.
This commit is contained in:
parent
b2a1103915
commit
7fb2f80dce
|
@ -44,8 +44,9 @@ static MemRefType convertTensorToMemRef(TensorType type) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Insert an allocation and deallocation for the given MemRefType.
|
/// Insert an allocation and deallocation for the given MemRefType.
|
||||||
static Value* insertAllocAndDealloc(MemRefType type, Location loc,
|
static Value* insertAllocAndDealloc(
|
||||||
PatternRewriter& rewriter, Value* oldMemRef = nullptr) {
|
MemRefType type, Location loc, PatternRewriter& rewriter,
|
||||||
|
bool insertDealloc, Value *oldMemRef = nullptr) {
|
||||||
// Put together alloc operands for any dynamic dimensions of the memref.
|
// Put together alloc operands for any dynamic dimensions of the memref.
|
||||||
AllocOp alloc;
|
AllocOp alloc;
|
||||||
if (oldMemRef) {
|
if (oldMemRef) {
|
||||||
|
@ -54,7 +55,6 @@ static Value* insertAllocAndDealloc(MemRefType type, Location loc,
|
||||||
for (int i = 0; i < memRefShape.size(); ++i)
|
for (int i = 0; i < memRefShape.size(); ++i)
|
||||||
if (memRefShape[i] < 0)
|
if (memRefShape[i] < 0)
|
||||||
allocOperands.push_back(rewriter.create<DimOp>(loc, oldMemRef, i));
|
allocOperands.push_back(rewriter.create<DimOp>(loc, oldMemRef, i));
|
||||||
|
|
||||||
alloc = rewriter.create<AllocOp>(loc, type, allocOperands);
|
alloc = rewriter.create<AllocOp>(loc, type, allocOperands);
|
||||||
} else {
|
} else {
|
||||||
alloc = rewriter.create<AllocOp>(loc, type);
|
alloc = rewriter.create<AllocOp>(loc, type);
|
||||||
|
@ -66,9 +66,36 @@ static Value* insertAllocAndDealloc(MemRefType type, Location loc,
|
||||||
if (hasAllConstantDimensions(type))
|
if (hasAllConstantDimensions(type))
|
||||||
alloc.getOperation()->moveBefore(&parentBlock->front());
|
alloc.getOperation()->moveBefore(&parentBlock->front());
|
||||||
|
|
||||||
|
if (insertDealloc) {
|
||||||
|
auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
|
||||||
|
dealloc.getOperation()->moveBefore(&parentBlock->back());
|
||||||
|
}
|
||||||
|
|
||||||
return alloc;
|
return alloc;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Determine if current function returns the result value of the
|
||||||
|
// current op being lowered. If it does then dealloc should not be
|
||||||
|
// inserted.
|
||||||
|
static bool checkInsertDealloc(Operation *currentOp) {
|
||||||
|
auto parentBlock = currentOp->getBlock();
|
||||||
|
|
||||||
|
bool insertDealloc = true;
|
||||||
|
parentBlock->walk([&insertDealloc, currentOp](ReturnOp op) {
|
||||||
|
assert(currentOp->getNumResults() < 2 &&
|
||||||
|
"No more than one result supported (for now).");
|
||||||
|
// If there is at least one result to investigate.
|
||||||
|
if (currentOp->getNumResults() > 0) {
|
||||||
|
auto result = currentOp->getResult(0);
|
||||||
|
for(auto operand : op.getOperands())
|
||||||
|
if (operand == result)
|
||||||
|
insertDealloc = false;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
return insertDealloc;
|
||||||
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -96,10 +123,14 @@ struct ONNXAddOpLowering : public ConversionPattern {
|
||||||
// TODO: verify that dimensions match.
|
// TODO: verify that dimensions match.
|
||||||
// TODO: can the dimension of the result differ after optimizations?
|
// TODO: can the dimension of the result differ after optimizations?
|
||||||
Value *alloc;
|
Value *alloc;
|
||||||
|
bool insertDealloc = checkInsertDealloc(op);
|
||||||
|
|
||||||
if (hasAllConstantDimensions(memRefType))
|
if (hasAllConstantDimensions(memRefType))
|
||||||
alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
|
alloc = insertAllocAndDealloc(
|
||||||
|
memRefType, loc, rewriter, insertDealloc);
|
||||||
else
|
else
|
||||||
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, operands[0]);
|
alloc = insertAllocAndDealloc(
|
||||||
|
memRefType, loc, rewriter, insertDealloc, operands[0]);
|
||||||
|
|
||||||
// Number of loops
|
// Number of loops
|
||||||
auto memRefShape = memRefType.getShape();
|
auto memRefShape = memRefType.getShape();
|
||||||
|
|
|
@ -0,0 +1,45 @@
|
||||||
|
// RUN: dlc-opt --shape-inference --lower-frontend %s -split-input-file | FileCheck %s
|
||||||
|
|
||||||
|
module {
|
||||||
|
func @test_sigmoid(%a1 : tensor<?x10xf32>, %a2 : tensor<?x10xf32>) -> tensor<*xf32> {
|
||||||
|
%0 = "onnx.Add"(%a1, %a2) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<*xf32>
|
||||||
|
%1 = "onnx.Add"(%0, %a2) : (tensor<*xf32>, tensor<?x10xf32>) -> tensor<*xf32>
|
||||||
|
"std.return"(%1) : (tensor<*xf32>) -> ()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK: func @test_sigmoid([[ARG0:%.+]]: memref<?x10xf32>, [[ARG1:%.+]]: memref<?x10xf32>) -> memref<?x10xf32> {
|
||||||
|
|
||||||
|
/// First Add
|
||||||
|
// CHECK: [[DIM_0:%.+]] = dim [[ARG0]], 0 : memref<?x10xf32>
|
||||||
|
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
|
||||||
|
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
|
||||||
|
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
|
||||||
|
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
|
||||||
|
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
|
||||||
|
// CHECK: [[DIM_2:%.+]] = dim [[ARG0]], 0 : memref<?x10xf32>
|
||||||
|
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
|
||||||
|
// CHECK: [[LOAD1:%.+]] = load [[ARG0]][%arg2, %arg3] : memref<?x10xf32>
|
||||||
|
// CHECK: [[LOAD2:%.+]] = load [[ARG1]][%arg2, %arg3] : memref<?x10xf32>
|
||||||
|
// CHECK: [[ADDF:%.+]] = addf [[LOAD1]], [[LOAD2]] : f32
|
||||||
|
// CHECK: store [[ADDF]], [[RES]][%arg2, %arg3] : memref<?x10xf32>
|
||||||
|
|
||||||
|
/// Second Add
|
||||||
|
// CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
|
||||||
|
// CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
|
||||||
|
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
|
||||||
|
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
|
||||||
|
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
|
||||||
|
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
|
||||||
|
// CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
|
||||||
|
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
|
||||||
|
// CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<?x10xf32>
|
||||||
|
// CHECK: [[LOAD2:%.+]] = load [[ARG1]][%arg2, %arg3] : memref<?x10xf32>
|
||||||
|
// CHECK: [[ADDF:%.+]] = addf [[LOAD1]], [[LOAD2]] : f32
|
||||||
|
// CHECK: store [[ADDF]], [[RET_RES]][%arg2, %arg3] : memref<?x10xf32>
|
||||||
|
|
||||||
|
/// Dealloc of first result.
|
||||||
|
// CHECK: dealloc [[RES]] : memref<?x10xf32>
|
||||||
|
// CHECK-NOT: dealloc [[RET_RES]] : memref<?x10xf32>
|
||||||
|
|
||||||
|
// CHECK: return [[RET_RES]] : memref<?x10xf32>
|
Loading…
Reference in New Issue