Add support for emitting individual memory pools with dynamic sizes (#211)
* Reorganize main function. * Follow review comments. * Emit constants are globals in Krnl and LLVM dialects. * Emit memory pools with dynamic sizes. * Reformat.
This commit is contained in:
		
							parent
							
								
									4db3edc025
								
							
						
					
					
						commit
						029fb5eb67
					
				|  | @ -295,8 +295,8 @@ std::vector<Value> getLoopIVsForBroadcasting(Location loc, | ||||||
|   return newLoopIVs; |   return newLoopIVs; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| Value emitConstantOp(ConversionPatternRewriter &rewriter, Location loc, | Value emitConstantOp( | ||||||
|     Type type, double value) { |     PatternRewriter &rewriter, Location loc, Type type, double value) { | ||||||
|   Attribute constantAttr; |   Attribute constantAttr; | ||||||
|   auto typeKind = type.getKind(); |   auto typeKind = type.getKind(); | ||||||
|   if (typeKind == StandardTypes::F16) { |   if (typeKind == StandardTypes::F16) { | ||||||
|  | @ -486,3 +486,31 @@ int64_t getMemRefSizeInBytes(Value val) { | ||||||
|   size *= getMemRefEltSizeInBytes(memRefType); |   size *= getMemRefEltSizeInBytes(memRefType); | ||||||
|   return size; |   return size; | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | Value getDynamicMemRefSizeInBytes( | ||||||
|  |     MemRefType type, Location loc, PatternRewriter &rewriter, AllocOp allocOp) { | ||||||
|  |   // Initialize the size variable with the size in bytes of the type.
 | ||||||
|  |   int64_t typeSize = getMemRefEltSizeInBytes(type); | ||||||
|  |   Value result = | ||||||
|  |       emitConstantOp(rewriter, loc, rewriter.getIndexType(), typeSize); | ||||||
|  | 
 | ||||||
|  |   // Multiply all dimensions (constant and dynamic).
 | ||||||
|  |   auto memRefShape = type.getShape(); | ||||||
|  |   auto rank = memRefShape.size(); | ||||||
|  |   int dynDimIdx = 0; | ||||||
|  |   for (int idx = 0; idx < rank; ++idx) { | ||||||
|  |     if (memRefShape[idx] < 0) { | ||||||
|  |       // Dyanmic size.
 | ||||||
|  |       auto dynamicDim = allocOp.getOperands()[dynDimIdx]; | ||||||
|  |       dynDimIdx++; | ||||||
|  |       result = rewriter.create<MulIOp>(loc, result, dynamicDim); | ||||||
|  |     } else { | ||||||
|  |       // Static size.
 | ||||||
|  |       auto staticDim = emitConstantOp( | ||||||
|  |           rewriter, loc, rewriter.getIndexType(), memRefShape[idx]); | ||||||
|  |       result = rewriter.create<MulIOp>(loc, result, staticDim); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   return result; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | @ -86,7 +86,7 @@ std::vector<Value> getLoopIVsForBroadcasting(Location loc, | ||||||
| // Use this function for small values only to avoid unexpected loss in type
 | // Use this function for small values only to avoid unexpected loss in type
 | ||||||
| // casting.
 | // casting.
 | ||||||
| Value emitConstantOp( | Value emitConstantOp( | ||||||
|     ConversionPatternRewriter &rewriter, Location loc, Type type, double value); |     PatternRewriter &rewriter, Location loc, Type type, double value); | ||||||
| 
 | 
 | ||||||
| // Emit a positive infinity constant of a specific type.
 | // Emit a positive infinity constant of a specific type.
 | ||||||
| // Supported types: F16, F32, F64, Int8, Int16, Int32, Int64.
 | // Supported types: F16, F32, F64, Int8, Int16, Int32, Int64.
 | ||||||
|  | @ -246,3 +246,6 @@ void populateLoweringONNXSplitOpPattern( | ||||||
| bool checkOpResultIsUsedByGetRef(AllocOp *allocOp); | bool checkOpResultIsUsedByGetRef(AllocOp *allocOp); | ||||||
| 
 | 
 | ||||||
| int64_t getMemRefSizeInBytes(Value val); | int64_t getMemRefSizeInBytes(Value val); | ||||||
|  | 
 | ||||||
|  | Value getDynamicMemRefSizeInBytes( | ||||||
|  |     MemRefType type, Location loc, PatternRewriter &rewriter, AllocOp allocOp); | ||||||
|  |  | ||||||
|  | @ -73,6 +73,10 @@ public: | ||||||
|     if (!checkOpResultIsUsedByGetRef(&allocOp)) |     if (!checkOpResultIsUsedByGetRef(&allocOp)) | ||||||
|       return failure(); |       return failure(); | ||||||
| 
 | 
 | ||||||
|  |     // TODO: remove once we support the bundling of dynamic memory pools.
 | ||||||
|  |     if (!hasAllConstantDimensions(memRefType)) | ||||||
|  |       return failure(); | ||||||
|  | 
 | ||||||
|     // Alloc memory type must be byte.
 |     // Alloc memory type must be byte.
 | ||||||
|     if (getMemRefEltSizeInBytes(memRefType) != 1) |     if (getMemRefEltSizeInBytes(memRefType) != 1) | ||||||
|       return failure(); |       return failure(); | ||||||
|  |  | ||||||
|  | @ -61,23 +61,34 @@ public: | ||||||
|     // TODO: Enable this pass for MemRef with dyanmic shapes.
 |     // TODO: Enable this pass for MemRef with dyanmic shapes.
 | ||||||
|     // If alloc operation is not returned then it is a candidate for
 |     // If alloc operation is not returned then it is a candidate for
 | ||||||
|     // being included in the memory pool.
 |     // being included in the memory pool.
 | ||||||
|     if (!hasAllConstantDimensions(memRefType) || |     if (checkOpResultIsReturned(&allocOp)) | ||||||
|         checkOpResultIsReturned(&allocOp)) |  | ||||||
|       return failure(); |       return failure(); | ||||||
| 
 | 
 | ||||||
|     // Check the result of this alloc is not already used by a krnl.getref.
 |     // Check the result of this alloc is not already used by a krnl.getref.
 | ||||||
|     if (checkOpResultIsUsedByGetRef(&allocOp)) |     if (checkOpResultIsUsedByGetRef(&allocOp)) | ||||||
|       return failure(); |       return failure(); | ||||||
| 
 | 
 | ||||||
|  |     AllocOp newAlloc; | ||||||
|  |     SmallVector<int64_t, 1> memPoolShape; | ||||||
|  |     if (hasAllConstantDimensions(memRefType)) { | ||||||
|       // Compute total size.
 |       // Compute total size.
 | ||||||
|       int64_t totalSize = getMemRefSizeInBytes(allocOp.getResult()); |       int64_t totalSize = getMemRefSizeInBytes(allocOp.getResult()); | ||||||
| 
 | 
 | ||||||
|       // Emit new alloc.
 |       // Emit new alloc.
 | ||||||
|     SmallVector<int64_t, 1> memPoolShape; |  | ||||||
|       memPoolShape.emplace_back(totalSize); |       memPoolShape.emplace_back(totalSize); | ||||||
|       auto memPoolMemRefType = |       auto memPoolMemRefType = | ||||||
|           MemRefType::get(memPoolShape, rewriter.getIntegerType(8)); |           MemRefType::get(memPoolShape, rewriter.getIntegerType(8)); | ||||||
|     auto newAlloc = rewriter.create<AllocOp>(loc, memPoolMemRefType); |       newAlloc = rewriter.create<AllocOp>(loc, memPoolMemRefType); | ||||||
|  |     } else { | ||||||
|  |       memPoolShape.emplace_back(-1); | ||||||
|  |       auto memPoolMemRefType = | ||||||
|  |           MemRefType::get(memPoolShape, rewriter.getIntegerType(8)); | ||||||
|  | 
 | ||||||
|  |       Value dyanmicTotalSize = | ||||||
|  |           getDynamicMemRefSizeInBytes(memRefType, loc, rewriter, allocOp); | ||||||
|  |       newAlloc = | ||||||
|  |           rewriter.create<AllocOp>(loc, memPoolMemRefType, dyanmicTotalSize); | ||||||
|  |     } | ||||||
| 
 | 
 | ||||||
|     // Emit new dealloc.
 |     // Emit new dealloc.
 | ||||||
|     auto dealloc = rewriter.create<DeallocOp>(loc, newAlloc); |     auto dealloc = rewriter.create<DeallocOp>(loc, newAlloc); | ||||||
|  |  | ||||||
|  | @ -62,3 +62,46 @@ func @test_enable_memory_pool_2(%arg0: tensor<10x10xf32>, %arg1: tensor<10x20xf3 | ||||||
|   // CHECK: dealloc [[MEMPOOL0]] : memref<800xi8> |   // CHECK: dealloc [[MEMPOOL0]] : memref<800xi8> | ||||||
|   // CHECK: return [[RES]] : memref<10x20xf32> |   // CHECK: return [[RES]] : memref<10x20xf32> | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | // Two intermediate dynamic sized MemRefs. | ||||||
|  | func @test_enable_memory_pool_3(%arg0: tensor<?x?xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<10x10xf32>) -> tensor<*xf32> { | ||||||
|  |   %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x10xf32>) -> tensor<*xf32> | ||||||
|  |   %1 = "onnx.Add"(%0, %0) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> | ||||||
|  |   %2 = "onnx.MatMul"(%0, %1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> | ||||||
|  |   return %2 : tensor<*xf32> | ||||||
|  | 
 | ||||||
|  |   // CHECK-LABEL: test_enable_memory_pool_3 | ||||||
|  |   // CHECK: [[CONST4:%.+]] = constant 4 : index | ||||||
|  |   // CHECK: [[CONST10:%.+]] = constant 10 : index | ||||||
|  |   // CHECK: [[CONST0_I64:%.+]] = constant 0 : i64 | ||||||
|  |   // CHECK: [[CONST1:%.+]] = constant 1 : index | ||||||
|  |   // CHECK: [[CONST0:%.+]] = constant 0 : index | ||||||
|  |   // CHECK: [[CST:%.+]] = constant 0.000000e+00 : f32 | ||||||
|  |   // CHECK: [[DIM1:%.+]] = dim %arg0, [[CONST0]] : memref<?x?xf32> | ||||||
|  |   // CHECK: [[TMP1:%.+]] = muli [[DIM1]], [[CONST4]] : index | ||||||
|  |   // CHECK: [[TMP2:%.+]] = muli [[TMP1]], [[CONST10]] : index | ||||||
|  |   // CHECK: [[MEMPOOL1:%.+]] = alloc([[TMP2]]) : memref<?xi8> | ||||||
|  |   // CHECK: [[DATA1:%.+]] = "krnl.getref"([[MEMPOOL1]], [[CONST0_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32> | ||||||
|  |   // CHECK: krnl.define_loops 2 | ||||||
|  |   // CHECK: krnl.iterate | ||||||
|  |   // CHECK: affine.store {{.*}}, [[DATA1]][%arg3, %arg4] : memref<?x10xf32> | ||||||
|  |   // CHECK: [[CMP1:%.+]] = cmpi "sgt", [[DIM1]], [[DIM1]] : index | ||||||
|  |   // CHECK: [[SELECT1:%.+]] = select [[CMP1]], [[DIM1]], [[DIM1]] : index | ||||||
|  |   // CHECK: [[TMP3:%.+]] = muli [[SELECT1]], [[CONST4]] : index | ||||||
|  |   // CHECK: [[TMP4:%.+]] = muli [[TMP3]], [[CONST10]] : index | ||||||
|  |   // CHECK: [[MEMPOOL2:%.+]] = alloc([[TMP4]]) : memref<?xi8> | ||||||
|  |   // CHECK: [[DATA2:%.+]] = "krnl.getref"([[MEMPOOL2]], [[CONST0_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32> | ||||||
|  |   // CHECK: krnl.define_loops 2 | ||||||
|  |   // CHECK: krnl.iterate | ||||||
|  |   // CHECK: affine.store {{.*}}, [[DATA2]][%arg3, %arg4] : memref<?x10xf32> | ||||||
|  |   // CHECK: [[DATA3:%.+]] = alloc([[DIM1]]) : memref<?x10xf32> | ||||||
|  |   // CHECK: krnl.define_loops 2 | ||||||
|  |   // CHECK: krnl.iterate | ||||||
|  |   // CHECK: affine.store [[CST]], [[DATA3]][%arg3, %arg4] : memref<?x10xf32> | ||||||
|  |   // CHECK: krnl.define_loops 1 | ||||||
|  |   // CHECK: krnl.iterate | ||||||
|  |   // CHECK: affine.store {{.*}}, [[DATA3]][%arg3, %arg4] : memref<?x10xf32> | ||||||
|  |   // CHECK: dealloc [[MEMPOOL2]] : memref<?xi8> | ||||||
|  |   // CHECK: dealloc [[MEMPOOL1]] : memref<?xi8> | ||||||
|  |   // CHECK: return [[DATA3]] : memref<?x10xf32> | ||||||
|  | } | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue