[MLIR] Add support for dealloc insertion (#386)

* Add support for the dealloc op.
* Check that no dealloc is inserted for a result that is returned by the function.
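The gist of the change: each lowering pattern now asks whether the enclosing function returns the op's result before allocating its output buffer, and only requests a matching dealloc when it does not. Below is a minimal sketch of that usage; it mirrors the ONNXAddOpLowering hunk in this diff (all names are taken from the diff) and is a fragment, not standalone-compilable code.

    // Sketch only: how a lowering pattern uses the helpers added in this commit.
    bool insertDealloc = checkInsertDealloc(op);  // false when the function returns op's result

    Value *alloc;
    if (hasAllConstantDimensions(memRefType))
      // Static shape: the alloc is hoisted to the top of the block and, if
      // requested, a dealloc is placed just before the block's last operation.
      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
    else
      // Dynamic shape: an existing memref (operands[0]) supplies the dynamic
      // dimensions via dim ops.
      alloc = insertAllocAndDealloc(
          memRefType, loc, rewriter, insertDealloc, operands[0]);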
parent b2a1103915
commit 7fb2f80dce
@@ -44,8 +44,9 @@ static MemRefType convertTensorToMemRef(TensorType type) {
 }
 
 /// Insert an allocation and deallocation for the given MemRefType.
-static Value* insertAllocAndDealloc(MemRefType type, Location loc,
-    PatternRewriter& rewriter, Value* oldMemRef = nullptr) {
+static Value* insertAllocAndDealloc(
+    MemRefType type, Location loc, PatternRewriter& rewriter,
+    bool insertDealloc, Value *oldMemRef = nullptr) {
   // Put together alloc operands for any dynamic dimensions of the memref.
   AllocOp alloc;
   if (oldMemRef) {
@@ -54,7 +55,6 @@ static Value* insertAllocAndDealloc(MemRefType type, Location loc,
     for (int i = 0; i < memRefShape.size(); ++i)
       if (memRefShape[i] < 0)
         allocOperands.push_back(rewriter.create<DimOp>(loc, oldMemRef, i));
-
     alloc = rewriter.create<AllocOp>(loc, type, allocOperands);
   } else {
     alloc = rewriter.create<AllocOp>(loc, type);
@@ -66,9 +66,36 @@ static Value* insertAllocAndDealloc(MemRefType type, Location loc,
   if (hasAllConstantDimensions(type))
     alloc.getOperation()->moveBefore(&parentBlock->front());
 
+  if (insertDealloc) {
+    auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
+    dealloc.getOperation()->moveBefore(&parentBlock->back());
+  }
+
   return alloc;
 }
 
+// Determine whether the current function returns the result value of the
+// op being lowered. If it does, a dealloc should not be inserted.
+static bool checkInsertDealloc(Operation *currentOp) {
+  auto parentBlock = currentOp->getBlock();
+
+  bool insertDealloc = true;
+  parentBlock->walk([&insertDealloc, currentOp](ReturnOp op) {
+    assert(currentOp->getNumResults() < 2 &&
+        "No more than one result supported (for now).");
+    // If there is at least one result to investigate.
+    if (currentOp->getNumResults() > 0) {
+      auto result = currentOp->getResult(0);
+      for (auto operand : op.getOperands())
+        if (operand == result)
+          insertDealloc = false;
+    }
+  });
+
+  return insertDealloc;
+}
+
 namespace {
 
 //===----------------------------------------------------------------------===//
@@ -96,10 +123,14 @@ struct ONNXAddOpLowering : public ConversionPattern {
     // TODO: verify that dimensions match.
     // TODO: can the dimension of the result differ after optimizations?
     Value *alloc;
+    bool insertDealloc = checkInsertDealloc(op);
+
     if (hasAllConstantDimensions(memRefType))
-      alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
+      alloc = insertAllocAndDealloc(
+          memRefType, loc, rewriter, insertDealloc);
     else
-      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, operands[0]);
+      alloc = insertAllocAndDealloc(
+          memRefType, loc, rewriter, insertDealloc, operands[0]);
 
     // Number of loops
     auto memRefShape = memRefType.getShape();

@@ -0,0 +1,45 @@
// RUN: dlc-opt --shape-inference --lower-frontend %s -split-input-file | FileCheck %s

module {
  func @test_sigmoid(%a1 : tensor<?x10xf32>, %a2 : tensor<?x10xf32>) -> tensor<*xf32> {
    %0 = "onnx.Add"(%a1, %a2) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<*xf32>
    %1 = "onnx.Add"(%0, %a2) : (tensor<*xf32>, tensor<?x10xf32>) -> tensor<*xf32>
    "std.return"(%1) : (tensor<*xf32>) -> ()
  }
}

// CHECK: func @test_sigmoid([[ARG0:%.+]]: memref<?x10xf32>, [[ARG1:%.+]]: memref<?x10xf32>) -> memref<?x10xf32> {

/// First Add
// CHECK: [[DIM_0:%.+]] = dim [[ARG0]], 0 : memref<?x10xf32>
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops  {
// CHECK:   krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim [[ARG0]], 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load [[ARG0]][%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[LOAD2:%.+]] = load [[ARG1]][%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[ADDF:%.+]] = addf [[LOAD1]], [[LOAD2]] : f32
// CHECK: store [[ADDF]], [[RES]][%arg2, %arg3] : memref<?x10xf32>

/// Second Add
// CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
// CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops  {
// CHECK:   krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[LOAD2:%.+]] = load [[ARG1]][%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[ADDF:%.+]] = addf [[LOAD1]], [[LOAD2]] : f32
// CHECK: store [[ADDF]], [[RET_RES]][%arg2, %arg3] : memref<?x10xf32>

/// Dealloc of first result.
// CHECK: dealloc [[RES]] : memref<?x10xf32>
// CHECK-NOT: dealloc [[RET_RES]] : memref<?x10xf32>

// CHECK: return [[RET_RES]] : memref<?x10xf32>