Add support for emitting individual memory pools with dynamic sizes (#211)
* Reorganize main function.
* Follow review comments.
* Emit constants as globals in Krnl and LLVM dialects.
* Emit memory pools with dynamic sizes.
* Reformat.
parent 4db3edc025
commit 029fb5eb67
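To make the effect concrete, here is a minimal sketch of the IR the memory-pool pass can now produce for an intermediate buffer with a dynamic dimension, assuming a memref<?x10xf32> result whose first dimension comes from %arg0. Value names and shapes are illustrative, the op spelling follows the std/krnl dialect forms that appear elsewhere in this diff, and the exact sequence is pinned down by the new lit test at the end of the change:

%c0 = constant 0 : index
%c4 = constant 4 : index
%c10 = constant 10 : index
%c0_i64 = constant 0 : i64

// Dynamic dimension of the intermediate result.
%d = dim %arg0, %c0 : memref<?x?xf32>

// Previously the pass bailed out here and kept the plain alloc:
//   %buf = alloc(%d) : memref<?x10xf32>

// With dynamic sizes supported, the byte size is computed at run time
// (%d * 4 bytes per f32 * 10 columns), a byte-typed pool of that size is
// allocated, and the buffer becomes a krnl.getref into that pool.
%t0 = muli %d, %c4 : index
%size = muli %t0, %c10 : index
%pool = alloc(%size) : memref<?xi8>
%buf = "krnl.getref"(%pool, %c0_i64) : (memref<?xi8>, i64) -> memref<?x10xf32>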
					
@@ -295,8 +295,8 @@ std::vector<Value> getLoopIVsForBroadcasting(Location loc,
   return newLoopIVs;
 }
 
-Value emitConstantOp(ConversionPatternRewriter &rewriter, Location loc,
-    Type type, double value) {
+Value emitConstantOp(
+    PatternRewriter &rewriter, Location loc, Type type, double value) {
   Attribute constantAttr;
   auto typeKind = type.getKind();
   if (typeKind == StandardTypes::F16) {
@@ -486,3 +486,31 @@ int64_t getMemRefSizeInBytes(Value val) {
   size *= getMemRefEltSizeInBytes(memRefType);
   return size;
 }
+
+Value getDynamicMemRefSizeInBytes(
+    MemRefType type, Location loc, PatternRewriter &rewriter, AllocOp allocOp) {
+  // Initialize the size variable with the size in bytes of the type.
+  int64_t typeSize = getMemRefEltSizeInBytes(type);
+  Value result =
+      emitConstantOp(rewriter, loc, rewriter.getIndexType(), typeSize);
+
+  // Multiply all dimensions (constant and dynamic).
+  auto memRefShape = type.getShape();
+  auto rank = memRefShape.size();
+  int dynDimIdx = 0;
+  for (int idx = 0; idx < rank; ++idx) {
+    if (memRefShape[idx] < 0) {
+      // Dynamic size.
+      auto dynamicDim = allocOp.getOperands()[dynDimIdx];
+      dynDimIdx++;
+      result = rewriter.create<MulIOp>(loc, result, dynamicDim);
+    } else {
+      // Static size.
+      auto staticDim = emitConstantOp(
+          rewriter, loc, rewriter.getIndexType(), memRefShape[idx]);
+      result = rewriter.create<MulIOp>(loc, result, staticDim);
+    }
+  }
+
+  return result;
+}

@@ -86,7 +86,7 @@ std::vector<Value> getLoopIVsForBroadcasting(Location loc,
 // Use this function for small values only to avoid unexpected loss in type
 // casting.
 Value emitConstantOp(
-    ConversionPatternRewriter &rewriter, Location loc, Type type, double value);
+    PatternRewriter &rewriter, Location loc, Type type, double value);
 
 // Emit a positive infinity constant of a specific type.
 // Supported types: F16, F32, F64, Int8, Int16, Int32, Int64.
@@ -246,3 +246,6 @@ void populateLoweringONNXSplitOpPattern(
 bool checkOpResultIsUsedByGetRef(AllocOp *allocOp);
 
 int64_t getMemRefSizeInBytes(Value val);
+
+Value getDynamicMemRefSizeInBytes(
+    MemRefType type, Location loc, PatternRewriter &rewriter, AllocOp allocOp);

@@ -73,6 +73,10 @@ public:
     if (!checkOpResultIsUsedByGetRef(&allocOp))
       return failure();
 
+    // TODO: remove once we support the bundling of dynamic memory pools.
+    if (!hasAllConstantDimensions(memRefType))
+      return failure();
+
     // Alloc memory type must be byte.
     if (getMemRefEltSizeInBytes(memRefType) != 1)
       return failure();

@@ -61,23 +61,34 @@ public:
     // TODO: Enable this pass for MemRef with dynamic shapes.
     // If alloc operation is not returned then it is a candidate for
     // being included in the memory pool.
-    if (!hasAllConstantDimensions(memRefType) ||
-        checkOpResultIsReturned(&allocOp))
+    if (checkOpResultIsReturned(&allocOp))
       return failure();
 
     // Check the result of this alloc is not already used by a krnl.getref.
     if (checkOpResultIsUsedByGetRef(&allocOp))
       return failure();
 
-    // Compute total size.
-    int64_t totalSize = getMemRefSizeInBytes(allocOp.getResult());
-
-    // Emit new alloc.
+    AllocOp newAlloc;
     SmallVector<int64_t, 1> memPoolShape;
-    memPoolShape.emplace_back(totalSize);
-    auto memPoolMemRefType =
-        MemRefType::get(memPoolShape, rewriter.getIntegerType(8));
-    auto newAlloc = rewriter.create<AllocOp>(loc, memPoolMemRefType);
+    if (hasAllConstantDimensions(memRefType)) {
+      // Compute total size.
+      int64_t totalSize = getMemRefSizeInBytes(allocOp.getResult());
+
+      // Emit new alloc.
+      memPoolShape.emplace_back(totalSize);
+      auto memPoolMemRefType =
+          MemRefType::get(memPoolShape, rewriter.getIntegerType(8));
+      newAlloc = rewriter.create<AllocOp>(loc, memPoolMemRefType);
+    } else {
+      memPoolShape.emplace_back(-1);
+      auto memPoolMemRefType =
+          MemRefType::get(memPoolShape, rewriter.getIntegerType(8));
+
+      Value dynamicTotalSize =
+          getDynamicMemRefSizeInBytes(memRefType, loc, rewriter, allocOp);
+      newAlloc =
+          rewriter.create<AllocOp>(loc, memPoolMemRefType, dynamicTotalSize);
+    }
 
     // Emit new dealloc.
     auto dealloc = rewriter.create<DeallocOp>(loc, newAlloc);

@@ -62,3 +62,46 @@ func @test_enable_memory_pool_2(%arg0: tensor<10x10xf32>, %arg1: tensor<10x20xf3
   // CHECK: dealloc [[MEMPOOL0]] : memref<800xi8>
   // CHECK: return [[RES]] : memref<10x20xf32>
 }
+
+// Two intermediate dynamically sized MemRefs.
+func @test_enable_memory_pool_3(%arg0: tensor<?x?xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<10x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x10xf32>) -> tensor<*xf32>
+  %1 = "onnx.Add"(%0, %0) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  %2 = "onnx.MatMul"(%0, %1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  return %2 : tensor<*xf32>
+
+  // CHECK-LABEL: test_enable_memory_pool_3
+  // CHECK: [[CONST4:%.+]] = constant 4 : index
+  // CHECK: [[CONST10:%.+]] = constant 10 : index
+  // CHECK: [[CONST0_I64:%.+]] = constant 0 : i64
+  // CHECK: [[CONST1:%.+]] = constant 1 : index
+  // CHECK: [[CONST0:%.+]] = constant 0 : index
+  // CHECK: [[CST:%.+]] = constant 0.000000e+00 : f32
+  // CHECK: [[DIM1:%.+]] = dim %arg0, [[CONST0]] : memref<?x?xf32>
+  // CHECK: [[TMP1:%.+]] = muli [[DIM1]], [[CONST4]] : index
+  // CHECK: [[TMP2:%.+]] = muli [[TMP1]], [[CONST10]] : index
+  // CHECK: [[MEMPOOL1:%.+]] = alloc([[TMP2]]) : memref<?xi8>
+  // CHECK: [[DATA1:%.+]] = "krnl.getref"([[MEMPOOL1]], [[CONST0_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32>
+  // CHECK: krnl.define_loops 2
+  // CHECK: krnl.iterate
+  // CHECK: affine.store {{.*}}, [[DATA1]][%arg3, %arg4] : memref<?x10xf32>
+  // CHECK: [[CMP1:%.+]] = cmpi "sgt", [[DIM1]], [[DIM1]] : index
+  // CHECK: [[SELECT1:%.+]] = select [[CMP1]], [[DIM1]], [[DIM1]] : index
+  // CHECK: [[TMP3:%.+]] = muli [[SELECT1]], [[CONST4]] : index
+  // CHECK: [[TMP4:%.+]] = muli [[TMP3]], [[CONST10]] : index
+  // CHECK: [[MEMPOOL2:%.+]] = alloc([[TMP4]]) : memref<?xi8>
+  // CHECK: [[DATA2:%.+]] = "krnl.getref"([[MEMPOOL2]], [[CONST0_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32>
+  // CHECK: krnl.define_loops 2
+  // CHECK: krnl.iterate
+  // CHECK: affine.store {{.*}}, [[DATA2]][%arg3, %arg4] : memref<?x10xf32>
+  // CHECK: [[DATA3:%.+]] = alloc([[DIM1]]) : memref<?x10xf32>
+  // CHECK: krnl.define_loops 2
+  // CHECK: krnl.iterate
+  // CHECK: affine.store [[CST]], [[DATA3]][%arg3, %arg4] : memref<?x10xf32>
+  // CHECK: krnl.define_loops 1
+  // CHECK: krnl.iterate
+  // CHECK: affine.store {{.*}}, [[DATA3]][%arg3, %arg4] : memref<?x10xf32>
+  // CHECK: dealloc [[MEMPOOL2]] : memref<?xi8>
+  // CHECK: dealloc [[MEMPOOL1]] : memref<?xi8>
+  // CHECK: return [[DATA3]] : memref<?x10xf32>
+}