diff --git a/MLIR.cmake b/MLIR.cmake
index f189ea9..e084172 100644
--- a/MLIR.cmake
+++ b/MLIR.cmake
@@ -268,7 +268,8 @@ set(ONNXMLIRWholeArchiveLibs
         OMAttributePromotion
         OMPromotableConstOperandsOpInterface
         OMElideConstants
-        OMElideKrnlGlobalConstants)
+        OMElideKrnlGlobalConstants
+        OMEnableMemoryPool)
 
 # Function to construct linkage option for the static libraries that must be
 # linked with --whole-archive (or equivalent).
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 9a92f3c..2dfa041 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -57,6 +57,7 @@ target_link_libraries(MainUtils
         OMResultTypeInferenceOpInterface
         OMElideConstants
         OMElideKrnlGlobalConstants
+        OMEnableMemoryPool
         OMKrnlToAffine
         OMKrnlToLLVM
         OMONNXToKrnl
diff --git a/src/Dialect/Krnl/KrnlOps.td b/src/Dialect/Krnl/KrnlOps.td
index acf4eb7..6b72883 100644
--- a/src/Dialect/Krnl/KrnlOps.td
+++ b/src/Dialect/Krnl/KrnlOps.td
@@ -205,3 +205,21 @@ def KrnlGlobalOp : Op<Krnl_Dialect, "global"> {
   let parser = ?;
   let printer = ?;
 }
+
+def KrnlGetRefOp : Op<Krnl_Dialect, "getref"> {
+  let summary = "Retrieve a MemRef from within another MemRef starting at a specific offset.";
+  let description = [{
+    Retrieves a MemRef from within another MemRef:
+
+    "krnl.getref"(%memref, %offset)
+
+    The offset is an integer which is used as an index into the input MemRef.
+    It works just like an array index.
+  }];
+
+  let arguments = (ins AnyTypeOf<[AnyMemRef]>:$mempool, AnyInteger:$offset);
+  let results = (outs AnyTypeOf<[AnyMemRef]>:$output);
+
+  let parser = ?;
+  let printer = ?;
+}
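Note: as an illustration of the op's semantics (hypothetical IR, not taken from this patch), a single byte pool can back multiple typed views; the offset indexes into the i8 pool, so the second view below starts 400 bytes in:

    %pool = alloc() : memref<800xi8>
    %c0 = constant 0 : i64
    %c400 = constant 400 : i64
    %t0 = "krnl.getref"(%pool, %c0) : (memref<800xi8>, i64) -> memref<10x10xf32>
    %t1 = "krnl.getref"(%pool, %c400) : (memref<800xi8>, i64) -> memref<10x10xf32>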
diff --git a/src/MainUtils.cpp b/src/MainUtils.cpp
index 175ff1f..4de03ae 100644
--- a/src/MainUtils.cpp
+++ b/src/MainUtils.cpp
@@ -102,6 +102,9 @@ void addONNXToKrnlPasses(mlir::PassManager &pm) {
   // from ONNX dialect to Standard dialect exposes additional canonicalization
   // oppertunities.
   pm.addPass(mlir::createCanonicalizerPass());
+
+  // TODO: make this pass optional:
+  pm.addPass(mlir::createKrnlEnableMemoryPoolPass());
 }
 
 void addKrnlToAffinePasses(mlir::PassManager &pm) {
diff --git a/src/Pass/Passes.hpp b/src/Pass/Passes.hpp
index 7a80240..fff6dbf 100644
--- a/src/Pass/Passes.hpp
+++ b/src/Pass/Passes.hpp
@@ -28,6 +28,9 @@ std::unique_ptr<Pass> createAttributePromotionPass();
 /// Pass for eliding the values of constant operations.
 std::unique_ptr<Pass> createElideConstantValuePass();
 
+/// Pass for enabling a memory pool for MemRefs.
+std::unique_ptr<Pass> createKrnlEnableMemoryPoolPass();
+
 /// Add pass for lowering to Krnl IR.
 std::unique_ptr<Pass> createLowerToKrnlPass();
diff --git a/src/Transform/CMakeLists.txt b/src/Transform/CMakeLists.txt
index 64a28ab..8c5cdb4 100644
--- a/src/Transform/CMakeLists.txt
+++ b/src/Transform/CMakeLists.txt
@@ -38,4 +38,17 @@ add_dependencies(OMElideKrnlGlobalConstants OMKrnlOpsInc)
 
 # Linking dependencies
 add_dependencies(OMElideKrnlGlobalConstants OMKrnlOps)
+add_library(OMEnableMemoryPool
+        EnableMemoryPool.cpp)
+target_include_directories(OMEnableMemoryPool
+        PRIVATE
+        ${ONNX_MLIR_SRC_ROOT}
+        ${ONNX_MLIR_BIN_ROOT}
+        ${ONNX_MLIR_SRC_ROOT})
+target_link_libraries(OMEnableMemoryPool
+        onnx)
+add_dependencies(OMEnableMemoryPool
+        OMKrnlOps
+        OMONNXOps)
+
 add_subdirectory(ONNX)
diff --git a/src/Transform/EnableMemoryPool.cpp b/src/Transform/EnableMemoryPool.cpp
new file mode 100644
index 0000000..3d5bf75
--- /dev/null
+++ b/src/Transform/EnableMemoryPool.cpp
@@ -0,0 +1,159 @@
+//===-------- EnableMemoryPool.cpp - Enable Memory Pool for MemRefs -------===//
+//
+// Copyright 2019-2020 The IBM Research Authors.
+//
+// =============================================================================
+//
+// For certain cases the number of individual memory allocations required for
+// all internal tensors is large and needs to be mitigated. This pass enables a
+// managed memory pool for allocating MemRefs.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
+#include "src/Dialect/Krnl/KrnlOps.hpp"
+#include "src/Pass/Passes.hpp"
+
+using namespace mlir;
+
+namespace {
+
+bool checkOpResultIsReturned(AllocOp *allocOp) {
+  auto parentBlock = allocOp->getOperation()->getBlock();
+
+  bool opIsReturned = false;
+  parentBlock->walk([&opIsReturned, allocOp](ReturnOp op) {
+    auto result = allocOp->getResult();
+    for (const auto &operand : op.getOperands())
+      if (operand == result)
+        opIsReturned = true;
+  });
+
+  return opIsReturned;
+}
+
+bool checkOpResultIsUsedByGetRef(AllocOp *allocOp) {
+  auto parentBlock = allocOp->getOperation()->getBlock();
+
+  bool opIsUsedInGetRef = false;
+  parentBlock->walk([&opIsUsedInGetRef, allocOp](KrnlGetRefOp op) {
+    auto result = allocOp->getResult();
+    for (const auto &operand : op.getOperands())
+      if (operand == result)
+        opIsUsedInGetRef = true;
+  });
+
+  return opIsUsedInGetRef;
+}
+
+/*!
+ * RewritePattern that replaces:
+ *   %0 = alloc() : memref<x>
+ * with:
+ *   %mem = alloc() : memref<x>
+ *   %0 = krnl.getref %mem : memref<x>
+ *
+ * For now, to enable testing, the offset will always be 0.
+ */
+class KrnlEnableMemoryPool : public OpRewritePattern<AllocOp> {
+public:
+  using OpRewritePattern<AllocOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(
+      AllocOp allocOp, PatternRewriter &rewriter) const override {
+    auto loc = allocOp.getLoc();
+
+    auto memRefType = convertToMemRefType(allocOp.getResult().getType());
+
+    // For now we only support constant tensors.
+    // TODO: Enable this pass for MemRefs with dynamic shapes.
+    // If the alloc operation is not returned, then it is a candidate for
+    // being included in the memory pool.
+    if (!hasAllConstantDimensions(memRefType) ||
+        checkOpResultIsReturned(&allocOp))
+      return failure();
+
+    // Check that the result of this alloc is not already used by a
+    // krnl.getref.
+    if (checkOpResultIsUsedByGetRef(&allocOp))
+      return failure();
+
+    // Compute total size.
+    auto memRefShape = memRefType.getShape();
+    int64_t totalSize = 1;
+    for (int i = 0; i < memRefShape.size(); i++)
+      totalSize *= memRefShape[i];
+    totalSize *= getMemRefEltSizeInBytes(memRefType);
+
+    // Emit new alloc.
+    SmallVector<int64_t, 1> memPoolShape;
+    memPoolShape.emplace_back(totalSize);
+    auto memPoolMemRefType =
+        MemRefType::get(memPoolShape, rewriter.getIntegerType(8));
+    auto newAlloc = rewriter.create<AllocOp>(loc, memPoolMemRefType);
+
+    // Emit new dealloc.
+    auto dealloc = rewriter.create<DeallocOp>(loc, newAlloc);
+    auto parentBlock = allocOp.getOperation()->getBlock();
+    dealloc.getOperation()->moveBefore(&parentBlock->back());
+
+    // Get reference to local MemRef.
+    auto zero = rewriter.create<ConstantOp>(
+        loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
+    auto poolMemRef =
+        rewriter.create<KrnlGetRefOp>(loc, memRefType, newAlloc, zero);
+
+    rewriter.replaceOp(allocOp, poolMemRef.getResult());
+
+    return success();
+  }
+};
+
+class KrnlEliminateOldDealloc : public OpRewritePattern<DeallocOp> {
+public:
+  using OpRewritePattern<DeallocOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(
+      DeallocOp deallocOp, PatternRewriter &rewriter) const override {
+    if (auto getRefOp = llvm::dyn_cast<KrnlGetRefOp>(
+            deallocOp.getOperand().getDefiningOp())) {
+      rewriter.eraseOp(deallocOp);
+      return success();
+    }
+
+    return failure();
+  }
+};
+
+// TODO: Replace the old dealloc with krnl.unsetref.
+
+/*!
+ * Function pass that enables memory pooling for MemRefs.
+ */
+class KrnlEnableMemoryPoolPass
+    : public PassWrapper<KrnlEnableMemoryPoolPass, FunctionPass> {
+public:
+  void runOnFunction() override {
+    auto function = getFunction();
+
+    ConversionTarget target(getContext());
+    OwningRewritePatternList patterns;
+    patterns.insert<KrnlEnableMemoryPool>(&getContext());
+    patterns.insert<KrnlEliminateOldDealloc>(&getContext());
+
+    applyPatternsAndFoldGreedily(function, patterns);
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::createKrnlEnableMemoryPoolPass() {
+  return std::make_unique<KrnlEnableMemoryPoolPass>();
+}
+
+static PassRegistration<KrnlEnableMemoryPoolPass> pass("enable-memory-pool",
+    "Enable a memory pool for allocating internal MemRefs.");
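Note: to make the rewrite concrete, a sketch (SSA names hypothetical) of the pass applied to a 10x10xf32 allocation; the pool size is 10 * 10 * 4 = 400 bytes and, as the pattern comment states, the offset is always 0 for now:

    // Before:
    %0 = alloc() : memref<10x10xf32>
    dealloc %0 : memref<10x10xf32>

    // After: the old dealloc is erased and a dealloc of the pool is
    // placed just before the end of the block.
    %c0 = constant 0 : i64
    %mem = alloc() : memref<400xi8>
    %0 = "krnl.getref"(%mem, %c0) : (memref<400xi8>, i64) -> memref<10x10xf32>
    dealloc %mem : memref<400xi8>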
diff --git a/src/Transform/LowerKrnl.cpp b/src/Transform/LowerKrnl.cpp
index 79e3803..79370d0 100644
--- a/src/Transform/LowerKrnl.cpp
+++ b/src/Transform/LowerKrnl.cpp
@@ -165,6 +165,7 @@ void KrnlToAffineLoweringPass::runOnFunction() {
   target.addLegalOp<KrnlMemcpyOp>();
   target.addLegalOp<KrnlEntryPointOp>();
   target.addLegalOp<KrnlGlobalOp>();
+  target.addLegalOp<KrnlGetRefOp>();
 
   OwningRewritePatternList patterns;
   patterns.insert<KrnlIterateOpLowering, KrnlTerminatorLowering,
diff --git a/src/Transform/LowerToLLVM.cpp b/src/Transform/LowerToLLVM.cpp
--- a/src/Transform/LowerToLLVM.cpp
+++ b/src/Transform/LowerToLLVM.cpp
@@ -615,6 +615,71 @@
   }
 };
 
+//===----------------------------------------------------------------------===//
+// KRNL to LLVM: KrnlGetRefOpLowering
+//===----------------------------------------------------------------------===//
+
+class KrnlGetRefOpLowering : public ConvertToLLVMPattern {
+public:
+  explicit KrnlGetRefOpLowering(
+      MLIRContext *context, LLVMTypeConverter &lowering_)
+      : ConvertToLLVMPattern(
+            KrnlGetRefOp::getOperationName(), context, lowering_) {}
+
+  LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    auto *context = op->getContext();
+    auto loc = op->getLoc();
+    auto *llvmDialect =
+        op->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
+    assert(llvmDialect && "expected llvm dialect to be registered");
+
+    KrnlGetRefOpOperandAdaptor operandAdaptor(operands);
+
+    // This is the type of the krnl.getref output. This type is used
+    // for the type of the internal MemRef.
+    auto type = op->getResult(0).getType();
+    auto memRefTy = type.cast<mlir::MemRefType>();
+    auto llvmMemRefType =
+        typeConverter.convertType(type).cast<LLVM::LLVMType>();
+    auto outputElementType =
+        typeConverter.convertType(memRefTy.getElementType());
+
+    // This is the start of the memory pool containing the output MemRef.
+    Type memPoolType = operandAdaptor.mempool()
+                           .getType()
+                           .cast<LLVM::LLVMType>()
+                           .getStructElementType(1);
+    Value alignedMemPoolBase = rewriter.create<LLVM::ExtractValueOp>(loc,
+        memPoolType, operandAdaptor.mempool(), rewriter.getI64ArrayAttr(1));
+
+    // Get the pointer using the offset.
+    auto offset = operandAdaptor.offset();
+    auto llvmMemPoolType =
+        typeConverter.convertType(memPoolType).cast<LLVM::LLVMType>();
+    auto outputMemPoolTypePtrAlloc = rewriter.create<LLVM::GEPOp>(
+        loc, llvmMemPoolType, alignedMemPoolBase, ArrayRef<Value>({offset}));
+
+    // Bitcast to the output MemRef type, i.e. from i8* to the element type
+    // of the output MemRef.
+    auto llvmOutputElementType = outputElementType.cast<LLVM::LLVMType>();
+    Value outputTypedPtrAlloc = rewriter.create<LLVM::BitcastOp>(
+        loc, llvmOutputElementType.getPointerTo(), outputMemPoolTypePtrAlloc);
+
+    // Create the llvm MemRef from the original MemRef and fill in the data
+    // pointers.
+    auto llvmMemRef = MemRefDescriptor::fromStaticShape(
+        rewriter, loc, typeConverter, memRefTy, outputTypedPtrAlloc);
+
+    rewriter.replaceOp(op, {llvmMemRef});
+    return success();
+  }
+
+private:
+  static int64_t ArrayAttrIntVal(ArrayAttr a, int i) {
+    return (a.getValue()[i]).cast<IntegerAttr>().getInt();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // KRNL to LLVM: KrnlGlobalOpLowering
 //===----------------------------------------------------------------------===//
@@ -648,6 +713,7 @@ void KrnlToLLVMLoweringPass::runOnOperation() {
       /*useAlignedAlloc=*/false);
 
   patterns.insert<KrnlGlobalOpLowering>(&getContext(), typeConverter);
+  patterns.insert<KrnlGetRefOpLowering>(&getContext(), typeConverter);
 
   // Lower from the `krnl` dialect i.e. the Reshape operation.
   patterns.insert<KrnlMemcpyOpLowering>(
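Note: the lowering above reduces each krnl.getref to a handful of LLVM dialect ops: extract the pool's aligned data pointer (field 1 of the MemRef descriptor struct), advance it by the i64 offset, bitcast from i8* to the element type, then build a fresh static-shape descriptor around the resulting pointer. A simplified sketch (SSA names hypothetical, descriptor insertvalues elided):

    %base = llvm.extractvalue %pool[1] : !llvm<"{ i8*, i8*, i64, [1 x i64], [1 x i64] }">
    %ptr = llvm.getelementptr %base[%offset] : (!llvm<"i8*">, !llvm.i64) -> !llvm<"i8*">
    %data = llvm.bitcast %ptr : !llvm<"i8*"> to !llvm<"float*">
    // %data then populates fields 0 and 1 of the new MemRef descriptor.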
diff --git a/test/mlir/krnl/krnl_getref_lowering.mlir b/test/mlir/krnl/krnl_getref_lowering.mlir
new file mode 100644
index 0000000..3e60b4e
--- /dev/null
+++ b/test/mlir/krnl/krnl_getref_lowering.mlir
@@ -0,0 +1,38 @@
+// RUN: onnx-mlir-opt --lower-krnl --lower-all-llvm %s -split-input-file | FileCheck %s
+
+func @test_getref_lowering(%arg0: memref<2x2xf32>) -> memref<2x2xf32> {
+  %c13_i64 = constant 13 : i64
+  %1 = alloc() : memref<10x10xf32>
+  %2 = "krnl.getref"(%1, %c13_i64) : (memref<10x10xf32>, i64) -> memref<2x2xf32>
+  return %2 : memref<2x2xf32>
+
+  // CHECK-LABEL: test_getref_lowering
+  // CHECK: %[[OFFSET:.+]] = llvm.mlir.constant(13 : i64) : !llvm.i64
+  // CHECK: [[CONST_10_0:%.+]] = llvm.mlir.constant(10 : index) : !llvm.i64
+  // CHECK: [[CONST_10_1:%.+]] = llvm.mlir.constant(10 : index) : !llvm.i64
+  // CHECK: [[MUL1:%.+]] = llvm.mul [[CONST_10_0]], [[CONST_10_1]] : !llvm.i64
+  // CHECK: [[FLOAT_STAR:%.+]] = llvm.mlir.null : !llvm<"float*">
+  // CHECK: %[[CONST_1:.+]] = llvm.mlir.constant(1 : index) : !llvm.i64
+  // CHECK: [[ELEM1:%.+]] = llvm.getelementptr [[FLOAT_STAR]][%[[CONST_1]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+  // CHECK: [[ELEM_SIZE:%.+]] = llvm.ptrtoint [[ELEM1]] : !llvm<"float*"> to !llvm.i64
+  // CHECK: [[MUL2:%.+]] = llvm.mul [[MUL1]], [[ELEM_SIZE]] : !llvm.i64
+  // CHECK: [[MEMPOOL:%.+]] = llvm.call @malloc([[MUL2]]) : (!llvm.i64) -> !llvm<"i8*">
+  // CHECK: [[TYPED_MEMPOOL:%.+]] = llvm.bitcast [[MEMPOOL]] : !llvm<"i8*"> to !llvm<"float*">
+  // CHECK: [[MEMPOOL_MEMREF:%.+]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: [[MEMREF1:%.+]] = llvm.insertvalue [[TYPED_MEMPOOL]], [[MEMPOOL_MEMREF]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: [[MEMREF2:%.+]] = llvm.insertvalue [[TYPED_MEMPOOL]], [[MEMREF1]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: llvm.mlir.constant
+  // CHECK: llvm.insertvalue
+  // CHECK: llvm.mlir.constant
+  // CHECK: llvm.mlir.constant
+  // CHECK: llvm.insertvalue
+  // CHECK: llvm.insertvalue
+  // CHECK: llvm.insertvalue
+  // CHECK: llvm.insertvalue
+  // CHECK: [[MEMPOOL1:%.+]] = llvm.extractvalue {{.*}}[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: [[MEMPOOL_ALLOC:%.+]] = llvm.getelementptr [[MEMPOOL1]][%[[OFFSET]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+  // CHECK: [[TYPED_MEMPOOL_ALLOC:%.+]] = llvm.bitcast [[MEMPOOL_ALLOC]] : !llvm<"float*"> to !llvm<"float*">
+  // CHECK: [[MEMPOOLED_MEMREF:%.+]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: [[MEMREF3:%.+]] = llvm.insertvalue [[TYPED_MEMPOOL_ALLOC]], [[MEMPOOLED_MEMREF]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: [[MEMREF4:%.+]] = llvm.insertvalue [[TYPED_MEMPOOL_ALLOC]], [[MEMREF3]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+}
diff --git a/test/mlir/krnl/memory_pool.mlir b/test/mlir/krnl/memory_pool.mlir
new file mode 100644
index 0000000..49fad5c
--- /dev/null
+++ b/test/mlir/krnl/memory_pool.mlir
@@ -0,0 +1,66 @@
+// RUN: onnx-mlir-opt --shape-inference --lower-frontend --enable-memory-pool --lower-krnl --lower-all-llvm %s -split-input-file | FileCheck %s
+
+func @test_memory_pool(%arg0: tensor<10x10xf32>) -> tensor<10x10xf32> {
+  %0 = "onnx.Add"(%arg0, %arg0) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
+  %1 = "onnx.Add"(%0, %arg0) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
+  return %1 : tensor<10x10xf32>
+
+  /// Define the offset inside the memory pool.
+  // CHECK: %[[OFFSET:.+]] = llvm.mlir.constant(0 : i64) : !llvm.i64
+
+  /// Allocate memory for the memory pool.
+  // CHECK: [[MEMPOOL_SIZE:%.+]] = llvm.mlir.constant(400 : index) : !llvm.i64
+  // CHECK: [[TMP1:%.+]] = llvm.mlir.null : !llvm<"i8*">
+  // CHECK: %[[CONST1:.+]] = llvm.mlir.constant(1 : index) : !llvm.i64
+  // CHECK: [[TMP2:%.+]] = llvm.getelementptr [[TMP1]][%[[CONST1]]] : (!llvm<"i8*">, !llvm.i64) -> !llvm<"i8*">
+  // CHECK: [[TYPE_SIZE_IN_BYTES:%.+]] = llvm.ptrtoint [[TMP2]] : !llvm<"i8*"> to !llvm.i64
+  // CHECK: [[TOTAL_SIZE:%.+]] = llvm.mul [[MEMPOOL_SIZE]], [[TYPE_SIZE_IN_BYTES]] : !llvm.i64
+  // CHECK: [[ALLOC_MEM_POOL:%.+]] = llvm.call @malloc([[TOTAL_SIZE]]) : (!llvm.i64) -> !llvm<"i8*">
+  // CHECK: [[BITCAST_ALLOC_MEM_POOL:%.+]] = llvm.bitcast [[ALLOC_MEM_POOL]] : !llvm<"i8*"> to !llvm<"i8*">
+
+  /// MemRef representing the memory pool and which contains the memory allocated above.
+  // CHECK: [[MEMREF0:%.+]] = llvm.mlir.undef : !llvm<"{ i8*, i8*, i64, [1 x i64], [1 x i64] }">
+  // CHECK: [[TMP3:%.+]] = llvm.insertvalue [[BITCAST_ALLOC_MEM_POOL]], [[MEMREF0]][0] : !llvm<"{ i8*, i8*, i64, [1 x i64], [1 x i64] }">
+  // CHECK: llvm.insertvalue [[BITCAST_ALLOC_MEM_POOL]], [[TMP3]][1] : !llvm<"{ i8*, i8*, i64, [1 x i64], [1 x i64] }">
+  // CHECK: llvm.mlir.constant(0 : index) : !llvm.i64
+  // CHECK: llvm.insertvalue
+  // CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
+  // CHECK: llvm.insertvalue
+  // CHECK: [[TMP4:%.+]] = llvm.insertvalue {{.*}}[4, 0] : !llvm<"{ i8*, i8*, i64, [1 x i64], [1 x i64] }">
+
+  /// Get reference within the memory pool where the data of the getref instruction has already been allocated.
+  // CHECK: [[MEMPOOL_BASE:%.+]] = llvm.extractvalue [[TMP4]][1] : !llvm<"{ i8*, i8*, i64, [1 x i64], [1 x i64] }">
+  // CHECK: [[GETREF_MEMORY:%.+]] = llvm.getelementptr [[MEMPOOL_BASE]][%[[OFFSET]]] : (!llvm<"i8*">, !llvm.i64) -> !llvm<"i8*">
+  // CHECK: [[CASTED_GETREF_MEMORY:%.+]] = llvm.bitcast [[GETREF_MEMORY]] : !llvm<"i8*"> to !llvm<"float*">
+
+  /// Create MemRef for krnl.getref.
+  // CHECK: [[MEMREF1:%.+]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: [[MEMREF1_TMP1:%.+]] = llvm.insertvalue [[CASTED_GETREF_MEMORY]], [[MEMREF1]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: [[MEMREF1_TMP2:%.+]] = llvm.insertvalue [[CASTED_GETREF_MEMORY]], [[MEMREF1_TMP1]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: [[CONST2:%.+]] = llvm.mlir.constant(0 : index) : !llvm.i64
+  // CHECK: [[MEMREF1_TMP3:%.+]] = llvm.insertvalue [[CONST2]], [[MEMREF1_TMP2]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: [[CONST3:%.+]] = llvm.mlir.constant(10 : index) : !llvm.i64
+  // CHECK: [[MEMREF1_TMP4:%.+]] = llvm.insertvalue [[CONST3]], [[MEMREF1_TMP3]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: [[CONST4:%.+]] = llvm.mlir.constant(10 : index) : !llvm.i64
+  // CHECK: [[MEMREF1_TMP5:%.+]] = llvm.insertvalue [[CONST4]], [[MEMREF1_TMP4]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: [[CONST5:%.+]] = llvm.mlir.constant(10 : index) : !llvm.i64
+  // CHECK: [[MEMREF1_TMP6:%.+]] = llvm.insertvalue [[CONST5]], [[MEMREF1_TMP5]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: [[CONST6:%.+]] = llvm.mlir.constant(1 : index) : !llvm.i64
+  // CHECK: [[MEMREF1_TMP7:%.+]] = llvm.insertvalue [[CONST6]], [[MEMREF1_TMP6]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+
+  /// Usage of the getref MemRef.
+  // CHECK: [[MEM0:%.+]] = llvm.extractvalue [[MEMREF1_TMP7]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: [[CONST7:%.+]] = llvm.mlir.constant(0 : index) : !llvm.i64
+  // CHECK: [[CONST8:%.+]] = llvm.mlir.constant(10 : index) : !llvm.i64
+  // CHECK: [[MUL1:%.+]] = llvm.mul {{.*}}, [[CONST8]] : !llvm.i64
+  // CHECK: [[ADD1:%.+]] = llvm.add [[CONST7]], [[MUL1]] : !llvm.i64
+  // CHECK: [[CONST9:%.+]] = llvm.mlir.constant(1 : index) : !llvm.i64
+  // CHECK: [[MUL2:%.+]] = llvm.mul {{.*}}, [[CONST9]] : !llvm.i64
+  // CHECK: %[[ADD2:.+]] = llvm.add [[ADD1]], [[MUL2]] : !llvm.i64
+  // CHECK: llvm.getelementptr [[MEM0]][%[[ADD2]]] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
+
+  /// Deallocation of the memory pool.
+  // CHECK: [[MEMPOOL_BASE_UNALIGNED:%.+]] = llvm.extractvalue [[TMP4]][0] : !llvm<"{ i8*, i8*, i64, [1 x i64], [1 x i64] }">
+  // CHECK: [[CASTED_MEMPOOL_BASE_UNALIGNED:%.+]] = llvm.bitcast [[MEMPOOL_BASE_UNALIGNED]] : !llvm<"i8*"> to !llvm<"i8*">
+  // CHECK: llvm.call @free([[CASTED_MEMPOOL_BASE_UNALIGNED]]) : (!llvm<"i8*">) -> ()
+}
diff --git a/test/mlir/onnx/onnx_enable_memory_pool.mlir b/test/mlir/onnx/onnx_enable_memory_pool.mlir
new file mode 100644
index 0000000..9b1e1c7
--- /dev/null
+++ b/test/mlir/onnx/onnx_enable_memory_pool.mlir
@@ -0,0 +1,69 @@
+// RUN: onnx-mlir-opt --shape-inference --lower-frontend --enable-memory-pool %s -split-input-file | FileCheck %s
+
+/// One intermediate value to allocate in the memory pool.
+func @test_enable_memory_pool(%arg0: tensor<10x10xf32>) -> tensor<10x10xf32> {
+  %0 = "onnx.Add"(%arg0, %arg0) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
+  %1 = "onnx.Add"(%0, %arg0) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
+  return %1 : tensor<10x10xf32>
+
+  // CHECK-LABEL: test_enable_memory_pool
+  // CHECK: [[CONST0:%.+]] = constant 0 : i64
+  // CHECK: [[RES:%.+]] = alloc() : memref<10x10xf32>
+  // CHECK: [[MEMPOOL:%.+]] = alloc() : memref<400xi8>
+  // CHECK: [[GETREF:%.+]] = "krnl.getref"([[MEMPOOL]], [[CONST0]]) : (memref<400xi8>, i64) -> memref<10x10xf32>
+  // CHECK: krnl.define_loops
+  // CHECK: krnl.optimize_loops
+  // CHECK: krnl.iterate
+  // CHECK: [[LOAD1:%.+]] = load %arg0[%arg1, %arg2] : memref<10x10xf32>
+  // CHECK: [[LOAD2:%.+]] = load %arg0[%arg1, %arg2] : memref<10x10xf32>
+  // CHECK: [[ADDF1:%.+]] = addf [[LOAD1]], [[LOAD2]] : f32
+  // CHECK: store [[ADDF1]], [[GETREF]][%arg1, %arg2] : memref<10x10xf32>
+  // CHECK: krnl.define_loops
+  // CHECK: krnl.optimize_loops
+  // CHECK: krnl.iterate
+  // CHECK: dealloc [[MEMPOOL]] : memref<400xi8>
+  // CHECK: return [[RES]] : memref<10x10xf32>
+}
+
+/// Two intermediate values to allocate in the memory pool.
+func @test_enable_memory_pool_2(%arg0: tensor<10x10xf32>, %arg1: tensor<10x20xf32>) -> tensor<10x20xf32> {
+  %0 = "onnx.Add"(%arg0, %arg0) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
+  %1 = "onnx.MatMul"(%0, %arg1) : (tensor<10x10xf32>, tensor<10x20xf32>) -> tensor<10x20xf32>
+  %2 = "onnx.Add"(%1, %arg1) : (tensor<10x20xf32>, tensor<10x20xf32>) -> tensor<10x20xf32>
+  return %2 : tensor<10x20xf32>
+
+  // CHECK-LABEL: test_enable_memory_pool_2
+  // CHECK: [[CONST0:%.+]] = constant 0 : i64
+  // CHECK: [[CONST1:%.+]] = constant 0.000000e+00 : f32
+  // CHECK: [[RES:%.+]] = alloc() : memref<10x20xf32>
+  // CHECK: [[MEMPOOL0:%.+]] = alloc() : memref<800xi8>
+  // CHECK: [[GETREF0:%.+]] = "krnl.getref"([[MEMPOOL0]], [[CONST0]]) : (memref<800xi8>, i64) -> memref<10x20xf32>
+  // CHECK: [[MEMPOOL1:%.+]] = alloc() : memref<400xi8>
+  // CHECK: [[GETREF1:%.+]] = "krnl.getref"([[MEMPOOL1]], [[CONST0]]) : (memref<400xi8>, i64) -> memref<10x10xf32>
+  // CHECK: krnl.define_loops
+  // CHECK: krnl.optimize_loops
+  // CHECK: krnl.iterate
+  // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xf32>
+  // CHECK: [[LOAD2:%.+]] = load %arg0[%arg2, %arg3] : memref<10x10xf32>
+  // CHECK: [[ADDF1:%.+]] = addf [[LOAD1]], [[LOAD2]] : f32
+  // CHECK: store [[ADDF1]], [[GETREF1]][%arg2, %arg3] : memref<10x10xf32>
+  // CHECK: krnl.define_loops
+  // CHECK: krnl.optimize_loops
+  // CHECK: krnl.iterate
+  // CHECK: [[LOAD3:%.+]] = load [[GETREF1]][%arg2, %arg4] : memref<10x10xf32>
+  // CHECK: [[LOAD4:%.+]] = load %arg1[%arg4, %arg3] : memref<10x20xf32>
+  // CHECK: [[LOAD5:%.+]] = load [[GETREF0]][%arg2, %arg3] : memref<10x20xf32>
+  // CHECK: [[MULF1:%.+]] = mulf [[LOAD3]], [[LOAD4]] : f32
+  // CHECK: [[ADDF2:%.+]] = addf [[LOAD5]], [[MULF1]] : f32
+  // CHECK: store [[ADDF2]], [[GETREF0]][%arg2, %arg3] : memref<10x20xf32>
+  // CHECK: krnl.define_loops
+  // CHECK: krnl.optimize_loops
+  // CHECK: krnl.iterate
+  // CHECK: [[LOAD6:%.+]] = load [[GETREF0]][%arg2, %arg3] : memref<10x20xf32>
+  // CHECK: [[LOAD7:%.+]] = load %arg1[%arg2, %arg3] : memref<10x20xf32>
+  // CHECK: [[ADDF3:%.+]] = addf [[LOAD6]], [[LOAD7]] : f32
+  // CHECK: store [[ADDF3]], [[RES]][%arg2, %arg3] : memref<10x20xf32>
+  // CHECK: dealloc [[MEMPOOL1]] : memref<400xi8>
+  // CHECK: dealloc [[MEMPOOL0]] : memref<800xi8>
+  // CHECK: return [[RES]] : memref<10x20xf32>
+}