Add support for emitting individual memory pools with dynamic sizes (#211)
* Reorganize main function. * Follow review comments. * Emit constants as globals in Krnl and LLVM dialects. * Emit memory pools with dynamic sizes. * Reformat.
This commit is contained in:
parent
4db3edc025
commit
029fb5eb67
|
@ -295,8 +295,8 @@ std::vector<Value> getLoopIVsForBroadcasting(Location loc,
|
||||||
return newLoopIVs;
|
return newLoopIVs;
|
||||||
}
|
}
|
||||||
|
|
||||||
Value emitConstantOp(ConversionPatternRewriter &rewriter, Location loc,
|
Value emitConstantOp(
|
||||||
Type type, double value) {
|
PatternRewriter &rewriter, Location loc, Type type, double value) {
|
||||||
Attribute constantAttr;
|
Attribute constantAttr;
|
||||||
auto typeKind = type.getKind();
|
auto typeKind = type.getKind();
|
||||||
if (typeKind == StandardTypes::F16) {
|
if (typeKind == StandardTypes::F16) {
|
||||||
|
@ -486,3 +486,31 @@ int64_t getMemRefSizeInBytes(Value val) {
|
||||||
size *= getMemRefEltSizeInBytes(memRefType);
|
size *= getMemRefEltSizeInBytes(memRefType);
|
||||||
return size;
|
return size;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Compute the size in bytes of a MemRef that may have dynamic dimensions,
/// emitted as a runtime Value of index type.
///
/// The size is accumulated as: eltSizeInBytes * dim_0 * dim_1 * ... * dim_n.
/// Static dimensions are materialized as constant index ops; dynamic
/// dimensions (negative entries in the shape) are taken from the operands of
/// the given AllocOp, which supplies dynamic dimension sizes in shape order.
///
/// \param type     MemRef type whose total byte size is being computed.
/// \param loc      Location to attach to the emitted operations.
/// \param rewriter Rewriter used to create the constant and multiply ops.
/// \param allocOp  Alloc whose operands provide the dynamic dimension sizes.
/// \return Value of index type holding the total size in bytes.
Value getDynamicMemRefSizeInBytes(
    MemRefType type, Location loc, PatternRewriter &rewriter, AllocOp allocOp) {
  // Initialize the size with the element size in bytes of the type.
  int64_t typeSize = getMemRefEltSizeInBytes(type);
  Value result =
      emitConstantOp(rewriter, loc, rewriter.getIndexType(), typeSize);

  // Multiply all dimensions (constant and dynamic).
  auto memRefShape = type.getShape();
  auto rank = memRefShape.size();
  // Tracks which alloc operand supplies the next dynamic dimension; it only
  // advances when a dynamic dimension is encountered.
  int dynDimIdx = 0;
  // Use size_t to match the type of rank and avoid a signed/unsigned
  // comparison in the loop condition.
  for (size_t idx = 0; idx < rank; ++idx) {
    if (memRefShape[idx] < 0) {
      // Dynamic size: read from the alloc's operand list.
      auto dynamicDim = allocOp.getOperands()[dynDimIdx];
      dynDimIdx++;
      result = rewriter.create<MulIOp>(loc, result, dynamicDim);
    } else {
      // Static size: emit as a constant of index type.
      auto staticDim = emitConstantOp(
          rewriter, loc, rewriter.getIndexType(), memRefShape[idx]);
      result = rewriter.create<MulIOp>(loc, result, staticDim);
    }
  }

  return result;
}
|
||||||
|
|
|
@ -86,7 +86,7 @@ std::vector<Value> getLoopIVsForBroadcasting(Location loc,
|
||||||
// Use this function for small values only to avoid unexpected loss in type
|
// Use this function for small values only to avoid unexpected loss in type
|
||||||
// casting.
|
// casting.
|
||||||
Value emitConstantOp(
|
Value emitConstantOp(
|
||||||
ConversionPatternRewriter &rewriter, Location loc, Type type, double value);
|
PatternRewriter &rewriter, Location loc, Type type, double value);
|
||||||
|
|
||||||
// Emit a positive infinity constant of a specific type.
|
// Emit a positive infinity constant of a specific type.
|
||||||
// Supported types: F16, F32, F64, Int8, Int16, Int32, Int64.
|
// Supported types: F16, F32, F64, Int8, Int16, Int32, Int64.
|
||||||
|
@ -246,3 +246,6 @@ void populateLoweringONNXSplitOpPattern(
|
||||||
bool checkOpResultIsUsedByGetRef(AllocOp *allocOp);
|
bool checkOpResultIsUsedByGetRef(AllocOp *allocOp);
|
||||||
|
|
||||||
int64_t getMemRefSizeInBytes(Value val);
|
int64_t getMemRefSizeInBytes(Value val);
|
||||||
|
|
||||||
|
Value getDynamicMemRefSizeInBytes(
|
||||||
|
MemRefType type, Location loc, PatternRewriter &rewriter, AllocOp allocOp);
|
||||||
|
|
|
@ -73,6 +73,10 @@ public:
|
||||||
if (!checkOpResultIsUsedByGetRef(&allocOp))
|
if (!checkOpResultIsUsedByGetRef(&allocOp))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
// TODO: remove once we support the bundling of dynamic memory pools.
|
||||||
|
if (!hasAllConstantDimensions(memRefType))
|
||||||
|
return failure();
|
||||||
|
|
||||||
// Alloc memory type must be byte.
|
// Alloc memory type must be byte.
|
||||||
if (getMemRefEltSizeInBytes(memRefType) != 1)
|
if (getMemRefEltSizeInBytes(memRefType) != 1)
|
||||||
return failure();
|
return failure();
|
||||||
|
|
|
@ -61,23 +61,34 @@ public:
|
||||||
// TODO: Enable this pass for MemRef with dynamic shapes.
|
// TODO: Enable this pass for MemRef with dynamic shapes.
|
||||||
// If alloc operation is not returned then it is a candidate for
|
// If alloc operation is not returned then it is a candidate for
|
||||||
// being included in the memory pool.
|
// being included in the memory pool.
|
||||||
if (!hasAllConstantDimensions(memRefType) ||
|
if (checkOpResultIsReturned(&allocOp))
|
||||||
checkOpResultIsReturned(&allocOp))
|
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
// Check the result of this alloc is not already used by a krnl.getref.
|
// Check the result of this alloc is not already used by a krnl.getref.
|
||||||
if (checkOpResultIsUsedByGetRef(&allocOp))
|
if (checkOpResultIsUsedByGetRef(&allocOp))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
// Compute total size.
|
AllocOp newAlloc;
|
||||||
int64_t totalSize = getMemRefSizeInBytes(allocOp.getResult());
|
|
||||||
|
|
||||||
// Emit new alloc.
|
|
||||||
SmallVector<int64_t, 1> memPoolShape;
|
SmallVector<int64_t, 1> memPoolShape;
|
||||||
memPoolShape.emplace_back(totalSize);
|
if (hasAllConstantDimensions(memRefType)) {
|
||||||
auto memPoolMemRefType =
|
// Compute total size.
|
||||||
MemRefType::get(memPoolShape, rewriter.getIntegerType(8));
|
int64_t totalSize = getMemRefSizeInBytes(allocOp.getResult());
|
||||||
auto newAlloc = rewriter.create<AllocOp>(loc, memPoolMemRefType);
|
|
||||||
|
// Emit new alloc.
|
||||||
|
memPoolShape.emplace_back(totalSize);
|
||||||
|
auto memPoolMemRefType =
|
||||||
|
MemRefType::get(memPoolShape, rewriter.getIntegerType(8));
|
||||||
|
newAlloc = rewriter.create<AllocOp>(loc, memPoolMemRefType);
|
||||||
|
} else {
|
||||||
|
memPoolShape.emplace_back(-1);
|
||||||
|
auto memPoolMemRefType =
|
||||||
|
MemRefType::get(memPoolShape, rewriter.getIntegerType(8));
|
||||||
|
|
||||||
|
Value dyanmicTotalSize =
|
||||||
|
getDynamicMemRefSizeInBytes(memRefType, loc, rewriter, allocOp);
|
||||||
|
newAlloc =
|
||||||
|
rewriter.create<AllocOp>(loc, memPoolMemRefType, dyanmicTotalSize);
|
||||||
|
}
|
||||||
|
|
||||||
// Emit new dealloc.
|
// Emit new dealloc.
|
||||||
auto dealloc = rewriter.create<DeallocOp>(loc, newAlloc);
|
auto dealloc = rewriter.create<DeallocOp>(loc, newAlloc);
|
||||||
|
|
|
@ -62,3 +62,46 @@ func @test_enable_memory_pool_2(%arg0: tensor<10x10xf32>, %arg1: tensor<10x20xf3
|
||||||
// CHECK: dealloc [[MEMPOOL0]] : memref<800xi8>
|
// CHECK: dealloc [[MEMPOOL0]] : memref<800xi8>
|
||||||
// CHECK: return [[RES]] : memref<10x20xf32>
|
// CHECK: return [[RES]] : memref<10x20xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Two intermediate dynamically sized MemRefs.
|
||||||
|
func @test_enable_memory_pool_3(%arg0: tensor<?x?xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<10x10xf32>) -> tensor<*xf32> {
|
||||||
|
%0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x10xf32>) -> tensor<*xf32>
|
||||||
|
%1 = "onnx.Add"(%0, %0) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
%2 = "onnx.MatMul"(%0, %1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
return %2 : tensor<*xf32>
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_enable_memory_pool_3
|
||||||
|
// CHECK: [[CONST4:%.+]] = constant 4 : index
|
||||||
|
// CHECK: [[CONST10:%.+]] = constant 10 : index
|
||||||
|
// CHECK: [[CONST0_I64:%.+]] = constant 0 : i64
|
||||||
|
// CHECK: [[CONST1:%.+]] = constant 1 : index
|
||||||
|
// CHECK: [[CONST0:%.+]] = constant 0 : index
|
||||||
|
// CHECK: [[CST:%.+]] = constant 0.000000e+00 : f32
|
||||||
|
// CHECK: [[DIM1:%.+]] = dim %arg0, [[CONST0]] : memref<?x?xf32>
|
||||||
|
// CHECK: [[TMP1:%.+]] = muli [[DIM1]], [[CONST4]] : index
|
||||||
|
// CHECK: [[TMP2:%.+]] = muli [[TMP1]], [[CONST10]] : index
|
||||||
|
// CHECK: [[MEMPOOL1:%.+]] = alloc([[TMP2]]) : memref<?xi8>
|
||||||
|
// CHECK: [[DATA1:%.+]] = "krnl.getref"([[MEMPOOL1]], [[CONST0_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32>
|
||||||
|
// CHECK: krnl.define_loops 2
|
||||||
|
// CHECK: krnl.iterate
|
||||||
|
// CHECK: affine.store {{.*}}, [[DATA1]][%arg3, %arg4] : memref<?x10xf32>
|
||||||
|
// CHECK: [[CMP1:%.+]] = cmpi "sgt", [[DIM1]], [[DIM1]] : index
|
||||||
|
// CHECK: [[SELECT1:%.+]] = select [[CMP1]], [[DIM1]], [[DIM1]] : index
|
||||||
|
// CHECK: [[TMP3:%.+]] = muli [[SELECT1]], [[CONST4]] : index
|
||||||
|
// CHECK: [[TMP4:%.+]] = muli [[TMP3]], [[CONST10]] : index
|
||||||
|
// CHECK: [[MEMPOOL2:%.+]] = alloc([[TMP4]]) : memref<?xi8>
|
||||||
|
// CHECK: [[DATA2:%.+]] = "krnl.getref"([[MEMPOOL2]], [[CONST0_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32>
|
||||||
|
// CHECK: krnl.define_loops 2
|
||||||
|
// CHECK: krnl.iterate
|
||||||
|
// CHECK: affine.store {{.*}}, [[DATA2]][%arg3, %arg4] : memref<?x10xf32>
|
||||||
|
// CHECK: [[DATA3:%.+]] = alloc([[DIM1]]) : memref<?x10xf32>
|
||||||
|
// CHECK: krnl.define_loops 2
|
||||||
|
// CHECK: krnl.iterate
|
||||||
|
// CHECK: affine.store [[CST]], [[DATA3]][%arg3, %arg4] : memref<?x10xf32>
|
||||||
|
// CHECK: krnl.define_loops 1
|
||||||
|
// CHECK: krnl.iterate
|
||||||
|
// CHECK: affine.store {{.*}}, [[DATA3]][%arg3, %arg4] : memref<?x10xf32>
|
||||||
|
// CHECK: dealloc [[MEMPOOL2]] : memref<?xi8>
|
||||||
|
// CHECK: dealloc [[MEMPOOL1]] : memref<?xi8>
|
||||||
|
// CHECK: return [[DATA3]] : memref<?x10xf32>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue