Support per-block bundling for static memory pools. (#325)
* Reorganize main function. * Follow review comments. * Emit constants as globals in Krnl and LLVM dialects. * Support the bundling on a per-block basis. * Format. * Fix test. * Fix indent. * Improve data structure. * Format. * Simplify maps for static pools. * Format. * Clean-up.
This commit is contained in:
parent
75930ffbcf
commit
0db735f48d
|
@ -44,6 +44,10 @@ typedef struct ONNXOperandsInitState {
|
||||||
// Helper data structure for the bundling of dynamic AllocOps.
|
// Helper data structure for the bundling of dynamic AllocOps.
|
||||||
std::map<FuncOp, std::unique_ptr<ONNXOperandsInitState>> initMap;
|
std::map<FuncOp, std::unique_ptr<ONNXOperandsInitState>> initMap;
|
||||||
|
|
||||||
|
// Handling of static memory pools on a per-block basis in each function.
|
||||||
|
typedef std::map<Block *, AllocOp> BlockToStaticPool;
|
||||||
|
std::map<FuncOp, std::unique_ptr<BlockToStaticPool>> staticPoolMap;
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Helper functions.
|
// Helper functions.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -153,11 +157,11 @@ KrnlGetRefOp getCurrentAllocGetRef(AllocOp *allocOp) {
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* RewritePattern that replaces:
|
* RewritePattern that replaces:
|
||||||
* %mem1 = alloc() : memref<<dims1>x<type>>
|
* %mempool = alloc() : memref<<dims1>x<type>>
|
||||||
* %mem2 = alloc() : memref<<dims2>x<type>>
|
* %mem2 = alloc() : memref<<dims2>x<type>>
|
||||||
* %1 = krnl.getref %mem2 0 : memref<<dims2>x<type>>
|
* %1 = krnl.getref %mem2 0 : memref<<dims2>x<type>>
|
||||||
* =>
|
* =>
|
||||||
* %mem1 = alloc() : memref<<dims1 + dims2>x<type>>
|
* %mempool = alloc() : memref<<dims1 + dims2>x<type>>
|
||||||
* %1 = krnl.getref %mem1 <dims1> : memref<<dims2>x<type>>
|
* %1 = krnl.getref %mempool <dims1> : memref<<dims2>x<type>>
|
||||||
*
|
*
|
||||||
*
|
*
|
||||||
|
@ -171,7 +175,7 @@ KrnlGetRefOp getCurrentAllocGetRef(AllocOp *allocOp) {
|
||||||
* operations are part of memory pooling.
|
* operations are part of memory pooling.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
class KrnlBundleMemoryPools : public OpRewritePattern<AllocOp> {
|
class KrnlBundleStaticMemoryPools : public OpRewritePattern<AllocOp> {
|
||||||
public:
|
public:
|
||||||
using OpRewritePattern<AllocOp>::OpRewritePattern;
|
using OpRewritePattern<AllocOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
@ -199,49 +203,85 @@ public:
|
||||||
if (memRefShape.size() != 1)
|
if (memRefShape.size() != 1)
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
int64_t currentMemPoolSize = memRefShape[0];
|
FuncOp function = getContainingFunction(allocOp);
|
||||||
|
|
||||||
// Get a KrnlGetRefOp which does not use the current alloc.
|
if (staticPoolMap.count(function) == 0) {
|
||||||
if (KrnlGetRefOp unbundledGetRef = getUnbundledGetRef(&allocOp)) {
|
return failure();
|
||||||
// Make sure that this get ref uses a static alloc.
|
}
|
||||||
auto unbundledGetRefType =
|
|
||||||
convertToMemRefType(unbundledGetRef.getResult().getType());
|
|
||||||
if (!hasAllConstantDimensions(unbundledGetRefType))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
// Current memory pool size is the offset for the newly bundled
|
std::unique_ptr<BlockToStaticPool> &blockToStaticPool =
|
||||||
// internal MemRef. Emit the offset as a constant.
|
staticPoolMap.at(function);
|
||||||
auto offset = rewriter.create<ConstantOp>(
|
|
||||||
loc, rewriter.getIntegerAttr(
|
|
||||||
rewriter.getIntegerType(64), currentMemPoolSize));
|
|
||||||
|
|
||||||
// Size in bytes of the output of the krnl.getref operation.
|
// Get parent block.
|
||||||
int64_t unbundledTotalSize =
|
Block *parentBlock = allocOp.getOperation()->getBlock();
|
||||||
getMemRefSizeInBytes(unbundledGetRef.getResult());
|
|
||||||
|
|
||||||
// Compute new size.
|
if (blockToStaticPool->count(parentBlock) == 0) {
|
||||||
int64_t bundleTotalSize = unbundledTotalSize + currentMemPoolSize;
|
allocOp.getOperation()->moveBefore(&parentBlock->front());
|
||||||
|
// Create new entry in the block map.
|
||||||
// We need to emit a new alloc which contains the additional MemRef.
|
blockToStaticPool->insert(
|
||||||
SmallVector<int64_t, 1> newMemPoolShape;
|
std::pair<Block *, AllocOp>(parentBlock, allocOp));
|
||||||
newMemPoolShape.emplace_back(bundleTotalSize);
|
|
||||||
auto bundledMemPoolMemRefType =
|
|
||||||
MemRefType::get(newMemPoolShape, rewriter.getIntegerType(8));
|
|
||||||
auto bundledAlloc =
|
|
||||||
rewriter.create<AllocOp>(loc, bundledMemPoolMemRefType);
|
|
||||||
|
|
||||||
// The newly bundled MemRef expressed as a KrnlGetRefOp.
|
|
||||||
auto bundledMemRef = rewriter.create<KrnlGetRefOp>(
|
|
||||||
loc, unbundledGetRef.getResult().getType(), bundledAlloc, offset);
|
|
||||||
rewriter.replaceOp(unbundledGetRef, bundledMemRef.getResult());
|
|
||||||
|
|
||||||
// Replace old memory pool with new one.
|
|
||||||
rewriter.replaceOp(allocOp, bundledAlloc.getResult());
|
|
||||||
|
|
||||||
|
// This is the initial memory pool for this block and it is
|
||||||
|
// trivially bundled hence it's safe to return success.
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
return failure();
|
// If this parent block has been found present in the map, it means
|
||||||
|
// a static memory bundle already exists. Fetch it.
|
||||||
|
AllocOp staticMemPoolAlloc = blockToStaticPool->at(parentBlock);
|
||||||
|
|
||||||
|
// If this is the alloc representing the static memory pool of this
|
||||||
|
// block, pattern matching must fail to avoid processing the static
|
||||||
|
// memory pool a second time.
|
||||||
|
if (allocOp == staticMemPoolAlloc)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto staticMemPoolShape =
|
||||||
|
convertToMemRefType(staticMemPoolAlloc.getResult().getType())
|
||||||
|
.getShape();
|
||||||
|
int64_t currentMemPoolSize = staticMemPoolShape[0];
|
||||||
|
|
||||||
|
// Get the getref of the current allocOp. There is exactly one such getref.
|
||||||
|
KrnlGetRefOp currentAllocGetRef = getCurrentAllocGetRef(&allocOp);
|
||||||
|
if (!currentAllocGetRef)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// Current memory pool size is the offset for the newly bundled
|
||||||
|
// internal MemRef. Emit the offset as a constant.
|
||||||
|
auto offset = rewriter.create<ConstantOp>(
|
||||||
|
loc, rewriter.getIntegerAttr(
|
||||||
|
rewriter.getIntegerType(64), currentMemPoolSize));
|
||||||
|
|
||||||
|
// Size in bytes of the output of the krnl.getref operation.
|
||||||
|
int64_t unbundledTotalSize = memRefShape[0];
|
||||||
|
|
||||||
|
// Compute new size.
|
||||||
|
int64_t bundleTotalSize = unbundledTotalSize + currentMemPoolSize;
|
||||||
|
|
||||||
|
// We need to emit a new alloc which contains the additional MemRef.
|
||||||
|
SmallVector<int64_t, 1> newMemPoolShape;
|
||||||
|
newMemPoolShape.emplace_back(bundleTotalSize);
|
||||||
|
auto bundledMemPoolMemRefType =
|
||||||
|
MemRefType::get(newMemPoolShape, rewriter.getIntegerType(8));
|
||||||
|
auto newStaticMemPoolAlloc =
|
||||||
|
rewriter.create<AllocOp>(loc, bundledMemPoolMemRefType);
|
||||||
|
|
||||||
|
// The newly bundled MemRef expressed as a KrnlGetRefOp.
|
||||||
|
auto bundledMemRef = rewriter.create<KrnlGetRefOp>(loc,
|
||||||
|
currentAllocGetRef.getResult().getType(), newStaticMemPoolAlloc,
|
||||||
|
offset);
|
||||||
|
rewriter.replaceOp(currentAllocGetRef, bundledMemRef.getResult());
|
||||||
|
|
||||||
|
// Replace old memory pool with new one.
|
||||||
|
rewriter.replaceOp(staticMemPoolAlloc, newStaticMemPoolAlloc.getResult());
|
||||||
|
|
||||||
|
// Update data structure to contain the newly constructed static memory
|
||||||
|
// pool.
|
||||||
|
blockToStaticPool->erase(parentBlock);
|
||||||
|
blockToStaticPool->insert(
|
||||||
|
std::pair<Block *, AllocOp>(parentBlock, newStaticMemPoolAlloc));
|
||||||
|
|
||||||
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -448,23 +488,25 @@ class KrnlBundleMemoryPoolsPass
|
||||||
public:
|
public:
|
||||||
void runOnFunction() override {
|
void runOnFunction() override {
|
||||||
auto function = getFunction();
|
auto function = getFunction();
|
||||||
|
|
||||||
// ModuleOp module = cast<ModuleOp>(function.getParentOp());
|
|
||||||
initMap.insert(std::pair<FuncOp, std::unique_ptr<ONNXOperandsInitState>>(
|
initMap.insert(std::pair<FuncOp, std::unique_ptr<ONNXOperandsInitState>>(
|
||||||
function, std::make_unique<ONNXOperandsInitState>()));
|
function, std::make_unique<ONNXOperandsInitState>()));
|
||||||
|
|
||||||
|
staticPoolMap.insert(std::pair<FuncOp, std::unique_ptr<BlockToStaticPool>>(
|
||||||
|
function, std::make_unique<BlockToStaticPool>()));
|
||||||
|
|
||||||
// Initialize state for this function.
|
// Initialize state for this function.
|
||||||
std::unique_ptr<ONNXOperandsInitState> &initState = initMap.at(function);
|
std::unique_ptr<ONNXOperandsInitState> &initState = initMap.at(function);
|
||||||
initState->initBlock = nullptr;
|
initState->initBlock = nullptr;
|
||||||
|
|
||||||
ConversionTarget target(getContext());
|
ConversionTarget target(getContext());
|
||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
patterns.insert<KrnlBundleMemoryPools, KrnlBundleDynamicMemoryPools>(
|
patterns.insert<KrnlBundleStaticMemoryPools, KrnlBundleDynamicMemoryPools>(
|
||||||
&getContext());
|
&getContext());
|
||||||
|
|
||||||
applyPatternsAndFoldGreedily(function, patterns);
|
applyPatternsAndFoldGreedily(function, patterns);
|
||||||
|
|
||||||
initMap.erase(function);
|
initMap.erase(function);
|
||||||
|
staticPoolMap.erase(function);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
|
@ -31,10 +31,10 @@ func @test_pool_bundling(%arg0: memref<10x10xf32>, %arg1: memref<10x20xf32>) ->
|
||||||
// CHECK-LABEL: test_pool_bundling
|
// CHECK-LABEL: test_pool_bundling
|
||||||
// CHECK: [[CONST_0:%.+]] = constant 0 : i64
|
// CHECK: [[CONST_0:%.+]] = constant 0 : i64
|
||||||
// CHECK: [[CONST_CST:%.+]] = constant 0.000000e+00 : f32
|
// CHECK: [[CONST_CST:%.+]] = constant 0.000000e+00 : f32
|
||||||
// CHECK: [[CONST_400:%.+]] = constant 400 : i64
|
|
||||||
// CHECK: [[CONST_1200:%.+]] = constant 1200 : i64
|
|
||||||
// CHECK: [[CONST_2000:%.+]] = constant 2000 : i64
|
|
||||||
// CHECK: [[CONST_2400:%.+]] = constant 2400 : i64
|
// CHECK: [[CONST_2400:%.+]] = constant 2400 : i64
|
||||||
|
// CHECK: [[CONST_2000:%.+]] = constant 2000 : i64
|
||||||
|
// CHECK: [[CONST_1200:%.+]] = constant 1200 : i64
|
||||||
|
// CHECK: [[CONST_400:%.+]] = constant 400 : i64
|
||||||
// CHECK: [[RES:%.+]] = alloc() : memref<10x20xf32>
|
// CHECK: [[RES:%.+]] = alloc() : memref<10x20xf32>
|
||||||
// CHECK: [[MEMPOOL:%.+]] = alloc() : memref<3200xi8>
|
// CHECK: [[MEMPOOL:%.+]] = alloc() : memref<3200xi8>
|
||||||
// CHECK: [[MEMREF1:%.+]] = "krnl.getref"([[MEMPOOL]], [[CONST_2400]]) : (memref<3200xi8>, i64) -> memref<10x20xf32>
|
// CHECK: [[MEMREF1:%.+]] = "krnl.getref"([[MEMPOOL]], [[CONST_2400]]) : (memref<3200xi8>, i64) -> memref<10x20xf32>
|
||||||
|
@ -149,12 +149,12 @@ func @test_dynamic_and_static_pool_bundling(%arg0: memref<?x?xf32>, %arg1: memre
|
||||||
return %15 : memref<?x10xf32>
|
return %15 : memref<?x10xf32>
|
||||||
|
|
||||||
// CHECK-LABEL: test_dynamic_and_static_pool_bundling
|
// CHECK-LABEL: test_dynamic_and_static_pool_bundling
|
||||||
|
// CHECK: [[C1200_I64:%.+]] = constant 1200 : i64
|
||||||
// CHECK: [[CST:%.+]] = constant 0.000000e+00 : f32
|
// CHECK: [[CST:%.+]] = constant 0.000000e+00 : f32
|
||||||
// CHECK: [[C0:%.+]] = constant 0 : index
|
// CHECK: [[C0:%.+]] = constant 0 : index
|
||||||
// CHECK: [[C4:%.+]] = constant 4 : index
|
// CHECK: [[C4:%.+]] = constant 4 : index
|
||||||
// CHECK: [[C10:%.+]] = constant 10 : index
|
// CHECK: [[C10:%.+]] = constant 10 : index
|
||||||
// CHECK: [[C400_I64:%.+]] = constant 400 : i64
|
// CHECK: [[C400_I64:%.+]] = constant 400 : i64
|
||||||
// CHECK: [[C2000_I64:%.+]] = constant 2000 : i64
|
|
||||||
// CHECK: [[C0_I64:%.+]] = constant 0 : i64
|
// CHECK: [[C0_I64:%.+]] = constant 0 : i64
|
||||||
// CHECK: [[DIM:%.+]] = dim %arg0, [[C0]] : memref<?x?xf32>
|
// CHECK: [[DIM:%.+]] = dim %arg0, [[C0]] : memref<?x?xf32>
|
||||||
// CHECK: [[SGT:%.+]] = cmpi "sgt", [[DIM]], [[DIM]] : index
|
// CHECK: [[SGT:%.+]] = cmpi "sgt", [[DIM]], [[DIM]] : index
|
||||||
|
@ -169,17 +169,199 @@ func @test_dynamic_and_static_pool_bundling(%arg0: memref<?x?xf32>, %arg1: memre
|
||||||
// CHECK: [[DATA1:%.+]] = "krnl.getref"([[DYN_MEMPOOL]], [[OFFSET1_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32>
|
// CHECK: [[DATA1:%.+]] = "krnl.getref"([[DYN_MEMPOOL]], [[OFFSET1_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32>
|
||||||
// CHECK: [[DATA2:%.+]] = "krnl.getref"([[DYN_MEMPOOL]], [[C0_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32>
|
// CHECK: [[DATA2:%.+]] = "krnl.getref"([[DYN_MEMPOOL]], [[C0_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32>
|
||||||
// CHECK: [[STATIC_MEMPOOL:%.+]] = alloc() : memref<2800xi8>
|
// CHECK: [[STATIC_MEMPOOL:%.+]] = alloc() : memref<2800xi8>
|
||||||
// CHECK: [[DATA3:%.+]] = "krnl.getref"([[STATIC_MEMPOOL]], [[C2000_I64]]) : (memref<2800xi8>, i64) -> memref<10x20xf32>
|
// CHECK: [[DATA3:%.+]] = "krnl.getref"([[STATIC_MEMPOOL]], [[C1200_I64]]) : (memref<2800xi8>, i64) -> memref<10x40xf32>
|
||||||
// CHECK: [[DATA4:%.+]] = "krnl.getref"([[STATIC_MEMPOOL]], [[C400_I64]]) : (memref<2800xi8>, i64) -> memref<10x40xf32>
|
// CHECK: [[DATA4:%.+]] = "krnl.getref"([[STATIC_MEMPOOL]], [[C400_I64]]) : (memref<2800xi8>, i64) -> memref<10x20xf32>
|
||||||
// CHECK: [[DATA5:%.+]] = "krnl.getref"([[STATIC_MEMPOOL]], [[C0_I64]]) : (memref<2800xi8>, i64) -> memref<10x10xf32>
|
// CHECK: [[DATA5:%.+]] = "krnl.getref"([[STATIC_MEMPOOL]], [[C0_I64]]) : (memref<2800xi8>, i64) -> memref<10x10xf32>
|
||||||
// CHECK: [[RES:%.+]] = alloc([[DIM]]) : memref<?x10xf32>
|
// CHECK: [[RES:%.+]] = alloc([[DIM]]) : memref<?x10xf32>
|
||||||
// CHECK: affine.store [[CST]], [[DATA1]][0, 0] : memref<?x10xf32>
|
// CHECK: affine.store [[CST]], [[DATA1]][0, 0] : memref<?x10xf32>
|
||||||
// CHECK: affine.store [[CST]], [[DATA2]][0, 0] : memref<?x10xf32>
|
// CHECK: affine.store [[CST]], [[DATA2]][0, 0] : memref<?x10xf32>
|
||||||
// CHECK: affine.store [[CST]], [[RES]][0, 0] : memref<?x10xf32>
|
// CHECK: affine.store [[CST]], [[RES]][0, 0] : memref<?x10xf32>
|
||||||
// CHECK: affine.store [[CST]], [[DATA5]][0, 0] : memref<10x10xf32>
|
// CHECK: affine.store [[CST]], [[DATA5]][0, 0] : memref<10x10xf32>
|
||||||
// CHECK: affine.store [[CST]], [[DATA3]][0, 0] : memref<10x20xf32>
|
// CHECK: affine.store [[CST]], [[DATA4]][0, 0] : memref<10x20xf32>
|
||||||
// CHECK: affine.store [[CST]], [[DATA4]][0, 0] : memref<10x40xf32>
|
// CHECK: affine.store [[CST]], [[DATA3]][0, 0] : memref<10x40xf32>
|
||||||
// CHECK: dealloc [[DYN_MEMPOOL]] : memref<?xi8>
|
// CHECK: dealloc [[DYN_MEMPOOL]] : memref<?xi8>
|
||||||
// CHECK: dealloc [[STATIC_MEMPOOL]] : memref<2800xi8>
|
// CHECK: dealloc [[STATIC_MEMPOOL]] : memref<2800xi8>
|
||||||
// CHECK: return [[RES]] : memref<?x10xf32>
|
// CHECK: return [[RES]] : memref<?x10xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Test bundling inside a sub-block.
|
||||||
|
func @static_mem_pool_rnn_subblock(%arg0: memref<1x3x2xf32>, %arg1: memref<1x4x2xf32>, %arg2: memref<1x4x4xf32>) -> memref<1x3x4xf32> attributes {input_names = ["X", "W", "R"], output_names = ["Y"]} {
|
||||||
|
%cst = constant 0.000000e+00 : f32
|
||||||
|
%c0_i64 = constant 0 : i64
|
||||||
|
%0 = alloc() : memref<1x3x4xf32>
|
||||||
|
%2 = krnl.define_loops 1
|
||||||
|
krnl.iterate(%2) with (%2 -> %arg3 = 0 to 1) {
|
||||||
|
%3:2 = krnl.define_loops 2
|
||||||
|
krnl.iterate(%3#0, %3#1) with (%3#0 -> %arg4 = 0 to 3, %3#1 -> %arg5 = 0 to 4) {
|
||||||
|
%4 = alloc() : memref<4xi8>
|
||||||
|
%5 = "krnl.getref"(%4, %c0_i64) : (memref<4xi8>, i64) -> memref<f32>
|
||||||
|
%6 = affine.load %0[0, symbol(%arg4), symbol(%arg5)] : memref<1x3x4xf32>
|
||||||
|
%7 = alloc() : memref<4xi8>
|
||||||
|
%8 = "krnl.getref"(%7, %c0_i64) : (memref<4xi8>, i64) -> memref<f32>
|
||||||
|
affine.store %cst, %8[] : memref<f32>
|
||||||
|
%9 = alloc() : memref<4xi8>
|
||||||
|
%10 = "krnl.getref"(%9, %c0_i64) : (memref<4xi8>, i64) -> memref<f32>
|
||||||
|
affine.store %cst, %10[] : memref<f32>
|
||||||
|
%11 = krnl.define_loops 1
|
||||||
|
krnl.iterate(%11) with (%11 -> %arg6 = 0 to 2) {
|
||||||
|
%25 = affine.load %arg0[symbol(%arg3), symbol(%arg4), symbol(%arg6)] : memref<1x3x2xf32>
|
||||||
|
%26 = affine.load %arg1[0, symbol(%arg5), symbol(%arg6)] : memref<1x4x2xf32>
|
||||||
|
%27 = mulf %25, %26 : f32
|
||||||
|
%28 = affine.load %8[] : memref<f32>
|
||||||
|
%29 = addf %28, %27 : f32
|
||||||
|
affine.store %29, %8[] : memref<f32>
|
||||||
|
%30 = affine.load %arg2[0, symbol(%arg5), symbol(%arg6)] : memref<1x4x4xf32>
|
||||||
|
%31 = mulf %6, %30 : f32
|
||||||
|
%32 = affine.load %10[] : memref<f32>
|
||||||
|
%33 = addf %32, %31 : f32
|
||||||
|
affine.store %33, %10[] : memref<f32>
|
||||||
|
}
|
||||||
|
%12 = affine.load %8[] : memref<f32>
|
||||||
|
%13 = affine.load %10[] : memref<f32>
|
||||||
|
%14 = addf %12, %13 : f32
|
||||||
|
%15 = alloc() : memref<4xi8>
|
||||||
|
%16 = "krnl.getref"(%15, %c0_i64) : (memref<4xi8>, i64) -> memref<f32>
|
||||||
|
affine.store %14, %16[] : memref<f32>
|
||||||
|
%17 = affine.load %16[] : memref<f32>
|
||||||
|
%18 = subf %cst, %17 : f32
|
||||||
|
%19 = exp %17 : f32
|
||||||
|
%20 = exp %18 : f32
|
||||||
|
%21 = subf %19, %20 : f32
|
||||||
|
%22 = addf %19, %20 : f32
|
||||||
|
%23 = divf %21, %22 : f32
|
||||||
|
affine.store %23, %5[] : memref<f32>
|
||||||
|
%24 = affine.load %5[] : memref<f32>
|
||||||
|
affine.store %24, %0[0, symbol(%arg4), symbol(%arg5)] : memref<1x3x4xf32>
|
||||||
|
dealloc %15 : memref<4xi8>
|
||||||
|
dealloc %9 : memref<4xi8>
|
||||||
|
dealloc %7 : memref<4xi8>
|
||||||
|
dealloc %4 : memref<4xi8>
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return %0 : memref<1x3x4xf32>
|
||||||
|
|
||||||
|
// CHECK-LABEL: static_mem_pool_rnn_subblock
|
||||||
|
// CHECK: [[CST:%.+]] = constant 0.000000e+00 : f32
|
||||||
|
// CHECK: [[C0:%.+]] = constant 0 : i64
|
||||||
|
// CHECK: [[C12:%.+]] = constant 12 : i64
|
||||||
|
// CHECK: [[C8:%.+]] = constant 8 : i64
|
||||||
|
// CHECK: [[C4:%.+]] = constant 4 : i64
|
||||||
|
// CHECK: [[RES:%.+]] = alloc() : memref<1x3x4xf32>
|
||||||
|
// CHECK: krnl.define_loops 1
|
||||||
|
// CHECK: krnl.iterate
|
||||||
|
// CHECK: krnl.define_loops 2
|
||||||
|
// CHECK: krnl.iterate
|
||||||
|
// CHECK: [[STATIC_MEM_POOL:%.+]] = alloc() : memref<16xi8>
|
||||||
|
// CHECK: [[REF1:%.+]] = "krnl.getref"([[STATIC_MEM_POOL]], [[C12]]) : (memref<16xi8>, i64) -> memref<f32>
|
||||||
|
// CHECK: affine.load
|
||||||
|
// CHECK: [[REF2:%.+]] = "krnl.getref"([[STATIC_MEM_POOL]], [[C8]]) : (memref<16xi8>, i64) -> memref<f32>
|
||||||
|
// CHECK: affine.store
|
||||||
|
// CHECK: [[REF3:%.+]] = "krnl.getref"([[STATIC_MEM_POOL]], [[C4]]) : (memref<16xi8>, i64) -> memref<f32>
|
||||||
|
// CHECK: affine.store
|
||||||
|
// CHECK: krnl.define_loops 1
|
||||||
|
// CHECK: krnl.iterate
|
||||||
|
// CHECK: [[REF4:%.+]] = "krnl.getref"([[STATIC_MEM_POOL]], [[C0]]) : (memref<16xi8>, i64) -> memref<f32>
|
||||||
|
// CHECK: dealloc [[STATIC_MEM_POOL]] : memref<16xi8>
|
||||||
|
// CHECK: return [[RES]] : memref<1x3x4xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Test bundling inside a sub-block and in the main block.
|
||||||
|
func @static_mem_pool_rnn_sub_and_main_block(%arg0: memref<1x3x2xf32>, %arg1: memref<1x4x2xf32>, %arg2: memref<1x4x4xf32>) -> memref<1x3x4xf32> attributes {input_names = ["X", "W", "R"], output_names = ["Y"]} {
|
||||||
|
%cst = constant 0.000000e+00 : f32
|
||||||
|
%c0_i64 = constant 0 : i64
|
||||||
|
%0 = alloc() : memref<1x3x4xf32>
|
||||||
|
%mem0 = alloc() : memref<4xi8>
|
||||||
|
%ref0 = "krnl.getref"(%mem0, %c0_i64) : (memref<4xi8>, i64) -> memref<f32>
|
||||||
|
%mem1 = alloc() : memref<4xi8>
|
||||||
|
%ref1 = "krnl.getref"(%mem1, %c0_i64) : (memref<4xi8>, i64) -> memref<f32>
|
||||||
|
%2 = krnl.define_loops 1
|
||||||
|
krnl.iterate(%2) with (%2 -> %arg3 = 0 to 1) {
|
||||||
|
%3:2 = krnl.define_loops 2
|
||||||
|
krnl.iterate(%3#0, %3#1) with (%3#0 -> %arg4 = 0 to 3, %3#1 -> %arg5 = 0 to 4) {
|
||||||
|
%4 = alloc() : memref<4xi8>
|
||||||
|
%5 = "krnl.getref"(%4, %c0_i64) : (memref<4xi8>, i64) -> memref<f32>
|
||||||
|
%6 = affine.load %0[0, symbol(%arg4), symbol(%arg5)] : memref<1x3x4xf32>
|
||||||
|
%7 = alloc() : memref<4xi8>
|
||||||
|
%8 = "krnl.getref"(%7, %c0_i64) : (memref<4xi8>, i64) -> memref<f32>
|
||||||
|
affine.store %cst, %8[] : memref<f32>
|
||||||
|
%9 = alloc() : memref<4xi8>
|
||||||
|
%10 = "krnl.getref"(%9, %c0_i64) : (memref<4xi8>, i64) -> memref<f32>
|
||||||
|
affine.store %cst, %10[] : memref<f32>
|
||||||
|
%11 = krnl.define_loops 1
|
||||||
|
krnl.iterate(%11) with (%11 -> %arg6 = 0 to 2) {
|
||||||
|
%25 = affine.load %arg0[symbol(%arg3), symbol(%arg4), symbol(%arg6)] : memref<1x3x2xf32>
|
||||||
|
%26 = affine.load %arg1[0, symbol(%arg5), symbol(%arg6)] : memref<1x4x2xf32>
|
||||||
|
%27 = mulf %25, %26 : f32
|
||||||
|
%28 = affine.load %8[] : memref<f32>
|
||||||
|
%29 = addf %28, %27 : f32
|
||||||
|
affine.store %29, %8[] : memref<f32>
|
||||||
|
%30 = affine.load %arg2[0, symbol(%arg5), symbol(%arg6)] : memref<1x4x4xf32>
|
||||||
|
%31 = mulf %6, %30 : f32
|
||||||
|
%32 = affine.load %10[] : memref<f32>
|
||||||
|
%33 = addf %32, %31 : f32
|
||||||
|
affine.store %33, %10[] : memref<f32>
|
||||||
|
affine.store %33, %ref0[] : memref<f32>
|
||||||
|
}
|
||||||
|
%12 = affine.load %8[] : memref<f32>
|
||||||
|
%13 = affine.load %10[] : memref<f32>
|
||||||
|
%14 = addf %12, %13 : f32
|
||||||
|
%15 = alloc() : memref<4xi8>
|
||||||
|
%16 = "krnl.getref"(%15, %c0_i64) : (memref<4xi8>, i64) -> memref<f32>
|
||||||
|
affine.store %14, %16[] : memref<f32>
|
||||||
|
%17 = affine.load %16[] : memref<f32>
|
||||||
|
%18 = subf %cst, %17 : f32
|
||||||
|
%19 = exp %17 : f32
|
||||||
|
%20 = exp %18 : f32
|
||||||
|
%21 = subf %19, %20 : f32
|
||||||
|
%22 = addf %19, %20 : f32
|
||||||
|
%23 = divf %21, %22 : f32
|
||||||
|
affine.store %23, %5[] : memref<f32>
|
||||||
|
%24 = affine.load %5[] : memref<f32>
|
||||||
|
affine.store %24, %0[0, symbol(%arg4), symbol(%arg5)] : memref<1x3x4xf32>
|
||||||
|
affine.store %24, %ref1[] : memref<f32>
|
||||||
|
dealloc %15 : memref<4xi8>
|
||||||
|
dealloc %9 : memref<4xi8>
|
||||||
|
dealloc %7 : memref<4xi8>
|
||||||
|
dealloc %4 : memref<4xi8>
|
||||||
|
}
|
||||||
|
}
|
||||||
|
%mem2 = alloc() : memref<4xi8>
|
||||||
|
%ref2 = "krnl.getref"(%mem2, %c0_i64) : (memref<4xi8>, i64) -> memref<f32>
|
||||||
|
%val = affine.load %ref1[] : memref<f32>
|
||||||
|
affine.store %val, %ref2[] : memref<f32>
|
||||||
|
dealloc %mem2 : memref<4xi8>
|
||||||
|
dealloc %mem1 : memref<4xi8>
|
||||||
|
dealloc %mem0 : memref<4xi8>
|
||||||
|
return %0 : memref<1x3x4xf32>
|
||||||
|
|
||||||
|
// CHECK-LABEL: static_mem_pool_rnn_sub_and_main_block
|
||||||
|
// CHECK: [[CST:%.+]] = constant 0.000000e+00 : f32
|
||||||
|
// CHECK: [[C0:%.+]] = constant 0 : i64
|
||||||
|
// CHECK: [[C12:%.+]] = constant 12 : i64
|
||||||
|
// CHECK: [[C8:%.+]] = constant 8 : i64
|
||||||
|
// CHECK: [[C4:%.+]] = constant 4 : i64
|
||||||
|
// CHECK: [[RES:%.+]] = alloc() : memref<1x3x4xf32>
|
||||||
|
// CHECK: [[STATIC_MEM_POOL_MAIN:%.+]] = alloc() : memref<12xi8>
|
||||||
|
// CHECK: [[MAIN_REF_0:%.+]] = "krnl.getref"([[STATIC_MEM_POOL_MAIN]], [[C8]]) : (memref<12xi8>, i64) -> memref<f32>
|
||||||
|
// CHECK: [[MAIN_REF_1:%.+]] = "krnl.getref"([[STATIC_MEM_POOL_MAIN]], [[C4]]) : (memref<12xi8>, i64) -> memref<f32>
|
||||||
|
// CHECK: krnl.define_loops 1
|
||||||
|
// CHECK: krnl.iterate
|
||||||
|
// CHECK: krnl.define_loops 2
|
||||||
|
// CHECK: krnl.iterate
|
||||||
|
// CHECK: [[STATIC_MEM_POOL:%.+]] = alloc() : memref<16xi8>
|
||||||
|
// CHECK: [[REF1:%.+]] = "krnl.getref"([[STATIC_MEM_POOL]], [[C12]]) : (memref<16xi8>, i64) -> memref<f32>
|
||||||
|
// CHECK: affine.load
|
||||||
|
// CHECK: [[REF2:%.+]] = "krnl.getref"([[STATIC_MEM_POOL]], [[C8]]) : (memref<16xi8>, i64) -> memref<f32>
|
||||||
|
// CHECK: affine.store
|
||||||
|
// CHECK: [[REF3:%.+]] = "krnl.getref"([[STATIC_MEM_POOL]], [[C4]]) : (memref<16xi8>, i64) -> memref<f32>
|
||||||
|
// CHECK: affine.store
|
||||||
|
// CHECK: krnl.define_loops 1
|
||||||
|
// CHECK: krnl.iterate
|
||||||
|
// CHECK: [[REF4:%.+]] = "krnl.getref"([[STATIC_MEM_POOL]], [[C0]]) : (memref<16xi8>, i64) -> memref<f32>
|
||||||
|
// CHECK: dealloc [[STATIC_MEM_POOL]] : memref<16xi8>
|
||||||
|
// CHECK: [[MAIN_REF_2:%.+]] = "krnl.getref"([[STATIC_MEM_POOL_MAIN]], [[C0]]) : (memref<12xi8>, i64) -> memref<f32>
|
||||||
|
// CHECK: [[LOAD:%.+]] = affine.load [[MAIN_REF_1]][] : memref<f32>
|
||||||
|
// CHECK: affine.store [[LOAD]], [[MAIN_REF_2]][] : memref<f32>
|
||||||
|
// CHECK: dealloc [[STATIC_MEM_POOL_MAIN]] : memref<12xi8>
|
||||||
|
// CHECK: return [[RES]] : memref<1x3x4xf32>
|
||||||
|
}
|
||||||
|
|
|
@ -12,10 +12,10 @@ func @test_bundle_memory_pool(%arg0: tensor<10x10xf32>, %arg1: tensor<10x20xf32>
|
||||||
// CHECK-LABEL: test_bundle_memory_pool
|
// CHECK-LABEL: test_bundle_memory_pool
|
||||||
// CHECK: [[CONST0:%.+]] = constant 0 : i64
|
// CHECK: [[CONST0:%.+]] = constant 0 : i64
|
||||||
// CHECK: [[CONST00:%.+]] = constant 0.000000e+00 : f32
|
// CHECK: [[CONST00:%.+]] = constant 0.000000e+00 : f32
|
||||||
// CHECK: [[CONST400:%.+]] = constant 400 : i64
|
|
||||||
// CHECK: [[CONST1200:%.+]] = constant 1200 : i64
|
|
||||||
// CHECK: [[CONST2000:%.+]] = constant 2000 : i64
|
|
||||||
// CHECK: [[CONST2400:%.+]] = constant 2400 : i64
|
// CHECK: [[CONST2400:%.+]] = constant 2400 : i64
|
||||||
|
// CHECK: [[CONST2000:%.+]] = constant 2000 : i64
|
||||||
|
// CHECK: [[CONST1200:%.+]] = constant 1200 : i64
|
||||||
|
// CHECK: [[CONST400:%.+]] = constant 400 : i64
|
||||||
// CHECK: [[RES:%.+]] = alloc() : memref<10x20xf32>
|
// CHECK: [[RES:%.+]] = alloc() : memref<10x20xf32>
|
||||||
// CHECK: [[MEMPOOL:%.+]] = alloc() : memref<3200xi8>
|
// CHECK: [[MEMPOOL:%.+]] = alloc() : memref<3200xi8>
|
||||||
// CHECK: "krnl.getref"([[MEMPOOL]], [[CONST2400]]) : (memref<3200xi8>, i64) -> memref<10x20xf32>
|
// CHECK: "krnl.getref"([[MEMPOOL]], [[CONST2400]]) : (memref<3200xi8>, i64) -> memref<10x20xf32>
|
||||||
|
|
Loading…
Reference in New Issue