diff --git a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp
index c57d0a4..8938478 100644
--- a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp
+++ b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp
@@ -295,8 +295,8 @@ std::vector<Value> getLoopIVsForBroadcasting(Location loc,
   return newLoopIVs;
 }
 
-Value emitConstantOp(ConversionPatternRewriter &rewriter, Location loc,
-    Type type, double value) {
+Value emitConstantOp(
+    PatternRewriter &rewriter, Location loc, Type type, double value) {
   Attribute constantAttr;
   auto typeKind = type.getKind();
   if (typeKind == StandardTypes::F16) {
@@ -486,3 +486,31 @@ int64_t getMemRefSizeInBytes(Value val) {
   size *= getMemRefEltSizeInBytes(memRefType);
   return size;
 }
+
+Value getDynamicMemRefSizeInBytes(
+    MemRefType type, Location loc, PatternRewriter &rewriter, AllocOp allocOp) {
+  // Initialize the size variable with the size in bytes of the type.
+  int64_t typeSize = getMemRefEltSizeInBytes(type);
+  Value result =
+      emitConstantOp(rewriter, loc, rewriter.getIndexType(), typeSize);
+
+  // Multiply all dimensions (constant and dynamic).
+  auto memRefShape = type.getShape();
+  auto rank = memRefShape.size();
+  int dynDimIdx = 0;
+  for (int idx = 0; idx < rank; ++idx) {
+    if (memRefShape[idx] < 0) {
+      // Dynamic size.
+      auto dynamicDim = allocOp.getOperands()[dynDimIdx];
+      dynDimIdx++;
+      result = rewriter.create<MulIOp>(loc, result, dynamicDim);
+    } else {
+      // Static size.
+      auto staticDim = emitConstantOp(
+          rewriter, loc, rewriter.getIndexType(), memRefShape[idx]);
+      result = rewriter.create<MulIOp>(loc, result, staticDim);
+    }
+  }
+
+  return result;
+}
diff --git a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp
index 6b6660c..3abc163 100644
--- a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp
+++ b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp
@@ -86,7 +86,7 @@ std::vector<Value> getLoopIVsForBroadcasting(Location loc,
 // Use this function for small values only to avoid unexpected loss in type
 // casting.
 Value emitConstantOp(
-    ConversionPatternRewriter &rewriter, Location loc, Type type, double value);
+    PatternRewriter &rewriter, Location loc, Type type, double value);
 
 // Emit a positive infinity constant of a specific type.
 // Supported types: F16, F32, F64, Int8, Int16, Int32, Int64.
@@ -246,3 +246,6 @@ void populateLoweringONNXSplitOpPattern(
 bool checkOpResultIsUsedByGetRef(AllocOp *allocOp);
 
 int64_t getMemRefSizeInBytes(Value val);
+
+Value getDynamicMemRefSizeInBytes(
+    MemRefType type, Location loc, PatternRewriter &rewriter, AllocOp allocOp);
diff --git a/src/Transform/BundleMemoryPools.cpp b/src/Transform/BundleMemoryPools.cpp
index f8e1ddc..81e6700 100644
--- a/src/Transform/BundleMemoryPools.cpp
+++ b/src/Transform/BundleMemoryPools.cpp
@@ -73,6 +73,10 @@ public:
     if (!checkOpResultIsUsedByGetRef(&allocOp))
       return failure();
 
+    // TODO: remove once we support the bundling of dynamic memory pools.
+    if (!hasAllConstantDimensions(memRefType))
+      return failure();
+
     // Alloc memory type must be byte.
     if (getMemRefEltSizeInBytes(memRefType) != 1)
       return failure();
diff --git a/src/Transform/EnableMemoryPool.cpp b/src/Transform/EnableMemoryPool.cpp
index c41293f..2439b7d 100644
--- a/src/Transform/EnableMemoryPool.cpp
+++ b/src/Transform/EnableMemoryPool.cpp
@@ -61,23 +61,34 @@ public:
     // TODO: Enable this pass for MemRef with dyanmic shapes.
     // If alloc operation is not returned then it is a candidate for
     // being included in the memory pool.
-    if (!hasAllConstantDimensions(memRefType) ||
-        checkOpResultIsReturned(&allocOp))
+    if (checkOpResultIsReturned(&allocOp))
       return failure();
 
     // Check the result of this alloc is not already used by a krnl.getref.
     if (checkOpResultIsUsedByGetRef(&allocOp))
       return failure();
 
-    // Compute total size.
-    int64_t totalSize = getMemRefSizeInBytes(allocOp.getResult());
-
-    // Emit new alloc.
+    AllocOp newAlloc;
     SmallVector<int64_t, 1> memPoolShape;
-    memPoolShape.emplace_back(totalSize);
-    auto memPoolMemRefType =
-        MemRefType::get(memPoolShape, rewriter.getIntegerType(8));
-    auto newAlloc = rewriter.create<AllocOp>(loc, memPoolMemRefType);
+    if (hasAllConstantDimensions(memRefType)) {
+      // Compute total size.
+      int64_t totalSize = getMemRefSizeInBytes(allocOp.getResult());
+
+      // Emit new alloc.
+      memPoolShape.emplace_back(totalSize);
+      auto memPoolMemRefType =
+          MemRefType::get(memPoolShape, rewriter.getIntegerType(8));
+      newAlloc = rewriter.create<AllocOp>(loc, memPoolMemRefType);
+    } else {
+      memPoolShape.emplace_back(-1);
+      auto memPoolMemRefType =
+          MemRefType::get(memPoolShape, rewriter.getIntegerType(8));
+
+      Value dynamicTotalSize =
+          getDynamicMemRefSizeInBytes(memRefType, loc, rewriter, allocOp);
+      newAlloc =
+          rewriter.create<AllocOp>(loc, memPoolMemRefType, dynamicTotalSize);
+    }
 
     // Emit new dealloc.
     auto dealloc = rewriter.create<DeallocOp>(loc, newAlloc);
diff --git a/test/mlir/onnx/onnx_enable_memory_pool.mlir b/test/mlir/onnx/onnx_enable_memory_pool.mlir
index 62b305c..5ae18d2 100644
--- a/test/mlir/onnx/onnx_enable_memory_pool.mlir
+++ b/test/mlir/onnx/onnx_enable_memory_pool.mlir
@@ -62,3 +62,46 @@ func @test_enable_memory_pool_2(%arg0: tensor<10x10xf32>, %arg1: tensor<10x20xf3
   // CHECK: dealloc [[MEMPOOL0]] : memref<800xi8>
   // CHECK: return [[RES]] : memref<10x20xf32>
 }
+
+// Two intermediate dynamic sized MemRefs.
+func @test_enable_memory_pool_3(%arg0: tensor<?x?xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<10x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x10xf32>) -> tensor<*xf32>
+  %1 = "onnx.Add"(%0, %0) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  %2 = "onnx.MatMul"(%0, %1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  return %2 : tensor<*xf32>
+
+  // CHECK-LABEL: test_enable_memory_pool_3
+  // CHECK: [[CONST4:%.+]] = constant 4 : index
+  // CHECK: [[CONST10:%.+]] = constant 10 : index
+  // CHECK: [[CONST0_I64:%.+]] = constant 0 : i64
+  // CHECK: [[CONST1:%.+]] = constant 1 : index
+  // CHECK: [[CONST0:%.+]] = constant 0 : index
+  // CHECK: [[CST:%.+]] = constant 0.000000e+00 : f32
+  // CHECK: [[DIM1:%.+]] = dim %arg0, [[CONST0]] : memref<?x?xf32>
+  // CHECK: [[TMP1:%.+]] = muli [[DIM1]], [[CONST4]] : index
+  // CHECK: [[TMP2:%.+]] = muli [[TMP1]], [[CONST10]] : index
+  // CHECK: [[MEMPOOL1:%.+]] = alloc([[TMP2]]) : memref<?xi8>
+  // CHECK: [[DATA1:%.+]] = "krnl.getref"([[MEMPOOL1]], [[CONST0_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32>
+  // CHECK: krnl.define_loops 2
+  // CHECK: krnl.iterate
+  // CHECK: affine.store {{.*}}, [[DATA1]][%arg3, %arg4] : memref<?x10xf32>
+  // CHECK: [[CMP1:%.+]] = cmpi "sgt", [[DIM1]], [[DIM1]] : index
+  // CHECK: [[SELECT1:%.+]] = select [[CMP1]], [[DIM1]], [[DIM1]] : index
+  // CHECK: [[TMP3:%.+]] = muli [[SELECT1]], [[CONST4]] : index
+  // CHECK: [[TMP4:%.+]] = muli [[TMP3]], [[CONST10]] : index
+  // CHECK: [[MEMPOOL2:%.+]] = alloc([[TMP4]]) : memref<?xi8>
+  // CHECK: [[DATA2:%.+]] = "krnl.getref"([[MEMPOOL2]], [[CONST0_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32>
+  // CHECK: krnl.define_loops 2
+  // CHECK: krnl.iterate
+  // CHECK: affine.store {{.*}}, [[DATA2]][%arg3, %arg4] : memref<?x10xf32>
+  // CHECK: [[DATA3:%.+]] = alloc([[DIM1]]) : memref<?x10xf32>
+  // CHECK: krnl.define_loops 2
+  // CHECK: krnl.iterate
+  // CHECK: affine.store [[CST]], [[DATA3]][%arg3, %arg4] : memref<?x10xf32>
+  // CHECK: krnl.define_loops 1
+  // CHECK: krnl.iterate
+  // CHECK: affine.store {{.*}}, [[DATA3]][%arg3, %arg4] : memref<?x10xf32>
+  // CHECK: dealloc [[MEMPOOL2]] : memref<?xi8>
+  // CHECK: dealloc [[MEMPOOL1]] : memref<?xi8>
+  // CHECK: return [[DATA3]] : memref<?x10xf32>
+}