Emit the dynamic memory pool (#290)

* Reorganize main function.
* Follow review comments.
* Emit constants as globals in the Krnl and LLVM dialects.
* Add support for bundling dynamic memory pools.
* Add dynamic bundling.
* Clean up code.
* Clean up file.
* Add a test for bundling the dynamic memory pool.
* Fixes: simplify the data structure; add a mixed test.
* Remove an unused import.
This commit is contained in:
parent 930e20f682
commit 9f69b2f317
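In outline, the commit adds a rewrite pattern that merges every dynamic alloc feeding a krnl.getref into a single growing dynamic memory pool. A minimal before/after sketch (same notation as the pattern documentation further below; operand names are illustrative):

    %dyn_mempool = alloc(%a) : memref<?xi8>
    %new_alloc   = alloc(%b) : memref<?xi8>
    %new_ref     = krnl.getref %new_alloc 0 : memref<?xi8>
  =>
    %c           = addi %a, %b
    %dyn_mempool = alloc(%c) : memref<?xi8>
    %new_ref     = krnl.getref %dyn_mempool %a : memref<?xi8>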
@@ -15,6 +15,7 @@
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SetVector.h"

#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
#include "src/Dialect/Krnl/KrnlOps.hpp"
@@ -24,6 +25,102 @@ using namespace mlir;

namespace {

//===----------------------------------------------------------------------===//
// Insertion point for initialization instructions and the blocks used for
// inserting the initialization and main code. These blocks will disappear
// when the first canonicalization is performed because the init block
// unconditionally branches into the second block. These blocks exist only for
// the purpose of this optimization.
// The information is recorded on a per-function basis.
//===----------------------------------------------------------------------===//

typedef struct ONNXOperandsInitState {
  Block *initBlock;
  Block *mainBlock;
  BranchOp branchInit;
  AllocOp dynamicMemoryPool;
} ONNXOperandsInitState;

// Helper data structure for the bundling of dynamic AllocOps.
std::map<FuncOp, std::unique_ptr<ONNXOperandsInitState>> initMap;
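// Note: runOnFunction() below creates the initMap entry for each function
// before applying the patterns and erases it afterwards, so the
// initMap.at(function) lookups in these helpers assume that entry exists.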
//===----------------------------------------------------------------------===//
// Helper functions.
//===----------------------------------------------------------------------===//

/// Retrieve the function which contains the current operation.
FuncOp getContainingFunction(AllocOp op) {
  Operation *parentFuncOp = op.getParentOp();

  // While parent is not a FuncOp and its cast to a FuncOp is null.
  while (!llvm::dyn_cast_or_null<FuncOp>(parentFuncOp))
    parentFuncOp = parentFuncOp->getParentOp();

  return cast<FuncOp>(parentFuncOp);
}

bool hasInitBlock(FuncOp function) {
  std::unique_ptr<ONNXOperandsInitState> &initState = initMap.at(function);
  return initState->initBlock != nullptr;
}

bool addInitBlock(PatternRewriter &rewriter, Location loc, AllocOp allocOp) {
  // If this is the first time we encounter an operation in this
  // function, we create an entry inside the initMap and split the
  // function body into an init block and a main block:
  //
  //   function func_name() {
  //     ... init block ...
  //     br ^bb1
  //   ^bb1:  // pred: ^bb0
  //     ... main block ...
  //     return
  //   }
  //
  // Note: the block ^bb0, being the first block, has its label omitted.
  //
  FuncOp function = getContainingFunction(allocOp);

  // If the function does not contain an init block, create one.
  if (!hasInitBlock(function)) {
    std::unique_ptr<ONNXOperandsInitState> &initState = initMap.at(function);
    initState = std::make_unique<ONNXOperandsInitState>();

    // All input arguments are considered part of the initialization block,
    // so add them to the operandsInInitBlock set.
    Block *functionBlock = &function.front();
    PatternRewriter::InsertionGuard insertGuard(rewriter);
    rewriter.setInsertionPointToStart(functionBlock);

    initState->initBlock = rewriter.getInsertionBlock();
    auto currentPoint = rewriter.getInsertionPoint();
    initState->mainBlock =
        rewriter.splitBlock(initState->initBlock, currentPoint);

    rewriter.setInsertionPointToEnd(initState->initBlock);

    // Insert a branch operation from initBlock to mainBlock. This
    // ensures the final code contains legal blocks.
    initState->branchInit =
        rewriter.create<BranchOp>(loc, initState->mainBlock);

    rewriter.setInsertionPointToStart(initState->mainBlock);

    // Save a reference to the current dynamic memory pool value.
    initState->dynamicMemoryPool = allocOp;

    return true;
  }

  return false;
}

bool isBlockArgument(Block *block, Value operand) {
  for (auto arg : block->getArguments())
    if (operand == arg)
      return true;
  return false;
}
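// A minimal alternative sketch, assuming llvm/ADT/STLExtras.h is available:
// the loop above should be equivalent to
//   return llvm::is_contained(block->getArguments(), operand);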
KrnlGetRefOp getUnbundledGetRef(AllocOp *memPool) {
  auto parentBlock = memPool->getOperation()->getBlock();

@@ -37,6 +134,23 @@ KrnlGetRefOp getUnbundledGetRef(AllocOp *memPool) {
  return unbundledGetRef;
}
KrnlGetRefOp getCurrentAllocGetRef(AllocOp *allocOp) {
  auto parentBlock = allocOp->getOperation()->getBlock();

  KrnlGetRefOp currentAllocGetRef = nullptr;
  parentBlock->walk([&currentAllocGetRef, allocOp](KrnlGetRefOp op) {
    auto result = allocOp->getResult();
    if (op.getOperands()[0] == result)
      currentAllocGetRef = op;
  });

  return currentAllocGetRef;
}
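// For reference, the getref shape these helpers search for, as it appears in
// the tests below:
//   %4 = "krnl.getref"(%3, %c0_i64) : (memref<?xi8>, i64) -> memref<?x10xf32>
// where %3 is the alloc'ed memory pool and %c0_i64 the byte offset into it.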
//===----------------------------------------------------------------------===//
// Rewrite patterns.
//===----------------------------------------------------------------------===//

/*!
 * RewritePattern that replaces:
 *   %mem1 = alloc() : memref<<dims1>x<type>>
@@ -73,7 +187,7 @@ public:
    if (!checkOpResultIsUsedByGetRef(&allocOp))
      return failure();

    // Only handle constant AllocOps.
    if (!hasAllConstantDimensions(memRefType))
      return failure();
@@ -85,12 +199,16 @@ public:
    if (memRefShape.size() != 1)
      return failure();

    // TODO: Add support for dynamic shapes.
    int64_t currentMemPoolSize = memRefShape[0];

    // Get a KrnlGetRefOp which does not use the current alloc.
    if (KrnlGetRefOp unbundledGetRef = getUnbundledGetRef(&allocOp)) {
      // Make sure that this getref uses a static alloc.
      auto unbundledGetRefType =
          convertToMemRefType(unbundledGetRef.getResult().getType());
      if (!hasAllConstantDimensions(unbundledGetRefType))
        return failure();

      // Current memory pool size is the offset for the newly bundled
      // internal MemRef. Emit the offset as a constant.
      auto offset = rewriter.create<ConstantOp>(
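      // Offsets are thus running sums of the sizes bundled so far: in the
      // mixed test below, the 2800-byte static pool holds a 400-byte memref
      // at offset 0, a 1600-byte memref at offset 400, and an 800-byte
      // memref at offset 2000.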
@@ -127,6 +245,201 @@ public:
  }
};
/*!
 * RewritePattern that merges a new dynamic AllocOp with the existing dynamic
 * memory pool.
 *   %dyn_mempool = alloc(%a) : memref<?xi8>
 *   %new_alloc   = alloc(%b) : memref<?xi8>
 *   %new_ref     = krnl.getref %new_alloc 0 : memref<?xi8>
 * =>
 *   %c           = addi %a, %b
 *   %dyn_mempool = alloc(%c) : memref<?xi8>
 *   %new_ref     = krnl.getref %dyn_mempool %a : memref<?xi8>
 */
class KrnlBundleDynamicMemoryPools : public OpRewritePattern<AllocOp> {
public:
  using OpRewritePattern<AllocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(
      AllocOp allocOp, PatternRewriter &rewriter) const override {
    auto loc = allocOp.getLoc();

    auto memRefType = convertToMemRefType(allocOp.getResult().getType());
    auto memRefShape = memRefType.getShape();

    // If the alloc result is not used by a getref then it cannot be part of
    // the memory pool.
    if (!checkOpResultIsUsedByGetRef(&allocOp))
      return failure();

    // Only handle dynamic allocs here.
    if (hasAllConstantDimensions(memRefType))
      return failure();

    // The alloc'ed memory type must be byte (i8).
    if (getMemRefEltSizeInBytes(memRefType) != 1)
      return failure();

    // The rank of the allocated MemRef must be 1.
    if (memRefShape.size() != 1)
      return failure();

    // Visit dependent operations in the current parent block and assemble a
    // trace of the operations which participate in the computation of the
    // size of the AllocOp.
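    // For example, the size trace for %3 = alloc(%2) in the test below is:
    //   %0 = dim %arg0, %c0 : memref<?x?xf32>
    //   %1 = muli %0, %c4 : index
    //   %2 = muli %1, %c10 : index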
    auto parentBlock = allocOp.getOperation()->getBlock();

    // Get the containing function.
    FuncOp function = getContainingFunction(allocOp);
    Block *firstBlock = &function.getBody().front();

    // If this is the alloc representing the memory pool and the function
    // already has an init block, pattern matching must fail to avoid
    // processing the dynamic memory pool a second time.
    if (hasInitBlock(function)) {
      std::unique_ptr<ONNXOperandsInitState> &initState = initMap.at(function);
      if (allocOp == initState->dynamicMemoryPool)
        return failure();
    }

    // Initialize the work queue data structure.
    Operation *op = allocOp.getOperation();
    std::vector<Value> operandList;
    for (const auto &operand : allocOp.getOperands()) {
      operandList.emplace_back(operand);
    }

    // Does the list of operations depend on a dynamic local AllocOp?
    bool dependsOnLocalDynamicAlloc = false;

    // Construct the set of operations on which the current AllocOp depends.
    llvm::SetVector<Operation *> dependentOps;
    while (operandList.size() > 0) {
      Value currentElement = operandList[0];
      Operation *definingOperation = currentElement.getDefiningOp();

      // If this value has not been seen before, process it.
      if (dependentOps.count(definingOperation) == 0) {
        // Add value to the dependent values list.
        dependentOps.insert(definingOperation);

        // Add its operands to the work queue.
        for (const auto &operand : definingOperation->getOperands()) {
          // Skip block arguments; we consider them to be leaves.
          if (!isBlockArgument(firstBlock, operand)) {
            operandList.emplace_back(operand);

            // If the current operation is a DimOp or a LoadOp and its
            // argument list involves a local AllocOp with dynamic sizes,
            // then the whole set of instructions cannot be moved.
            if (llvm::dyn_cast<DimOp>(definingOperation) ||
                llvm::dyn_cast<LoadOp>(definingOperation)) {
              Operation *operandOp = operand.getDefiningOp();
              if (operandOp) {
                auto localAlloc = llvm::dyn_cast<AllocOp>(operandOp);
                if (localAlloc) {
                  auto memRefType =
                      convertToMemRefType(localAlloc.getResult().getType());
                  if (!hasAllConstantDimensions(memRefType))
                    dependsOnLocalDynamicAlloc = true;
                }

                // If the operand is a getref then this alloc cannot be
                // bundled.
                auto memPool = llvm::dyn_cast<KrnlGetRefOp>(operandOp);
                if (memPool)
                  dependsOnLocalDynamicAlloc = true;
              }
            }
          }
        }
      }

      // Erase the first element from the work queue.
      operandList.erase(operandList.begin());
    }

    if (dependsOnLocalDynamicAlloc)
      return failure();
    // Order the dependent values in the same order in which they appear in
    // the code. One cannot iterate over and make changes to the order of the
    // operations of a block, so a temporary ordered list of dependent
    // instructions is necessary.
    llvm::SmallVector<Operation *, 32> orderedDependentOps;
    for (auto &op :
        llvm::make_range(parentBlock->begin(), std::prev(parentBlock->end())))
      if (dependentOps.count(&op) > 0)
        orderedDependentOps.emplace_back(&op);

    // Since no dynamic alloc is in the trace of the dependent operations,
    // emit the size calculation in the init block if one already exists;
    // if not, create the init block.
    bool addedInitBlock = addInitBlock(rewriter, loc, allocOp);

    // Move the ordered dependent size calculation to the init block.
    std::unique_ptr<ONNXOperandsInitState> &initState = initMap.at(function);
    for (auto &op : orderedDependentOps)
      op->moveBefore(initState->branchInit);

    // Bundled MemRef type: memref<?xi8>.
    SmallVector<int64_t, 1> memPoolShape;
    memPoolShape.emplace_back(-1);
    auto bundledMemPoolMemRefType =
        MemRefType::get(memPoolShape, rewriter.getIntegerType(8));

    // Get the getref of the current allocOp. There is exactly one such getref.
    KrnlGetRefOp currentAllocGetRef = getCurrentAllocGetRef(&allocOp);
    if (!currentAllocGetRef)
      return failure();

    // Add the current alloc size to the current MemPool size.
    Value dynamicMemoryPoolSize = initState->dynamicMemoryPool.getOperand(0);
    if (addedInitBlock) {
      Value zero = emitConstantOp(rewriter, loc, rewriter.getIndexType(), 0);
      zero.getDefiningOp()->moveBefore(initState->branchInit);
      dynamicMemoryPoolSize = zero;
    }

    AddIOp bundledAllocOperand = rewriter.create<AddIOp>(
        loc, dynamicMemoryPoolSize, allocOp.getOperand(0));
    bundledAllocOperand.getOperation()->moveBefore(initState->branchInit);

    // The newly bundled MemRef is expressed as a KrnlGetRefOp. The current
    // memory pool size is the offset for the newly bundled internal MemRef.
    Value integerDynamicMemoryPoolSize = rewriter.create<IndexCastOp>(
        loc, dynamicMemoryPoolSize, rewriter.getIntegerType(64));
    integerDynamicMemoryPoolSize.getDefiningOp()->moveBefore(
        initState->branchInit);
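    // In the tests below, this size accumulation materializes in the init
    // block as:
    //   [[MEMPOOL_SIZE:%.+]] = addi [[OFFSET1]], [[OFFSET2]] : index
    //   [[OFFSET1_I64:%.+]] = index_cast [[OFFSET1]] : index to i64
    // followed by a single alloc([[MEMPOOL_SIZE]]) for the bundled pool.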
    // Emit a new alloc which accommodates the additional MemRef.
    AllocOp bundledAlloc = rewriter.create<AllocOp>(
        loc, bundledMemPoolMemRefType, bundledAllocOperand.getResult());
    bundledAlloc.getOperation()->moveBefore(&initState->mainBlock->front());

    KrnlGetRefOp bundledMemRef = rewriter.create<KrnlGetRefOp>(loc,
        currentAllocGetRef.getResult().getType(), bundledAlloc,
        integerDynamicMemoryPoolSize);
    bundledMemRef.getOperation()->moveAfter(bundledAlloc);

    // Replace the old memory pool with the new one.
    rewriter.replaceOp(initState->dynamicMemoryPool, bundledAlloc.getResult());

    // Replace the old getref with a new getref into the new memory pool.
    rewriter.replaceOp(currentAllocGetRef, bundledMemRef.getResult());

    // Update the MemPool size.
    initState->dynamicMemoryPool = bundledAlloc;

    return success();
  }
};
/*!
 * Function pass that enables memory pooling for MemRefs.
 */
@@ -136,11 +449,22 @@ public:
  void runOnFunction() override {
    auto function = getFunction();

    initMap.insert(std::pair<FuncOp, std::unique_ptr<ONNXOperandsInitState>>(
        function, std::make_unique<ONNXOperandsInitState>()));

    // Initialize the state for this function.
    std::unique_ptr<ONNXOperandsInitState> &initState = initMap.at(function);
    initState->initBlock = nullptr;

    ConversionTarget target(getContext());
    OwningRewritePatternList patterns;
    patterns.insert<KrnlBundleMemoryPools, KrnlBundleDynamicMemoryPools>(
        &getContext());

    applyPatternsAndFoldGreedily(function, patterns);

    initMap.erase(function);
  }
};
} // namespace
@@ -51,3 +51,135 @@ func @test_pool_bundling(%arg0: memref<10x10xf32>, %arg1: memref<10x20xf32>) ->
  // CHECK: dealloc [[MEMPOOL]] : memref<3200xi8>
  // CHECK: return [[RES]] : memref<10x20xf32>
}
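// The test below checks that the two dynamic allocs (%3 and %10), each feeding
// a "krnl.getref", are merged into one dynamic pool whose size is the addi of
// the two individual sizes, with the second getref rebased to a nonzero offset.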
func @test_dynamic_pool_bundling(%arg0: memref<?x?xf32>) -> memref<?x10xf32> {
  %c1 = constant 1 : index
  %c0 = constant 0 : index
  %cst = constant 0.000000e+00 : f32
  %ind = constant 0 : index
  %c4 = constant 4 : index
  %c10 = constant 10 : index
  %c0_i64 = constant 0 : i64
  %0 = dim %arg0, %c0 : memref<?x?xf32>
  %1 = muli %0, %c4 : index
  %2 = muli %1, %c10 : index
  %3 = alloc(%2) : memref<?xi8>
  %4 = "krnl.getref"(%3, %c0_i64) : (memref<?xi8>, i64) -> memref<?x10xf32>
  %6 = cmpi "sgt", %0, %0 : index
  %7 = select %6, %0, %0 : index
  %8 = muli %7, %c4 : index
  %9 = muli %8, %c10 : index
  %10 = alloc(%9) : memref<?xi8>
  %11 = "krnl.getref"(%10, %c0_i64) : (memref<?xi8>, i64) -> memref<?x10xf32>
  %12 = cmpi "eq", %0, %c1 : index
  %13 = cmpi "eq", %0, %c1 : index
  %15 = alloc(%0) : memref<?x10xf32>
  affine.store %cst, %4[%ind, %ind] : memref<?x10xf32>
  affine.store %cst, %11[%ind, %ind] : memref<?x10xf32>
  affine.store %cst, %15[%ind, %ind] : memref<?x10xf32>
  dealloc %10 : memref<?xi8>
  dealloc %3 : memref<?xi8>
  return %15 : memref<?x10xf32>

  // CHECK-LABEL: test_dynamic_pool_bundling
  // CHECK: [[CST:%.+]] = constant 0.000000e+00 : f32
  // CHECK: [[C0:%.+]] = constant 0 : index
  // CHECK: [[C4:%.+]] = constant 4 : index
  // CHECK: [[C10:%.+]] = constant 10 : index
  // CHECK: [[C0_I64:%.+]] = constant 0 : i64
  // CHECK: [[DIM:%.+]] = dim %arg0, [[C0]] : memref<?x?xf32>
  // CHECK: [[SGT:%.+]] = cmpi "sgt", [[DIM]], [[DIM]] : index
  // CHECK: [[SELECT:%.+]] = select [[SGT]], [[DIM]], [[DIM]] : index
  // CHECK: [[MUL1:%.+]] = muli [[SELECT]], [[C4]] : index
  // CHECK: [[OFFSET1:%.+]] = muli [[MUL1]], [[C10]] : index
  // CHECK: [[MUL2:%.+]] = muli [[DIM]], [[C4]] : index
  // CHECK: [[OFFSET2:%.+]] = muli [[MUL2]], [[C10]] : index
  // CHECK: [[MEMPOOL_SIZE:%.+]] = addi [[OFFSET1]], [[OFFSET2]] : index
  // CHECK: [[OFFSET1_I64:%.+]] = index_cast [[OFFSET1]] : index to i64
  // CHECK: [[DYN_MEMPOOL:%.+]] = alloc([[MEMPOOL_SIZE]]) : memref<?xi8>
  // CHECK: [[DATA1:%.+]] = "krnl.getref"([[DYN_MEMPOOL]], [[OFFSET1_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32>
  // CHECK: [[DATA2:%.+]] = "krnl.getref"([[DYN_MEMPOOL]], [[C0_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32>
  // CHECK: [[RES:%.+]] = alloc([[DIM]]) : memref<?x10xf32>
  // CHECK: affine.store [[CST]], [[DATA1]][0, 0] : memref<?x10xf32>
  // CHECK: affine.store [[CST]], [[DATA2]][0, 0] : memref<?x10xf32>
  // CHECK: affine.store [[CST]], [[RES]][0, 0] : memref<?x10xf32>
  // CHECK: dealloc [[DYN_MEMPOOL]] : memref<?xi8>
  // CHECK: return [[RES]] : memref<?x10xf32>
}
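// The mixed test below additionally checks that constant-size allocs are
// bundled into their own 2800-byte static pool, independently of the dynamic
// pool built from the dim-based sizes.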
func @test_dynamic_and_static_pool_bundling(%arg0: memref<?x?xf32>, %arg1: memref<10x10xf32>) -> memref<?x10xf32> {
  %c1 = constant 1 : index
  %c0 = constant 0 : index
  %cst = constant 0.000000e+00 : f32
  %ind = constant 0 : index
  %c4 = constant 4 : index
  %c10 = constant 10 : index
  %c0_i64 = constant 0 : i64
  %0 = dim %arg0, %c0 : memref<?x?xf32>
  %1 = muli %0, %c4 : index
  %2 = muli %1, %c10 : index
  %3 = alloc(%2) : memref<?xi8>
  %4 = "krnl.getref"(%3, %c0_i64) : (memref<?xi8>, i64) -> memref<?x10xf32>
  %const_alloc1 = alloc() : memref<800xi8>
  %const_ref1 = "krnl.getref"(%const_alloc1, %c0_i64) : (memref<800xi8>, i64) -> memref<10x20xf32>
  %const_alloc2 = alloc() : memref<400xi8>
  %const_ref2 = "krnl.getref"(%const_alloc2, %c0_i64) : (memref<400xi8>, i64) -> memref<10x10xf32>
  %6 = cmpi "sgt", %0, %0 : index
  %7 = select %6, %0, %0 : index
  %8 = muli %7, %c4 : index
  %9 = muli %8, %c10 : index
  %10 = alloc(%9) : memref<?xi8>
  %11 = "krnl.getref"(%10, %c0_i64) : (memref<?xi8>, i64) -> memref<?x10xf32>
  %12 = cmpi "eq", %0, %c1 : index
  %13 = cmpi "eq", %0, %c1 : index
  %15 = alloc(%0) : memref<?x10xf32>
  %const_alloc3 = alloc() : memref<1600xi8>
  %const_ref3 = "krnl.getref"(%const_alloc3, %c0_i64) : (memref<1600xi8>, i64) -> memref<10x40xf32>
  affine.store %cst, %4[%ind, %ind] : memref<?x10xf32>
  affine.store %cst, %11[%ind, %ind] : memref<?x10xf32>
  affine.store %cst, %15[%ind, %ind] : memref<?x10xf32>
  affine.store %cst, %const_ref2[%ind, %ind] : memref<10x10xf32>
  affine.store %cst, %const_ref1[%ind, %ind] : memref<10x20xf32>
  affine.store %cst, %const_ref3[%ind, %ind] : memref<10x40xf32>
  dealloc %10 : memref<?xi8>
  dealloc %3 : memref<?xi8>
  dealloc %const_alloc1 : memref<800xi8>
  dealloc %const_alloc2 : memref<400xi8>
  dealloc %const_alloc3 : memref<1600xi8>
  return %15 : memref<?x10xf32>

  // CHECK-LABEL: test_dynamic_and_static_pool_bundling
  // CHECK: [[CST:%.+]] = constant 0.000000e+00 : f32
  // CHECK: [[C0:%.+]] = constant 0 : index
  // CHECK: [[C4:%.+]] = constant 4 : index
  // CHECK: [[C10:%.+]] = constant 10 : index
  // CHECK: [[C400_I64:%.+]] = constant 400 : i64
  // CHECK: [[C2000_I64:%.+]] = constant 2000 : i64
  // CHECK: [[C0_I64:%.+]] = constant 0 : i64
  // CHECK: [[DIM:%.+]] = dim %arg0, [[C0]] : memref<?x?xf32>
  // CHECK: [[SGT:%.+]] = cmpi "sgt", [[DIM]], [[DIM]] : index
  // CHECK: [[SELECT:%.+]] = select [[SGT]], [[DIM]], [[DIM]] : index
  // CHECK: [[MUL1:%.+]] = muli [[SELECT]], [[C4]] : index
  // CHECK: [[OFFSET1:%.+]] = muli [[MUL1]], [[C10]] : index
  // CHECK: [[MUL2:%.+]] = muli [[DIM]], [[C4]] : index
  // CHECK: [[OFFSET2:%.+]] = muli [[MUL2]], [[C10]] : index
  // CHECK: [[MEMPOOL_SIZE:%.+]] = addi [[OFFSET1]], [[OFFSET2]] : index
  // CHECK: [[OFFSET1_I64:%.+]] = index_cast [[OFFSET1]] : index to i64
  // CHECK: [[DYN_MEMPOOL:%.+]] = alloc([[MEMPOOL_SIZE]]) : memref<?xi8>
  // CHECK: [[DATA1:%.+]] = "krnl.getref"([[DYN_MEMPOOL]], [[OFFSET1_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32>
  // CHECK: [[DATA2:%.+]] = "krnl.getref"([[DYN_MEMPOOL]], [[C0_I64]]) : (memref<?xi8>, i64) -> memref<?x10xf32>
  // CHECK: [[STATIC_MEMPOOL:%.+]] = alloc() : memref<2800xi8>
  // CHECK: [[DATA3:%.+]] = "krnl.getref"([[STATIC_MEMPOOL]], [[C2000_I64]]) : (memref<2800xi8>, i64) -> memref<10x20xf32>
  // CHECK: [[DATA4:%.+]] = "krnl.getref"([[STATIC_MEMPOOL]], [[C400_I64]]) : (memref<2800xi8>, i64) -> memref<10x40xf32>
  // CHECK: [[DATA5:%.+]] = "krnl.getref"([[STATIC_MEMPOOL]], [[C0_I64]]) : (memref<2800xi8>, i64) -> memref<10x10xf32>
  // CHECK: [[RES:%.+]] = alloc([[DIM]]) : memref<?x10xf32>
  // CHECK: affine.store [[CST]], [[DATA1]][0, 0] : memref<?x10xf32>
  // CHECK: affine.store [[CST]], [[DATA2]][0, 0] : memref<?x10xf32>
  // CHECK: affine.store [[CST]], [[RES]][0, 0] : memref<?x10xf32>
  // CHECK: affine.store [[CST]], [[DATA5]][0, 0] : memref<10x10xf32>
  // CHECK: affine.store [[CST]], [[DATA3]][0, 0] : memref<10x20xf32>
  // CHECK: affine.store [[CST]], [[DATA4]][0, 0] : memref<10x40xf32>
  // CHECK: dealloc [[DYN_MEMPOOL]] : memref<?xi8>
  // CHECK: dealloc [[STATIC_MEMPOOL]] : memref<2800xi8>
  // CHECK: return [[RES]] : memref<?x10xf32>
}