Bundle individual memory pools into a single memory pool (#183)
* Reorganize main function.
* Follow review comments.
* Emit constants as globals in Krnl and LLVM dialects.
* Add memory pooling for constant sized arrays.
* Clean code.
* Add simple bundling test.

Co-authored-by: Tian Jin <tjingrant@gmail.com>
parent 0a936edf79
commit 100bfc81b4
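Once built, the new pass can be exercised on its own or as part of the ONNX-to-Krnl pipeline. A minimal invocation sketch (file names are placeholders; the flags are the ones registered in this commit and used in the test RUN lines below):

  # Bundle the per-tensor pools that --enable-memory-pool already emitted:
  onnx-mlir-opt --bundle-memory-pools --canonicalize pools.mlir

  # Full path from ONNX ops, as the end-to-end test does:
  onnx-mlir-opt --shape-inference --lower-frontend --enable-memory-pool \
      --bundle-memory-pools --canonicalize model.mlir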
@@ -22,7 +22,8 @@ set(OMLibs
   OMElideConstants
   OMElideKrnlGlobalConstants
   OMPackKrnlGlobalConstants
-  OMEnableMemoryPool)
+  OMEnableMemoryPool
+  OMBundleMemoryPools)
 set(OMLibs ${OMLibs} PARENT_SCOPE)
 
 message(STATUS "OMLibs" ${OMLibs})
@@ -461,3 +461,28 @@ Value emitNegativeInfinityConstantOp(
 int64_t ArrayAttrIntVal(ArrayAttr a, int i) {
   return (a.getValue()[i]).cast<IntegerAttr>().getInt();
 }
+
+bool checkOpResultIsUsedByGetRef(AllocOp *allocOp) {
+  auto parentBlock = allocOp->getOperation()->getBlock();
+
+  bool opIsUsedInGetRef = false;
+  parentBlock->walk([&opIsUsedInGetRef, allocOp](KrnlGetRefOp op) {
+    auto result = allocOp->getResult();
+    for (const auto &operand : op.getOperands())
+      if (operand == result)
+        opIsUsedInGetRef = true;
+  });
+
+  return opIsUsedInGetRef;
+}
+
+// TODO: support dynamic sizes.
+int64_t getMemRefSizeInBytes(Value val) {
+  auto memRefType = convertToMemRefType(val.getType());
+  auto memRefShape = memRefType.getShape();
+  int64_t size = 1;
+  for (int i = 0; i < memRefShape.size(); i++)
+    size *= memRefShape[i];
+  size *= getMemRefEltSizeInBytes(memRefType);
+  return size;
+}
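As a sanity check of getMemRefSizeInBytes against the tests below (worked numbers, not part of the diff):

  size(memref<10x20xf32>) = 10 * 20 * 4 bytes = 800  ->  pool alloc() : memref<800xi8>
  size(memref<10x10xf32>) = 10 * 10 * 4 bytes = 400  ->  pool alloc() : memref<400xi8>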
@@ -242,3 +242,7 @@ void populateLoweringONNXSqueezeOpPattern(
 void populateLoweringONNXSplitOpPattern(
     OwningRewritePatternList &patterns, MLIRContext *ctx);
+
+bool checkOpResultIsUsedByGetRef(AllocOp *allocOp);
+
+int64_t getMemRefSizeInBytes(Value val);
@@ -42,6 +42,12 @@ void initOMPasses() {
         return mlir::createKrnlEnableMemoryPoolPass();
       });
 
+  mlir::registerPass("bundle-memory-pools",
+      "Bundle memory pools of internal MemRefs into a single memory pool.",
+      []() -> std::unique_ptr<mlir::Pass> {
+        return mlir::createKrnlBundleMemoryPoolsPass();
+      });
+
   mlir::registerPass(
       "lower-krnl", "Lower Krnl dialect.", []() -> std::unique_ptr<mlir::Pass> {
         return mlir::createLowerKrnlPass();
@@ -269,6 +269,8 @@ void addONNXToKrnlPasses(mlir::PassManager &pm) {
 
   // TODO: make this pass optional:
   pm.addPass(mlir::createKrnlEnableMemoryPoolPass());
+  pm.addPass(mlir::createKrnlBundleMemoryPoolsPass());
+  pm.addPass(mlir::createCanonicalizerPass());
 }
 
 void addKrnlToAffinePasses(mlir::PassManager &pm) {
@@ -31,6 +31,9 @@ std::unique_ptr<Pass> createElideConstantValuePass();
 /// Pass for enabling a memory pool for MemRefs.
 std::unique_ptr<Pass> createKrnlEnableMemoryPoolPass();
 
+/// Pass for bundling the memory pools of internal MemRefs into a single pool.
+std::unique_ptr<Pass> createKrnlBundleMemoryPoolsPass();
+
 /// Add pass for lowering to Krnl IR.
 std::unique_ptr<Pass> createLowerToKrnlPass();
 
@@ -0,0 +1,148 @@
+//===-- BundleMemoryPools.cpp - Bundle Memory Pools for internal MemRefs -===//
+//
+// Copyright 2019-2020 The IBM Research Authors.
+//
+// =============================================================================
+//
+// For certain cases the number of individual memory allocations required for
+// all internal tensors is large and needs to be mitigated. This pass bundles
+// all the internal MemRef memory pools emitted by the EnableMemoryPool pass
+// into a single memory pool.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
+#include "src/Dialect/Krnl/KrnlOps.hpp"
+#include "src/Pass/Passes.hpp"
+
+using namespace mlir;
+
+namespace {
+
+KrnlGetRefOp getUnbundledGetRef(AllocOp *memPool) {
+  auto parentBlock = memPool->getOperation()->getBlock();
+
+  KrnlGetRefOp unbundledGetRef = nullptr;
+  parentBlock->walk([&unbundledGetRef, memPool](KrnlGetRefOp op) {
+    auto result = memPool->getResult();
+    if (op.getOperands()[0] != result)
+      unbundledGetRef = op;
+  });
+
+  return unbundledGetRef;
+}
+
+/*!
+ * RewritePattern that replaces:
+ *   %mem1 = alloc() : memref<<dims1>x<type>>
+ *   %mem2 = alloc() : memref<<dims2>x<type>>
+ *   %1 = krnl.getref %mem2 0 : memref<<dims2>x<type>>
+ * =>
+ *   %mem1 = alloc() : memref<<dims1 + dims2>x<type>>
+ *   %1 = krnl.getref %mem1 <dims1> : memref<<dims2>x<type>>
+ *
+ *
+ * ASSUMPTION: All krnl.getref operations in the program have been emitted
+ *             by the EnableMemoryPool pass, i.e. there are no krnl.getref
+ *             operations which are not related to the memory pool.
+ *             krnl.getref is an operation specific to memory management;
+ *             for other use cases, use MLIR Standard dialect operations.
+ *             This assumption simplifies the code and avoids additional
+ *             checks to ensure that all the participating krnl.getref
+ *             operations are part of memory pooling.
+ */
+class KrnlBundleMemoryPools : public OpRewritePattern<AllocOp> {
+public:
+  using OpRewritePattern<AllocOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(
+      AllocOp allocOp, PatternRewriter &rewriter) const override {
+    auto loc = allocOp.getLoc();
+
+    auto memRefType = convertToMemRefType(allocOp.getResult().getType());
+    auto memRefShape = memRefType.getShape();
+
+    // If the alloc result is not used by a getref then it cannot be part of
+    // the memory pool.
+    if (!checkOpResultIsUsedByGetRef(&allocOp))
+      return failure();
+
+    // Alloc memory type must be byte.
+    if (getMemRefEltSizeInBytes(memRefType) != 1)
+      return failure();
+
+    // Rank of the allocated MemRef must be 1.
+    if (memRefShape.size() != 1)
+      return failure();
+
+    // TODO: change this when dynamic shapes are supported.
+    // TODO: add support for dynamic shapes.
+    int64_t currentMemPoolSize = memRefShape[0];
+
+    // Get a KrnlGetRefOp which does not use the current alloc.
+    if (KrnlGetRefOp unbundledGetRef = getUnbundledGetRef(&allocOp)) {
+      unbundledGetRef.dump();
+
+      // Current memory pool size is the offset for the newly bundled
+      // internal MemRef. Emit the offset as a constant.
+      auto offset = rewriter.create<ConstantOp>(
+          loc, rewriter.getIntegerAttr(
+                   rewriter.getIntegerType(64), currentMemPoolSize));
+
+      // Size in bytes of the output of the krnl.getref operation.
+      int64_t unbundledTotalSize =
+          getMemRefSizeInBytes(unbundledGetRef.getResult());
+
+      // Compute new size.
+      int64_t bundleTotalSize = unbundledTotalSize + currentMemPoolSize;
+
+      // We need to emit a new alloc which contains the additional MemRef.
+      SmallVector<int64_t, 1> newMemPoolShape;
+      newMemPoolShape.emplace_back(bundleTotalSize);
+      auto bundledMemPoolMemRefType =
+          MemRefType::get(newMemPoolShape, rewriter.getIntegerType(8));
+      auto bundledAlloc =
+          rewriter.create<AllocOp>(loc, bundledMemPoolMemRefType);
+
+      // The newly bundled MemRef expressed as a KrnlGetRefOp.
+      auto bundledMemRef = rewriter.create<KrnlGetRefOp>(
+          loc, unbundledGetRef.getResult().getType(), bundledAlloc, offset);
+      rewriter.replaceOp(unbundledGetRef, bundledMemRef.getResult());
+
+      // Replace old memory pool with new one.
+      rewriter.replaceOp(allocOp, bundledAlloc.getResult());
+
+      return success();
+    }
+
+    return failure();
+  }
+};
+
+/*!
+ * Function pass that bundles the memory pools of internal MemRefs.
+ */
+class KrnlBundleMemoryPoolsPass
+    : public PassWrapper<KrnlBundleMemoryPoolsPass, FunctionPass> {
+public:
+  void runOnFunction() override {
+    auto function = getFunction();
+
+    ConversionTarget target(getContext());
+    OwningRewritePatternList patterns;
+    patterns.insert<KrnlBundleMemoryPools>(&getContext());
+
+    applyPatternsAndFoldGreedily(function, patterns);
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::createKrnlBundleMemoryPoolsPass() {
+  return std::make_unique<KrnlBundleMemoryPoolsPass>();
+}
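Each successful match of KrnlBundleMemoryPools folds exactly one foreign krnl.getref into the matched alloc's pool, growing it by that getref's size; applyPatternsAndFoldGreedily reapplies the pattern until getUnbundledGetRef comes up empty, and the canonicalizer scheduled after this pass (see the pipeline change above) erases the allocs and deallocs left without users. A sketch of a single rewrite step on illustrative IR (value names are hypothetical; which pool is matched first depends on driver order):

  // Before: two independent pools.
  %a = alloc() : memref<800xi8>
  %t0 = "krnl.getref"(%a, %c0_i64) : (memref<800xi8>, i64) -> memref<10x20xf32>
  %b = alloc() : memref<400xi8>
  %t1 = "krnl.getref"(%b, %c0_i64) : (memref<400xi8>, i64) -> memref<10x10xf32>

  // After matching %b: %t0 is bundled in at offset 400 (the size of %b's
  // pool), and the pool grows by %t0's 800 bytes to 400 + 800 = 1200.
  %pool = alloc() : memref<1200xi8>
  %t1 = "krnl.getref"(%pool, %c0_i64) : (memref<1200xi8>, i64) -> memref<10x10xf32>
  %t0 = "krnl.getref"(%pool, %c400_i64) : (memref<1200xi8>, i64) -> memref<10x20xf32>
  // %a is now used only by its dealloc and is removed by --canonicalize.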
@@ -46,6 +46,8 @@ target_include_directories(OMPackKrnlGlobalConstants
   ${ONNX_MLIR_SRC_ROOT}
   ${ONNX_MLIR_BIN_ROOT}
   ${ONNX_MLIR_SRC_ROOT})
+add_dependencies(OMPackKrnlGlobalConstants
+  OMKrnlOps)
 
 add_library(OMEnableMemoryPool
   EnableMemoryPool.cpp)
@@ -54,13 +56,23 @@ target_include_directories(OMEnableMemoryPool
   ${ONNX_MLIR_SRC_ROOT}
   ${ONNX_MLIR_BIN_ROOT}
   ${ONNX_MLIR_SRC_ROOT})
-add_dependencies(OMPackKrnlGlobalConstants
-  OMKrnlOps)
-
 target_link_libraries(OMEnableMemoryPool
   onnx)
 add_dependencies(OMEnableMemoryPool
   OMKrnlOps
   OMONNXOps)
+
+add_library(OMBundleMemoryPools
+  BundleMemoryPools.cpp)
+target_include_directories(OMBundleMemoryPools
+  PRIVATE
+  ${ONNX_MLIR_SRC_ROOT}
+  ${ONNX_MLIR_BIN_ROOT}
+  ${ONNX_MLIR_SRC_ROOT})
+target_link_libraries(OMBundleMemoryPools
+  onnx)
+add_dependencies(OMBundleMemoryPools
+  OMKrnlOps
+  OMONNXOps)
 
 add_subdirectory(ONNX)
@@ -37,20 +37,6 @@ bool checkOpResultIsReturned(AllocOp *allocOp) {
   return opIsReturned;
 }
 
-bool checkOpResultIsUsedByGetRef(AllocOp *allocOp) {
-  auto parentBlock = allocOp->getOperation()->getBlock();
-
-  bool opIsUsedInGetRef = false;
-  parentBlock->walk([&opIsUsedInGetRef, allocOp](KrnlGetRefOp op) {
-    auto result = allocOp->getResult();
-    for (const auto &operand : op.getOperands())
-      if (operand == result)
-        opIsUsedInGetRef = true;
-  });
-
-  return opIsUsedInGetRef;
-}
-
 /*!
  * RewritePattern that replaces:
  *   %0 = alloc() : memref<<dims>x<type>>
@@ -84,11 +70,7 @@ public:
       return failure();
 
     // Compute total size.
-    auto memRefShape = memRefType.getShape();
-    int64_t totalSize = 1;
-    for (int i = 0; i < memRefShape.size(); i++)
-      totalSize *= memRefShape[i];
-    totalSize *= getMemRefEltSizeInBytes(memRefType);
+    int64_t totalSize = getMemRefSizeInBytes(allocOp.getResult());
 
     // Emit new alloc.
     SmallVector<int64_t, 1> memPoolShape;
@@ -0,0 +1,53 @@
+// RUN: onnx-mlir-opt --bundle-memory-pools --canonicalize %s -split-input-file | FileCheck %s
+
+func @test_pool_bundling(%arg0: memref<10x10xf32>, %arg1: memref<10x20xf32>) -> memref<10x20xf32> {
+  %c0_i64 = constant 0 : i64
+  %ind = constant 0 : index
+  %cst = constant 0.000000e+00 : f32
+  %0 = alloc() : memref<10x20xf32>
+  %1 = alloc() : memref<800xi8>
+  %2 = "krnl.getref"(%1, %c0_i64) : (memref<800xi8>, i64) -> memref<10x20xf32>
+  %3 = alloc() : memref<400xi8>
+  %4 = "krnl.getref"(%3, %c0_i64) : (memref<400xi8>, i64) -> memref<10x10xf32>
+  %5 = alloc() : memref<800xi8>
+  %6 = "krnl.getref"(%5, %c0_i64) : (memref<800xi8>, i64) -> memref<10x20xf32>
+  %7 = alloc() : memref<800xi8>
+  %8 = "krnl.getref"(%7, %c0_i64) : (memref<800xi8>, i64) -> memref<10x20xf32>
+  %9 = alloc() : memref<400xi8>
+  %10 = "krnl.getref"(%9, %c0_i64) : (memref<400xi8>, i64) -> memref<10x10xf32>
+  affine.store %cst, %10[%ind, %ind] : memref<10x10xf32>
+  affine.store %cst, %8[%ind, %ind] : memref<10x20xf32>
+  affine.store %cst, %6[%ind, %ind] : memref<10x20xf32>
+  affine.store %cst, %4[%ind, %ind] : memref<10x10xf32>
+  affine.store %cst, %2[%ind, %ind] : memref<10x20xf32>
+  affine.store %cst, %0[%ind, %ind] : memref<10x20xf32>
+  dealloc %9 : memref<400xi8>
+  dealloc %7 : memref<800xi8>
+  dealloc %5 : memref<800xi8>
+  dealloc %3 : memref<400xi8>
+  dealloc %1 : memref<800xi8>
+  return %0 : memref<10x20xf32>
+
+  // CHECK-LABEL: test_pool_bundling
+  // CHECK: [[CONST_0:%.+]] = constant 0 : i64
+  // CHECK: [[CONST_CST:%.+]] = constant 0.000000e+00 : f32
+  // CHECK: [[CONST_400:%.+]] = constant 400 : i64
+  // CHECK: [[CONST_1200:%.+]] = constant 1200 : i64
+  // CHECK: [[CONST_2000:%.+]] = constant 2000 : i64
+  // CHECK: [[CONST_2400:%.+]] = constant 2400 : i64
+  // CHECK: [[RES:%.+]] = alloc() : memref<10x20xf32>
+  // CHECK: [[MEMPOOL:%.+]] = alloc() : memref<3200xi8>
+  // CHECK: [[MEMREF1:%.+]] = "krnl.getref"([[MEMPOOL]], [[CONST_2400]]) : (memref<3200xi8>, i64) -> memref<10x20xf32>
+  // CHECK: [[MEMREF2:%.+]] = "krnl.getref"([[MEMPOOL]], [[CONST_2000]]) : (memref<3200xi8>, i64) -> memref<10x10xf32>
+  // CHECK: [[MEMREF3:%.+]] = "krnl.getref"([[MEMPOOL]], [[CONST_1200]]) : (memref<3200xi8>, i64) -> memref<10x20xf32>
+  // CHECK: [[MEMREF4:%.+]] = "krnl.getref"([[MEMPOOL]], [[CONST_400]]) : (memref<3200xi8>, i64) -> memref<10x20xf32>
+  // CHECK: [[MEMREF5:%.+]] = "krnl.getref"([[MEMPOOL]], [[CONST_0]]) : (memref<3200xi8>, i64) -> memref<10x10xf32>
+  // CHECK: affine.store %cst, [[MEMREF5]][0, 0] : memref<10x10xf32>
+  // CHECK: affine.store %cst, [[MEMREF4]][0, 0] : memref<10x20xf32>
+  // CHECK: affine.store %cst, [[MEMREF3]][0, 0] : memref<10x20xf32>
+  // CHECK: affine.store %cst, [[MEMREF2]][0, 0] : memref<10x10xf32>
+  // CHECK: affine.store %cst, [[MEMREF1]][0, 0] : memref<10x20xf32>
+  // CHECK: affine.store %cst, [[RES]][0, 0] : memref<10x20xf32>
+  // CHECK: dealloc [[MEMPOOL]] : memref<3200xi8>
+  // CHECK: return [[RES]] : memref<10x20xf32>
+}
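The CHECK values above are just running sums of the bundled sizes: the five pools (400 + 800 + 800 + 400 + 800 bytes) bundle into one memref<3200xi8>, and each getref's offset is the total size of everything bundled before it:

  offset(%10) = 0      (10x10xf32, 400 bytes)
  offset(%8)  = 400    (10x20xf32, 800 bytes)
  offset(%6)  = 1200   (10x20xf32, 800 bytes)
  offset(%4)  = 2000   (10x10xf32, 400 bytes)
  offset(%2)  = 2400   (10x20xf32, 800 bytes)
  total       = 3200 bytes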
@@ -0,0 +1,28 @@
+// RUN: onnx-mlir-opt --shape-inference --lower-frontend --enable-memory-pool --bundle-memory-pools --canonicalize %s -split-input-file | FileCheck %s
+
+func @test_bundle_memory_pool(%arg0: tensor<10x10xf32>, %arg1: tensor<10x20xf32>) -> tensor<10x20xf32> {
+  %0 = "onnx.Add"(%arg0, %arg0) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
+  %1 = "onnx.MatMul"(%0, %arg1) : (tensor<10x10xf32>, tensor<10x20xf32>) -> tensor<10x20xf32>
+  %2 = "onnx.Add"(%1, %arg1) : (tensor<10x20xf32>, tensor<10x20xf32>) -> tensor<10x20xf32>
+  %3 = "onnx.Add"(%0, %arg0) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
+  %4 = "onnx.MatMul"(%3, %arg1) : (tensor<10x10xf32>, tensor<10x20xf32>) -> tensor<10x20xf32>
+  %5 = "onnx.Add"(%4, %arg1) : (tensor<10x20xf32>, tensor<10x20xf32>) -> tensor<10x20xf32>
+  return %5 : tensor<10x20xf32>
+
+  // CHECK-LABEL: test_bundle_memory_pool
+  // CHECK: [[CONST0:%.+]] = constant 0 : i64
+  // CHECK: [[CONST00:%.+]] = constant 0.000000e+00 : f32
+  // CHECK: [[CONST400:%.+]] = constant 400 : i64
+  // CHECK: [[CONST1200:%.+]] = constant 1200 : i64
+  // CHECK: [[CONST2000:%.+]] = constant 2000 : i64
+  // CHECK: [[CONST2400:%.+]] = constant 2400 : i64
+  // CHECK: [[RES:%.+]] = alloc() : memref<10x20xf32>
+  // CHECK: [[MEMPOOL:%.+]] = alloc() : memref<3200xi8>
+  // CHECK: "krnl.getref"([[MEMPOOL]], [[CONST2400]]) : (memref<3200xi8>, i64) -> memref<10x20xf32>
+  // CHECK: "krnl.getref"([[MEMPOOL]], [[CONST2000]]) : (memref<3200xi8>, i64) -> memref<10x10xf32>
+  // CHECK: "krnl.getref"([[MEMPOOL]], [[CONST1200]]) : (memref<3200xi8>, i64) -> memref<10x20xf32>
+  // CHECK: "krnl.getref"([[MEMPOOL]], [[CONST400]]) : (memref<3200xi8>, i64) -> memref<10x20xf32>
+  // CHECK: "krnl.getref"([[MEMPOOL]], [[CONST0]]) : (memref<3200xi8>, i64) -> memref<10x10xf32>
+  // CHECK: dealloc [[MEMPOOL]] : memref<3200xi8>
+  // CHECK: return [[RES]] : memref<10x20xf32>
+}