Add basic support for memory pool (#161)
* Reorganize main function.
* Follow review comments.
* Emit constants as globals in Krnl and LLVM dialects.
* Replace internal malloc with memory pool and getref instruction.
* Lower krnl.getref to LLVM.
* Fix formatting issues.
* Add tests.
* Add missing dependency.
* Improve LLVM lowering.
* Add test to show getref is generic.
parent ca185002f2
commit 4ab96fbc6c
@@ -268,7 +268,8 @@ set(ONNXMLIRWholeArchiveLibs
   OMAttributePromotion
   OMPromotableConstOperandsOpInterface
   OMElideConstants
-  OMElideKrnlGlobalConstants)
+  OMElideKrnlGlobalConstants
+  OMEnableMemoryPool)

 # Function to construct linkage option for the static libraries that must be
 # linked with --whole-archive (or equivalent).
@@ -57,6 +57,7 @@ target_link_libraries(MainUtils
   OMResultTypeInferenceOpInterface
   OMElideConstants
   OMElideKrnlGlobalConstants
+  OMEnableMemoryPool
   OMKrnlToAffine
   OMKrnlToLLVM
   OMONNXToKrnl
@@ -205,3 +205,21 @@ def KrnlGlobalOp : Op<Krnl_Dialect, "global"> {
   let parser = ?;
   let printer = ?;
 }
+
+def KrnlGetRefOp : Op<Krnl_Dialect, "getref"> {
+  let summary = "Get a MemRef from within another MemRef starting at a specific offset.";
+  let description = [{
+    Retrieves a MemRef from within another MemRef:
+
+    "krnl.getref"(%memref, %offset)
+
+    The offset is an integer which is used as an index into the input MemRef. It works
+    just like an array index.
+  }];
+
+  let arguments = (ins AnyTypeOf<[AnyMemRef]>:$mempool, AnyInteger:$offset);
+  let results = (outs AnyTypeOf<[AnyMemRef]>:$output);
+
+  let parser = ?;
+  let printer = ?;
+}
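Conceptually, krnl.getref carves a typed view out of an already-allocated byte pool at a fixed offset; it performs no allocation of its own. A minimal C++ analogy of these semantics (illustrative only, not code from this patch):

#include <cstdint>
#include <cstdlib>

int main() {
  // One allocation for the whole pool, analogous to %mem = alloc() : memref<400xi8>.
  int8_t *pool = static_cast<int8_t *>(std::malloc(400));

  // Analogous to "krnl.getref"(%mem, %c0) : (memref<400xi8>, i64) -> memref<10x10xf32>:
  // no new allocation, just a typed view at the given offset into the pool.
  float *view = reinterpret_cast<float *>(pool + 0);
  view[0] = 1.0f; // indexing works like an ordinary array, as the description says

  std::free(pool); // a single dealloc releases every view at once
  return 0;
}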
@@ -102,6 +102,9 @@ void addONNXToKrnlPasses(mlir::PassManager &pm) {
   // from ONNX dialect to Standard dialect exposes additional canonicalization
   // opportunities.
   pm.addPass(mlir::createCanonicalizerPass());
+
+  // TODO: make this pass optional:
+  pm.addPass(mlir::createKrnlEnableMemoryPoolPass());
 }
 
 void addKrnlToAffinePasses(mlir::PassManager &pm) {
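For experimentation outside the default pipeline, the pass factory declared in Passes.hpp (next hunk) can also be scheduled directly. A minimal sketch, assuming the MLIR PassManager API of this codebase's vintage; the helper name runMemoryPoolOnly is hypothetical:

#include "mlir/Pass/PassManager.h"
#include "src/Pass/Passes.hpp"

// Hypothetical helper: run only the memory-pool pass on a module that has
// already been lowered to the Krnl dialect.
mlir::LogicalResult runMemoryPoolOnly(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::createKrnlEnableMemoryPoolPass());
  return pm.run(module);
}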
@@ -28,6 +28,9 @@ std::unique_ptr<Pass> createAttributePromotionPass();
 /// Pass for eliding the values of constant operations.
 std::unique_ptr<Pass> createElideConstantValuePass();
 
+/// Pass for enabling a memory pool for MemRefs.
+std::unique_ptr<Pass> createKrnlEnableMemoryPoolPass();
+
 /// Add pass for lowering to Krnl IR.
 std::unique_ptr<Pass> createLowerToKrnlPass();
@@ -38,4 +38,17 @@ add_dependencies(OMElideKrnlGlobalConstants OMKrnlOpsInc)
 # Linking dependencies
 add_dependencies(OMElideKrnlGlobalConstants OMKrnlOps)
 
+add_library(OMEnableMemoryPool
+  EnableMemoryPool.cpp)
+target_include_directories(OMEnableMemoryPool
+  PRIVATE
+  ${ONNX_MLIR_SRC_ROOT}
+  ${ONNX_MLIR_BIN_ROOT}
+  ${ONNX_MLIR_SRC_ROOT})
+target_link_libraries(OMEnableMemoryPool
+  onnx)
+add_dependencies(OMEnableMemoryPool
+  OMKrnlOps
+  OMONNXOps)
+
 add_subdirectory(ONNX)
@@ -0,0 +1,159 @@ new file: EnableMemoryPool.cpp (all lines added)
//===-------- EnableMemoryPool.cpp - Enable Memory Pool for MemRefs -------===//
//
// Copyright 2019-2020 The IBM Research Authors.
//
// =============================================================================
//
// For certain cases the number of individual memory allocations required for
// all internal tensors is large and needs to be mitigated. This pass enables a
// managed memory pool for allocating MemRefs.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
#include "src/Dialect/Krnl/KrnlOps.hpp"
#include "src/Pass/Passes.hpp"

using namespace mlir;

namespace {

bool checkOpResultIsReturned(AllocOp *allocOp) {
  auto parentBlock = allocOp->getOperation()->getBlock();

  bool opIsReturned = false;
  parentBlock->walk([&opIsReturned, allocOp](ReturnOp op) {
    auto result = allocOp->getResult();
    for (const auto &operand : op.getOperands())
      if (operand == result)
        opIsReturned = true;
  });

  return opIsReturned;
}

bool checkOpResultIsUsedByGetRef(AllocOp *allocOp) {
  auto parentBlock = allocOp->getOperation()->getBlock();

  bool opIsUsedInGetRef = false;
  parentBlock->walk([&opIsUsedInGetRef, allocOp](KrnlGetRefOp op) {
    auto result = allocOp->getResult();
    for (const auto &operand : op.getOperands())
      if (operand == result)
        opIsUsedInGetRef = true;
  });

  return opIsUsedInGetRef;
}

/*!
 * RewritePattern that replaces:
 *   %0 = alloc() : memref<<dims>x<type>>
 * with:
 *   %mem = alloc() : memref<<dims>x<type>>
 *   %0 = krnl.getref %mem <offset> : memref<<dims>x<type>>
 *
 * For now, to enable testing, offset will always be 0.
 */
class KrnlEnableMemoryPool : public OpRewritePattern<AllocOp> {
public:
  using OpRewritePattern<AllocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(
      AllocOp allocOp, PatternRewriter &rewriter) const override {
    auto loc = allocOp.getLoc();

    auto memRefType = convertToMemRefType(allocOp.getResult().getType());

    // For now we only support constant tensors.
    // TODO: Enable this pass for MemRefs with dynamic shapes.
    // If the alloc operation is not returned then it is a candidate for
    // being included in the memory pool.
    if (!hasAllConstantDimensions(memRefType) ||
        checkOpResultIsReturned(&allocOp))
      return failure();

    // Check the result of this alloc is not already used by a krnl.getref.
    if (checkOpResultIsUsedByGetRef(&allocOp))
      return failure();

    // Compute total size.
    auto memRefShape = memRefType.getShape();
    int64_t totalSize = 1;
    for (int i = 0; i < memRefShape.size(); i++)
      totalSize *= memRefShape[i];
    totalSize *= getMemRefEltSizeInBytes(memRefType);

    // Emit new alloc.
    SmallVector<int64_t, 1> memPoolShape;
    memPoolShape.emplace_back(totalSize);
    auto memPoolMemRefType =
        MemRefType::get(memPoolShape, rewriter.getIntegerType(8));
    auto newAlloc = rewriter.create<AllocOp>(loc, memPoolMemRefType);

    // Emit new dealloc.
    auto dealloc = rewriter.create<DeallocOp>(loc, newAlloc);
    auto parentBlock = allocOp.getOperation()->getBlock();
    dealloc.getOperation()->moveBefore(&parentBlock->back());

    // Get reference to local MemRef.
    auto zero = rewriter.create<ConstantOp>(
        loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
    auto poolMemRef =
        rewriter.create<KrnlGetRefOp>(loc, memRefType, newAlloc, zero);

    rewriter.replaceOp(allocOp, poolMemRef.getResult());

    return success();
  }
};

class KrnlEliminateOldDealloc : public OpRewritePattern<DeallocOp> {
public:
  using OpRewritePattern<DeallocOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(
      DeallocOp deallocOp, PatternRewriter &rewriter) const override {
    if (auto getRefOp = llvm::dyn_cast<KrnlGetRefOp>(
            deallocOp.getOperand().getDefiningOp())) {
      rewriter.eraseOp(deallocOp);
      return success();
    }

    return failure();
  }
};

// TODO: Replace old dealloc with krnl.unsetref.

/*!
 * Function pass that enables memory pooling for MemRefs.
 */
class KrnlEnableMemoryPoolPass
    : public PassWrapper<KrnlEnableMemoryPoolPass, FunctionPass> {
public:
  void runOnFunction() override {
    auto function = getFunction();

    ConversionTarget target(getContext());
    OwningRewritePatternList patterns;
    patterns.insert<KrnlEnableMemoryPool>(&getContext());
    patterns.insert<KrnlEliminateOldDealloc>(&getContext());

    applyPatternsAndFoldGreedily(function, patterns);
  }
};
} // namespace

std::unique_ptr<Pass> mlir::createKrnlEnableMemoryPoolPass() {
  return std::make_unique<KrnlEnableMemoryPoolPass>();
}

static PassRegistration<KrnlEnableMemoryPoolPass> pass("enable-memory-pool",
    "Enable a memory pool for allocating internal MemRefs.");
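As a worked check of the totalSize computation above: a static memref<10x10xf32> yields 10 * 10 * 4 = 400 bytes, matching the memref<400xi8> pools in the tests below (and 10x20xf32 yields 800). A standalone sketch of the same arithmetic:

#include <cassert>
#include <cstdint>
#include <vector>

// Mirrors KrnlEnableMemoryPool's size computation: the product of all static
// dimensions times the element size in bytes.
int64_t poolSizeInBytes(const std::vector<int64_t> &shape, int64_t eltSizeInBytes) {
  int64_t totalSize = 1;
  for (int64_t dim : shape)
    totalSize *= dim;
  return totalSize * eltSizeInBytes;
}

int main() {
  assert(poolSizeInBytes({10, 10}, 4) == 400); // memref<10x10xf32> -> memref<400xi8>
  assert(poolSizeInBytes({10, 20}, 4) == 800); // memref<10x20xf32> -> memref<800xi8>
  return 0;
}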
@@ -165,6 +165,7 @@ void KrnlToAffineLoweringPass::runOnFunction() {
   target.addLegalOp<KrnlMemcpyOp>();
   target.addLegalOp<KrnlEntryPointOp>();
   target.addLegalOp<KrnlGlobalOp>();
+  target.addLegalOp<KrnlGetRefOp>();
 
   OwningRewritePatternList patterns;
   patterns.insert<KrnlIterateOpLowering, KrnlTerminatorLowering,
@@ -88,6 +88,71 @@ static FlatSymbolRefAttr getOrInsertMemcpy(PatternRewriter &rewriter,
   return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context);
 }
 
+//===----------------------------------------------------------------------===//
+// KRNL to LLVM: KrnlGetRefOpLowering
+//===----------------------------------------------------------------------===//
+
+class KrnlGetRefOpLowering : public ConvertToLLVMPattern {
+public:
+  explicit KrnlGetRefOpLowering(
+      MLIRContext *context, LLVMTypeConverter &lowering_)
+      : ConvertToLLVMPattern(
+            KrnlGetRefOp::getOperationName(), context, lowering_) {}
+
+  LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    auto *context = op->getContext();
+    auto loc = op->getLoc();
+    auto *llvmDialect =
+        op->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
+    assert(llvmDialect && "expected llvm dialect to be registered");
+
+    KrnlGetRefOpOperandAdaptor operandAdaptor(operands);
+
+    // This is the type of the krnl.getref output. This type is used
+    // for the type of the internal MemRef.
+    auto type = op->getResult(0).getType();
+    auto memRefTy = type.cast<mlir::MemRefType>();
+    auto llvmMemRefType =
+        typeConverter.convertType(type).cast<LLVM::LLVMType>();
+    auto outputElementType =
+        typeConverter.convertType(memRefTy.getElementType());
+
+    // This is the start of the memory pool containing the output MemRef.
+    Type memPoolType = operandAdaptor.mempool()
+                           .getType()
+                           .cast<LLVM::LLVMType>()
+                           .getStructElementType(1);
+    Value alignedMemPoolBase = rewriter.create<LLVM::ExtractValueOp>(loc,
+        memPoolType, operandAdaptor.mempool(), rewriter.getI64ArrayAttr(1));
+
+    // Get pointer using the offset.
+    auto offset = operandAdaptor.offset();
+    auto llvmMemPoolType =
+        typeConverter.convertType(memPoolType).cast<LLVM::LLVMType>();
+    auto outputMemPoolTypePtrAlloc = rewriter.create<LLVM::GEPOp>(
+        loc, llvmMemPoolType, alignedMemPoolBase, ArrayRef<Value>({offset}));
+
+    // Bitcast to output MemRef type i.e. from i8* to the element type
+    // of the output MemRef.
+    auto llvmOutputElementType = outputElementType.cast<LLVM::LLVMType>();
+    Value outputTypedPtrAlloc = rewriter.create<LLVM::BitcastOp>(
+        loc, llvmOutputElementType.getPointerTo(), outputMemPoolTypePtrAlloc);
+
+    // Create llvm MemRef from original MemRef and fill the data pointers.
+    auto llvmMemRef = MemRefDescriptor::fromStaticShape(
+        rewriter, loc, typeConverter, memRefTy, outputTypedPtrAlloc);
+
+    rewriter.replaceOp(op, {llvmMemRef});
+    return success();
+  }
+
+private:
+  static int64_t ArrayAttrIntVal(ArrayAttr a, int i) {
+    return (a.getValue()[i]).cast<IntegerAttr>().getInt();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // KRNL to LLVM: KrnlGlobalOpLowering
 //===----------------------------------------------------------------------===//
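The MemRefDescriptor that fromStaticShape fills in corresponds to the LLVM struct type appearing in the CHECK lines of the tests below; for a rank-2 f32 MemRef it has roughly this layout (a C++ orientation sketch, not code from the patch):

#include <cstdint>

// Rough equivalent of !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">.
// KrnlGetRefOpLowering reads field 1 (the aligned pointer) of the pool's
// descriptor, offsets and bitcasts it, then stores it into fields 0 and 1
// of the getref result's descriptor.
struct MemRefDescriptor2DF32 {
  float *allocated;   // base pointer as returned by the allocator
  float *aligned;     // aligned data pointer used for loads and stores
  int64_t offset;     // offset into the data, in elements
  int64_t sizes[2];   // dimension sizes
  int64_t strides[2]; // strides, in elements
};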
@@ -648,6 +713,7 @@ void KrnlToLLVMLoweringPass::runOnOperation() {
       /*useAlignedAlloc=*/false);
 
   patterns.insert<KrnlGlobalOpLowering>(&getContext(), typeConverter);
+  patterns.insert<KrnlGetRefOpLowering>(&getContext(), typeConverter);
 
   // Lower from the `krnl` dialect i.e. the Reshape operation.
   patterns.insert<KrnlMemcpyOpLowering, KrnlEntryPointOpLowering>(
@@ -0,0 +1,38 @@ new file (all lines added)
// RUN: onnx-mlir-opt --lower-krnl --lower-all-llvm %s -split-input-file | FileCheck %s

func @test_getref_lowering(%arg0: memref<2x2xf32>) -> memref<2x2xf32> {
  %c13_i64 = constant 13 : i64
  %1 = alloc() : memref<10x10xf32>
  %2 = "krnl.getref"(%1, %c13_i64) : (memref<10x10xf32>, i64) -> memref<2x2xf32>
  return %2 : memref<2x2xf32>

  // CHECK-LABEL: test_getref_lowering
  // CHECK: %[[OFFSET:.+]] = llvm.mlir.constant(13 : i64) : !llvm.i64
  // CHECK: [[CONST_10_0:%.+]] = llvm.mlir.constant(10 : index) : !llvm.i64
  // CHECK: [[CONST_10_1:%.+]] = llvm.mlir.constant(10 : index) : !llvm.i64
  // CHECK: [[MUL1:%.+]] = llvm.mul [[CONST_10_0]], [[CONST_10_1]] : !llvm.i64
  // CHECK: [[FLOAT_STAR:%.+]] = llvm.mlir.null : !llvm<"float*">
  // CHECK: %[[CONST_1:.+]] = llvm.mlir.constant(1 : index) : !llvm.i64
  // CHECK: [[ELEM1:%.+]] = llvm.getelementptr [[FLOAT_STAR]][%[[CONST_1]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
  // CHECK: [[ELEM_SIZE:%.+]] = llvm.ptrtoint [[ELEM1]] : !llvm<"float*"> to !llvm.i64
  // CHECK: [[MUL2:%.+]] = llvm.mul [[MUL1]], [[ELEM_SIZE]] : !llvm.i64
  // CHECK: [[MEMPOOL:%.+]] = llvm.call @malloc([[MUL2]]) : (!llvm.i64) -> !llvm<"i8*">
  // CHECK: [[TYPED_MEMPOOL:%.+]] = llvm.bitcast [[MEMPOOL]] : !llvm<"i8*"> to !llvm<"float*">
  // CHECK: [[MEMPOOL_MEMREF:%.+]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
  // CHECK: [[MEMREF1:%.+]] = llvm.insertvalue [[TYPED_MEMPOOL]], [[MEMPOOL_MEMREF]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
  // CHECK: [[MEMREF2:%.+]] = llvm.insertvalue [[TYPED_MEMPOOL]], [[MEMREF1]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
  // CHECK: llvm.mlir.constant
  // CHECK: llvm.insertvalue
  // CHECK: llvm.mlir.constant
  // CHECK: llvm.mlir.constant
  // CHECK: llvm.insertvalue
  // CHECK: llvm.insertvalue
  // CHECK: llvm.insertvalue
  // CHECK: llvm.insertvalue
  // CHECK: [[MEMPOOL1:%.+]] = llvm.extractvalue {{.*}}[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
  // CHECK: [[MEMPOOL_ALLOC:%.+]] = llvm.getelementptr [[MEMPOOL1]][%[[OFFSET]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
  // CHECK: [[TYPED_MEMPOOL_ALLOC:%.+]] = llvm.bitcast [[MEMPOOL_ALLOC]] : !llvm<"float*"> to !llvm<"float*">
  // CHECK: [[MEMPOOLED_MEMREF:%.+]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
  // CHECK: [[MEMREF3:%.+]] = llvm.insertvalue [[TYPED_MEMPOOL_ALLOC]], [[MEMPOOLED_MEMREF]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
  // CHECK: [[MEMREF4:%.+]] = llvm.insertvalue [[TYPED_MEMPOOL_ALLOC]], [[MEMREF3]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
}
@@ -0,0 +1,66 @@ new file (all lines added)
// RUN: onnx-mlir-opt --shape-inference --lower-frontend --enable-memory-pool --lower-krnl --lower-all-llvm %s -split-input-file | FileCheck %s

func @test_memory_pool(%arg0: tensor<10x10xf32>) -> tensor<10x10xf32> {
  %0 = "onnx.Add"(%arg0, %arg0) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
  %1 = "onnx.Add"(%0, %arg0) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
  return %1 : tensor<10x10xf32>

  /// Define the offset inside the memory pool.
  // CHECK: %[[OFFSET:.+]] = llvm.mlir.constant(0 : i64) : !llvm.i64

  /// Allocate memory for the memory pool.
  // CHECK: [[MEMPOOL_SIZE:%.+]] = llvm.mlir.constant(400 : index) : !llvm.i64
  // CHECK: [[TMP1:%.+]] = llvm.mlir.null : !llvm<"i8*">
  // CHECK: %[[CONST1:.+]] = llvm.mlir.constant(1 : index) : !llvm.i64
  // CHECK: [[TMP2:%.+]] = llvm.getelementptr [[TMP1]][%[[CONST1]]] : (!llvm<"i8*">, !llvm.i64) -> !llvm<"i8*">
  // CHECK: [[TYPE_SIZE_IN_BYTES:%.+]] = llvm.ptrtoint [[TMP2]] : !llvm<"i8*"> to !llvm.i64
  // CHECK: [[TOTAL_SIZE:%.+]] = llvm.mul [[MEMPOOL_SIZE]], [[TYPE_SIZE_IN_BYTES]] : !llvm.i64
  // CHECK: [[ALLOC_MEM_POOL:%.+]] = llvm.call @malloc([[TOTAL_SIZE]]) : (!llvm.i64) -> !llvm<"i8*">
  // CHECK: [[BITCAST_ALLOC_MEM_POOL:%.+]] = llvm.bitcast [[ALLOC_MEM_POOL]] : !llvm<"i8*"> to !llvm<"i8*">

  /// MemRef representing the memory pool, which contains the memory allocated above.
  // CHECK: [[MEMREF0:%.+]] = llvm.mlir.undef : !llvm<"{ i8*, i8*, i64, [1 x i64], [1 x i64] }">
  // CHECK: [[TMP3:%.+]] = llvm.insertvalue [[BITCAST_ALLOC_MEM_POOL]], [[MEMREF0]][0] : !llvm<"{ i8*, i8*, i64, [1 x i64], [1 x i64] }">
  // CHECK: llvm.insertvalue [[BITCAST_ALLOC_MEM_POOL]], [[TMP3]][1] : !llvm<"{ i8*, i8*, i64, [1 x i64], [1 x i64] }">
  // CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
  // CHECK: llvm.insertvalue
  // CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
  // CHECK: llvm.insertvalue
  // CHECK: [[TMP4:%.+]] = llvm.insertvalue {{.*}}[4, 0] : !llvm<"{ i8*, i8*, i64, [1 x i64], [1 x i64] }">

  /// Get reference within the memory pool where the data of the getref instruction has already been allocated.
  // CHECK: [[MEMPOOL_BASE:%.+]] = llvm.extractvalue [[TMP4]][1] : !llvm<"{ i8*, i8*, i64, [1 x i64], [1 x i64] }">
  // CHECK: [[GETREF_MEMORY:%.+]] = llvm.getelementptr [[MEMPOOL_BASE]][%[[OFFSET]]] : (!llvm<"i8*">, !llvm.i64) -> !llvm<"i8*">
  // CHECK: [[CASTED_GETREF_MEMORY:%.+]] = llvm.bitcast [[GETREF_MEMORY]] : !llvm<"i8*"> to !llvm<"float*">

  /// Create MemRef for krnl.getref.
  // CHECK: [[MEMREF1:%.+]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
  // CHECK: [[MEMREF1_TMP1:%.+]] = llvm.insertvalue [[CASTED_GETREF_MEMORY]], [[MEMREF1]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
  // CHECK: [[MEMREF1_TMP2:%.+]] = llvm.insertvalue [[CASTED_GETREF_MEMORY]], [[MEMREF1_TMP1]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
  // CHECK: [[CONST2:%.+]] = llvm.mlir.constant(0 : index) : !llvm.i64
  // CHECK: [[MEMREF1_TMP3:%.+]] = llvm.insertvalue [[CONST2]], [[MEMREF1_TMP2]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
  // CHECK: [[CONST3:%.+]] = llvm.mlir.constant(10 : index) : !llvm.i64
  // CHECK: [[MEMREF1_TMP4:%.+]] = llvm.insertvalue [[CONST3]], [[MEMREF1_TMP3]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
  // CHECK: [[CONST4:%.+]] = llvm.mlir.constant(10 : index) : !llvm.i64
  // CHECK: [[MEMREF1_TMP5:%.+]] = llvm.insertvalue [[CONST4]], [[MEMREF1_TMP4]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
  // CHECK: [[CONST5:%.+]] = llvm.mlir.constant(10 : index) : !llvm.i64
  // CHECK: [[MEMREF1_TMP6:%.+]] = llvm.insertvalue [[CONST5]], [[MEMREF1_TMP5]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
  // CHECK: [[CONST6:%.+]] = llvm.mlir.constant(1 : index) : !llvm.i64
  // CHECK: [[MEMREF1_TMP7:%.+]] = llvm.insertvalue [[CONST6]], [[MEMREF1_TMP6]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">

  /// Usage of the getref MemRef.
  // CHECK: [[MEM0:%.+]] = llvm.extractvalue [[MEMREF1_TMP7]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
  // CHECK: [[CONST7:%.+]] = llvm.mlir.constant(0 : index) : !llvm.i64
  // CHECK: [[CONST8:%.+]] = llvm.mlir.constant(10 : index) : !llvm.i64
  // CHECK: [[MUL1:%.+]] = llvm.mul {{.*}}, [[CONST8]] : !llvm.i64
  // CHECK: [[ADD1:%.+]] = llvm.add [[CONST7]], [[MUL1]] : !llvm.i64
  // CHECK: [[CONST9:%.+]] = llvm.mlir.constant(1 : index) : !llvm.i64
  // CHECK: [[MUL2:%.+]] = llvm.mul {{.*}}, [[CONST9]] : !llvm.i64
  // CHECK: %[[ADD2:.+]] = llvm.add [[ADD1]], [[MUL2]] : !llvm.i64
  // CHECK: llvm.getelementptr [[MEM0]][%[[ADD2]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">

  /// Deallocation of the memory pool.
  // CHECK: [[MEMPOOL_BASE_UNALIGNED:%.+]] = llvm.extractvalue [[TMP4]][0] : !llvm<"{ i8*, i8*, i64, [1 x i64], [1 x i64] }">
  // CHECK: [[CASTED_MEMPOOL_BASE_UNALIGNED:%.+]] = llvm.bitcast [[MEMPOOL_BASE_UNALIGNED]] : !llvm<"i8*"> to !llvm<"i8*">
  // CHECK: llvm.call @free([[CASTED_MEMPOOL_BASE_UNALIGNED]]) : (!llvm<"i8*">) -> ()
}
@@ -0,0 +1,69 @@ new file (all lines added)
// RUN: onnx-mlir-opt --shape-inference --lower-frontend --enable-memory-pool %s -split-input-file | FileCheck %s

/// One intermediate value to allocate in the memory pool.
func @test_enable_memory_pool(%arg0: tensor<10x10xf32>) -> tensor<10x10xf32> {
  %0 = "onnx.Add"(%arg0, %arg0) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
  %1 = "onnx.Add"(%0, %arg0) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
  return %1 : tensor<10x10xf32>

  // CHECK-LABEL: test_enable_memory_pool
  // CHECK: [[CONST0:%.+]] = constant 0 : i64
  // CHECK: [[RES:%.+]] = alloc() : memref<10x10xf32>
  // CHECK: [[MEMPOOL:%.+]] = alloc() : memref<400xi8>
  // CHECK: [[GETREF:%.+]] = "krnl.getref"([[MEMPOOL]], [[CONST0]]) : (memref<400xi8>, i64) -> memref<10x10xf32>
  // CHECK: krnl.define_loops
  // CHECK: krnl.optimize_loops
  // CHECK: krnl.iterate
  // CHECK: [[LOAD1:%.+]] = load %arg0[%arg1, %arg2] : memref<10x10xf32>
  // CHECK: [[LOAD2:%.+]] = load %arg0[%arg1, %arg2] : memref<10x10xf32>
  // CHECK: [[ADDF1:%.+]] = addf [[LOAD1]], [[LOAD2]] : f32
  // CHECK: store [[ADDF1]], [[GETREF]][%arg1, %arg2] : memref<10x10xf32>
  // CHECK: krnl.define_loops
  // CHECK: krnl.optimize_loops
  // CHECK: krnl.iterate
  // CHECK: dealloc [[MEMPOOL]] : memref<400xi8>
  // CHECK: return [[RES]] : memref<10x10xf32>
}

/// Two intermediate values to allocate in the memory pool.
func @test_enable_memory_pool_2(%arg0: tensor<10x10xf32>, %arg1: tensor<10x20xf32>) -> tensor<10x20xf32> {
  %0 = "onnx.Add"(%arg0, %arg0) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
  %1 = "onnx.MatMul"(%0, %arg1) : (tensor<10x10xf32>, tensor<10x20xf32>) -> tensor<10x20xf32>
  %2 = "onnx.Add"(%1, %arg1) : (tensor<10x20xf32>, tensor<10x20xf32>) -> tensor<10x20xf32>
  return %2 : tensor<10x20xf32>

  // CHECK-LABEL: test_enable_memory_pool_2
  // CHECK: [[CONST0:%.+]] = constant 0 : i64
  // CHECK: [[CONST1:%.+]] = constant 0.000000e+00 : f32
  // CHECK: [[RES:%.+]] = alloc() : memref<10x20xf32>
  // CHECK: [[MEMPOOL0:%.+]] = alloc() : memref<800xi8>
  // CHECK: [[GETREF0:%.+]] = "krnl.getref"([[MEMPOOL0]], [[CONST0]]) : (memref<800xi8>, i64) -> memref<10x20xf32>
  // CHECK: [[MEMPOOL1:%.+]] = alloc() : memref<400xi8>
  // CHECK: [[GETREF1:%.+]] = "krnl.getref"([[MEMPOOL1]], [[CONST0]]) : (memref<400xi8>, i64) -> memref<10x10xf32>
  // CHECK: krnl.define_loops
  // CHECK: krnl.optimize_loops
  // CHECK: krnl.iterate
  // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xf32>
  // CHECK: [[LOAD2:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xf32>
  // CHECK: [[ADDF1:%.+]] = addf [[LOAD1]], [[LOAD2]] : f32
  // CHECK: store [[ADDF1]], [[GETREF1]][%arg2, %arg3] : memref<10x10xf32>
  // CHECK: krnl.define_loops
  // CHECK: krnl.optimize_loops
  // CHECK: krnl.iterate
  // CHECK: [[LOAD3:%.+]] = load [[GETREF1]][%arg2, %arg4] : memref<10x10xf32>
  // CHECK: [[LOAD4:%.+]] = load %arg1[%arg4, %arg3] : memref<10x20xf32>
  // CHECK: [[LOAD5:%.+]] = load [[GETREF0]][%arg2, %arg3] : memref<10x20xf32>
  // CHECK: [[MULF1:%.+]] = mulf [[LOAD3]], [[LOAD4]] : f32
  // CHECK: [[ADDF2:%.+]] = addf [[LOAD5]], [[MULF1]] : f32
  // CHECK: store [[ADDF2]], [[GETREF0]][%arg2, %arg3] : memref<10x20xf32>
  // CHECK: krnl.define_loops
  // CHECK: krnl.optimize_loops
  // CHECK: krnl.iterate
  // CHECK: [[LOAD6:%.+]] = load [[GETREF0]][%arg2, %arg3] : memref<10x20xf32>
  // CHECK: [[LOAD7:%.+]] = load %arg1[%arg2, %arg3] : memref<10x20xf32>
  // CHECK: [[ADDF3:%.+]] = addf [[LOAD6]], [[LOAD7]] : f32
  // CHECK: store [[ADDF3]], [[RES]][%arg2, %arg3] : memref<10x20xf32>
  // CHECK: dealloc [[MEMPOOL1]] : memref<400xi8>
  // CHECK: dealloc [[MEMPOOL0]] : memref<800xi8>
  // CHECK: return [[RES]] : memref<10x20xf32>
}