Add support for emitting individual memory pools with dynamic sizes (#211)

* Reorganize main function.

* Follow review comments.

* Emit constants are globals in Krnl and LLVM dialects.

* Emit memory pools with dynamic sizes.

* Reformat.
This commit is contained in:
Gheorghe-Teodor Bercea 2020-07-30 12:24:07 -04:00 committed by GitHub
parent 4db3edc025
commit 029fb5eb67
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 102 additions and 13 deletions

View File

@ -295,8 +295,8 @@ std::vector<Value> getLoopIVsForBroadcasting(Location loc,
return newLoopIVs;
}
Value emitConstantOp(ConversionPatternRewriter &rewriter, Location loc,
Type type, double value) {
Value emitConstantOp(
PatternRewriter &rewriter, Location loc, Type type, double value) {
Attribute constantAttr;
auto typeKind = type.getKind();
if (typeKind == StandardTypes::F16) {
@ -486,3 +486,31 @@ int64_t getMemRefSizeInBytes(Value val) {
size *= getMemRefEltSizeInBytes(memRefType);
return size;
}
Value getDynamicMemRefSizeInBytes(
MemRefType type, Location loc, PatternRewriter &rewriter, AllocOp allocOp) {
// Initialize the size variable with the size in bytes of the type.
int64_t typeSize = getMemRefEltSizeInBytes(type);
Value result =
emitConstantOp(rewriter, loc, rewriter.getIndexType(), typeSize);
// Multiply all dimensions (constant and dynamic).
auto memRefShape = type.getShape();
auto rank = memRefShape.size();
int dynDimIdx = 0;
for (int idx = 0; idx < rank; ++idx) {
if (memRefShape[idx] < 0) {
// Dyanmic size.
auto dynamicDim = allocOp.getOperands()[dynDimIdx];
dynDimIdx++;
result = rewriter.create<MulIOp>(loc, result, dynamicDim);
} else {
// Static size.
auto staticDim = emitConstantOp(
rewriter, loc, rewriter.getIndexType(), memRefShape[idx]);
result = rewriter.create<MulIOp>(loc, result, staticDim);
}
}
return result;
}

View File

@ -86,7 +86,7 @@ std::vector<Value> getLoopIVsForBroadcasting(Location loc,
// Use this function for small values only to avoid unexpected loss in type
// casting.
Value emitConstantOp(
ConversionPatternRewriter &rewriter, Location loc, Type type, double value);
PatternRewriter &rewriter, Location loc, Type type, double value);
// Emit a positive infinity constant of a specific type.
// Supported types: F16, F32, F64, Int8, Int16, Int32, Int64.
@ -246,3 +246,6 @@ void populateLoweringONNXSplitOpPattern(
bool checkOpResultIsUsedByGetRef(AllocOp *allocOp);
int64_t getMemRefSizeInBytes(Value val);
Value getDynamicMemRefSizeInBytes(
MemRefType type, Location loc, PatternRewriter &rewriter, AllocOp allocOp);

View File

@ -73,6 +73,10 @@ public:
if (!checkOpResultIsUsedByGetRef(&allocOp))
return failure();
// TODO: remove once we support the bundling of dynamic memory pools.
if (!hasAllConstantDimensions(memRefType))
return failure();
// Alloc memory type must be byte.
if (getMemRefEltSizeInBytes(memRefType) != 1)
return failure();

View File

@ -61,23 +61,34 @@ public:
// TODO: Enable this pass for MemRef with dyanmic shapes.
// If alloc operation is not returned then it is a candidate for
// being included in the memory pool.
if (!hasAllConstantDimensions(memRefType) ||
checkOpResultIsReturned(&allocOp))
if (checkOpResultIsReturned(&allocOp))
return failure();
// Check the result of this alloc is not already used by a krnl.getref.
if (checkOpResultIsUsedByGetRef(&allocOp))
return failure();
// Compute total size.
int64_t totalSize = getMemRefSizeInBytes(allocOp.getResult());
// Emit new alloc.
AllocOp newAlloc;
SmallVector<int64_t, 1> memPoolShape;
memPoolShape.emplace_back(totalSize);
auto memPoolMemRefType =
MemRefType::get(memPoolShape, rewriter.getIntegerType(8));
auto newAlloc = rewriter.create<AllocOp>(loc, memPoolMemRefType);
if (hasAllConstantDimensions(memRefType)) {
// Compute total size.
int64_t totalSize = getMemRefSizeInBytes(allocOp.getResult());
// Emit new alloc.
memPoolShape.emplace_back(totalSize);
auto memPoolMemRefType =
MemRefType::get(memPoolShape, rewriter.getIntegerType(8));
newAlloc = rewriter.create<AllocOp>(loc, memPoolMemRefType);
} else {
memPoolShape.emplace_back(-1);
auto memPoolMemRefType =
MemRefType::get(memPoolShape, rewriter.getIntegerType(8));
Value dyanmicTotalSize =
getDynamicMemRefSizeInBytes(memRefType, loc, rewriter, allocOp);
newAlloc =
rewriter.create<AllocOp>(loc, memPoolMemRefType, dyanmicTotalSize);
}
// Emit new dealloc.
auto dealloc = rewriter.create<DeallocOp>(loc, newAlloc);

View File

@ -62,3 +62,46 @@ func @test_enable_memory_pool_2(%arg0: tensor<10x10xf32>, %arg1: tensor<10x20xf3
// CHECK: dealloc [[MEMPOOL0]] : memref<800xi8>
// CHECK: return [[RES]] : memref<10x20xf32>
}
// Two intermediate dynamic sized MemRefs.
func @test_enable_memory_pool_3(%arg0: tensor<?x?xf32>, %arg1: tensor<?x10xf32>, %arg2: tensor<10x10xf32>) -> tensor<*xf32> {
%0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x10xf32>) -> tensor<*xf32>
%1 = "onnx.Add"(%0, %0) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
%2 = "onnx.MatMul"(%0, %1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %2 : tensor<*xf32>
// CHECK-LABEL: test_enable_memory_pool_3
// CHECK: [[CONST4:%.+]] = constant 4 : index
// CHECK: [[CONST10:%.+]] = constant 10 : index
// CHECK: [[CONST0_I64:%.+]] = constant 0 : i64
// CHECK: [[CONST1:%.+]] = constant 1 : index
// CHECK: [[CONST0:%.+]] = constant 0 : index
// CHECK: [[CST:%.+]] = constant 0.000000e+00 : f32
// CHECK: [[DIM1:%.+]] = dim %arg0, [[CONST0]] : memref<?x?xf32>
// CHECK: [[TMP1:%.+]] = muli [[DIM1]], [[CONST4]] : index
// CHECK: [[TMP2:%.+]] = muli [[TMP1]], [[CONST10]] : index
// CHECK: [[MEMPOOL1:%.+]] = alloc([[TMP2]]) : memref<?xi8>
// CHECK: [[DATA1:%.+]] = "krnl.getref"([[MEMPOOL1]], [[CONST0_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32>
// CHECK: krnl.define_loops 2
// CHECK: krnl.iterate
// CHECK: affine.store {{.*}}, [[DATA1]][%arg3, %arg4] : memref<?x10xf32>
// CHECK: [[CMP1:%.+]] = cmpi "sgt", [[DIM1]], [[DIM1]] : index
// CHECK: [[SELECT1:%.+]] = select [[CMP1]], [[DIM1]], [[DIM1]] : index
// CHECK: [[TMP3:%.+]] = muli [[SELECT1]], [[CONST4]] : index
// CHECK: [[TMP4:%.+]] = muli [[TMP3]], [[CONST10]] : index
// CHECK: [[MEMPOOL2:%.+]] = alloc([[TMP4]]) : memref<?xi8>
// CHECK: [[DATA2:%.+]] = "krnl.getref"([[MEMPOOL2]], [[CONST0_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32>
// CHECK: krnl.define_loops 2
// CHECK: krnl.iterate
// CHECK: affine.store {{.*}}, [[DATA2]][%arg3, %arg4] : memref<?x10xf32>
// CHECK: [[DATA3:%.+]] = alloc([[DIM1]]) : memref<?x10xf32>
// CHECK: krnl.define_loops 2
// CHECK: krnl.iterate
// CHECK: affine.store [[CST]], [[DATA3]][%arg3, %arg4] : memref<?x10xf32>
// CHECK: krnl.define_loops 1
// CHECK: krnl.iterate
// CHECK: affine.store {{.*}}, [[DATA3]][%arg3, %arg4] : memref<?x10xf32>
// CHECK: dealloc [[MEMPOOL2]] : memref<?xi8>
// CHECK: dealloc [[MEMPOOL1]] : memref<?xi8>
// CHECK: return [[DATA3]] : memref<?x10xf32>
}