[MLIR] Add support for dealloc insertion (#386)
* Add support for dealloc op. * Check dealloc for returned result not present.
This commit is contained in:
parent
b2a1103915
commit
7fb2f80dce
|
@ -44,8 +44,9 @@ static MemRefType convertTensorToMemRef(TensorType type) {
|
|||
}
|
||||
|
||||
/// Insert an allocation and deallocation for the given MemRefType.
|
||||
static Value* insertAllocAndDealloc(MemRefType type, Location loc,
|
||||
PatternRewriter& rewriter, Value* oldMemRef = nullptr) {
|
||||
static Value* insertAllocAndDealloc(
|
||||
MemRefType type, Location loc, PatternRewriter& rewriter,
|
||||
bool insertDealloc, Value *oldMemRef = nullptr) {
|
||||
// Put together alloc operands for any dynamic dimensions of the memref.
|
||||
AllocOp alloc;
|
||||
if (oldMemRef) {
|
||||
|
@ -54,7 +55,6 @@ static Value* insertAllocAndDealloc(MemRefType type, Location loc,
|
|||
for (int i = 0; i < memRefShape.size(); ++i)
|
||||
if (memRefShape[i] < 0)
|
||||
allocOperands.push_back(rewriter.create<DimOp>(loc, oldMemRef, i));
|
||||
|
||||
alloc = rewriter.create<AllocOp>(loc, type, allocOperands);
|
||||
} else {
|
||||
alloc = rewriter.create<AllocOp>(loc, type);
|
||||
|
@ -66,9 +66,36 @@ static Value* insertAllocAndDealloc(MemRefType type, Location loc,
|
|||
if (hasAllConstantDimensions(type))
|
||||
alloc.getOperation()->moveBefore(&parentBlock->front());
|
||||
|
||||
if (insertDealloc) {
|
||||
auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
|
||||
dealloc.getOperation()->moveBefore(&parentBlock->back());
|
||||
}
|
||||
|
||||
return alloc;
|
||||
}
|
||||
|
||||
// Determine if current function returns the result value of the
|
||||
// current op being lowered. If it does then dealloc should not be
|
||||
// inserted.
|
||||
static bool checkInsertDealloc(Operation *currentOp) {
|
||||
auto parentBlock = currentOp->getBlock();
|
||||
|
||||
bool insertDealloc = true;
|
||||
parentBlock->walk([&insertDealloc, currentOp](ReturnOp op) {
|
||||
assert(currentOp->getNumResults() < 2 &&
|
||||
"No more than one result supported (for now).");
|
||||
// If there is at least one result to investigate.
|
||||
if (currentOp->getNumResults() > 0) {
|
||||
auto result = currentOp->getResult(0);
|
||||
for(auto operand : op.getOperands())
|
||||
if (operand == result)
|
||||
insertDealloc = false;
|
||||
}
|
||||
});
|
||||
|
||||
return insertDealloc;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -96,10 +123,14 @@ struct ONNXAddOpLowering : public ConversionPattern {
|
|||
// TODO: verify that dimensions match.
|
||||
// TODO: can the dimension of the result differ after optimizations?
|
||||
Value *alloc;
|
||||
bool insertDealloc = checkInsertDealloc(op);
|
||||
|
||||
if (hasAllConstantDimensions(memRefType))
|
||||
alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
|
||||
alloc = insertAllocAndDealloc(
|
||||
memRefType, loc, rewriter, insertDealloc);
|
||||
else
|
||||
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, operands[0]);
|
||||
alloc = insertAllocAndDealloc(
|
||||
memRefType, loc, rewriter, insertDealloc, operands[0]);
|
||||
|
||||
// Number of loops
|
||||
auto memRefShape = memRefType.getShape();
|
||||
|
|
|
@ -0,0 +1,45 @@
|
|||
// RUN: dlc-opt --shape-inference --lower-frontend %s -split-input-file | FileCheck %s
|
||||
|
||||
module {
|
||||
func @test_sigmoid(%a1 : tensor<?x10xf32>, %a2 : tensor<?x10xf32>) -> tensor<*xf32> {
|
||||
%0 = "onnx.Add"(%a1, %a2) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<*xf32>
|
||||
%1 = "onnx.Add"(%0, %a2) : (tensor<*xf32>, tensor<?x10xf32>) -> tensor<*xf32>
|
||||
"std.return"(%1) : (tensor<*xf32>) -> ()
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK: func @test_sigmoid([[ARG0:%.+]]: memref<?x10xf32>, [[ARG1:%.+]]: memref<?x10xf32>) -> memref<?x10xf32> {
|
||||
|
||||
/// First Add
|
||||
// CHECK: [[DIM_0:%.+]] = dim [[ARG0]], 0 : memref<?x10xf32>
|
||||
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
|
||||
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
|
||||
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
|
||||
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
|
||||
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
|
||||
// CHECK: [[DIM_2:%.+]] = dim [[ARG0]], 0 : memref<?x10xf32>
|
||||
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
|
||||
// CHECK: [[LOAD1:%.+]] = load [[ARG0]][%arg2, %arg3] : memref<?x10xf32>
|
||||
// CHECK: [[LOAD2:%.+]] = load [[ARG1]][%arg2, %arg3] : memref<?x10xf32>
|
||||
// CHECK: [[ADDF:%.+]] = addf [[LOAD1]], [[LOAD2]] : f32
|
||||
// CHECK: store [[ADDF]], [[RES]][%arg2, %arg3] : memref<?x10xf32>
|
||||
|
||||
/// Second Add
|
||||
// CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
|
||||
// CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
|
||||
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
|
||||
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
|
||||
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
|
||||
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
|
||||
// CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
|
||||
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
|
||||
// CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<?x10xf32>
|
||||
// CHECK: [[LOAD2:%.+]] = load [[ARG1]][%arg2, %arg3] : memref<?x10xf32>
|
||||
// CHECK: [[ADDF:%.+]] = addf [[LOAD1]], [[LOAD2]] : f32
|
||||
// CHECK: store [[ADDF]], [[RET_RES]][%arg2, %arg3] : memref<?x10xf32>
|
||||
|
||||
/// Dealloc of first result.
|
||||
// CHECK: dealloc [[RES]] : memref<?x10xf32>
|
||||
// CHECK-NOT: dealloc [[RET_RES]] : memref<?x10xf32>
|
||||
|
||||
// CHECK: return [[RET_RES]] : memref<?x10xf32>
|
Loading…
Reference in New Issue