Add basic support for memory pool (#161)

* Reorganize main function.

* Follow review comments.

* Emit constants as globals in Krnl and LLVM dialects.

* Replace internal malloc with memory pool and getref instruction.

* Lower krnl.getref to LLVM.

* Fix formatting issues.

* Add tests.

* Add missing dependency.

* Improve LLVM lowering.

* Add test to show getref is generic.
Gheorghe-Teodor Bercea 2020-06-09 16:48:33 -04:00 committed by GitHub
parent ca185002f2
commit 4ab96fbc6c
12 changed files with 439 additions and 1 deletion


@@ -268,7 +268,8 @@ set(ONNXMLIRWholeArchiveLibs
OMAttributePromotion
OMPromotableConstOperandsOpInterface
OMElideConstants
OMElideKrnlGlobalConstants)
OMElideKrnlGlobalConstants
OMEnableMemoryPool)
# Function to construct linkage option for the static libraries that must be
# linked with --whole-archive (or equivalent).


@@ -57,6 +57,7 @@ target_link_libraries(MainUtils
OMResultTypeInferenceOpInterface
OMElideConstants
OMElideKrnlGlobalConstants
OMEnableMemoryPool
OMKrnlToAffine
OMKrnlToLLVM
OMONNXToKrnl


@@ -205,3 +205,21 @@ def KrnlGlobalOp : Op<Krnl_Dialect, "global"> {
let parser = ?;
let printer = ?;
}
def KrnlGetRefOp : Op<Krnl_Dialect, "getref"> {
let summary = "Krnl a MemRef from within another MemRef starting at a specific offset.";
let description = [{
Retrieves a MemRef from within another MemRef:
"krnl.getref"(%memref, %offset)
The offset is an integer which is used as an index into the input MemRef. It works
just like an array index.
}];
let arguments = (ins AnyTypeOf<[AnyMemRef]>:$mempool, AnyInteger:$offset);
let results = (outs AnyTypeOf<[AnyMemRef]>:$output);
let parser = ?;
let printer = ?;
}
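For illustration, a minimal sketch of the intended usage (mirroring the tests added below), where a typed MemRef is carved out of a byte-sized memory pool at offset 0; the value names are illustrative:

%c0 = constant 0 : i64
%pool = alloc() : memref<400xi8>
%ref = "krnl.getref"(%pool, %c0) : (memref<400xi8>, i64) -> memref<10x10xf32>

The op itself is generic: the source MemRef may have any element type, and the offset is simply an index into it.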


@@ -102,6 +102,9 @@ void addONNXToKrnlPasses(mlir::PassManager &pm) {
// from ONNX dialect to Standard dialect exposes additional canonicalization
// opportunities.
pm.addPass(mlir::createCanonicalizerPass());
// TODO: make this pass optional.
pm.addPass(mlir::createKrnlEnableMemoryPoolPass());
}
void addKrnlToAffinePasses(mlir::PassManager &pm) {
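Since the pass is registered under the name enable-memory-pool (see EnableMemoryPool.cpp below), it can also be exercised in isolation, e.g. onnx-mlir-opt --enable-memory-pool input.mlir, just as the RUN lines of the new tests do.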


@@ -28,6 +28,9 @@ std::unique_ptr<Pass> createAttributePromotionPass();
/// Pass for eliding the values of constant operations.
std::unique_ptr<Pass> createElideConstantValuePass();
/// Pass for enabling a memory pool for MemRefs.
std::unique_ptr<Pass> createKrnlEnableMemoryPoolPass();
/// Add pass for lowering to Krnl IR.
std::unique_ptr<Pass> createLowerToKrnlPass();


@@ -38,4 +38,17 @@ add_dependencies(OMElideKrnlGlobalConstants OMKrnlOpsInc)
# Linking dependencies
add_dependencies(OMElideKrnlGlobalConstants OMKrnlOps)
add_library(OMEnableMemoryPool
EnableMemoryPool.cpp)
target_include_directories(OMEnableMemoryPool
PRIVATE
${ONNX_MLIR_SRC_ROOT}
${ONNX_MLIR_BIN_ROOT}
${ONNX_MLIR_SRC_ROOT})
target_link_libraries(OMEnableMemoryPool
onnx)
add_dependencies(OMEnableMemoryPool
OMKrnlOps
OMONNXOps)
add_subdirectory(ONNX)


@@ -0,0 +1,159 @@
//===-------- EnableMemoryPool.cpp - Enable Memory Pool for MemRefs -------===//
//
// Copyright 2019-2020 The IBM Research Authors.
//
// =============================================================================
//
// For certain cases the number of individual memory allocations required for
// all internal tensors is large and needs to be mitigated. This pass enables a
// managed memory pool for allocating MemRefs.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
#include "src/Dialect/Krnl/KrnlOps.hpp"
#include "src/Pass/Passes.hpp"
using namespace mlir;
namespace {
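// Returns true if the result of allocOp is returned by a ReturnOp in the
// enclosing block; such allocations must remain standalone and cannot be
// folded into the memory pool.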
bool checkOpResultIsReturned(AllocOp *allocOp) {
auto parentBlock = allocOp->getOperation()->getBlock();
bool opIsReturned = false;
parentBlock->walk([&opIsReturned, allocOp](ReturnOp op) {
auto result = allocOp->getResult();
for (const auto &operand : op.getOperands())
if (operand == result)
opIsReturned = true;
});
return opIsReturned;
}
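// Returns true if the result of allocOp already feeds a krnl.getref, i.e.
// the allocation has already been moved into a memory pool.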
bool checkOpResultIsUsedByGetRef(AllocOp *allocOp) {
auto parentBlock = allocOp->getOperation()->getBlock();
bool opIsUsedInGetRef = false;
parentBlock->walk([&opIsUsedInGetRef, allocOp](KrnlGetRefOp op) {
auto result = allocOp->getResult();
for (const auto &operand : op.getOperands())
if (operand == result)
opIsUsedInGetRef = true;
});
return opIsUsedInGetRef;
}
/*!
* RewritePattern that replaces:
* %0 = alloc() : memref<<dims>x<type>>
* with:
* %mem = alloc() : memref<<dims>x<type>>
* %0 = krnl.getref %mem <offset> : memref<<dims>x<type>>
*
* For now, to enable testing, offset will always be 0.
*/
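/*
 * A concrete instance (hypothetical shape, mirroring the tests below):
 *   %0 = alloc() : memref<10x10xf32>
 * becomes:
 *   %c0 = constant 0 : i64
 *   %mem = alloc() : memref<400xi8>
 *   %0 = "krnl.getref"(%mem, %c0) : (memref<400xi8>, i64) -> memref<10x10xf32>
 */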
class KrnlEnableMemoryPool : public OpRewritePattern<AllocOp> {
public:
using OpRewritePattern<AllocOp>::OpRewritePattern;
LogicalResult matchAndRewrite(
AllocOp allocOp, PatternRewriter &rewriter) const override {
auto loc = allocOp.getLoc();
auto memRefType = convertToMemRefType(allocOp.getResult().getType());
// For now we only support MemRefs with constant shapes.
// TODO: enable this pass for MemRefs with dynamic shapes.
// If the result of the alloc operation is not returned from the
// function, it is a candidate for inclusion in the memory pool.
if (!hasAllConstantDimensions(memRefType) ||
checkOpResultIsReturned(&allocOp))
return failure();
// Check the result of this alloc is not already used by a krnl.getref.
if (checkOpResultIsUsedByGetRef(&allocOp))
return failure();
// Compute total size.
auto memRefShape = memRefType.getShape();
int64_t totalSize = 1;
for (size_t i = 0; i < memRefShape.size(); i++)
totalSize *= memRefShape[i];
totalSize *= getMemRefEltSizeInBytes(memRefType);
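// e.g. for memref<10x10xf32>: 10 * 10 * 4 bytes per f32 = 400 bytes.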
// Emit new alloc.
SmallVector<int64_t, 1> memPoolShape;
memPoolShape.emplace_back(totalSize);
auto memPoolMemRefType =
MemRefType::get(memPoolShape, rewriter.getIntegerType(8));
auto newAlloc = rewriter.create<AllocOp>(loc, memPoolMemRefType);
// Emit new dealloc.
auto dealloc = rewriter.create<DeallocOp>(loc, newAlloc);
auto parentBlock = allocOp.getOperation()->getBlock();
dealloc.getOperation()->moveBefore(&parentBlock->back());
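// The dealloc is placed just before the block terminator so that the pool
// outlives all the krnl.getref values that point into it.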
// Get reference to local MemRef.
auto zero = rewriter.create<ConstantOp>(
loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
auto poolMemRef =
rewriter.create<KrnlGetRefOp>(loc, memRefType, newAlloc, zero);
rewriter.replaceOp(allocOp, poolMemRef.getResult());
return success();
}
};
class KrnlEliminateOldDealloc : public OpRewritePattern<DeallocOp> {
public:
using OpRewritePattern<DeallocOp>::OpRewritePattern;
LogicalResult matchAndRewrite(
DeallocOp deallocOp, PatternRewriter &rewriter) const override {
auto *definingOp = deallocOp.getOperand().getDefiningOp();
// Guard against block arguments, which have no defining op.
if (definingOp && llvm::isa<KrnlGetRefOp>(definingOp)) {
rewriter.eraseOp(deallocOp);
return success();
}
return failure();
}
};
// TODO: Replace old dealloc with krnl.unsetref.
/*!
* Function pass that enables memory pooling for MemRefs.
*/
class KrnlEnableMemoryPoolPass
: public PassWrapper<KrnlEnableMemoryPoolPass, FunctionPass> {
public:
void runOnFunction() override {
auto function = getFunction();
ConversionTarget target(getContext());
OwningRewritePatternList patterns;
patterns.insert<KrnlEnableMemoryPool>(&getContext());
patterns.insert<KrnlEliminateOldDealloc>(&getContext());
applyPatternsAndFoldGreedily(function, patterns);
}
};
} // namespace
std::unique_ptr<Pass> mlir::createKrnlEnableMemoryPoolPass() {
return std::make_unique<KrnlEnableMemoryPoolPass>();
}
static PassRegistration<KrnlEnableMemoryPoolPass> pass("enable-memory-pool",
"Enable a memory pool for allocating internal MemRefs.");


@@ -165,6 +165,7 @@ void KrnlToAffineLoweringPass::runOnFunction() {
target.addLegalOp<KrnlMemcpyOp>();
target.addLegalOp<KrnlEntryPointOp>();
target.addLegalOp<KrnlGlobalOp>();
target.addLegalOp<KrnlGetRefOp>();
OwningRewritePatternList patterns;
patterns.insert<KrnlIterateOpLowering, KrnlTerminatorLowering,


@@ -88,6 +88,71 @@ static FlatSymbolRefAttr getOrInsertMemcpy(PatternRewriter &rewriter,
return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context);
}
//===----------------------------------------------------------------------===//
// KRNL to LLVM: KrnlGetRefOpLowering
//===----------------------------------------------------------------------===//
class KrnlGetRefOpLowering : public ConvertToLLVMPattern {
public:
explicit KrnlGetRefOpLowering(
MLIRContext *context, LLVMTypeConverter &lowering_)
: ConvertToLLVMPattern(
KrnlGetRefOp::getOperationName(), context, lowering_) {}
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto *context = op->getContext();
auto loc = op->getLoc();
auto *llvmDialect =
op->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
assert(llvmDialect && "expected llvm dialect to be registered");
KrnlGetRefOpOperandAdaptor operandAdaptor(operands);
// This is the type of the krnl.getref output. This type is used
// for the type of the internal MemRef.
auto type = op->getResult(0).getType();
auto memRefTy = type.cast<mlir::MemRefType>();
auto llvmMemRefType =
typeConverter.convertType(type).cast<LLVM::LLVMType>();
auto outputElementType =
typeConverter.convertType(memRefTy.getElementType());
// This is the start of the memory pool containing the output MemRef.
Type memPoolType = operandAdaptor.mempool()
.getType()
.cast<LLVM::LLVMType>()
.getStructElementType(1);
Value alignedMemPoolBase = rewriter.create<LLVM::ExtractValueOp>(loc,
memPoolType, operandAdaptor.mempool(), rewriter.getI64ArrayAttr(1));
// Get pointer using the offset.
auto offset = operandAdaptor.offset();
auto llvmMemPoolType =
typeConverter.convertType(memPoolType).cast<LLVM::LLVMType>();
auto outputMemPoolTypePtrAlloc = rewriter.create<LLVM::GEPOp>(
loc, llvmMemPoolType, alignedMemPoolBase, ArrayRef<Value>({offset}));
// Bitcast to output MemRef type i.e. from i8* to the element type
// of the output MemRef.
auto llvmOutputElementType = outputElementType.cast<LLVM::LLVMType>();
Value outputTypedPtrAlloc = rewriter.create<LLVM::BitcastOp>(
loc, llvmOutputElementType.getPointerTo(), outputMemPoolTypePtrAlloc);
// Create llvm MemRef from original MemRef and fill the data pointers.
auto llvmMemRef = MemRefDescriptor::fromStaticShape(
rewriter, loc, typeConverter, memRefTy, outputTypedPtrAlloc);
rewriter.replaceOp(op, {llvmMemRef});
return success();
}
private:
static int64_t ArrayAttrIntVal(ArrayAttr a, int i) {
return (a.getValue()[i]).cast<IntegerAttr>().getInt();
}
};
//===----------------------------------------------------------------------===//
// KRNL to LLVM: KrnlGlobalOpLowering
//===----------------------------------------------------------------------===//
@@ -648,6 +713,7 @@ void KrnlToLLVMLoweringPass::runOnOperation() {
/*useAlignedAlloc=*/false);
patterns.insert<KrnlGlobalOpLowering>(&getContext(), typeConverter);
patterns.insert<KrnlGetRefOpLowering>(&getContext(), typeConverter);
// Lower from the `krnl` dialect i.e. the Reshape operation.
patterns.insert<KrnlMemcpyOpLowering, KrnlEntryPointOpLowering>(
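In effect, for a byte-typed pool the lowering reduces to an extractvalue/GEP/bitcast sequence; the following is a sketch distilled from the CHECK lines of the tests below (value names are illustrative):

%base = llvm.extractvalue %pool[1] : !llvm<"{ i8*, i8*, i64, [1 x i64], [1 x i64] }">
%ptr = llvm.getelementptr %base[%offset] : (!llvm<"i8*">, !llvm.i64) -> !llvm<"i8*">
%typed = llvm.bitcast %ptr : !llvm<"i8*"> to !llvm<"float*">

%typed then seeds both data pointer fields of the freshly built MemRef descriptor.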


@@ -0,0 +1,38 @@
// RUN: onnx-mlir-opt --lower-krnl --lower-all-llvm %s -split-input-file | FileCheck %s
func @test_getref_lowering(%arg0: memref<2x2xf32>) -> memref<2x2xf32> {
%c13_i64 = constant 13 : i64
%1 = alloc() : memref<10x10xf32>
%2 = "krnl.getref"(%1, %c13_i64) : (memref<10x10xf32>, i64) -> memref<2x2xf32>
return %2 : memref<2x2xf32>
// CHECK-LABEL: test_getref_lowering
// CHECK: %[[OFFSET:.+]] = llvm.mlir.constant(13 : i64) : !llvm.i64
// CHECK: [[CONST_10_0:%.+]] = llvm.mlir.constant(10 : index) : !llvm.i64
// CHECK: [[CONST_10_1:%.+]] = llvm.mlir.constant(10 : index) : !llvm.i64
// CHECK: [[MUL1:%.+]] = llvm.mul [[CONST_10_0]], [[CONST_10_1]] : !llvm.i64
// CHECK: [[FLOAT_STAR:%.+]] = llvm.mlir.null : !llvm<"float*">
// CHECK: %[[CONST_1:.+]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK: [[ELEM1:%.+]] = llvm.getelementptr [[FLOAT_STAR]][%[[CONST_1]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
// CHECK: [[ELEM_SIZE:%.+]] = llvm.ptrtoint [[ELEM1]] : !llvm<"float*"> to !llvm.i64
// CHECK: [[MUL2:%.+]] = llvm.mul [[MUL1]], [[ELEM_SIZE]] : !llvm.i64
// CHECK: [[MEMPOOL:%.+]] = llvm.call @malloc([[MUL2]]) : (!llvm.i64) -> !llvm<"i8*">
// CHECK: [[TYPED_MEMPOOL:%.+]] = llvm.bitcast [[MEMPOOL]] : !llvm<"i8*"> to !llvm<"float*">
// CHECK: [[MEMPOOL_MEMREF:%.+]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[MEMREF1:%.+]] = llvm.insertvalue [[TYPED_MEMPOOL]], [[MEMPOOL_MEMREF]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[MEMREF2:%.+]] = llvm.insertvalue [[TYPED_MEMPOOL]], [[MEMREF1]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: llvm.mlir.constant
// CHECK: llvm.insertvalue
// CHECK: llvm.mlir.constant
// CHECK: llvm.mlir.constant
// CHECK: llvm.insertvalue
// CHECK: llvm.insertvalue
// CHECK: llvm.insertvalue
// CHECK: llvm.insertvalue
// CHECK: [[MEMPOOL1:%.+]] = llvm.extractvalue {{.*}}[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[MEMPOOL_ALLOC:%.+]] = llvm.getelementptr [[MEMPOOL1]][%[[OFFSET]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
// CHECK: [[TYPED_MEMPOOL_ALLOC:%.+]] = llvm.bitcast [[MEMPOOL_ALLOC]] : !llvm<"float*"> to !llvm<"float*">
// CHECK: [[MEMPOOLED_MEMREF:%.+]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[MEMREF3:%.+]] = llvm.insertvalue [[TYPED_MEMPOOL_ALLOC]], [[MEMPOOLED_MEMREF]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[MEMREF4:%.+]] = llvm.insertvalue [[TYPED_MEMPOOL_ALLOC]], [[MEMREF3]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
}


@@ -0,0 +1,66 @@
// RUN: onnx-mlir-opt --shape-inference --lower-frontend --enable-memory-pool --lower-krnl --lower-all-llvm %s -split-input-file | FileCheck %s
func @test_memory_pool(%arg0: tensor<10x10xf32>) -> tensor<10x10xf32> {
%0 = "onnx.Add"(%arg0, %arg0) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
%1 = "onnx.Add"(%0, %arg0) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
return %1 : tensor<10x10xf32>
/// Define the offset inside the memory pool.
// CHECK: %[[OFFSET:.+]] = llvm.mlir.constant(0 : i64) : !llvm.i64
/// Allocate memory for the memory pool.
// CHECK: [[MEMPOOL_SIZE:%.+]] = llvm.mlir.constant(400 : index) : !llvm.i64
// CHECK: [[TMP1:%.+]] = llvm.mlir.null : !llvm<"i8*">
// CHECK: %[[CONST1:.+]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK: [[TMP2:%.+]] = llvm.getelementptr [[TMP1]][%[[CONST1]]] : (!llvm<"i8*">, !llvm.i64) -> !llvm<"i8*">
// CHECK: [[TYPE_SIZE_IN_BYTES:%.+]] = llvm.ptrtoint [[TMP2]] : !llvm<"i8*"> to !llvm.i64
// CHECK: [[TOTAL_SIZE:%.+]] = llvm.mul [[MEMPOOL_SIZE]], [[TYPE_SIZE_IN_BYTES]] : !llvm.i64
// CHECK: [[ALLOC_MEM_POOL:%.+]] = llvm.call @malloc([[TOTAL_SIZE]]) : (!llvm.i64) -> !llvm<"i8*">
// CHECK: [[BITCAST_ALLOC_MEM_POOL:%.+]] = llvm.bitcast [[ALLOC_MEM_POOL]] : !llvm<"i8*"> to !llvm<"i8*">
/// MemRef representing the memory pool, which contains the memory allocated above.
// CHECK: [[MEMREF0:%.+]] = llvm.mlir.undef : !llvm<"{ i8*, i8*, i64, [1 x i64], [1 x i64] }">
// CHECK: [[TMP3:%.+]] = llvm.insertvalue [[BITCAST_ALLOC_MEM_POOL]], [[MEMREF0]][0] : !llvm<"{ i8*, i8*, i64, [1 x i64], [1 x i64] }">
// CHECK: llvm.insertvalue [[BITCAST_ALLOC_MEM_POOL]], [[TMP3]][1] : !llvm<"{ i8*, i8*, i64, [1 x i64], [1 x i64] }">
// CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK: llvm.insertvalue
// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK: llvm.insertvalue
// CHECK: [[TMP4:%.+]] = llvm.insertvalue {{.*}}[4, 0] : !llvm<"{ i8*, i8*, i64, [1 x i64], [1 x i64] }">
/// Get a reference to the location within the memory pool where the data of the getref instruction was allocated.
// CHECK: [[MEMPOOL_BASE:%.+]] = llvm.extractvalue [[TMP4]][1] : !llvm<"{ i8*, i8*, i64, [1 x i64], [1 x i64] }">
// CHECK: [[GETREF_MEMORY:%.+]] = llvm.getelementptr [[MEMPOOL_BASE]][%[[OFFSET]]] : (!llvm<"i8*">, !llvm.i64) -> !llvm<"i8*">
// CHECK: [[CASTED_GETREF_MEMORY:%.+]] = llvm.bitcast [[GETREF_MEMORY]] : !llvm<"i8*"> to !llvm<"float*">
/// Create MemRef for krnl.getref.
// CHECK: [[MEMREF1:%.+]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[MEMREF1_TMP1:%.+]] = llvm.insertvalue [[CASTED_GETREF_MEMORY]], [[MEMREF1]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[MEMREF1_TMP2:%.+]] = llvm.insertvalue [[CASTED_GETREF_MEMORY]], [[MEMREF1_TMP1]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[CONST2:%.+]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK: [[MEMREF1_TMP3:%.+]] = llvm.insertvalue [[CONST2]], [[MEMREF1_TMP2]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[CONST3:%.+]] = llvm.mlir.constant(10 : index) : !llvm.i64
// CHECK: [[MEMREF1_TMP4:%.+]] = llvm.insertvalue [[CONST3]], [[MEMREF1_TMP3]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[CONST4:%.+]] = llvm.mlir.constant(10 : index) : !llvm.i64
// CHECK: [[MEMREF1_TMP5:%.+]] = llvm.insertvalue [[CONST4]], [[MEMREF1_TMP4]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[CONST5:%.+]] = llvm.mlir.constant(10 : index) : !llvm.i64
// CHECK: [[MEMREF1_TMP6:%.+]] = llvm.insertvalue [[CONST5]], [[MEMREF1_TMP5]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[CONST6:%.+]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK: [[MEMREF1_TMP7:%.+]] = llvm.insertvalue [[CONST6]], [[MEMREF1_TMP6]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
/// Usage of the getref MemRef.
// CHECK: [[MEM0:%.+]] = llvm.extractvalue [[MEMREF1_TMP7]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[CONST7:%.+]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK: [[CONST8:%.+]] = llvm.mlir.constant(10 : index) : !llvm.i64
// CHECK: [[MUL1:%.+]] = llvm.mul {{.*}}, [[CONST8]] : !llvm.i64
// CHECK: [[ADD1:%.+]] = llvm.add [[CONST7]], [[MUL1]] : !llvm.i64
// CHECK: [[CONST9:%.+]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK: [[MUL2:%.+]] = llvm.mul {{.*}}, [[CONST9]] : !llvm.i64
// CHECK: %[[ADD2:.+]] = llvm.add [[ADD1]], [[MUL2]] : !llvm.i64
// CHECK: llvm.getelementptr [[MEM0]][%[[ADD2]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
/// Deallocation of the memory pool.
// CHECK: [[MEMPOOL_BASE_UNALIGNED:%.+]] = llvm.extractvalue [[TMP4]][0] : !llvm<"{ i8*, i8*, i64, [1 x i64], [1 x i64] }">
// CHECK: [[CASTED_MEMPOOL_BASE_UNALIGNED:%.+]] = llvm.bitcast [[MEMPOOL_BASE_UNALIGNED]] : !llvm<"i8*"> to !llvm<"i8*">
// CHECK: llvm.call @free([[CASTED_MEMPOOL_BASE_UNALIGNED]]) : (!llvm<"i8*">) -> ()
}


@@ -0,0 +1,69 @@
// RUN: onnx-mlir-opt --shape-inference --lower-frontend --enable-memory-pool %s -split-input-file | FileCheck %s
/// One intermediate value to allocate in the memory pool.
func @test_enable_memory_pool(%arg0: tensor<10x10xf32>) -> tensor<10x10xf32> {
%0 = "onnx.Add"(%arg0, %arg0) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
%1 = "onnx.Add"(%0, %arg0) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
return %1 : tensor<10x10xf32>
// CHECK-LABEL: test_enable_memory_pool
// CHECK: [[CONST0:%.+]] = constant 0 : i64
// CHECK: [[RES:%.+]] = alloc() : memref<10x10xf32>
// CHECK: [[MEMPOOL:%.+]] = alloc() : memref<400xi8>
// CHECK: [[GETREF:%.+]] = "krnl.getref"([[MEMPOOL]], [[CONST0]]) : (memref<400xi8>, i64) -> memref<10x10xf32>
// CHECK: krnl.define_loops
// CHECK: krnl.optimize_loops
// CHECK: krnl.iterate
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg1, %arg2] : memref<10x10xf32>
// CHECK: [[LOAD2:%.+]] = load %arg0[%arg1, %arg2] : memref<10x10xf32>
// CHECK: [[ADDF1:%.+]] = addf [[LOAD1]], [[LOAD2]] : f32
// CHECK: store [[ADDF1]], [[GETREF]][%arg1, %arg2] : memref<10x10xf32>
// CHECK: krnl.define_loops
// CHECK: krnl.optimize_loops
// CHECK: krnl.iterate
// CHECK: dealloc [[MEMPOOL]] : memref<400xi8>
// CHECK: return [[RES]] : memref<10x10xf32>
}
/// Two intermediate values to allocate in the memory pool.
func @test_enable_memory_pool_2(%arg0: tensor<10x10xf32>, %arg1: tensor<10x20xf32>) -> tensor<10x20xf32> {
%0 = "onnx.Add"(%arg0, %arg0) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
%1 = "onnx.MatMul"(%0, %arg1) : (tensor<10x10xf32>, tensor<10x20xf32>) -> tensor<10x20xf32>
%2 = "onnx.Add"(%1, %arg1) : (tensor<10x20xf32>, tensor<10x20xf32>) -> tensor<10x20xf32>
return %2 : tensor<10x20xf32>
// CHECK-LABEL: test_enable_memory_pool_2
// CHECK: [[CONST0:%.+]] = constant 0 : i64
// CHECK: [[CONST1:%.+]] = constant 0.000000e+00 : f32
// CHECK: [[RES:%.+]] = alloc() : memref<10x20xf32>
// CHECK: [[MEMPOOL0:%.+]] = alloc() : memref<800xi8>
// CHECK: [[GETREF0:%.+]] = "krnl.getref"([[MEMPOOL0]], [[CONST0]]) : (memref<800xi8>, i64) -> memref<10x20xf32>
// CHECK: [[MEMPOOL1:%.+]] = alloc() : memref<400xi8>
// CHECK: [[GETREF1:%.+]] = "krnl.getref"([[MEMPOOL1]], [[CONST0]]) : (memref<400xi8>, i64) -> memref<10x10xf32>
// CHECK: krnl.define_loops
// CHECK: krnl.optimize_loops
// CHECK: krnl.iterate
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xf32>
// CHECK: [[LOAD2:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xf32>
// CHECK: [[ADDF1:%.+]] = addf [[LOAD1]], [[LOAD2]] : f32
// CHECK: store [[ADDF1]], [[GETREF1]][%arg2, %arg3] : memref<10x10xf32>
// CHECK: krnl.define_loops
// CHECK: krnl.optimize_loops
// CHECK: krnl.iterate
// CHECK: [[LOAD3:%.+]] = load [[GETREF1]][%arg2, %arg4] : memref<10x10xf32>
// CHECK: [[LOAD4:%.+]] = load %arg1[%arg4, %arg3] : memref<10x20xf32>
// CHECK: [[LOAD5:%.+]] = load [[GETREF0]][%arg2, %arg3] : memref<10x20xf32>
// CHECK: [[MULF1:%.+]] = mulf [[LOAD3]], [[LOAD4]] : f32
// CHECK: [[ADDF2:%.+]] = addf [[LOAD5]], [[MULF1]] : f32
// CHECK: store [[ADDF2]], [[GETREF0]][%arg2, %arg3] : memref<10x20xf32>
// CHECK: krnl.define_loops
// CHECK: krnl.optimize_loops
// CHECK: krnl.iterate
// CHECK: [[LOAD6:%.+]] = load [[GETREF0]][%arg2, %arg3] : memref<10x20xf32>
// CHECK: [[LOAD7:%.+]] = load %arg1[%arg2, %arg3] : memref<10x20xf32>
// CHECK: [[ADDF3:%.+]] = addf [[LOAD6]], [[LOAD7]] : f32
// CHECK: store [[ADDF3]], [[RES]][%arg2, %arg3] : memref<10x20xf32>
// CHECK: dealloc [[MEMPOOL1]] : memref<400xi8>
// CHECK: dealloc [[MEMPOOL0]] : memref<800xi8>
// CHECK: return [[RES]] : memref<10x20xf32>
}