Add support for emitting individual memory pools with dynamic sizes (#211)
* Reorganize main function.
* Follow review comments.
* Emit constants as globals in Krnl and LLVM dialects.
* Emit memory pools with dynamic sizes.
* Reformat.
parent 4db3edc025
commit 029fb5eb67
@@ -295,8 +295,8 @@ std::vector<Value> getLoopIVsForBroadcasting(Location loc,
   return newLoopIVs;
 }
 
-Value emitConstantOp(ConversionPatternRewriter &rewriter, Location loc,
-    Type type, double value) {
+Value emitConstantOp(
+    PatternRewriter &rewriter, Location loc, Type type, double value) {
   Attribute constantAttr;
   auto typeKind = type.getKind();
   if (typeKind == StandardTypes::F16) {
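A note on the signature change above (an inference from the patch, not stated in it): in MLIR, ConversionPatternRewriter derives from PatternRewriter, so widening emitConstantOp to take the base class lets plain rewrite patterns such as the memory-pool passes below reuse the helper, while existing dialect-conversion call sites keep compiling. A toy C++ sketch of the compatibility, with stand-in names:

// Stand-in model of the MLIR rewriter hierarchy; names are illustrative only.
struct PatternRewriterLike {};
struct ConversionPatternRewriterLike : PatternRewriterLike {};

// Widened helper: now accepts the base class by reference.
void emitConstantLike(PatternRewriterLike &rewriter) { (void)rewriter; }

int main() {
  ConversionPatternRewriterLike conv;
  emitConstantLike(conv); // derived-to-base binding; callers stay unchanged
}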
@@ -486,3 +486,31 @@ int64_t getMemRefSizeInBytes(Value val) {
   size *= getMemRefEltSizeInBytes(memRefType);
   return size;
 }
+
+Value getDynamicMemRefSizeInBytes(
+    MemRefType type, Location loc, PatternRewriter &rewriter, AllocOp allocOp) {
+  // Initialize the size variable with the size in bytes of the type.
+  int64_t typeSize = getMemRefEltSizeInBytes(type);
+  Value result =
+      emitConstantOp(rewriter, loc, rewriter.getIndexType(), typeSize);
+
+  // Multiply all dimensions (constant and dynamic).
+  auto memRefShape = type.getShape();
+  auto rank = memRefShape.size();
+  int dynDimIdx = 0;
+  for (int idx = 0; idx < rank; ++idx) {
+    if (memRefShape[idx] < 0) {
+      // Dynamic size.
+      auto dynamicDim = allocOp.getOperands()[dynDimIdx];
+      dynDimIdx++;
+      result = rewriter.create<MulIOp>(loc, result, dynamicDim);
+    } else {
+      // Static size.
+      auto staticDim = emitConstantOp(
+          rewriter, loc, rewriter.getIndexType(), memRefShape[idx]);
+      result = rewriter.create<MulIOp>(loc, result, staticDim);
+    }
+  }
+
+  return result;
+}
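The new helper folds the element size and every static dimension into index constants and multiplies in each dynamic dimension, consumed in order from the alloc's operands. A minimal standalone C++ sketch of the same arithmetic (hypothetical helper, not part of the patch):

#include <cstdint>
#include <vector>

// Byte size of a memref-like buffer: element size times all extents, with
// dynamic extents (-1 in the shape) taken in order from dynDims, mirroring
// how the pass reads the AllocOp's operands.
int64_t bufferSizeInBytes(int64_t eltSizeInBytes,
    const std::vector<int64_t> &shape, const std::vector<int64_t> &dynDims) {
  int64_t size = eltSizeInBytes;
  size_t dynDimIdx = 0;
  for (int64_t dim : shape)
    size *= (dim < 0) ? dynDims[dynDimIdx++] : dim;
  return size;
}

// Example: shape {-1, 10}, f32 elements (4 bytes), runtime dim 7
// -> 4 * 7 * 10 = 280 bytes, matching the muli chain the pass emits.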
@@ -86,7 +86,7 @@ std::vector<Value> getLoopIVsForBroadcasting(Location loc,
 // Use this function for small values only to avoid unexpected loss in type
 // casting.
 Value emitConstantOp(
-    ConversionPatternRewriter &rewriter, Location loc, Type type, double value);
+    PatternRewriter &rewriter, Location loc, Type type, double value);
 
 // Emit a positive infinity constant of a specific type.
 // Supported types: F16, F32, F64, Int8, Int16, Int32, Int64.
@@ -246,3 +246,6 @@ void populateLoweringONNXSplitOpPattern(
 bool checkOpResultIsUsedByGetRef(AllocOp *allocOp);
 
 int64_t getMemRefSizeInBytes(Value val);
+
+Value getDynamicMemRefSizeInBytes(
+    MemRefType type, Location loc, PatternRewriter &rewriter, AllocOp allocOp);
@@ -73,6 +73,10 @@ public:
     if (!checkOpResultIsUsedByGetRef(&allocOp))
       return failure();
 
+    // TODO: remove once we support the bundling of dynamic memory pools.
+    if (!hasAllConstantDimensions(memRefType))
+      return failure();
+
     // Alloc memory type must be byte.
     if (getMemRefEltSizeInBytes(memRefType) != 1)
       return failure();
@@ -61,23 +61,34 @@ public:
     // TODO: Enable this pass for MemRef with dynamic shapes.
     // If alloc operation is not returned then it is a candidate for
     // being included in the memory pool.
-    if (!hasAllConstantDimensions(memRefType) ||
-        checkOpResultIsReturned(&allocOp))
+    if (checkOpResultIsReturned(&allocOp))
       return failure();
 
     // Check the result of this alloc is not already used by a krnl.getref.
     if (checkOpResultIsUsedByGetRef(&allocOp))
      return failure();
 
-    // Compute total size.
-    int64_t totalSize = getMemRefSizeInBytes(allocOp.getResult());
-
-    // Emit new alloc.
+    AllocOp newAlloc;
     SmallVector<int64_t, 1> memPoolShape;
-    memPoolShape.emplace_back(totalSize);
-    auto memPoolMemRefType =
-        MemRefType::get(memPoolShape, rewriter.getIntegerType(8));
-    auto newAlloc = rewriter.create<AllocOp>(loc, memPoolMemRefType);
+    if (hasAllConstantDimensions(memRefType)) {
+      // Compute total size.
+      int64_t totalSize = getMemRefSizeInBytes(allocOp.getResult());
+
+      // Emit new alloc.
+      memPoolShape.emplace_back(totalSize);
+      auto memPoolMemRefType =
+          MemRefType::get(memPoolShape, rewriter.getIntegerType(8));
+      newAlloc = rewriter.create<AllocOp>(loc, memPoolMemRefType);
+    } else {
+      memPoolShape.emplace_back(-1);
+      auto memPoolMemRefType =
+          MemRefType::get(memPoolShape, rewriter.getIntegerType(8));
+
+      Value dynamicTotalSize =
+          getDynamicMemRefSizeInBytes(memRefType, loc, rewriter, allocOp);
+      newAlloc =
+          rewriter.create<AllocOp>(loc, memPoolMemRefType, dynamicTotalSize);
+    }
 
     // Emit new dealloc.
     auto dealloc = rewriter.create<DeallocOp>(loc, newAlloc);
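The branch above keeps the previous behavior for fully static shapes (a fixed-size i8 pool with the byte count baked into the MemRef type) and otherwise emits a memref<?xi8> pool whose single dynamic extent is supplied at runtime by the value from getDynamicMemRefSizeInBytes. A small C++ sketch of that decision, with hypothetical names:

#include <cstdint>

// Shape choice for the i8 memory pool: static pools encode the byte count in
// the type; dynamic pools use extent -1 and pass the byte count as an
// operand of the alloc instead.
struct PoolShape {
  int64_t extent;          // totalSize for static pools, -1 for dynamic ones
  bool sizeIsAllocOperand; // true when the byte count is a runtime value
};

PoolShape choosePoolShape(bool allDimsConstant, int64_t staticTotalSize) {
  return allDimsConstant ? PoolShape{staticTotalSize, false}
                         : PoolShape{-1, true};
}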
@@ -62,3 +62,46 @@ func @test_enable_memory_pool_2(%arg0: tensor<10x10xf32>, %arg1: tensor<10x20xf3
 // CHECK: dealloc [[MEMPOOL0]] : memref<800xi8>
 // CHECK: return [[RES]] : memref<10x20xf32>
 }
+
+// Two intermediate dynamically sized MemRefs.
+func @test_enable_memory_pool_3(%arg0: tensor<?x?xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<10x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x10xf32>) -> tensor<*xf32>
+  %1 = "onnx.Add"(%0, %0) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  %2 = "onnx.MatMul"(%0, %1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
+  return %2 : tensor<*xf32>
+
+  // CHECK-LABEL: test_enable_memory_pool_3
+  // CHECK: [[CONST4:%.+]] = constant 4 : index
+  // CHECK: [[CONST10:%.+]] = constant 10 : index
+  // CHECK: [[CONST0_I64:%.+]] = constant 0 : i64
+  // CHECK: [[CONST1:%.+]] = constant 1 : index
+  // CHECK: [[CONST0:%.+]] = constant 0 : index
+  // CHECK: [[CST:%.+]] = constant 0.000000e+00 : f32
+  // CHECK: [[DIM1:%.+]] = dim %arg0, [[CONST0]] : memref<?x?xf32>
+  // CHECK: [[TMP1:%.+]] = muli [[DIM1]], [[CONST4]] : index
+  // CHECK: [[TMP2:%.+]] = muli [[TMP1]], [[CONST10]] : index
+  // CHECK: [[MEMPOOL1:%.+]] = alloc([[TMP2]]) : memref<?xi8>
+  // CHECK: [[DATA1:%.+]] = "krnl.getref"([[MEMPOOL1]], [[CONST0_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32>
+  // CHECK: krnl.define_loops 2
+  // CHECK: krnl.iterate
+  // CHECK: affine.store {{.*}}, [[DATA1]][%arg3, %arg4] : memref<?x10xf32>
+  // CHECK: [[CMP1:%.+]] = cmpi "sgt", [[DIM1]], [[DIM1]] : index
+  // CHECK: [[SELECT1:%.+]] = select [[CMP1]], [[DIM1]], [[DIM1]] : index
+  // CHECK: [[TMP3:%.+]] = muli [[SELECT1]], [[CONST4]] : index
+  // CHECK: [[TMP4:%.+]] = muli [[TMP3]], [[CONST10]] : index
+  // CHECK: [[MEMPOOL2:%.+]] = alloc([[TMP4]]) : memref<?xi8>
+  // CHECK: [[DATA2:%.+]] = "krnl.getref"([[MEMPOOL2]], [[CONST0_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32>
+  // CHECK: krnl.define_loops 2
+  // CHECK: krnl.iterate
+  // CHECK: affine.store {{.*}}, [[DATA2]][%arg3, %arg4] : memref<?x10xf32>
+  // CHECK: [[DATA3:%.+]] = alloc([[DIM1]]) : memref<?x10xf32>
+  // CHECK: krnl.define_loops 2
+  // CHECK: krnl.iterate
+  // CHECK: affine.store [[CST]], [[DATA3]][%arg3, %arg4] : memref<?x10xf32>
+  // CHECK: krnl.define_loops 1
+  // CHECK: krnl.iterate
+  // CHECK: affine.store {{.*}}, [[DATA3]][%arg3, %arg4] : memref<?x10xf32>
+  // CHECK: dealloc [[MEMPOOL2]] : memref<?xi8>
+  // CHECK: dealloc [[MEMPOOL1]] : memref<?xi8>
+  // CHECK: return [[DATA3]] : memref<?x10xf32>
+}