Compiling Models with Large Constant Arrays (#146)

* PoC works.

* MNist works.

* Clean up.

* Fix test.

* Make Linux work.

* Use consistent symbol name.

* Fix variable name.

* Fix array addr access.

* Bug fix.

* Bug fix.

* install before running e2e tests.

* Fix build config.

* Use sudo when installing.

* Make embeddedDataLoader position independent.

* Enable ResNet50.

* Format code.

* Format MainUtil.

* Try not using sudo to install.

* Supply runtime dir via environment variable.

* Dump problematic operation.

* Dump entire function.

* Debug.

* Dump input.

* Dump constant op.

* Debug.

* Debug.

* Debug.

* Print to stderr.

* take care of endianness.

* Use endianness-aware execution session.

* Fix ZLinux error.

* Include warning when desired output endianness can't be deduced.

* Remove debug code.

* Remove debug code in shape inference.

* Support binary-decoder for testing constants packing.

* Support filename, move-to-file, elision-threshold configurations in constant packing pass for easy testing.

* Add lit test, fix lit test type mismatch.

* Add more consts packing tests.

* Ensure intermediate files are properly cleaned up.

* No need for constant elimination.

* Link with threading libraries.

* Remove debug code.

* Format code.

* More tests.

* test nit.

* Remove debug code.

* Reduce hard-coded constants.

* Use temporary and unique working directory for hosting model parameters.

* Test if it works.

* Try to find objcopy.

* Rename symbols using objcopy.

* Move sanitized name to linux section.

* Use verbose mode for debugging.

* Disambiguate pass constructor.

* Fix symbol name.

* Use Command API to build and execute commands.

* Move linux to use Command API.

* Fix reset args.

* Execute redefine sym.

* Format code.

* Do not use verbose mode for CircleCI.

* Remove debug code.

* Prettify code, add comments.

* getSegmentData -> getEmbeddedConstPool

* vector -> std::vector.

* Make sure we properly clean up intermediate files.

* Fix test cases.

* Add runtime directory.

* Trigger rebuild.

* [Merge with master] fix debug script.

* Diable affine fusion pass for now.

* Support generic fallback const packing mechanism.

* Remove debug code.

* Handle the case where objcopy is not available.

* Fix Windows missing types.

* Support int64.

* Copy packed constant to a local directory for non-Linux/Mac platforms.

* Nit: remove debug code, refactor const pack preprocessing out as a separate function.

* Cannot make preprocessConstPack a standalone function because file removers are stack-allocated, and they are deallocated prematurely when function stack gets popped, deleteing intermediate files too early.

* Don't require executable filename.

* Import ONNX data types directly.

* Fix LIT test.

* Bug fix, use moved string value.

* Remove redundant filenames.

* Fix CMake script.

* Embed endianness information as a symbol, and check during runtime.

* More comments, update lit tests.

* Fix lit test on BE machine.

* Copyright notices.
This commit is contained in:
Tian Jin 2020-06-12 10:27:05 +08:00 committed by GitHub
parent 8c4d527eea
commit e0ae583da0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
36 changed files with 994 additions and 129 deletions

View File

@ -71,4 +71,4 @@ cmake -DCMAKE_INSTALL_PREFIX=${INSTALL_PATH} .. \
make -j$(nproc) onnx-mlir make -j$(nproc) onnx-mlir
make -j$(nproc) check-onnx-lit make -j$(nproc) check-onnx-lit
make -j$(nproc) check-onnx-backend RUNTIME_DIR=$(pwd)/lib make -j$(nproc) check-onnx-backend

View File

@ -40,14 +40,14 @@ jobs:
command: | command: |
sudo pip install -q -e ./onnx-mlir/third_party/onnx sudo pip install -q -e ./onnx-mlir/third_party/onnx
cd onnx-mlir/build cd onnx-mlir/build
VERBOSE=1 cmake --build . --target check-onnx-backend RUNTIME_DIR=$(pwd)/lib cmake --build . --target check-onnx-backend
- run: - run:
name: Run Unit Tests name: Run Unit Tests
command: | command: |
cd onnx-mlir/build cd onnx-mlir/build
# Need to include the bin directory in $PATH, # Need to include the bin directory in $PATH,
# otherwise CTest fails to find the test executables. # otherwise CTest fails to find the test executables.
PATH=$(pwd)/bin:$PATH make test -j$(nproc) RUNTIME_DIR=$(pwd)/lib PATH=$(pwd)/bin:$PATH make test -j$(nproc)
- run: - run:
name: Run DocCheck name: Run DocCheck
command: cd onnx-mlir/build && cmake --build . --target check-doc command: cd onnx-mlir/build && cmake --build . --target check-doc

View File

@ -269,6 +269,7 @@ set(ONNXMLIRWholeArchiveLibs
OMPromotableConstOperandsOpInterface OMPromotableConstOperandsOpInterface
OMElideConstants OMElideConstants
OMElideKrnlGlobalConstants OMElideKrnlGlobalConstants
OMPackKrnlGlobalConstants
OMEnableMemoryPool) OMEnableMemoryPool)
# Function to construct linkage option for the static libraries that must be # Function to construct linkage option for the static libraries that must be

View File

@ -7,6 +7,8 @@
// Helper methods for handling input ONNX models. // Helper methods for handling input ONNX models.
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include <llvm/Support/Endian.h>
#include <llvm/Support/SwapByteOrder.h>
#include "src/Builder/FrontendDialectHelper.hpp" #include "src/Builder/FrontendDialectHelper.hpp"
@ -104,8 +106,16 @@ static std::vector<T> CreateArrayAttribute(onnx::TensorProto initializer) {
std::copy(initializer.raw_data().begin(), initializer.raw_data().end(), std::copy(initializer.raw_data().begin(), initializer.raw_data().end(),
back_inserter(byteInitializer)); back_inserter(byteInitializer));
size = initializer.raw_data().size() / sizeof(T); size = initializer.raw_data().size() / sizeof(T);
T *res = reinterpret_cast<T *>(&byteInitializer[0]); T *arrayPtr = reinterpret_cast<T *>(&byteInitializer[0]);
return std::vector<T>(res, res + size); auto array = std::vector<T>(arrayPtr, arrayPtr + size);
// Perform byte swap if system endianness is BE.
// ONNX tensor content raw data is always in LE.
if (llvm::support::endian::system_endianness() !=
llvm::support::endianness::little)
for (int i = 0; i < array.size(); i++)
llvm::sys::swapByteOrder<T>(array[i]);
return array;
} }
// copy, no need to take care of endianness // copy, no need to take care of endianness

View File

@ -25,9 +25,6 @@ if(NOT EXISTS "${LLVM_PROJ_BUILD}/bin/llc")
message(ERROR "Cannot find llc.") message(ERROR "Cannot find llc.")
endif() endif()
# Get the compiler command name, the C++ compiler is needed to to translate
# object files to shared libraries.
get_filename_component(CXX_COMPILER_FILENAME ${CMAKE_CXX_COMPILER} NAME)
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/ExternalUtil.hpp.in configure_file(${CMAKE_CURRENT_SOURCE_DIR}/ExternalUtil.hpp.in
${CMAKE_CURRENT_BINARY_DIR}/ExternalUtil.hpp) ${CMAKE_CURRENT_BINARY_DIR}/ExternalUtil.hpp)
@ -57,6 +54,7 @@ target_link_libraries(MainUtils
OMResultTypeInferenceOpInterface OMResultTypeInferenceOpInterface
OMElideConstants OMElideConstants
OMElideKrnlGlobalConstants OMElideKrnlGlobalConstants
OMPackKrnlGlobalConstants
OMEnableMemoryPool OMEnableMemoryPool
OMKrnlToAffine OMKrnlToAffine
OMKrnlToLLVM OMKrnlToLLVM
@ -71,6 +69,9 @@ if (INCLUDE_ONNX_ML)
add_dependencies(MainUtils OMMLONNXOpsInc) add_dependencies(MainUtils OMMLONNXOpsInc)
endif() endif()
add_dependencies(onnx-mlir cruntime)
add_dependencies(onnx-mlir EmbeddedDataLoader)
target_include_directories(onnx-mlir PRIVATE ${ONNX_MLIR_SRC_ROOT}) target_include_directories(onnx-mlir PRIVATE ${ONNX_MLIR_SRC_ROOT})
target_include_directories(onnx-mlir PRIVATE ${CMAKE_BINARY_DIR}) target_include_directories(onnx-mlir PRIVATE ${CMAKE_BINARY_DIR})
target_include_directories(onnx-mlir PRIVATE ${ONNX_MLIR_BIN_ROOT}) target_include_directories(onnx-mlir PRIVATE ${ONNX_MLIR_BIN_ROOT})

View File

@ -38,9 +38,11 @@ struct ONNXConstantOpLowering : public ConversionPattern {
// Emit the constant global in Krnl dialect. // Emit the constant global in Krnl dialect.
auto constantGlobal = rewriter.create<KrnlGlobalOp>(loc, memRefType, auto constantGlobal = rewriter.create<KrnlGlobalOp>(loc, memRefType,
rewriter.getI64ArrayAttr(shape), /*shape=*/rewriter.getI64ArrayAttr(shape),
/*name=*/
rewriter.getStringAttr("constant_" + std::to_string(constantID)), rewriter.getStringAttr("constant_" + std::to_string(constantID)),
constantOp.value().getValue()); /*value=*/constantOp.value().getValue(),
/*offset=*/nullptr);
// Increment constant ID: // Increment constant ID:
constantID++; constantID++;

View File

@ -196,16 +196,66 @@ def KrnlMemcpyOp : Op<Krnl_Dialect, "memcpy"> {
def KrnlGlobalOp : Op<Krnl_Dialect, "global"> { def KrnlGlobalOp : Op<Krnl_Dialect, "global"> {
let summary = "Krnl global operation"; let summary = "Krnl global operation";
let description = [{ let description = [{
Operation for holding global data values. Operation for holding global data values. A global constant can have a
meaningful name recorded as its `name` attribute. Its content is stored
in the `value` dense element attribute. Alternatively, if the constants
are packed together, `offset` records the byte offset in the global
constant pool from which the current constant is to be read.
}]; }];
let arguments = (ins AnyAttr:$shape, StrAttr:$name, OptionalAttr<AnyAttr>:$value); let arguments = (ins AnyAttr:$shape,
StrAttr:$name, OptionalAttr<AnyAttr>:$value, OptionalAttr<I64Attr>:$offset);
let results = (outs AnyTypeOf<[AnyMemRef]>:$output); let results = (outs AnyTypeOf<[AnyMemRef]>:$output);
let parser = ?; let parser = ?;
let printer = ?; let printer = ?;
} }
def KrnlPackedConstantOp : Op<Krnl_Dialect, "packed_const"> {
let summary = "Krnl packed constant operation";
let description = [{
Operation for holding packed constants.
}];
let arguments = (ins I64Attr:$size_in_bytes,
BoolAttr:$is_le,
OptionalAttr<AnyIntElementsAttr<8>>:$value,
OptionalAttr<StrAttr>:$file_name);
let results = (outs I64:$output);
let extraClassDeclaration = [{
// The *path* to the file storing the constant pack on disk is encoded
// as a global variable in the LLVM module of the lowered model.
// getConstPackFilePathSymbolName returns the name of this symbol;
// getConstPackFilePathStrLenSymbolName returns the name of the symbol holding
// a constant value equal to the length of the file path.
static StringRef getConstPackFilePathSymbolName() { return "constPackFilePath"; }
static StringRef getConstPackFilePathStrLenSymbolName() { return "constPackFilePathStrLen"; }
// The *name* of the file storing the constant pack is also recorded for
// convenience. Similarly, getConstPackFileNameSymbolName and
// getConstPackFileNameStrLenSymbolName returns records the symbol holding
// the string constant representing the filename and the length of this
// string constant.
static StringRef getConstPackFileNameSymbolName() { return "constPackFileName"; }
static StringRef getConstPackFileNameStrLenSymbolName() { return "constPackFileNameStrLen"; }
// We record whether the constant pack is stored in LE byte order as a
// int8 symbol; it is to be interpreted as a boolean switch - with 0
// meaning that the constant pack is not stored in LE byte order and
// non-0 values meaning that it is stored in LE byte order.
static StringRef getConstPackIsLESymbolName() { return "constPackIsLE"; }
// The name of a function we call to read packed constants embedded within
// the current binary executable/library, or in the case of unsupported platform,
// from a binary constant pack file.
static StringRef getEmbeddedDataLoaderMethodName() {
return "getEmbeddedConstPool";
}
}];
let parser = ?;
let printer = ?;
}
def KrnlGetRefOp : Op<Krnl_Dialect, "getref"> { def KrnlGetRefOp : Op<Krnl_Dialect, "getref"> {
let summary = "Krnl a MemRef from within another MemRef starting at a specific offset."; let summary = "Krnl a MemRef from within another MemRef starting at a specific offset.";
let description = [{ let description = [{

View File

@ -1138,7 +1138,6 @@ LogicalResult ONNXReshapeOp::inferShapes() {
if (constantOp) { if (constantOp) {
DenseElementsAttr valueAttribute = DenseElementsAttr valueAttribute =
constantOp.valueAttr().dyn_cast<DenseElementsAttr>(); constantOp.valueAttr().dyn_cast<DenseElementsAttr>();
if (!valueAttribute) if (!valueAttribute)
return emitError("DenseElementsAttr expected"); return emitError("DenseElementsAttr expected");
// Get dims from valueAttribute. // Get dims from valueAttribute.

View File

@ -1,9 +1,10 @@
#pragma once #pragma once
#include <string> #include <string>
namespace onnx_mlir { namespace onnx_mlir {
const std::string kLlcPath = "@LLVM_PROJ_BUILD@/bin/llc"; const std::string kLlcPath = "@LLVM_PROJ_BUILD@/bin/llc";
const std::string kCxxPath = "@CMAKE_CXX_COMPILER@"; const std::string kCxxPath = "@CMAKE_CXX_COMPILER@";
const std::string kCxxFileName = "@CXX_COMPILER_FILENAME@"; const std::string kLinkerPath = "@CMAKE_LINKER@";
const std::string kRuntimeDirPath = "@CMAKE_BINARY_DIR@/lib"; const std::string kObjCopyPath = "@CMAKE_OBJCOPY@";
} } // namespace onnx_mlir

View File

@ -9,8 +9,15 @@
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include <cstdio> #include <cstdio>
#include <cstdlib>
#include <fcntl.h> #include <fcntl.h>
#include <regex>
#include <string>
#include <llvm/Support/FileSystem.h>
#include <llvm/Support/Program.h> #include <llvm/Support/Program.h>
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <mlir/IR/SymbolTable.h>
#include "src/ExternalUtil.hpp" #include "src/ExternalUtil.hpp"
#include "src/MainUtils.hpp" #include "src/MainUtils.hpp"
@ -26,6 +33,63 @@
using namespace std; using namespace std;
using namespace onnx_mlir; using namespace onnx_mlir;
namespace {
llvm::Optional<std::string> getEnvVar(std::string name) {
if (const char *envVerbose = std::getenv(name.c_str()))
return std::string(envVerbose);
return llvm::None;
}
// Helper struct to make command construction and execution easy & readable.
struct Command {
std::string _path;
std::vector<std::string> _args;
Command(std::string exePath)
: _path(std::move(exePath)),
_args({llvm::sys::path::filename(_path).str()}) {}
// Append a single string argument.
Command &appendStr(const std::string &arg) {
_args.emplace_back(arg);
return *this;
}
// Append a list of string arguments.
Command &appendList(const std::vector<std::string> &args) {
_args.insert(_args.end(), args.begin(), args.end());
return *this;
}
// Reset arguments.
Command &resetArgs() {
auto exeFileName = _args.front();
_args.clear();
_args.emplace_back(exeFileName);
return *this;
}
// Execute command.
void exec() {
auto argsRef = std::vector<llvm::StringRef>(_args.begin(), _args.end());
bool verbose = false;
if (const auto &verboseStr = getEnvVar("VERBOSE"))
istringstream(verboseStr.getValue()) >> verbose;
// If in verbose mode, print out command before execution.
if (verbose)
cout << llvm::join(argsRef, " ") << "\n";
int rc = llvm::sys::ExecuteAndWait(_path, llvm::makeArrayRef(argsRef));
if (rc != 0) {
fprintf(stderr, "%s\n", llvm::join(argsRef, " ").c_str());
llvm_unreachable("Command execution failed.");
}
}
};
} // namespace
void LoadMLIR(string inputFilename, mlir::MLIRContext &context, void LoadMLIR(string inputFilename, mlir::MLIRContext &context,
mlir::OwningModuleRef &module) { mlir::OwningModuleRef &module) {
// Handle '.mlir' input to the ONNX MLIR frontend. // Handle '.mlir' input to the ONNX MLIR frontend.
@ -50,31 +114,123 @@ void LoadMLIR(string inputFilename, mlir::MLIRContext &context,
void compileModuleToSharedLibrary( void compileModuleToSharedLibrary(
const mlir::OwningModuleRef &module, string outputBaseName) { const mlir::OwningModuleRef &module, string outputBaseName) {
// Extract constant pack file name, which is embedded as a symbol in the
// module being compiled.
auto constPackFilePathSym = (*module).lookupSymbol<mlir::LLVM::GlobalOp>(
mlir::KrnlPackedConstantOp::getConstPackFilePathSymbolName());
auto constPackFilePath = constPackFilePathSym.valueAttr()
.dyn_cast_or_null<mlir::StringAttr>()
.getValue()
.str();
llvm::FileRemover constPackRemover(constPackFilePath);
llvm::Optional<std::string> constPackObjPath;
#if __APPLE__
// Create a empty stub file, compile it to an empty obj file.
llvm::SmallVector<char, 20> stubSrcPath;
llvm::sys::fs::createTemporaryFile("stub", "cpp", stubSrcPath);
llvm::FileRemover subSrcRemover(stubSrcPath);
std::string stubSrcPathStr(stubSrcPath.begin(), stubSrcPath.end());
Command createStubObj(/*exePath=*/kCxxPath);
std::string stubObjPathStr = stubSrcPathStr + ".o";
createStubObj.appendList({"-o", stubObjPathStr})
.appendList({"-c", stubSrcPathStr})
.exec();
llvm::FileRemover stubObjRemover(stubObjPathStr);
// Embed data into the empty stub obj file.
constPackObjPath = constPackFilePath + ".o";
Command genParamObj(/*exePath=*/kLinkerPath);
genParamObj.appendStr("-r")
.appendList({"-o", constPackObjPath.getValue()})
.appendList({"-sectcreate", "binary", "param", constPackFilePath})
.appendStr(stubObjPathStr)
.exec();
llvm::FileRemover constPackObjRemover(constPackObjPath.getValue());
#elif __linux__
// Create param.o holding packed parameter values.
constPackObjPath = constPackFilePath + ".o";
Command genParamObj(/*exePath=*/kLinkerPath);
genParamObj.appendStr("-r")
.appendList({"-b", "binary"})
.appendList({"-o", constPackObjPath.getValue()})
.appendStr(constPackFilePath)
.exec();
llvm::FileRemover constPackObjRemover(constPackObjPath.getValue());
// Figure out what is the default symbol name describing the start/end
// address of the embedded data.
std::regex e("[^0-9A-Za-z]");
auto sanitizedName =
"_binary_" + std::regex_replace(constPackFilePath, e, "_");
// Rename the symbols to saner ones expected by the runtime function.
Command redefineSym(/*exePath=*/kObjCopyPath);
redefineSym.appendStr("--redefine-sym")
.appendStr(sanitizedName + "_start=_binary_param_bin_start")
.appendStr(constPackObjPath.getValue())
.exec();
redefineSym.resetArgs()
.appendStr("--redefine-sym")
.appendStr(sanitizedName + "_end=_binary_param_bin_end")
.appendStr(constPackObjPath.getValue())
.exec();
#else
llvm::SmallVector<char, 10> permConstPackFileName(
constPackFilePath.begin(), constPackFilePath.end());
llvm::sys::path::replace_extension(permConstPackFileName, "bin");
std::string permConstPackFileNameStr(
permConstPackFileName.begin(), permConstPackFileName.end());
auto constPackFileName = llvm::sys::path::filename(outputBaseName) + "." +
llvm::sys::path::filename(permConstPackFileNameStr);
llvm::sys::fs::rename(constPackFilePath, constPackFileName);
mlir::Builder builder(*module);
(*module)
.lookupSymbol<mlir::LLVM::GlobalOp>(
mlir::KrnlPackedConstantOp::getConstPackFileNameSymbolName())
.valueAttr(builder.getStringAttr(constPackFileName.str()));
(*module)
.lookupSymbol<mlir::LLVM::GlobalOp>(
mlir::KrnlPackedConstantOp::getConstPackFileNameStrLenSymbolName())
.valueAttr(builder.getI64IntegerAttr(constPackFileName.str().size()));
#endif
// Write LLVM bitcode. // Write LLVM bitcode.
string outputFilename = outputBaseName + ".bc"; string outputFilename = outputBaseName + ".bc";
error_code error; error_code error;
llvm::raw_fd_ostream moduleBitcodeStream( llvm::raw_fd_ostream moduleBitcodeStream(
outputFilename, error, llvm::sys::fs::F_None); outputFilename, error, llvm::sys::fs::F_None);
llvm::WriteBitcodeToFile( llvm::WriteBitcodeToFile(
*mlir::translateModuleToLLVMIR(*module), moduleBitcodeStream); *mlir::translateModuleToLLVMIR(*module), moduleBitcodeStream);
moduleBitcodeStream.flush(); moduleBitcodeStream.flush();
llvm::FileRemover bcRemover(outputFilename);
// Compile bitcode to object file. // Compile LLVM bitcode to object file.
std::vector<std::string> llcArgs = { Command llvmToObj(/*exePath=*/kLlcPath);
"llc", "-filetype=obj", "-relocation-model=pic", outputFilename}; llvmToObj.appendStr("-filetype=obj");
auto llcArgStrRefs = llvmToObj.appendStr("-relocation-model=pic");
std::vector<llvm::StringRef>(llcArgs.begin(), llcArgs.end()); llvmToObj.appendStr(outputFilename);
llvm::sys::ExecuteAndWait(kLlcPath, llvm::makeArrayRef(llcArgStrRefs)); llvmToObj.exec();
std::string modelObjPath = outputBaseName + ".o";
llvm::FileRemover modelObjRemover(modelObjPath);
// Link with runtime. std::string runtimeDirInclFlag;
// TODO(tjingrant): link with runtime library in LLVM, and make the shared if (getEnvVar("RUNTIME_DIR").hasValue())
// library more self-contained. runtimeDirInclFlag = "-L" + getEnvVar("RUNTIME_DIR").getValue();
std::vector<std::string> cxxArgs = {kCxxFileName, "-shared", "-fPIC",
outputBaseName + ".o", "-o", outputBaseName + ".so", // Link everything into a shared object.
"-L" + kRuntimeDirPath, "-lcruntime", "-Wl,-rpath," + kRuntimeDirPath}; Command link(kCxxPath);
auto argsArrayRefVector = link.appendList({"-shared", "-fPIC"})
std::vector<llvm::StringRef>(cxxArgs.begin(), cxxArgs.end()); .appendStr(modelObjPath)
llvm::sys::ExecuteAndWait(kCxxPath, llvm::makeArrayRef(argsArrayRefVector)); .appendStr(constPackObjPath.getValueOr(""))
.appendList({"-o", outputBaseName + ".so"})
.appendStr(runtimeDirInclFlag)
.appendList({"-lEmbeddedDataLoader", "-lcruntime"})
.exec();
} }
void registerDialects() { void registerDialects() {
@ -98,6 +254,7 @@ void addONNXToMLIRPasses(mlir::PassManager &pm) {
void addONNXToKrnlPasses(mlir::PassManager &pm) { void addONNXToKrnlPasses(mlir::PassManager &pm) {
pm.addPass(mlir::createLowerToKrnlPass()); pm.addPass(mlir::createLowerToKrnlPass());
pm.addPass(mlir::createPackKrnlGlobalConstantsPass());
// An additional pass of canonicalization is helpful because lowering // An additional pass of canonicalization is helpful because lowering
// from ONNX dialect to Standard dialect exposes additional canonicalization // from ONNX dialect to Standard dialect exposes additional canonicalization
// oppertunities. // oppertunities.
@ -110,7 +267,7 @@ void addONNXToKrnlPasses(mlir::PassManager &pm) {
void addKrnlToAffinePasses(mlir::PassManager &pm) { void addKrnlToAffinePasses(mlir::PassManager &pm) {
pm.addPass(mlir::createLowerKrnlPass()); pm.addPass(mlir::createLowerKrnlPass());
// Fuse loops in Affine dialect. // Fuse loops in Affine dialect.
pm.addPass(mlir::createLoopFusionPass()); // pm.addPass(mlir::createLoopFusionPass());
} }
void addKrnlToLLVMPasses(mlir::PassManager &pm) { void addKrnlToLLVMPasses(mlir::PassManager &pm) {

View File

@ -43,4 +43,7 @@ std::unique_ptr<Pass> createElideConstGlobalValuePass();
/// Pass for lowering Krnl dialect to LLVM dialect. /// Pass for lowering Krnl dialect to LLVM dialect.
std::unique_ptr<Pass> createKrnlLowerToLLVMPass(); std::unique_ptr<Pass> createKrnlLowerToLLVMPass();
/// Pass for packing Krnl global constants.
std::unique_ptr<Pass> createPackKrnlGlobalConstantsPass();
} // end namespace mlir } // end namespace mlir

View File

@ -1,6 +1,6 @@
# Create shared libcruntime.so since model.so linkage for backend tests # Create shared libcruntime.so since model.so linkage for backend tests
# will fail on x86 Linux if cruntime is statically linked. # will fail on x86 Linux if cruntime is statically linked.
add_library(cruntime SHARED add_library(cruntime STATIC
DynMemRef.cpp DynMemRef.cpp
DynMemRef.h DynMemRef.h
DataType.h) DataType.h)
@ -34,6 +34,13 @@ target_include_directories(PyRuntime PRIVATE
${ONNX_MLIR_BIN_ROOT} ${ONNX_MLIR_BIN_ROOT}
${ONNX_MLIR_SRC_ROOT}) ${ONNX_MLIR_SRC_ROOT})
add_library(EmbeddedDataLoader STATIC
GetEmbeddedConstPool.h
GetEmbeddedConstPool.cpp)
set_target_properties(EmbeddedDataLoader PROPERTIES
POSITION_INDEPENDENT_CODE TRUE)
add_dependencies(PyRuntime cruntime) add_dependencies(PyRuntime cruntime)
install(FILES DynMemRef.h DESTINATION include) install(FILES DynMemRef.h DESTINATION include)
install(TARGETS cruntime DESTINATION lib) install(TARGETS cruntime DESTINATION lib)
install(TARGETS EmbeddedDataLoader DESTINATION lib)

View File

@ -14,6 +14,10 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "DynMemRef.h" #include "DynMemRef.h"
namespace { namespace {

View File

@ -0,0 +1,88 @@
//===---- GetEmbeddedConstPool.h - Get Embedded Const Pool API Func Impl---===//
//
// Copyright 2019-2020 The IBM Research Authors.
//
// =============================================================================
//
// This file contains runtime API implementation to extract constant pool values
// embedded within the shared library binary files.
//
//===----------------------------------------------------------------------===//
#include "GetEmbeddedConstPool.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
// Adapted from:
// https://developer.ibm.com/technologies/systems/articles/au-endianc/
const int i = 1;
#define IS_SYSTEM_LE() (!((*(char *)&i) == 0))
#define XOR(a, b) (!(a) != !(b))
extern const char constPackIsLE;
void checkEndianness() {
if (XOR(IS_SYSTEM_LE(), constPackIsLE)) {
fprintf(stderr, "Constant pack is stored in a byte order that is not "
"native to this current system.");
exit(1);
}
}
#if __APPLE__
#include <mach-o/getsect.h>
extern const struct mach_header_64 _mh_dylib_header;
void *getEmbeddedConstPool(int64_t size_in_byte) {
checkEndianness();
size_t size = size_in_byte;
unsigned char *data =
getsectiondata(&_mh_dylib_header, "binary", "param", &size);
float *data_ptr = (float *)data;
void *buffer = malloc(size);
memcpy(buffer, data, size);
return data;
}
#elif __linux__
extern char _binary_param_bin_start;
extern char _binary_param_bin_end;
void *getEmbeddedConstPool(int64_t _) {
checkEndianness();
auto size = (unsigned int)(&_binary_param_bin_end - &_binary_param_bin_start);
void *buffer = malloc(size);
memcpy(buffer, &_binary_param_bin_start, size);
return buffer;
}
#else
extern char constPackFileName[];
extern int64_t constPackFileNameStrLen;
void *getEmbeddedConstPool(int64_t _) {
checkEndianness();
char *fname = (char *)calloc(1, constPackFileNameStrLen + 1);
memcpy(fname, constPackFileName, constPackFileNameStrLen);
// Adapted from https://stackoverflow.com/a/22059317 .
FILE *fileptr;
char *buffer;
long filelen;
fileptr = fopen(fname, "rb"); // Open the file in binary mode
fseek(fileptr, 0, SEEK_END); // Jump to the end of the file
filelen = ftell(fileptr); // Get the current byte offset in the file
rewind(fileptr); // Jump back to the beginning of the file
buffer = (char *)malloc(filelen * sizeof(char)); // Enough memory for the file
fread(buffer, filelen, 1, fileptr); // Read in the entire file
fclose(fileptr); // Close the file
return (void *)buffer;
}
#endif

View File

@ -0,0 +1,18 @@
//===---- GetEmbeddedConstPool.h - Get Embedded Const Pool API Func Decl---===//
//
// Copyright 2019-2020 The IBM Research Authors.
//
// =============================================================================
//
// This file contains runtime API declarations to extract constant pool values
// embedded within the shared library binary files.
//
//===----------------------------------------------------------------------===//
#pragma once
#include <stdint.h>
extern "C" {
void *getEmbeddedConstPool(int64_t size_in_byte);
}

View File

@ -0,0 +1,95 @@
//===----- BinaryDecoder.cpp - Decode binary files into typed arrays ------===//
//
// Copyright 2019-2020 The IBM Research Authors.
//
// =============================================================================
//
// This file contains implementation of a utility called BinaryDecoder, which
// decodes a sequence of binary data within a binary file specified by an
// offset and a length into a typed array and print to stdout.
//
//===----------------------------------------------------------------------===//
#include <fstream>
#include <iostream>
#include <string>
#include "onnx/onnx_pb.h"
#include <llvm/Support/CommandLine.h>
#if defined(_WIN32)
#include <stdint.h>
typedef uint8_t u_int8_t;
typedef uint16_t u_int16_t;
typedef uint32_t u_int32_t;
typedef uint64_t u_int64_t;
#endif
llvm::cl::opt<std::string> Filename(
llvm::cl::Positional, llvm::cl::desc("<input file>"), llvm::cl::Required);
llvm::cl::opt<int64_t> Start("s",
llvm::cl::desc("Specify the index of the starting byte"),
llvm::cl::value_desc("start"), llvm::cl::Required);
llvm::cl::opt<int64_t> Size("n",
llvm::cl::desc("Specify the number of bytes of data to decode"),
llvm::cl::value_desc("size"), llvm::cl::Required);
llvm::cl::opt<bool> Remove(
"rm", llvm::cl::desc(
"Whether to remove the file being decoded after inspection."));
llvm::cl::opt<onnx::TensorProto::DataType> DataType(
llvm::cl::desc("Choose data type to decode:"),
llvm::cl::values(clEnumVal(onnx::TensorProto::FLOAT, "FLOAT"),
clEnumVal(onnx::TensorProto::UINT8, "UINT8"),
clEnumVal(onnx::TensorProto::INT8, "INT8"),
clEnumVal(onnx::TensorProto::UINT16, "UINT16"),
clEnumVal(onnx::TensorProto::INT16, "INT16"),
clEnumVal(onnx::TensorProto::INT32, "INT32"),
clEnumVal(onnx::TensorProto::INT64, "INT64"),
clEnumVal(onnx::TensorProto::STRING, "STRING"),
clEnumVal(onnx::TensorProto::BOOL, "BOOL"),
clEnumVal(onnx::TensorProto::FLOAT16, "FLOAT16"),
clEnumVal(onnx::TensorProto::DOUBLE, "DOUBLE"),
clEnumVal(onnx::TensorProto::UINT32, "UINT32"),
clEnumVal(onnx::TensorProto::UINT64, "UINT64")));
template <typename T>
int printBuffer(std::vector<char> buffer) {
auto *ptr = (T *)&buffer[0];
auto data = std::vector<T>(ptr, ptr + buffer.size() / sizeof(T));
for (const auto &elem : data)
std::cout << elem << " ";
return 0;
}
int main(int argc, char **argv) {
llvm::cl::ParseCommandLineOptions(argc, argv);
std::vector<char> buffer(Size);
std::ifstream file(Filename, std::ios::in | std::ios::binary);
if (!file)
return -1;
file.seekg(Start, file.beg);
file.read(&buffer[0], Size);
file.close();
if (Remove)
llvm::sys::fs::remove(Filename);
#define PRINT_BUFFER_FOR_TYPE(ONNX_TYPE, CPP_TYPE) \
if (DataType == ONNX_TYPE) \
return printBuffer<CPP_TYPE>(buffer);
PRINT_BUFFER_FOR_TYPE(onnx::TensorProto::UINT8, u_int8_t);
PRINT_BUFFER_FOR_TYPE(onnx::TensorProto::UINT16, u_int16_t);
PRINT_BUFFER_FOR_TYPE(onnx::TensorProto::INT16, int16_t);
PRINT_BUFFER_FOR_TYPE(onnx::TensorProto::INT32, int32_t);
PRINT_BUFFER_FOR_TYPE(onnx::TensorProto::INT64, int64_t);
PRINT_BUFFER_FOR_TYPE(onnx::TensorProto::FLOAT, float);
PRINT_BUFFER_FOR_TYPE(onnx::TensorProto::DOUBLE, double);
PRINT_BUFFER_FOR_TYPE(onnx::TensorProto::UINT32, u_int32_t);
PRINT_BUFFER_FOR_TYPE(onnx::TensorProto::UINT64, u_int64_t);
}

View File

@ -0,0 +1,9 @@
add_executable(binary-decoder BinaryDecoder.cpp)
target_link_libraries(binary-decoder
${LLVMSupport}
${LLVMDemangle}
${CURSES_LIBRARIES}
${CMAKE_THREAD_LIBS_INIT})
message(STATUS "incl dir" ${ONNX_INCLUDE_DIRS})
target_link_libraries(binary-decoder onnx)

View File

@ -1 +1,2 @@
add_subdirectory(ONNXMLIROpt) add_subdirectory(ONNXMLIROpt)
add_subdirectory(BinaryDecoder)

View File

@ -27,7 +27,8 @@ target_link_libraries(OMKrnlToLLVM
onnx) onnx)
add_library(OMElideKrnlGlobalConstants add_library(OMElideKrnlGlobalConstants
ElideKrnlGlobalConstants.cpp) ElideKrnlGlobalConstants.cpp
ElideKrnlGlobalConstants.hpp)
target_include_directories(OMElideKrnlGlobalConstants target_include_directories(OMElideKrnlGlobalConstants
PRIVATE PRIVATE
${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_SRC_ROOT}
@ -38,6 +39,14 @@ add_dependencies(OMElideKrnlGlobalConstants OMKrnlOpsInc)
# Linking dependencies # Linking dependencies
add_dependencies(OMElideKrnlGlobalConstants OMKrnlOps) add_dependencies(OMElideKrnlGlobalConstants OMKrnlOps)
add_library(OMPackKrnlGlobalConstants
PackKrnlGlobalConstants.cpp)
target_include_directories(OMPackKrnlGlobalConstants
PRIVATE
${ONNX_MLIR_SRC_ROOT}
${ONNX_MLIR_BIN_ROOT}
${ONNX_MLIR_SRC_ROOT})
add_library(OMEnableMemoryPool add_library(OMEnableMemoryPool
EnableMemoryPool.cpp) EnableMemoryPool.cpp)
target_include_directories(OMEnableMemoryPool target_include_directories(OMEnableMemoryPool
@ -45,6 +54,9 @@ target_include_directories(OMEnableMemoryPool
${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_SRC_ROOT}
${ONNX_MLIR_BIN_ROOT} ${ONNX_MLIR_BIN_ROOT}
${ONNX_MLIR_SRC_ROOT}) ${ONNX_MLIR_SRC_ROOT})
add_dependencies(OMPackKrnlGlobalConstants
OMKrnlOps)
target_link_libraries(OMEnableMemoryPool target_link_libraries(OMEnableMemoryPool
onnx) onnx)
add_dependencies(OMEnableMemoryPool add_dependencies(OMEnableMemoryPool

View File

@ -22,34 +22,31 @@
#include "src/Dialect/Krnl/KrnlOps.hpp" #include "src/Dialect/Krnl/KrnlOps.hpp"
#include "src/Pass/Passes.hpp" #include "src/Pass/Passes.hpp"
#include "ElideKrnlGlobalConstants.hpp"
using namespace mlir; using namespace mlir;
namespace { const int64_t KrnlConstGlobalValueElision::kDefaultElisionThreshold = 32;
/*! mlir::LogicalResult KrnlConstGlobalValueElision::matchAndRewrite(
* RewritePattern that replaces existing constant Krnl global values mlir::KrnlGlobalOp op, mlir::PatternRewriter &rewriter) const {
* with a similar operation which preserves all attributes except the value
* attribute.
*/
class KrnlConstGlobalValueElision : public OpRewritePattern<KrnlGlobalOp> {
public:
using OpRewritePattern<KrnlGlobalOp>::OpRewritePattern;
LogicalResult matchAndRewrite(
KrnlGlobalOp op, PatternRewriter &rewriter) const override {
auto loc = op.getLoc(); auto loc = op.getLoc();
if (op.value().hasValue()) { if (op.value().hasValue()) {
auto newGlobalOp = rewriter.create<KrnlGlobalOp>( const auto &valAttr = op.valueAttr().dyn_cast_or_null<DenseElementsAttr>();
loc, op.getResult().getType(), op.shape(), op.name(), nullptr); if (valAttr.getNumElements() > elisionThreshold) {
IntegerAttr offsetAttr = op.offset() ? op.offsetAttr() : nullptr;
auto newGlobalOp = rewriter.create<KrnlGlobalOp>(loc,
op.getResult().getType(), /*shape=*/op.shape(),
/*name=*/op.name(), /*value=*/nullptr, /*offset=*/offsetAttr);
rewriter.replaceOp(op, newGlobalOp.getResult()); rewriter.replaceOp(op, newGlobalOp.getResult());
} }
}
return success(); return success();
} }
};
namespace {
/*! /*!
* Function pass that performs constant value elision of Krnl globals. * Function pass that performs constant value elision of Krnl globals.
*/ */
@ -66,6 +63,7 @@ public:
applyPatternsAndFoldGreedily(function, patterns); applyPatternsAndFoldGreedily(function, patterns);
} }
}; };
} // namespace } // namespace
std::unique_ptr<Pass> mlir::createElideConstGlobalValuePass() { std::unique_ptr<Pass> mlir::createElideConstGlobalValuePass() {

View File

@ -0,0 +1,35 @@
#pragma once
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "src/Dialect/Krnl/KrnlOps.hpp"
/*!
* RewritePattern that replaces existing constant Krnl global values
* with a similar operation which preserves all attributes except the value
* attribute.
*/
class KrnlConstGlobalValueElision
: public mlir::OpRewritePattern<mlir::KrnlGlobalOp> {
public:
/*
* A threshold value specifying the maximum number of elements a constant
* operation can hold as an attribute. If the number exceeds this threshold,
* constants will be packed together and, in the case where `move-to-file`
* option is enabled, stored as a binary file on disk. This can help preserve
* readability of IR dump and improve compilation speed.
*/
static const int64_t kDefaultElisionThreshold;
int64_t elisionThreshold;
using mlir::OpRewritePattern<mlir::KrnlGlobalOp>::OpRewritePattern;
explicit KrnlConstGlobalValueElision(
mlir::MLIRContext *context, int64_t elisionThreshold)
: OpRewritePattern(context), elisionThreshold(elisionThreshold) {}
mlir::LogicalResult matchAndRewrite(
mlir::KrnlGlobalOp op, mlir::PatternRewriter &rewriter) const override;
};

View File

@ -32,10 +32,9 @@ namespace {
static FlatSymbolRefAttr getOrInsertExternFunc(StringRef funcName, static FlatSymbolRefAttr getOrInsertExternFunc(StringRef funcName,
ModuleOp module, mlir::LLVM::LLVMType funcType, PatternRewriter &rewriter) { ModuleOp module, mlir::LLVM::LLVMType funcType, PatternRewriter &rewriter) {
auto *context = module.getContext(); auto *context = module.getContext();
if (module.lookupSymbol<LLVM::LLVMFuncOp>(funcName)) { if (auto sym = module.lookupSymbol<LLVM::LLVMFuncOp>(funcName)) {
auto symbolRef = SymbolRefAttr::get(funcName, context); assert(sym.getType() == funcType && "wrong symbol type");
assert(symbolRef.getType() == funcType && "wrong symbol type"); return SymbolRefAttr::get(funcName, context);
return symbolRef;
} }
// Insert the function into the body of the parent module. // Insert the function into the body of the parent module.
@ -203,6 +202,8 @@ public:
// The llvm type of the global (example: [2 x [8 x float]]) // The llvm type of the global (example: [2 x [8 x float]])
auto llvmGlobalType = globalType.cast<LLVM::LLVMType>(); auto llvmGlobalType = globalType.cast<LLVM::LLVMType>();
mlir::Value alloc;
if (krnlGlobalOp.value().hasValue()) {
{ {
OpBuilder::InsertionGuard insertGuard(rewriter); OpBuilder::InsertionGuard insertGuard(rewriter);
rewriter.setInsertionPointToStart(module.getBody()); rewriter.setInsertionPointToStart(module.getBody());
@ -222,7 +223,7 @@ public:
// This is a region of local memory and needs to be emitted as an alloca. // This is a region of local memory and needs to be emitted as an alloca.
auto one = rewriter.create<LLVM::ConstantOp>( auto one = rewriter.create<LLVM::ConstantOp>(
loc, llvmI64Ty, rewriter.getI64IntegerAttr(1)); loc, llvmI64Ty, rewriter.getI64IntegerAttr(1));
auto alloc = rewriter.create<LLVM::AllocaOp>( alloc = rewriter.create<LLVM::AllocaOp>(
loc, llvmGlobalType.getPointerTo(), one, /*alignment=*/0); loc, llvmGlobalType.getPointerTo(), one, /*alignment=*/0);
// Copy constant value into the local alloca: // Copy constant value into the local alloca:
@ -234,12 +235,13 @@ public:
Value i8PtrGlobal = Value i8PtrGlobal =
rewriter.create<LLVM::BitcastOp>(loc, llvmI8PtrTy, globalValue); rewriter.create<LLVM::BitcastOp>(loc, llvmI8PtrTy, globalValue);
// - Set size. // - Set size.
Value memRefElementSize = rewriter.create<LLVM::ConstantOp>(loc, llvmI64Ty, Value memRefElementSize =
rewriter.create<LLVM::ConstantOp>(loc, llvmI64Ty,
rewriter.getI64IntegerAttr(getMemRefEltSizeInBytes(memRefTy))); rewriter.getI64IntegerAttr(getMemRefEltSizeInBytes(memRefTy)));
Value numElementsValue = rewriter.create<LLVM::ConstantOp>( Value numElementsValue = rewriter.create<LLVM::ConstantOp>(
loc, llvmI64Ty, rewriter.getI64IntegerAttr(numElements)); loc, llvmI64Ty, rewriter.getI64IntegerAttr(numElements));
Value totalElementsSize = Value totalElementsSize = rewriter.create<LLVM::MulOp>(
rewriter.create<LLVM::MulOp>(loc, memRefElementSize, numElementsValue); loc, memRefElementSize, numElementsValue);
Value int64Size = Value int64Size =
rewriter.create<LLVM::SExtOp>(loc, llvmI64Ty, totalElementsSize); rewriter.create<LLVM::SExtOp>(loc, llvmI64Ty, totalElementsSize);
// - Set volatile. // - Set volatile.
@ -251,7 +253,29 @@ public:
rewriter.create<CallOp>(loc, memcpyRef, rewriter.create<CallOp>(loc, memcpyRef,
LLVM::LLVMType::getVoidTy(llvmDialect), LLVM::LLVMType::getVoidTy(llvmDialect),
ArrayRef<Value>({int8PtrAlloc, i8PtrGlobal, int64Size, isVolatile})); ArrayRef<Value>({int8PtrAlloc, i8PtrGlobal, int64Size, isVolatile}));
} else {
// Some frequently used types.
auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
// Allocate the memory where the constants will be used from.
// This is a region of local memory and needs to be emitted as an alloca.
auto one = rewriter.create<LLVM::ConstantOp>(
loc, llvmI64Ty, rewriter.getI64IntegerAttr(1));
auto base = module.lookupSymbol<LLVM::GlobalOp>("packedConst");
assert(base && "Cannot find symbol packedConst.");
Value constPackBasePtrAddr =
rewriter.create<LLVM::AddressOfOp>(loc, base);
Value constPackBasePtr = rewriter.create<LLVM::LoadOp>(
loc, base.getType(), constPackBasePtrAddr);
auto offset = rewriter.create<LLVM::ConstantOp>(loc, llvmI64Ty,
rewriter.getI64IntegerAttr(
krnlGlobalOp.offsetAttr().getValue().getSExtValue()));
alloc = rewriter.create<LLVM::GEPOp>(
loc, llvmI8PtrTy, constPackBasePtr, ValueRange({offset}));
}
// Prepare data to be inserted into MemRef. // Prepare data to be inserted into MemRef.
auto llvmConstantElementType = constantElementType.cast<LLVM::LLVMType>(); auto llvmConstantElementType = constantElementType.cast<LLVM::LLVMType>();
Value typedAlloc = rewriter.create<LLVM::BitcastOp>( Value typedAlloc = rewriter.create<LLVM::BitcastOp>(
@ -681,6 +705,115 @@ private:
} }
} }
}; };
//===----------------------------------------------------------------------===//
// KRNL to LLVM: KrnlPackedConstOpLowering
//===----------------------------------------------------------------------===//
class KrnlPackedConstOpLowering : public ConvertToLLVMPattern {
public:
explicit KrnlPackedConstOpLowering(
MLIRContext *context, LLVMTypeConverter &lowering_)
: ConvertToLLVMPattern(
KrnlPackedConstantOp::getOperationName(), context, lowering_) {}
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto *context = op->getContext();
ModuleOp module = op->getParentOfType<ModuleOp>();
auto loc = op->getLoc();
auto *llvmDialect =
op->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
assert(llvmDialect && "expected llvm dialect to be registered");
auto packedConstOp = llvm::dyn_cast<KrnlPackedConstantOp>(op);
LLVM::GlobalOp globalBase;
// Some frequently used types.
auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
{
OpBuilder::InsertionGuard insertGuard(rewriter);
rewriter.setInsertionPointToStart(module.getBody());
globalBase = rewriter.create<LLVM::GlobalOp>(loc, llvmI8PtrTy,
/*isConstant=*/false, LLVM::Linkage::Internal, "packedConst",
nullptr);
}
auto mainFunc = module.lookupSymbol<FuncOp>("main_graph");
assert(mainFunc);
rewriter.setInsertionPoint(
&mainFunc.getBody().front(), mainFunc.getBody().front().begin());
// - Initialize the global constant base.
Value basePtrAddr = rewriter.create<LLVM::AddressOfOp>(loc, globalBase);
auto getEmbeddedConstPoolRef = getOrInsertExternFunc(
KrnlPackedConstantOp::getEmbeddedDataLoaderMethodName(), module,
LLVM::LLVMType::getFunctionTy(
llvmI8PtrTy, {llvmI64Ty}, /*isVarArg=*/false),
rewriter);
auto constPackSize = rewriter.create<LLVM::ConstantOp>(loc,
LLVM::LLVMType::getInt64Ty(llvmDialect),
packedConstOp.size_in_bytesAttr());
Value alloc = rewriter
.create<CallOp>(loc, getEmbeddedConstPoolRef, llvmI8PtrTy,
ArrayRef<Value>({constPackSize}))
.getResult(0);
rewriter.create<LLVM::StoreOp>(loc, alloc, basePtrAddr);
{
OpBuilder::InsertionGuard insertGuard(rewriter);
rewriter.setInsertionPointToStart(module.getBody());
// Record constant pack *file path* as a global variable (by recording the
// file path string's underlying char array + its length).
const auto &fileNameAttr = packedConstOp.file_nameAttr();
auto type =
LLVM::LLVMType::getArrayTy(LLVM::LLVMType::getInt8Ty(llvmDialect),
fileNameAttr.getValue().size());
rewriter.create<LLVM::GlobalOp>(loc, type, /*isConstant=*/true,
LLVM::Linkage::External,
mlir::KrnlPackedConstantOp::getConstPackFilePathSymbolName(),
fileNameAttr);
type = LLVM::LLVMType::getInt64Ty(llvmDialect);
rewriter.create<LLVM::GlobalOp>(loc, type, /*isConstant=*/true,
LLVM::Linkage::External,
mlir::KrnlPackedConstantOp::getConstPackFilePathStrLenSymbolName(),
rewriter.getI64IntegerAttr(fileNameAttr.getValue().size()));
// Record constant pack *file name* as a global variable (by recording the
// file name string's underlying char array + its length).
auto constPackFileName =
llvm::sys::path::filename(fileNameAttr.getValue());
type = LLVM::LLVMType::getArrayTy(
LLVM::LLVMType::getInt8Ty(llvmDialect), constPackFileName.size());
rewriter.create<LLVM::GlobalOp>(loc, type, /*isConstant=*/true,
LLVM::Linkage::External,
mlir::KrnlPackedConstantOp::getConstPackFileNameSymbolName(),
rewriter.getStringAttr(constPackFileName));
type = LLVM::LLVMType::getInt64Ty(llvmDialect);
rewriter.create<LLVM::GlobalOp>(loc, type, /*isConstant=*/true,
LLVM::Linkage::External,
mlir::KrnlPackedConstantOp::getConstPackFileNameStrLenSymbolName(),
rewriter.getI64IntegerAttr(constPackFileName.size()));
type = LLVM::LLVMType::getInt8Ty(llvmDialect);
rewriter.create<LLVM::GlobalOp>(loc, type, /*isConstant=*/true,
LLVM::Linkage::External,
mlir::KrnlPackedConstantOp::getConstPackIsLESymbolName(),
rewriter.getI8IntegerAttr(packedConstOp.is_le()));
}
rewriter.eraseOp(op);
return success();
}
private:
static int64_t ArrayAttrIntVal(ArrayAttr a, int i) {
return (a.getValue()[i]).cast<IntegerAttr>().getInt();
}
};
} // end namespace } // end namespace
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -712,7 +845,8 @@ void KrnlToLLVMLoweringPass::runOnOperation() {
/*emitCWrapperS=*/true, /*emitCWrapperS=*/true,
/*useAlignedAlloc=*/false); /*useAlignedAlloc=*/false);
patterns.insert<KrnlGlobalOpLowering>(&getContext(), typeConverter); patterns.insert<KrnlGlobalOpLowering, KrnlPackedConstOpLowering>(
&getContext(), typeConverter);
patterns.insert<KrnlGetRefOpLowering>(&getContext(), typeConverter); patterns.insert<KrnlGetRefOpLowering>(&getContext(), typeConverter);
// Lower from the `krnl` dialect i.e. the Reshape operation. // Lower from the `krnl` dialect i.e. the Reshape operation.
@ -722,8 +856,9 @@ void KrnlToLLVMLoweringPass::runOnOperation() {
// We want to completely lower to LLVM, so we use a `FullConversion`. This // We want to completely lower to LLVM, so we use a `FullConversion`. This
// ensures that only legal operations will remain after the conversion. // ensures that only legal operations will remain after the conversion.
if (failed(applyFullConversion( if (failed(applyFullConversion(
getOperation(), target, patterns, &typeConverter))) getOperation(), target, patterns, &typeConverter))) {
signalPassFailure(); signalPassFailure();
}
} }
/// Create the pass for lowering `Krnl`, `Affine` and `Std` dialects to LLVM. /// Create the pass for lowering `Krnl`, `Affine` and `Std` dialects to LLVM.

View File

@ -50,8 +50,9 @@ public:
int64_t dynamicOperations = 0; int64_t dynamicOperations = 0;
f.walk([&](mlir::Operation *op) { f.walk([&](mlir::Operation *op) {
if (returnsDynamicShape(op)) if (returnsDynamicShape(op)) {
dynamicOperations++; dynamicOperations++;
}
}); });
// If any dynamic operations remain, this indicates a failure. // If any dynamic operations remain, this indicates a failure.

View File

@ -0,0 +1,127 @@
//===- ElideKrnlGlobalConstants.cpp - Krnl Constant lobal Value Elision ---===//
//
// Copyright 2019-2020 The IBM Research Authors.
//
// =============================================================================
//
// In practice, the constant values of Global Krnl operations may be large
// enough to hinder the readability of the MLIR intermediate representation.
//
// This file creates a pass which elides the explicit values of constant
// global operations. This pass has purely cosmetic purposes and should only be
// run to obtain a compact representation of the program when emitting Krnl
// dialect code. This pass should never be invoked on code meant to be run.
//
//===----------------------------------------------------------------------===//
#include <fstream>
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/FileSystem.h"
#include "src/Dialect/Krnl/KrnlOps.hpp"
#include "src/Pass/Passes.hpp"
#include "src/Transform/ElideKrnlGlobalConstants.hpp"
using namespace mlir;
namespace {
/*!
* Function pass that performs constant value elision of Krnl globals.
*/
class PackKrnlGlobalConstantsPass
: public PassWrapper<PackKrnlGlobalConstantsPass, OperationPass<ModuleOp>> {
public:
/// Make sure that we have a valid default constructor and copy constructor to
/// make sure that the options are initialized properly.
PackKrnlGlobalConstantsPass() = default;
PackKrnlGlobalConstantsPass(const PackKrnlGlobalConstantsPass &pass) {}
void runOnOperation() override {
auto module = getOperation();
OpBuilder builder(&getContext());
// Packing constant arrays to packedConst.
std::vector<char> packedConst;
module.walk([&](KrnlGlobalOp op) {
assert(op.value());
op.offsetAttr(builder.getI64IntegerAttr(packedConst.size()));
assert(op.value()->isa<DenseElementsAttr>());
const auto &denseAttr = op.valueAttr().cast<DenseElementsAttr>();
auto numElements = denseAttr.getNumElements();
if (numElements <= elisionThreshold)
return;
// TODO(tjingrant) verify we can actually use the raw data.
std::vector<char> rawData = denseAttr.getRawData();
packedConst.insert(packedConst.end(), rawData.begin(), rawData.end());
});
// Remove value attributes from krnl constant op.
ConversionTarget target(getContext());
OwningRewritePatternList patterns;
patterns.insert<KrnlConstGlobalValueElision>(
&getContext(), elisionThreshold);
// Apply constant value elision.
module.walk(
[&](FuncOp func) { applyPatternsAndFoldGreedily(func, patterns); });
bool isLE = llvm::support::endian::system_endianness() ==
llvm::support::endianness::little;
mlir::OperationState state(module.getLoc(), "krnl.packed_const");
KrnlPackedConstantOp::build(builder, state,
builder.getIntegerType(/*width=*/64),
/*size_in_bytes=*/builder.getI64IntegerAttr(packedConst.size()),
/*is_le=*/builder.getBoolAttr(isLE),
/*value=*/nullptr,
/*file_name=*/nullptr);
auto packedConstOp =
llvm::cast<mlir::KrnlPackedConstantOp>(mlir::Operation::create(state));
module.insert(module.begin(), packedConstOp);
if (moveToFile) {
std::string pathStr;
if (filename.hasValue()) {
pathStr = filename.getValue();
} else {
llvm::SmallVector<char, 10> path;
llvm::sys::fs::createTemporaryFile("packed_const", "tmp", path);
pathStr = std::string(path.begin(), path.end());
}
packedConstOp.file_nameAttr(builder.getStringAttr(pathStr));
std::ofstream outfile(pathStr, std::ofstream::binary);
outfile.write(packedConst.data(), packedConst.size());
} else {
auto shapeTy =
RankedTensorType::get({static_cast<int64_t>(packedConst.size())},
builder.getIntegerType(8));
auto denseAttr =
DenseIntElementsAttr::get(shapeTy, llvm::makeArrayRef(packedConst));
packedConstOp.valueAttr(denseAttr);
}
}
Option<bool> moveToFile{*this, "move-to-file",
llvm::cl::desc("Whether to move the packed constant to a file."),
llvm::cl::init(true)};
Option<int64_t> elisionThreshold{*this, "elision-threshold",
llvm::cl::desc(
"A threshold value specifying the maximum number of elements a "
"constant operation can hold as an attribute. If the number exceeds "
"this threshold, constants will be packed together and, in the case "
"where `move-to-file` option is enabled, stored as a binary file on "
"disk. This can help preserve readability of IR dump and improve "
"compilation speed."),
llvm::cl::init(KrnlConstGlobalValueElision::kDefaultElisionThreshold)};
Option<std::string> filename{*this, "filename",
llvm::cl::desc(
"Specify a file in which the packed constant is to be stored.")};
};
} // namespace
std::unique_ptr<Pass> mlir::createPackKrnlGlobalConstantsPass() {
return std::make_unique<PackKrnlGlobalConstantsPass>();
}
static PassRegistration<PackKrnlGlobalConstantsPass> pass("pack-krnl-constants",
"Elide the constant values of the Global Krnl operations.");

View File

@ -6,6 +6,7 @@ from __future__ import unicode_literals
import os import os
import sys import sys
import unittest import unittest
import warnings
import onnx.backend.base import onnx.backend.base
import onnx.backend.test import onnx.backend.test
@ -29,7 +30,57 @@ from PyRuntime import ExecutionSession
def execute_commands(cmds): def execute_commands(cmds):
if (VERBOSE): if (VERBOSE):
print(" ".join(cmds)) print(" ".join(cmds))
subprocess.run(cmds, stdout=subprocess.PIPE) subprocess.run(cmds)
# There are two issues, which necessitates the adoption of this endianness
# aware wrapper around Execution Session:
# 1. Input arrays are given sometimes in native byte order, sometime in
# LE byte order, and as soon as the python array enters into py::array
# C++ objects through pybind, we will no longer be able to query their
# endianness. So we must intercept the inputs and convert them into
# native endianness.
# 2. Output arrays are compared with reference outputs, the comparison
# unfortunately includes checking that our outputs and reference outputs
# share the same endianness. So we try to figure out what is the desired
# reference output endianness, and convert our outputs to this desired
# endianness.
class EndiannessAwareExecutionSession(ExecutionSession):
def __init__(self, path, entry_point):
super().__init__(path, entry_point)
def is_input_le(self, inputs):
inputs_endianness = list(map(lambda x: x.dtype.byteorder, inputs))
endianness_is_consistent = len(set(inputs_endianness)) <= 1
assert endianness_is_consistent, \
"Input arrays contain a mixture of endianness configuration."
sys_is_le = sys.byteorder == 'little'
# To interpret character symbols indicating endianness:
# https://numpy.org/doc/stable/reference/generated/numpy.dtype.byteorder.html
explicitly_le = inputs_endianness[0] == "<"
implicitly_le = (inputs_endianness[0] == "=" and sys_is_le)
return explicitly_le or implicitly_le
def run(self, inputs, **kwargs):
if len(inputs):
# Deduce desired endianness of output from inputs.
sys_is_le = sys.byteorder == 'little'
inp_is_le = self.is_input_le(inputs)
if (sys_is_le != inp_is_le):
inputs = list(
map(lambda x: x.byteswap().newbyteorder(), inputs))
outputs = super().run(inputs)
if (sys_is_le != inp_is_le):
outputs = list(
map(lambda x: x.byteswap().newbyteorder(), outputs))
return outputs
else:
# Can't deduce desired output endianess, fingers crossed.
warnings.warn(
"Cannot deduce desired output endianness, using native endianness by default."
)
return super().run(inputs)
class DummyBackend(onnx.backend.base.Backend): class DummyBackend(onnx.backend.base.Backend):
@ -40,7 +91,8 @@ class DummyBackend(onnx.backend.base.Backend):
onnx.save(model, "temp_model.onnx") onnx.save(model, "temp_model.onnx")
# Call frontend to process temp_model.onnx, bit code will be generated. # Call frontend to process temp_model.onnx, bit code will be generated.
execute_commands([ONNX_MLIR, "temp_model.onnx"]) execute_commands([ONNX_MLIR, "temp_model.onnx"])
return ExecutionSession("./temp_model.so", "_dyn_entry_point_main_graph") return EndiannessAwareExecutionSession("./temp_model.so",
"_dyn_entry_point_main_graph")
@classmethod @classmethod
def supports_device(cls, device): def supports_device(cls, device):
@ -80,7 +132,6 @@ test_to_enable = [
"test_concat_3d_axis_0_cpu", "test_concat_3d_axis_0_cpu",
"test_concat_3d_axis_1_cpu", "test_concat_3d_axis_1_cpu",
"test_concat_3d_axis_2_cpu", "test_concat_3d_axis_2_cpu",
"test_concat_1d_axis_negative_1_cpu", "test_concat_1d_axis_negative_1_cpu",
"test_concat_2d_axis_negative_1_cpu", "test_concat_2d_axis_negative_1_cpu",
"test_concat_2d_axis_negative_2_cpu", "test_concat_2d_axis_negative_2_cpu",
@ -359,12 +410,18 @@ test_to_enable = [
"test_split_variable_parts_1d_cpu", "test_split_variable_parts_1d_cpu",
"test_split_variable_parts_2d_cpu", "test_split_variable_parts_2d_cpu",
"test_split_variable_parts_default_axis_cpu", "test_split_variable_parts_default_axis_cpu",
# ResNet
"test_resnet50_cpu",
] ]
# Extract name of all test cases. # Extract name of all test cases.
import inspect import inspect
all_tests = inspect.getmembers( all_tests = []
all_tests += inspect.getmembers(
backend_test.test_cases["OnnxBackendRealModelTest"])
all_tests += inspect.getmembers(
backend_test.test_cases["OnnxBackendNodeModelTest"]) backend_test.test_cases["OnnxBackendNodeModelTest"])
all_test_names = list(map(lambda x: x[0], all_tests)) all_test_names = list(map(lambda x: x[0], all_tests))
@ -372,8 +429,7 @@ all_test_names = list(map(lambda x: x[0], all_tests))
for test_name in test_to_enable: for test_name in test_to_enable:
assert test_name in all_test_names, """test name {} not found, it is likely assert test_name in all_test_names, """test name {} not found, it is likely
that you may have misspelled the test name or the specified test does not that you may have misspelled the test name or the specified test does not
exist in the version of onnx package you installed.""".format( exist in the version of onnx package you installed.""".format(test_name)
test_name)
backend_test.include(r"^{}$".format(test_name)) backend_test.include(r"^{}$".format(test_name))
# import all test cases at global scope to make them visible to python.unittest # import all test cases at global scope to make them visible to python.unittest

View File

@ -6,7 +6,7 @@ configure_lit_site_cfg(${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in
MAIN_CONFIG MAIN_CONFIG
${CMAKE_CURRENT_SOURCE_DIR}/lit.cfg.py) ${CMAKE_CURRENT_SOURCE_DIR}/lit.cfg.py)
set(ONNX_MLIR_TEST_DEPENDS onnx-mlir-opt) set(ONNX_MLIR_TEST_DEPENDS onnx-mlir-opt binary-decoder)
add_lit_testsuite(check-onnx-lit add_lit_testsuite(check-onnx-lit
"Running the ONNX MLIR regression tests" "Running the ONNX MLIR regression tests"

View File

@ -0,0 +1,10 @@
// RUN: onnx-mlir-opt --pack-krnl-constants='elision-threshold=3 move-to-file=true filename=test1.bin' %s -split-input-file && binary-decoder test1.bin -s 0 -n 16 --onnx::TensorProto::FLOAT -rm | FileCheck %s -check-prefix=BINARY_DECODER_1
// RUN: onnx-mlir-opt --pack-krnl-constants='elision-threshold=3 move-to-file=true filename=test2.bin' %s -split-input-file && binary-decoder test2.bin -s 16 -n 32 --onnx::TensorProto::INT32 -rm | FileCheck %s -check-prefix=BINARY_DECODER_2
// BINARY_DECODER_1: 0.1 0.2 0.3 0.4
// BINARY_DECODER_2: 1 2 3 4
func @test_krnl_const_packing_file_mixing_types() -> memref<1x4xf32> {
%0 = "krnl.global"() {name = "constant_0", shape = [1, 4], value = dense<[[0.1, 0.2, 0.3, 0.4]]> : tensor<1x4xf32>} : () -> memref<1x4xf32>
%1 = "krnl.global"() {name = "constant_1", shape = [1, 4], value = dense<[[1, 2, 3, 4]]> : tensor<1x4xi32>} : () -> memref<1x4xi32>
return %0 : memref<1x4xf32>
}

View File

@ -0,0 +1,8 @@
import sys
root = config.root
if sys.byteorder == "little":
config.unsupported = True
else:
config.unsupported = False

View File

@ -0,0 +1,11 @@
// RUN: onnx-mlir-opt --pack-krnl-constants='elision-threshold=3 move-to-file=false' %s -split-input-file | FileCheck %s
// CHECK: [[CONST_PACK:%.+]] = "krnl.packed_const"() {is_le = false, size_in_bytes = 32 : i64, value = dense<[0, 0, 0, 0, 63, -128, 0, 0, 64, 0, 0, 0, 64, 64, 0, 0, 0, 0, 0, 0, 63, -128, 0, 0, 64, 0, 0, 0, 64, 64, 0, 0]> : tensor<32xi8>} : () -> i64
// CHECK-LABEL: func @test_krnl_const_packing() -> memref<1x4xf32> {
// CHECK-NEXT: [[CONST0:%.+]] = "krnl.global"() {name = "constant_0", offset = 0 : i64, shape = [1, 4]} : () -> memref<1x4xf32>
// CHECK-NEXT: [[CONST1:%.+]] = "krnl.global"() {name = "constant_1", offset = 16 : i64, shape = [1, 4]} : () -> memref<1x4xf32>
func @test_krnl_const_packing() -> memref<1x4xf32> {
%0 = "krnl.global"() {name = "constant_0", shape = [1, 4], value = dense<[[0., 1., 2., 3.]]> : tensor<1x4xf32>} : () -> memref<1x4xf32>
%1 = "krnl.global"() {name = "constant_1", shape = [1, 4], value = dense<[[0., 1., 2., 3.]]> : tensor<1x4xf32>} : () -> memref<1x4xf32>
return %0 : memref<1x4xf32>
}

View File

@ -0,0 +1,8 @@
import sys
root = config.root
if sys.byteorder == "little":
config.unsupported = False
else:
config.unsupported = True

View File

@ -0,0 +1,11 @@
// RUN: onnx-mlir-opt --pack-krnl-constants='elision-threshold=3 move-to-file=false' %s -split-input-file | FileCheck %s
// CHECK: [[CONST_PACK:%.+]] = "krnl.packed_const"() {is_le = true, size_in_bytes = 32 : i64, value = dense<[0, 0, 0, 0, 0, 0, -128, 63, 0, 0, 0, 64, 0, 0, 64, 64, 0, 0, 0, 0, 0, 0, -128, 63, 0, 0, 0, 64, 0, 0, 64, 64]> : tensor<32xi8>} : () -> i64
// CHECK-LABEL: func @test_krnl_const_packing() -> memref<1x4xf32> {
// CHECK-NEXT: [[CONST0:%.+]] = "krnl.global"() {name = "constant_0", offset = 0 : i64, shape = [1, 4]} : () -> memref<1x4xf32>
// CHECK-NEXT: [[CONST1:%.+]] = "krnl.global"() {name = "constant_1", offset = 16 : i64, shape = [1, 4]} : () -> memref<1x4xf32>
func @test_krnl_const_packing() -> memref<1x4xf32> {
%0 = "krnl.global"() {name = "constant_0", shape = [1, 4], value = dense<[[0., 1., 2., 3.]]> : tensor<1x4xf32>} : () -> memref<1x4xf32>
%1 = "krnl.global"() {name = "constant_1", shape = [1, 4], value = dense<[[0., 1., 2., 3.]]> : tensor<1x4xf32>} : () -> memref<1x4xf32>
return %0 : memref<1x4xf32>
}

View File

@ -0,0 +1,8 @@
// RUN: onnx-mlir-opt --pack-krnl-constants='elision-threshold=3 move-to-file=true filename=test-pack-consts-to-file-same-type.bin' %s -split-input-file && binary-decoder test-pack-consts-to-file-same-type.bin -s 0 -n 32 --onnx::TensorProto::FLOAT -rm | FileCheck %s
// CHECK: 0 1 2 3 0 1 2 3
func @test_krnl_const_packing_file() -> memref<1x4xf32> {
%0 = "krnl.global"() {name = "constant_0", shape = [1, 4], value = dense<[[0., 1., 2., 3.]]> : tensor<1x4xf32>} : () -> memref<1x4xf32>
%1 = "krnl.global"() {name = "constant_1", shape = [1, 4], value = dense<[[0., 1., 2., 3.]]> : tensor<1x4xf32>} : () -> memref<1x4xf32>
return %0 : memref<1x4xf32>
}

View File

@ -31,7 +31,7 @@ tool_dirs = [
config.onnx_mlir_tools_dir, config.mlir_tools_dir, config.llvm_tools_dir config.onnx_mlir_tools_dir, config.mlir_tools_dir, config.llvm_tools_dir
] ]
tool_names = [ tool_names = [
'onnx-mlir-opt', 'mlir-opt', 'mlir-translate' 'onnx-mlir-opt', 'mlir-opt', 'mlir-translate', "binary-decoder"
] ]
tools = [ToolSubst(s, unresolved='ignore') for s in tool_names] tools = [ToolSubst(s, unresolved='ignore') for s in tool_names]
llvm_config.add_tool_substitutions(tools, tool_dirs) llvm_config.add_tool_substitutions(tools, tool_dirs)

View File

@ -1,4 +1,3 @@
import lit.llvm import lit.llvm
if '@BUILD_SHARED_LIBS@' == 'ON': if '@BUILD_SHARED_LIBS@' == 'ON':

View File

@ -1,10 +1,10 @@
// RUN: onnx-mlir-opt --elide-krnl-constants %s -split-input-file | FileCheck %s // RUN: onnx-mlir-opt --elide-krnl-constants %s -split-input-file | FileCheck %s
// CHECK-LABEL: func @test_elide_krnl_global_constant(%arg0: memref<1xf32>) -> memref<1x10xf32> // CHECK-LABEL: func @test_elide_krnl_global_constant(%arg0: memref<1xf32>) -> memref<1x70xf32>
func @test_elide_krnl_global_constant(%arg0: memref<1xf32>) -> memref<1x10xf32> { func @test_elide_krnl_global_constant(%arg0: memref<1xf32>) -> memref<1x70xf32> {
%0 = "krnl.global"() {name = "constant_0", shape = [1, 10], value = dense<[[-0.0448560268, 0.00779166119, 0.0681008175, 0.0299937408, -0.126409635, 0.14021875, -0.0552849025, -0.0493838154, 0.0843220502, -0.0545404144]]> : tensor<1x10xf32>} : () -> memref<1x10xf32> %0 = "krnl.global"() {name = "constant_0", shape = [1, 70], value = dense<[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]> : tensor<1x70xf32>} : () -> memref<1x70xf32>
return %0 : memref<1x10xf32> return %0 : memref<1x70xf32>
// CHECK: %0 = "krnl.global"() {name = "constant_0", shape = [1, 10]} : () -> memref<1x10xf32> // CHECK: {{.*}} = "krnl.global"() {name = "constant_0", shape = [1, 70]} : () -> memref<1x70xf32>
// CHECK: return %0 : memref<1x10xf32> // CHECK: return {{.*}} : memref<1x70xf32>
} }

View File

@ -37,7 +37,7 @@ bool isOMConvTheSameAsNaiveImplFor(const int N, const int C, const int H,
llvm::SmallVector<Type, 1> outputsType{yType}; llvm::SmallVector<Type, 1> outputsType{yType};
auto funcType = builder.getFunctionType(inputsType, outputsType); auto funcType = builder.getFunctionType(inputsType, outputsType);
string funcName = "test_conv"; string funcName = "main_graph";
llvm::SmallVector<NamedAttribute, 1> attrs; llvm::SmallVector<NamedAttribute, 1> attrs;
auto funcOp = auto funcOp =
builder.create<FuncOp>(UnknownLoc::get(&ctx), funcName, funcType, attrs); builder.create<FuncOp>(UnknownLoc::get(&ctx), funcName, funcType, attrs);
@ -88,13 +88,13 @@ bool isOMConvTheSameAsNaiveImplFor(const int N, const int C, const int H,
OwningModuleRef moduleRef(module); OwningModuleRef moduleRef(module);
llvm::SmallVector<char, 10> path; llvm::SmallVector<char, 10> path;
llvm::sys::fs::createTemporaryFile("_test_conv", "", path); llvm::sys::fs::createTemporaryFile("_main_graph", "", path);
string pathStr(path.begin(), path.end()); string pathStr(path.begin(), path.end());
llvm::FileRemover remover(path); llvm::FileRemover remover(path);
compileModule(moduleRef, ctx, pathStr, EmitLib); compileModule(moduleRef, ctx, pathStr, EmitLib);
onnx_mlir::ExecutionSession sess( onnx_mlir::ExecutionSession sess(
pathStr + ".so", "_dyn_entry_point_test_conv"); pathStr + ".so", "_dyn_entry_point_main_graph");
std::vector<unique_ptr<DynMemRef>> inputs; std::vector<unique_ptr<DynMemRef>> inputs;
auto xDmr = unique_ptr<DynMemRef>(getRndRealDmr<float>({N, C, H, W})); auto xDmr = unique_ptr<DynMemRef>(getRndRealDmr<float>({N, C, H, W}));