From e0ae583da041c7115054250aa1d0fdcc8bcb4957 Mon Sep 17 00:00:00 2001 From: Tian Jin Date: Fri, 12 Jun 2020 10:27:05 +0800 Subject: [PATCH] Compiling Models with Large Constant Arrays (#146) * PoC works. * MNist works. * Clean up. * Fix test. * Make Linux work. * Use consistent symbol name. * Fix variable name. * Fix array addr access. * Bug fix. * Bug fix. * install before running e2e tests. * Fix build config. * Use sudo when installing. * Make embeddedDataLoader position independent. * Enable ResNet50. * Format code. * Format MainUtil. * Try not using sudo to install. * Supply runtime dir via environment variable. * Dump problematic operation. * Dump entire function. * Debug. * Dump input. * Dump constant op. * Debug. * Debug. * Debug. * Print to stderr. * take care of endianness. * Use endianness-aware execution session. * Fix ZLinux error. * Include warning when desired output endianness can't be deduced. * Remove debug code. * Remove debug code in shape inference. * Support binary-decoder for testing constants packing. * Support filename, move-to-file, elision-threshold configurations in constant packing pass for easy testing. * Add lit test, fix lit test type mismatch. * Add more consts packing tests. * Ensure intermediate files are properly cleaned up. * No need for constant elimination. * Link with threading libraries. * Remove debug code. * Format code. * More tests. * test nit. * Remove debug code. * Reduce hard-coded constants. * Use temporary and unique working directory for hosting model parameters. * Test if it works. * Try to find objcopy. * Rename symbols using objcopy. * Move sanitized name to linux section. * Use verbose mode for debugging. * Disambiguate pass constructor. * Fix symbol name. * Use Command API to build and execute commands. * Move linux to use Command API. * Fix reset args. * Execute redefine sym. * Format code. * Do not use verbose mode for CircleCI. * Remove debug code. * Prettify code, add comments. 
* getSegmentData -> getEmbeddedConstPool * vector -> std::vector. * Make sure we properly clean up intermediate files. * Fix test cases. * Add runtime directory. * Trigger rebuild. * [Merge with master] fix debug script. * Disable affine fusion pass for now. * Support generic fallback const packing mechanism. * Remove debug code. * Handle the case where objcopy is not available. * Fix Windows missing types. * Support int64. * Copy packed constant to a local directory for non-Linux/Mac platforms. * Nit: remove debug code, refactor const pack preprocessing out as a separate function. * Cannot make preprocessConstPack a standalone function because file removers are stack-allocated, and they are deallocated prematurely when the function's stack frame is popped, deleting intermediate files too early. * Don't require executable filename. * Import ONNX data types directly. * Fix LIT test. * Bug fix, use moved string value. * Remove redundant filenames. * Fix CMake script. * Embed endianness information as a symbol, and check during runtime. * More comments, update lit tests. * Fix lit test on BE machine. * Copyright notices. 
--- .buildbot/z13.sh | 2 +- .circleci/config.yml | 6 +- MLIR.cmake | 1 + src/Builder/FrontendDialectHelper.cpp | 14 +- src/CMakeLists.txt | 7 +- src/Conversion/ONNXToKrnl/Tensor/Constant.cpp | 6 +- src/Dialect/Krnl/KrnlOps.td | 54 +++- src/Dialect/ONNX/ONNXOps.cpp | 1 - src/ExternalUtil.hpp.in | 7 +- src/MainUtils.cpp | 189 ++++++++++++-- src/Pass/Passes.hpp | 3 + src/Runtime/CMakeLists.txt | 9 +- src/Runtime/DynMemRef.cpp | 4 + src/Runtime/GetEmbeddedConstPool.cpp | 88 +++++++ src/Runtime/GetEmbeddedConstPool.h | 18 ++ src/Tool/BinaryDecoder/BinaryDecoder.cpp | 95 +++++++ src/Tool/BinaryDecoder/CMakeLists.txt | 9 + src/Tool/CMakeLists.txt | 3 +- src/Transform/CMakeLists.txt | 14 +- src/Transform/ElideKrnlGlobalConstants.cpp | 40 ++- src/Transform/ElideKrnlGlobalConstants.hpp | 35 +++ src/Transform/LowerToLLVM.cpp | 241 ++++++++++++++---- src/Transform/ONNX/ShapeInferencePass.cpp | 3 +- src/Transform/PackKrnlGlobalConstants.cpp | 127 +++++++++ test/backend/test.py | 68 ++++- test/mlir/CMakeLists.txt | 2 +- test/mlir/krnl/pack_consts_mix_types.mlir | 10 + .../krnl/pack_krnl_constants_be/lit.local.cfg | 8 + .../pack_krnl_constants.mlir | 11 + .../krnl/pack_krnl_constants_le/lit.local.cfg | 8 + .../pack_krnl_constants.mlir | 11 + test/mlir/krnl/pack_krnl_consts_to_file.mlir | 8 + test/mlir/lit.cfg.py | 2 +- test/mlir/lit.site.cfg.py.in | 1 - test/mlir/onnx/onnx_krnl_global_elision.mlir | 12 +- test/numerical/TestConv.cpp | 6 +- 36 files changed, 994 insertions(+), 129 deletions(-) create mode 100644 src/Runtime/GetEmbeddedConstPool.cpp create mode 100644 src/Runtime/GetEmbeddedConstPool.h create mode 100644 src/Tool/BinaryDecoder/BinaryDecoder.cpp create mode 100644 src/Tool/BinaryDecoder/CMakeLists.txt create mode 100644 src/Transform/ElideKrnlGlobalConstants.hpp create mode 100644 src/Transform/PackKrnlGlobalConstants.cpp create mode 100644 test/mlir/krnl/pack_consts_mix_types.mlir create mode 100644 test/mlir/krnl/pack_krnl_constants_be/lit.local.cfg create mode 
100644 test/mlir/krnl/pack_krnl_constants_be/pack_krnl_constants.mlir create mode 100644 test/mlir/krnl/pack_krnl_constants_le/lit.local.cfg create mode 100644 test/mlir/krnl/pack_krnl_constants_le/pack_krnl_constants.mlir create mode 100644 test/mlir/krnl/pack_krnl_consts_to_file.mlir diff --git a/.buildbot/z13.sh b/.buildbot/z13.sh index 0ef81dc..c5361fa 100755 --- a/.buildbot/z13.sh +++ b/.buildbot/z13.sh @@ -71,4 +71,4 @@ cmake -DCMAKE_INSTALL_PREFIX=${INSTALL_PATH} .. \ make -j$(nproc) onnx-mlir make -j$(nproc) check-onnx-lit -make -j$(nproc) check-onnx-backend +RUNTIME_DIR=$(pwd)/lib make -j$(nproc) check-onnx-backend diff --git a/.circleci/config.yml b/.circleci/config.yml index c590cd2..e578e10 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -40,14 +40,14 @@ jobs: command: | sudo pip install -q -e ./onnx-mlir/third_party/onnx cd onnx-mlir/build - VERBOSE=1 cmake --build . --target check-onnx-backend + RUNTIME_DIR=$(pwd)/lib cmake --build . --target check-onnx-backend - run: name: Run Unit Tests command: | cd onnx-mlir/build # Need to include the bin directory in $PATH, # otherwise CTest fails to find the test executables. - PATH=$(pwd)/bin:$PATH make test -j$(nproc) + RUNTIME_DIR=$(pwd)/lib PATH=$(pwd)/bin:$PATH make test -j$(nproc) - run: name: Run DocCheck command: cd onnx-mlir/build && cmake --build . 
--target check-doc @@ -60,4 +60,4 @@ jobs: diff docs/Dialects ../docs/Dialects - run: name: Print the Current Time - command: date + command: date \ No newline at end of file diff --git a/MLIR.cmake b/MLIR.cmake index e084172..0350e4b 100644 --- a/MLIR.cmake +++ b/MLIR.cmake @@ -269,6 +269,7 @@ set(ONNXMLIRWholeArchiveLibs OMPromotableConstOperandsOpInterface OMElideConstants OMElideKrnlGlobalConstants + OMPackKrnlGlobalConstants OMEnableMemoryPool) # Function to construct linkage option for the static libraries that must be diff --git a/src/Builder/FrontendDialectHelper.cpp b/src/Builder/FrontendDialectHelper.cpp index 63bbde6..4efad0a 100644 --- a/src/Builder/FrontendDialectHelper.cpp +++ b/src/Builder/FrontendDialectHelper.cpp @@ -7,6 +7,8 @@ // Helper methods for handling input ONNX models. // //===----------------------------------------------------------------------===// +#include +#include #include "src/Builder/FrontendDialectHelper.hpp" @@ -104,8 +106,16 @@ static std::vector CreateArrayAttribute(onnx::TensorProto initializer) { std::copy(initializer.raw_data().begin(), initializer.raw_data().end(), back_inserter(byteInitializer)); size = initializer.raw_data().size() / sizeof(T); - T *res = reinterpret_cast(&byteInitializer[0]); - return std::vector(res, res + size); + T *arrayPtr = reinterpret_cast(&byteInitializer[0]); + auto array = std::vector(arrayPtr, arrayPtr + size); + // Perform byte swap if system endianness is BE. + // ONNX tensor content raw data is always in LE. 
+ if (llvm::support::endian::system_endianness() != + llvm::support::endianness::little) + for (int i = 0; i < array.size(); i++) + llvm::sys::swapByteOrder(array[i]); + + return array; } // copy, no need to take care of endianness diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 2dfa041..e6ef433 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -25,9 +25,6 @@ if(NOT EXISTS "${LLVM_PROJ_BUILD}/bin/llc") message(ERROR "Cannot find llc.") endif() -# Get the compiler command name, the C++ compiler is needed to to translate -# object files to shared libraries. -get_filename_component(CXX_COMPILER_FILENAME ${CMAKE_CXX_COMPILER} NAME) configure_file(${CMAKE_CURRENT_SOURCE_DIR}/ExternalUtil.hpp.in ${CMAKE_CURRENT_BINARY_DIR}/ExternalUtil.hpp) @@ -57,6 +54,7 @@ target_link_libraries(MainUtils OMResultTypeInferenceOpInterface OMElideConstants OMElideKrnlGlobalConstants + OMPackKrnlGlobalConstants OMEnableMemoryPool OMKrnlToAffine OMKrnlToLLVM @@ -71,6 +69,9 @@ if (INCLUDE_ONNX_ML) add_dependencies(MainUtils OMMLONNXOpsInc) endif() +add_dependencies(onnx-mlir cruntime) +add_dependencies(onnx-mlir EmbeddedDataLoader) + target_include_directories(onnx-mlir PRIVATE ${ONNX_MLIR_SRC_ROOT}) target_include_directories(onnx-mlir PRIVATE ${CMAKE_BINARY_DIR}) target_include_directories(onnx-mlir PRIVATE ${ONNX_MLIR_BIN_ROOT}) diff --git a/src/Conversion/ONNXToKrnl/Tensor/Constant.cpp b/src/Conversion/ONNXToKrnl/Tensor/Constant.cpp index 3b68c09..d4fa675 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/Constant.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/Constant.cpp @@ -38,9 +38,11 @@ struct ONNXConstantOpLowering : public ConversionPattern { // Emit the constant global in Krnl dialect. 
auto constantGlobal = rewriter.create(loc, memRefType, - rewriter.getI64ArrayAttr(shape), + /*shape=*/rewriter.getI64ArrayAttr(shape), + /*name=*/ rewriter.getStringAttr("constant_" + std::to_string(constantID)), - constantOp.value().getValue()); + /*value=*/constantOp.value().getValue(), + /*offset=*/nullptr); // Increment constant ID: constantID++; diff --git a/src/Dialect/Krnl/KrnlOps.td b/src/Dialect/Krnl/KrnlOps.td index 6b72883..baea4bf 100644 --- a/src/Dialect/Krnl/KrnlOps.td +++ b/src/Dialect/Krnl/KrnlOps.td @@ -196,16 +196,66 @@ def KrnlMemcpyOp : Op { def KrnlGlobalOp : Op { let summary = "Krnl global operation"; let description = [{ - Operation for holding global data values. + Operation for holding global data values. A global constant can have a + meaningful name recorded as its `name` attribute. Its content is stored + in the `value` dense element attribute. Alternatively, if the constants + are packed together, `offset` records the byte offset in the global + constant pool from which the current constant is to be read. }]; - let arguments = (ins AnyAttr:$shape, StrAttr:$name, OptionalAttr:$value); + let arguments = (ins AnyAttr:$shape, + StrAttr:$name, OptionalAttr:$value, OptionalAttr:$offset); let results = (outs AnyTypeOf<[AnyMemRef]>:$output); let parser = ?; let printer = ?; } +def KrnlPackedConstantOp : Op { + let summary = "Krnl packed constant operation"; + let description = [{ + Operation for holding packed constants. + }]; + + let arguments = (ins I64Attr:$size_in_bytes, + BoolAttr:$is_le, + OptionalAttr>:$value, + OptionalAttr:$file_name); + let results = (outs I64:$output); + + let extraClassDeclaration = [{ + // The *path* to the file storing the constant pack on disk is encoded + // as a global variable in the LLVM module of the lowered model. 
+ // getConstPackFilePathSymbolName returns the name of this symbol; + // getConstPackFilePathStrLenSymbolName returns the name of the symbol holding + // a constant value equal to the length of the file path. + static StringRef getConstPackFilePathSymbolName() { return "constPackFilePath"; } + static StringRef getConstPackFilePathStrLenSymbolName() { return "constPackFilePathStrLen"; } + + // The *name* of the file storing the constant pack is also recorded for + // convenience. Similarly, getConstPackFileNameSymbolName and + // getConstPackFileNameStrLenSymbolName returns records the symbol holding + // the string constant representing the filename and the length of this + // string constant. + static StringRef getConstPackFileNameSymbolName() { return "constPackFileName"; } + static StringRef getConstPackFileNameStrLenSymbolName() { return "constPackFileNameStrLen"; } + + // We record whether the constant pack is stored in LE byte order as a + // int8 symbol; it is to be interpreted as a boolean switch - with 0 + // meaning that the constant pack is not stored in LE byte order and + // non-0 values meaning that it is stored in LE byte order. + static StringRef getConstPackIsLESymbolName() { return "constPackIsLE"; } + // The name of a function we call to read packed constants embedded within + // the current binary executable/library, or in the case of unsupported platform, + // from a binary constant pack file. 
+ static StringRef getEmbeddedDataLoaderMethodName() { + return "getEmbeddedConstPool"; + } + }]; + let parser = ?; + let printer = ?; +} + def KrnlGetRefOp : Op { let summary = "Krnl a MemRef from within another MemRef starting at a specific offset."; let description = [{ diff --git a/src/Dialect/ONNX/ONNXOps.cpp b/src/Dialect/ONNX/ONNXOps.cpp index 367328b..99a6c78 100644 --- a/src/Dialect/ONNX/ONNXOps.cpp +++ b/src/Dialect/ONNX/ONNXOps.cpp @@ -1138,7 +1138,6 @@ LogicalResult ONNXReshapeOp::inferShapes() { if (constantOp) { DenseElementsAttr valueAttribute = constantOp.valueAttr().dyn_cast(); - if (!valueAttribute) return emitError("DenseElementsAttr expected"); // Get dims from valueAttribute. diff --git a/src/ExternalUtil.hpp.in b/src/ExternalUtil.hpp.in index 8defdef..5aa81a2 100644 --- a/src/ExternalUtil.hpp.in +++ b/src/ExternalUtil.hpp.in @@ -1,9 +1,10 @@ #pragma once + #include namespace onnx_mlir { const std::string kLlcPath = "@LLVM_PROJ_BUILD@/bin/llc"; const std::string kCxxPath = "@CMAKE_CXX_COMPILER@"; -const std::string kCxxFileName = "@CXX_COMPILER_FILENAME@"; -const std::string kRuntimeDirPath = "@CMAKE_BINARY_DIR@/lib"; -} +const std::string kLinkerPath = "@CMAKE_LINKER@"; +const std::string kObjCopyPath = "@CMAKE_OBJCOPY@"; +} // namespace onnx_mlir diff --git a/src/MainUtils.cpp b/src/MainUtils.cpp index 4de03ae..ff16f07 100644 --- a/src/MainUtils.cpp +++ b/src/MainUtils.cpp @@ -9,8 +9,15 @@ //===----------------------------------------------------------------------===// #include +#include #include +#include +#include + +#include #include +#include +#include #include "src/ExternalUtil.hpp" #include "src/MainUtils.hpp" @@ -26,6 +33,63 @@ using namespace std; using namespace onnx_mlir; +namespace { + +llvm::Optional getEnvVar(std::string name) { + if (const char *envVerbose = std::getenv(name.c_str())) + return std::string(envVerbose); + return llvm::None; +} + +// Helper struct to make command construction and execution easy & readable. 
+struct Command { + std::string _path; + std::vector _args; + + Command(std::string exePath) + : _path(std::move(exePath)), + _args({llvm::sys::path::filename(_path).str()}) {} + + // Append a single string argument. + Command &appendStr(const std::string &arg) { + _args.emplace_back(arg); + return *this; + } + + // Append a list of string arguments. + Command &appendList(const std::vector &args) { + _args.insert(_args.end(), args.begin(), args.end()); + return *this; + } + + // Reset arguments. + Command &resetArgs() { + auto exeFileName = _args.front(); + _args.clear(); + _args.emplace_back(exeFileName); + return *this; + } + + // Execute command. + void exec() { + auto argsRef = std::vector(_args.begin(), _args.end()); + bool verbose = false; + if (const auto &verboseStr = getEnvVar("VERBOSE")) + istringstream(verboseStr.getValue()) >> verbose; + + // If in verbose mode, print out command before execution. + if (verbose) + cout << llvm::join(argsRef, " ") << "\n"; + int rc = llvm::sys::ExecuteAndWait(_path, llvm::makeArrayRef(argsRef)); + + if (rc != 0) { + fprintf(stderr, "%s\n", llvm::join(argsRef, " ").c_str()); + llvm_unreachable("Command execution failed."); + } + } +}; +} // namespace + void LoadMLIR(string inputFilename, mlir::MLIRContext &context, mlir::OwningModuleRef &module) { // Handle '.mlir' input to the ONNX MLIR frontend. @@ -50,31 +114,123 @@ void LoadMLIR(string inputFilename, mlir::MLIRContext &context, void compileModuleToSharedLibrary( const mlir::OwningModuleRef &module, string outputBaseName) { + // Extract constant pack file name, which is embedded as a symbol in the + // module being compiled. 
+ auto constPackFilePathSym = (*module).lookupSymbol( + mlir::KrnlPackedConstantOp::getConstPackFilePathSymbolName()); + auto constPackFilePath = constPackFilePathSym.valueAttr() + .dyn_cast_or_null() + .getValue() + .str(); + llvm::FileRemover constPackRemover(constPackFilePath); + + llvm::Optional constPackObjPath; +#if __APPLE__ + // Create a empty stub file, compile it to an empty obj file. + llvm::SmallVector stubSrcPath; + llvm::sys::fs::createTemporaryFile("stub", "cpp", stubSrcPath); + llvm::FileRemover subSrcRemover(stubSrcPath); + std::string stubSrcPathStr(stubSrcPath.begin(), stubSrcPath.end()); + Command createStubObj(/*exePath=*/kCxxPath); + std::string stubObjPathStr = stubSrcPathStr + ".o"; + createStubObj.appendList({"-o", stubObjPathStr}) + .appendList({"-c", stubSrcPathStr}) + .exec(); + llvm::FileRemover stubObjRemover(stubObjPathStr); + + // Embed data into the empty stub obj file. + constPackObjPath = constPackFilePath + ".o"; + Command genParamObj(/*exePath=*/kLinkerPath); + genParamObj.appendStr("-r") + .appendList({"-o", constPackObjPath.getValue()}) + .appendList({"-sectcreate", "binary", "param", constPackFilePath}) + .appendStr(stubObjPathStr) + .exec(); + llvm::FileRemover constPackObjRemover(constPackObjPath.getValue()); + +#elif __linux__ + // Create param.o holding packed parameter values. + constPackObjPath = constPackFilePath + ".o"; + Command genParamObj(/*exePath=*/kLinkerPath); + genParamObj.appendStr("-r") + .appendList({"-b", "binary"}) + .appendList({"-o", constPackObjPath.getValue()}) + .appendStr(constPackFilePath) + .exec(); + llvm::FileRemover constPackObjRemover(constPackObjPath.getValue()); + + // Figure out what is the default symbol name describing the start/end + // address of the embedded data. + std::regex e("[^0-9A-Za-z]"); + auto sanitizedName = + "_binary_" + std::regex_replace(constPackFilePath, e, "_"); + + // Rename the symbols to saner ones expected by the runtime function. 
+ Command redefineSym(/*exePath=*/kObjCopyPath); + redefineSym.appendStr("--redefine-sym") + .appendStr(sanitizedName + "_start=_binary_param_bin_start") + .appendStr(constPackObjPath.getValue()) + .exec(); + redefineSym.resetArgs() + .appendStr("--redefine-sym") + .appendStr(sanitizedName + "_end=_binary_param_bin_end") + .appendStr(constPackObjPath.getValue()) + .exec(); + +#else + llvm::SmallVector permConstPackFileName( + constPackFilePath.begin(), constPackFilePath.end()); + llvm::sys::path::replace_extension(permConstPackFileName, "bin"); + std::string permConstPackFileNameStr( + permConstPackFileName.begin(), permConstPackFileName.end()); + auto constPackFileName = llvm::sys::path::filename(outputBaseName) + "." + + llvm::sys::path::filename(permConstPackFileNameStr); + llvm::sys::fs::rename(constPackFilePath, constPackFileName); + + mlir::Builder builder(*module); + (*module) + .lookupSymbol( + mlir::KrnlPackedConstantOp::getConstPackFileNameSymbolName()) + .valueAttr(builder.getStringAttr(constPackFileName.str())); + (*module) + .lookupSymbol( + mlir::KrnlPackedConstantOp::getConstPackFileNameStrLenSymbolName()) + .valueAttr(builder.getI64IntegerAttr(constPackFileName.str().size())); +#endif + // Write LLVM bitcode. string outputFilename = outputBaseName + ".bc"; error_code error; llvm::raw_fd_ostream moduleBitcodeStream( outputFilename, error, llvm::sys::fs::F_None); + llvm::WriteBitcodeToFile( *mlir::translateModuleToLLVMIR(*module), moduleBitcodeStream); moduleBitcodeStream.flush(); + llvm::FileRemover bcRemover(outputFilename); - // Compile bitcode to object file. - std::vector llcArgs = { - "llc", "-filetype=obj", "-relocation-model=pic", outputFilename}; - auto llcArgStrRefs = - std::vector(llcArgs.begin(), llcArgs.end()); - llvm::sys::ExecuteAndWait(kLlcPath, llvm::makeArrayRef(llcArgStrRefs)); + // Compile LLVM bitcode to object file. 
+ Command llvmToObj(/*exePath=*/kLlcPath); + llvmToObj.appendStr("-filetype=obj"); + llvmToObj.appendStr("-relocation-model=pic"); + llvmToObj.appendStr(outputFilename); + llvmToObj.exec(); + std::string modelObjPath = outputBaseName + ".o"; + llvm::FileRemover modelObjRemover(modelObjPath); - // Link with runtime. - // TODO(tjingrant): link with runtime library in LLVM, and make the shared - // library more self-contained. - std::vector cxxArgs = {kCxxFileName, "-shared", "-fPIC", - outputBaseName + ".o", "-o", outputBaseName + ".so", - "-L" + kRuntimeDirPath, "-lcruntime", "-Wl,-rpath," + kRuntimeDirPath}; - auto argsArrayRefVector = - std::vector(cxxArgs.begin(), cxxArgs.end()); - llvm::sys::ExecuteAndWait(kCxxPath, llvm::makeArrayRef(argsArrayRefVector)); + std::string runtimeDirInclFlag; + if (getEnvVar("RUNTIME_DIR").hasValue()) + runtimeDirInclFlag = "-L" + getEnvVar("RUNTIME_DIR").getValue(); + + // Link everything into a shared object. + Command link(kCxxPath); + link.appendList({"-shared", "-fPIC"}) + .appendStr(modelObjPath) + .appendStr(constPackObjPath.getValueOr("")) + .appendList({"-o", outputBaseName + ".so"}) + .appendStr(runtimeDirInclFlag) + .appendList({"-lEmbeddedDataLoader", "-lcruntime"}) + .exec(); } void registerDialects() { @@ -98,6 +254,7 @@ void addONNXToMLIRPasses(mlir::PassManager &pm) { void addONNXToKrnlPasses(mlir::PassManager &pm) { pm.addPass(mlir::createLowerToKrnlPass()); + pm.addPass(mlir::createPackKrnlGlobalConstantsPass()); // An additional pass of canonicalization is helpful because lowering // from ONNX dialect to Standard dialect exposes additional canonicalization // oppertunities. @@ -110,7 +267,7 @@ void addONNXToKrnlPasses(mlir::PassManager &pm) { void addKrnlToAffinePasses(mlir::PassManager &pm) { pm.addPass(mlir::createLowerKrnlPass()); // Fuse loops in Affine dialect. 
- pm.addPass(mlir::createLoopFusionPass()); + // pm.addPass(mlir::createLoopFusionPass()); } void addKrnlToLLVMPasses(mlir::PassManager &pm) { diff --git a/src/Pass/Passes.hpp b/src/Pass/Passes.hpp index fff6dbf..9988537 100644 --- a/src/Pass/Passes.hpp +++ b/src/Pass/Passes.hpp @@ -43,4 +43,7 @@ std::unique_ptr createElideConstGlobalValuePass(); /// Pass for lowering Krnl dialect to LLVM dialect. std::unique_ptr createKrnlLowerToLLVMPass(); +/// Pass for packing Krnl global constants. +std::unique_ptr createPackKrnlGlobalConstantsPass(); + } // end namespace mlir diff --git a/src/Runtime/CMakeLists.txt b/src/Runtime/CMakeLists.txt index 0e7742c..df921a1 100644 --- a/src/Runtime/CMakeLists.txt +++ b/src/Runtime/CMakeLists.txt @@ -1,6 +1,6 @@ # Create shared libcruntime.so since model.so linkage for backend tests # will fail on x86 Linux if cruntime is statically linked. -add_library(cruntime SHARED +add_library(cruntime STATIC DynMemRef.cpp DynMemRef.h DataType.h) @@ -34,6 +34,13 @@ target_include_directories(PyRuntime PRIVATE ${ONNX_MLIR_BIN_ROOT} ${ONNX_MLIR_SRC_ROOT}) +add_library(EmbeddedDataLoader STATIC + GetEmbeddedConstPool.h + GetEmbeddedConstPool.cpp) +set_target_properties(EmbeddedDataLoader PROPERTIES + POSITION_INDEPENDENT_CODE TRUE) + add_dependencies(PyRuntime cruntime) install(FILES DynMemRef.h DESTINATION include) install(TARGETS cruntime DESTINATION lib) +install(TARGETS EmbeddedDataLoader DESTINATION lib) diff --git a/src/Runtime/DynMemRef.cpp b/src/Runtime/DynMemRef.cpp index 9fa12cd..7238f82 100644 --- a/src/Runtime/DynMemRef.cpp +++ b/src/Runtime/DynMemRef.cpp @@ -14,6 +14,10 @@ #include #include +#include +#include +#include + #include "DynMemRef.h" namespace { diff --git a/src/Runtime/GetEmbeddedConstPool.cpp b/src/Runtime/GetEmbeddedConstPool.cpp new file mode 100644 index 0000000..5337286 --- /dev/null +++ b/src/Runtime/GetEmbeddedConstPool.cpp @@ -0,0 +1,88 @@ +//===---- GetEmbeddedConstPool.h - Get Embedded Const Pool API Func 
Impl---===// +// +// Copyright 2019-2020 The IBM Research Authors. +// +// ============================================================================= +// +// This file contains runtime API implementation to extract constant pool values +// embedded within the shared library binary files. +// +//===----------------------------------------------------------------------===// + +#include "GetEmbeddedConstPool.h" + +#include +#include +#include + +// Adapted from: +// https://developer.ibm.com/technologies/systems/articles/au-endianc/ +const int i = 1; +#define IS_SYSTEM_LE() (!((*(char *)&i) == 0)) + +#define XOR(a, b) (!(a) != !(b)) + +extern const char constPackIsLE; + +void checkEndianness() { + if (XOR(IS_SYSTEM_LE(), constPackIsLE)) { + fprintf(stderr, "Constant pack is stored in a byte order that is not " + "native to this current system."); + exit(1); + } +} + +#if __APPLE__ +#include +extern const struct mach_header_64 _mh_dylib_header; + +void *getEmbeddedConstPool(int64_t size_in_byte) { + checkEndianness(); + size_t size = size_in_byte; + unsigned char *data = + getsectiondata(&_mh_dylib_header, "binary", "param", &size); + float *data_ptr = (float *)data; + void *buffer = malloc(size); + memcpy(buffer, data, size); + return data; +} + +#elif __linux__ +extern char _binary_param_bin_start; +extern char _binary_param_bin_end; + +void *getEmbeddedConstPool(int64_t _) { + checkEndianness(); + auto size = (unsigned int)(&_binary_param_bin_end - &_binary_param_bin_start); + void *buffer = malloc(size); + memcpy(buffer, &_binary_param_bin_start, size); + return buffer; +} + +#else + +extern char constPackFileName[]; +extern int64_t constPackFileNameStrLen; + +void *getEmbeddedConstPool(int64_t _) { + checkEndianness(); + char *fname = (char *)calloc(1, constPackFileNameStrLen + 1); + memcpy(fname, constPackFileName, constPackFileNameStrLen); + + // Adapted from https://stackoverflow.com/a/22059317 . 
+ FILE *fileptr; + char *buffer; + long filelen; + + fileptr = fopen(fname, "rb"); // Open the file in binary mode + fseek(fileptr, 0, SEEK_END); // Jump to the end of the file + filelen = ftell(fileptr); // Get the current byte offset in the file + rewind(fileptr); // Jump back to the beginning of the file + + buffer = (char *)malloc(filelen * sizeof(char)); // Enough memory for the file + fread(buffer, filelen, 1, fileptr); // Read in the entire file + fclose(fileptr); // Close the file + + return (void *)buffer; +} +#endif \ No newline at end of file diff --git a/src/Runtime/GetEmbeddedConstPool.h b/src/Runtime/GetEmbeddedConstPool.h new file mode 100644 index 0000000..8286b5d --- /dev/null +++ b/src/Runtime/GetEmbeddedConstPool.h @@ -0,0 +1,18 @@ +//===---- GetEmbeddedConstPool.h - Get Embedded Const Pool API Func Decl---===// +// +// Copyright 2019-2020 The IBM Research Authors. +// +// ============================================================================= +// +// This file contains runtime API declarations to extract constant pool values +// embedded within the shared library binary files. +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +extern "C" { +void *getEmbeddedConstPool(int64_t size_in_byte); +} \ No newline at end of file diff --git a/src/Tool/BinaryDecoder/BinaryDecoder.cpp b/src/Tool/BinaryDecoder/BinaryDecoder.cpp new file mode 100644 index 0000000..4af5697 --- /dev/null +++ b/src/Tool/BinaryDecoder/BinaryDecoder.cpp @@ -0,0 +1,95 @@ +//===----- BinaryDecoder.cpp - Decode binary files into typed arrays ------===// +// +// Copyright 2019-2020 The IBM Research Authors. +// +// ============================================================================= +// +// This file contains implementation of a utility called BinaryDecoder, which +// decodes a sequence of binary data within a binary file specified by an +// offset and a length into a typed array and print to stdout. 
+// +//===----------------------------------------------------------------------===// + +#include +#include +#include + +#include "onnx/onnx_pb.h" +#include + +#if defined(_WIN32) + +#include + +typedef uint8_t u_int8_t; +typedef uint16_t u_int16_t; +typedef uint32_t u_int32_t; +typedef uint64_t u_int64_t; + +#endif + +llvm::cl::opt Filename( + llvm::cl::Positional, llvm::cl::desc(""), llvm::cl::Required); +llvm::cl::opt Start("s", + llvm::cl::desc("Specify the index of the starting byte"), + llvm::cl::value_desc("start"), llvm::cl::Required); +llvm::cl::opt Size("n", + llvm::cl::desc("Specify the number of bytes of data to decode"), + llvm::cl::value_desc("size"), llvm::cl::Required); +llvm::cl::opt Remove( + "rm", llvm::cl::desc( + "Whether to remove the file being decoded after inspection.")); + +llvm::cl::opt DataType( + llvm::cl::desc("Choose data type to decode:"), + llvm::cl::values(clEnumVal(onnx::TensorProto::FLOAT, "FLOAT"), + clEnumVal(onnx::TensorProto::UINT8, "UINT8"), + clEnumVal(onnx::TensorProto::INT8, "INT8"), + clEnumVal(onnx::TensorProto::UINT16, "UINT16"), + clEnumVal(onnx::TensorProto::INT16, "INT16"), + clEnumVal(onnx::TensorProto::INT32, "INT32"), + clEnumVal(onnx::TensorProto::INT64, "INT64"), + clEnumVal(onnx::TensorProto::STRING, "STRING"), + clEnumVal(onnx::TensorProto::BOOL, "BOOL"), + clEnumVal(onnx::TensorProto::FLOAT16, "FLOAT16"), + clEnumVal(onnx::TensorProto::DOUBLE, "DOUBLE"), + clEnumVal(onnx::TensorProto::UINT32, "UINT32"), + clEnumVal(onnx::TensorProto::UINT64, "UINT64"))); + +template +int printBuffer(std::vector buffer) { + auto *ptr = (T *)&buffer[0]; + auto data = std::vector(ptr, ptr + buffer.size() / sizeof(T)); + for (const auto &elem : data) + std::cout << elem << " "; + return 0; +} + +int main(int argc, char **argv) { + llvm::cl::ParseCommandLineOptions(argc, argv); + std::vector buffer(Size); + std::ifstream file(Filename, std::ios::in | std::ios::binary); + if (!file) + return -1; + file.seekg(Start, file.beg); + 
file.read(&buffer[0], Size); + file.close(); + + if (Remove) + llvm::sys::fs::remove(Filename); + +#define PRINT_BUFFER_FOR_TYPE(ONNX_TYPE, CPP_TYPE) \ + if (DataType == ONNX_TYPE) \ + return printBuffer(buffer); + + PRINT_BUFFER_FOR_TYPE(onnx::TensorProto::UINT8, u_int8_t); + PRINT_BUFFER_FOR_TYPE(onnx::TensorProto::UINT16, u_int16_t); + PRINT_BUFFER_FOR_TYPE(onnx::TensorProto::INT16, int16_t); + PRINT_BUFFER_FOR_TYPE(onnx::TensorProto::INT32, int32_t); + PRINT_BUFFER_FOR_TYPE(onnx::TensorProto::INT64, int64_t); + + PRINT_BUFFER_FOR_TYPE(onnx::TensorProto::FLOAT, float); + PRINT_BUFFER_FOR_TYPE(onnx::TensorProto::DOUBLE, double); + PRINT_BUFFER_FOR_TYPE(onnx::TensorProto::UINT32, u_int32_t); + PRINT_BUFFER_FOR_TYPE(onnx::TensorProto::UINT64, u_int64_t); +} \ No newline at end of file diff --git a/src/Tool/BinaryDecoder/CMakeLists.txt b/src/Tool/BinaryDecoder/CMakeLists.txt new file mode 100644 index 0000000..a56538d --- /dev/null +++ b/src/Tool/BinaryDecoder/CMakeLists.txt @@ -0,0 +1,9 @@ +add_executable(binary-decoder BinaryDecoder.cpp) +target_link_libraries(binary-decoder + ${LLVMSupport} + ${LLVMDemangle} + ${CURSES_LIBRARIES} + ${CMAKE_THREAD_LIBS_INIT}) + +message(STATUS "incl dir" ${ONNX_INCLUDE_DIRS}) +target_link_libraries(binary-decoder onnx) \ No newline at end of file diff --git a/src/Tool/CMakeLists.txt b/src/Tool/CMakeLists.txt index 73d56f9..22cbfe7 100644 --- a/src/Tool/CMakeLists.txt +++ b/src/Tool/CMakeLists.txt @@ -1 +1,2 @@ -add_subdirectory(ONNXMLIROpt) \ No newline at end of file +add_subdirectory(ONNXMLIROpt) +add_subdirectory(BinaryDecoder) \ No newline at end of file diff --git a/src/Transform/CMakeLists.txt b/src/Transform/CMakeLists.txt index 8c5cdb4..af2bc9a 100644 --- a/src/Transform/CMakeLists.txt +++ b/src/Transform/CMakeLists.txt @@ -27,7 +27,8 @@ target_link_libraries(OMKrnlToLLVM onnx) add_library(OMElideKrnlGlobalConstants - ElideKrnlGlobalConstants.cpp) + ElideKrnlGlobalConstants.cpp + ElideKrnlGlobalConstants.hpp) 
target_include_directories(OMElideKrnlGlobalConstants PRIVATE ${ONNX_MLIR_SRC_ROOT} @@ -38,6 +39,14 @@ add_dependencies(OMElideKrnlGlobalConstants OMKrnlOpsInc) # Linking dependencies add_dependencies(OMElideKrnlGlobalConstants OMKrnlOps) +add_library(OMPackKrnlGlobalConstants + PackKrnlGlobalConstants.cpp) +target_include_directories(OMPackKrnlGlobalConstants + PRIVATE + ${ONNX_MLIR_SRC_ROOT} + ${ONNX_MLIR_BIN_ROOT} + ${ONNX_MLIR_SRC_ROOT}) + add_library(OMEnableMemoryPool EnableMemoryPool.cpp) target_include_directories(OMEnableMemoryPool @@ -45,6 +54,9 @@ target_include_directories(OMEnableMemoryPool ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT} ${ONNX_MLIR_SRC_ROOT}) +add_dependencies(OMPackKrnlGlobalConstants + OMKrnlOps) + target_link_libraries(OMEnableMemoryPool onnx) add_dependencies(OMEnableMemoryPool diff --git a/src/Transform/ElideKrnlGlobalConstants.cpp b/src/Transform/ElideKrnlGlobalConstants.cpp index 73a0bbe..3630809 100644 --- a/src/Transform/ElideKrnlGlobalConstants.cpp +++ b/src/Transform/ElideKrnlGlobalConstants.cpp @@ -22,34 +22,31 @@ #include "src/Dialect/Krnl/KrnlOps.hpp" #include "src/Pass/Passes.hpp" +#include "ElideKrnlGlobalConstants.hpp" + using namespace mlir; -namespace { +const int64_t KrnlConstGlobalValueElision::kDefaultElisionThreshold = 32; -/*! - * RewritePattern that replaces existing constant Krnl global values - * with a similar operation which preserves all attributes except the value - * attribute. 
- */ +mlir::LogicalResult KrnlConstGlobalValueElision::matchAndRewrite( + mlir::KrnlGlobalOp op, mlir::PatternRewriter &rewriter) const { + auto loc = op.getLoc(); -class KrnlConstGlobalValueElision : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite( - KrnlGlobalOp op, PatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - if (op.value().hasValue()) { - auto newGlobalOp = rewriter.create( - loc, op.getResult().getType(), op.shape(), op.name(), nullptr); + if (op.value().hasValue()) { + const auto &valAttr = op.valueAttr().dyn_cast_or_null(); + if (valAttr.getNumElements() > elisionThreshold) { + IntegerAttr offsetAttr = op.offset() ? op.offsetAttr() : nullptr; + auto newGlobalOp = rewriter.create(loc, + op.getResult().getType(), /*shape=*/op.shape(), + /*name=*/op.name(), /*value=*/nullptr, /*offset=*/offsetAttr); rewriter.replaceOp(op, newGlobalOp.getResult()); } - - return success(); } -}; + return success(); +} + +namespace { /*! * Function pass that performs constant value elision of Krnl globals. */ @@ -66,6 +63,7 @@ public: applyPatternsAndFoldGreedily(function, patterns); } }; + } // namespace std::unique_ptr mlir::createElideConstGlobalValuePass() { @@ -73,4 +71,4 @@ std::unique_ptr mlir::createElideConstGlobalValuePass() { } static PassRegistration pass("elide-krnl-constants", - "Elide the constant values of the Global Krnl operations."); \ No newline at end of file + "Elide the constant values of the Global Krnl operations."); diff --git a/src/Transform/ElideKrnlGlobalConstants.hpp b/src/Transform/ElideKrnlGlobalConstants.hpp new file mode 100644 index 0000000..2fdd9b7 --- /dev/null +++ b/src/Transform/ElideKrnlGlobalConstants.hpp @@ -0,0 +1,35 @@ +#pragma once + +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "src/Dialect/Krnl/KrnlOps.hpp" + +/*! 
+ * RewritePattern that replaces existing constant Krnl global values + * with a similar operation which preserves all attributes except the value + * attribute. + */ +class KrnlConstGlobalValueElision + : public mlir::OpRewritePattern { +public: + /* + * A threshold value specifying the maximum number of elements a constant + * operation can hold as an attribute. If the number exceeds this threshold, + * constants will be packed together and, in the case where `move-to-file` + * option is enabled, stored as a binary file on disk. This can help preserve + * readability of IR dump and improve compilation speed. + */ + static const int64_t kDefaultElisionThreshold; + + int64_t elisionThreshold; + + using mlir::OpRewritePattern::OpRewritePattern; + + explicit KrnlConstGlobalValueElision( + mlir::MLIRContext *context, int64_t elisionThreshold) + : OpRewritePattern(context), elisionThreshold(elisionThreshold) {} + + mlir::LogicalResult matchAndRewrite( + mlir::KrnlGlobalOp op, mlir::PatternRewriter &rewriter) const override; +}; diff --git a/src/Transform/LowerToLLVM.cpp b/src/Transform/LowerToLLVM.cpp index 313b1ed..1b4f6cb 100644 --- a/src/Transform/LowerToLLVM.cpp +++ b/src/Transform/LowerToLLVM.cpp @@ -32,10 +32,9 @@ namespace { static FlatSymbolRefAttr getOrInsertExternFunc(StringRef funcName, ModuleOp module, mlir::LLVM::LLVMType funcType, PatternRewriter &rewriter) { auto *context = module.getContext(); - if (module.lookupSymbol(funcName)) { - auto symbolRef = SymbolRefAttr::get(funcName, context); - assert(symbolRef.getType() == funcType && "wrong symbol type"); - return symbolRef; + if (auto sym = module.lookupSymbol(funcName)) { + assert(sym.getType() == funcType && "wrong symbol type"); + return SymbolRefAttr::get(funcName, context); } // Insert the function into the body of the parent module. 
@@ -203,55 +202,80 @@ public: // The llvm type of the global (example: [2 x [8 x float]]) auto llvmGlobalType = globalType.cast(); - { - OpBuilder::InsertionGuard insertGuard(rewriter); - rewriter.setInsertionPointToStart(module.getBody()); + mlir::Value alloc; + if (krnlGlobalOp.value().hasValue()) { + { + OpBuilder::InsertionGuard insertGuard(rewriter); + rewriter.setInsertionPointToStart(module.getBody()); - assert(krnlGlobalOp.value().hasValue() && - "Krnl Global must always have a value"); - global = rewriter.create(loc, llvmGlobalType, - /*isConstant=*/true, LLVM::Linkage::Internal, name, - krnlGlobalOp.value().getValue()); + assert(krnlGlobalOp.value().hasValue() && + "Krnl Global must always have a value"); + global = rewriter.create(loc, llvmGlobalType, + /*isConstant=*/true, LLVM::Linkage::Internal, name, + krnlGlobalOp.value().getValue()); + } + + // Some frequently used types. + auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect); + auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect); + + // Allocate the memory where the constants will be used from. + // This is a region of local memory and needs to be emitted as an alloca. + auto one = rewriter.create( + loc, llvmI64Ty, rewriter.getI64IntegerAttr(1)); + alloc = rewriter.create( + loc, llvmGlobalType.getPointerTo(), one, /*alignment=*/0); + + // Copy constant value into the local alloca: + // - Bitcast alloc to i8* + Value int8PtrAlloc = + rewriter.create(loc, llvmI8PtrTy, alloc); + // - Bitcast global to i8* + Value globalValue = rewriter.create(loc, global); + Value i8PtrGlobal = + rewriter.create(loc, llvmI8PtrTy, globalValue); + // - Set size. 
+ Value memRefElementSize = + rewriter.create(loc, llvmI64Ty, + rewriter.getI64IntegerAttr(getMemRefEltSizeInBytes(memRefTy))); + Value numElementsValue = rewriter.create( + loc, llvmI64Ty, rewriter.getI64IntegerAttr(numElements)); + Value totalElementsSize = rewriter.create( + loc, memRefElementSize, numElementsValue); + Value int64Size = + rewriter.create(loc, llvmI64Ty, totalElementsSize); + // - Set volatile. + Value isVolatile = rewriter.create(loc, + LLVM::LLVMType::getInt1Ty(llvmDialect), + rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0)); + // - Copy constant data into the alloca. + auto memcpyRef = getOrInsertMemcpy(rewriter, module, llvmDialect); + rewriter.create(loc, memcpyRef, + LLVM::LLVMType::getVoidTy(llvmDialect), + ArrayRef({int8PtrAlloc, i8PtrGlobal, int64Size, isVolatile})); + } else { + // Some frequently used types. + auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect); + auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect); + + // Allocate the memory where the constants will be used from. + // This is a region of local memory and needs to be emitted as an alloca. + auto one = rewriter.create( + loc, llvmI64Ty, rewriter.getI64IntegerAttr(1)); + + auto base = module.lookupSymbol("packedConst"); + assert(base && "Cannot find symbol packedConst."); + + Value constPackBasePtrAddr = + rewriter.create(loc, base); + Value constPackBasePtr = rewriter.create( + loc, base.getType(), constPackBasePtrAddr); + auto offset = rewriter.create(loc, llvmI64Ty, + rewriter.getI64IntegerAttr( + krnlGlobalOp.offsetAttr().getValue().getSExtValue())); + alloc = rewriter.create( + loc, llvmI8PtrTy, constPackBasePtr, ValueRange({offset})); } - - // Some frequently used types. - auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect); - auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect); - - // Allocate the memory where the constants will be used from. - // This is a region of local memory and needs to be emitted as an alloca. 
- auto one = rewriter.create( - loc, llvmI64Ty, rewriter.getI64IntegerAttr(1)); - auto alloc = rewriter.create( - loc, llvmGlobalType.getPointerTo(), one, /*alignment=*/0); - - // Copy constant value into the local alloca: - // - Bitcast alloc to i8* - Value int8PtrAlloc = - rewriter.create(loc, llvmI8PtrTy, alloc); - // - Bitcast global to i8* - Value globalValue = rewriter.create(loc, global); - Value i8PtrGlobal = - rewriter.create(loc, llvmI8PtrTy, globalValue); - // - Set size. - Value memRefElementSize = rewriter.create(loc, llvmI64Ty, - rewriter.getI64IntegerAttr(getMemRefEltSizeInBytes(memRefTy))); - Value numElementsValue = rewriter.create( - loc, llvmI64Ty, rewriter.getI64IntegerAttr(numElements)); - Value totalElementsSize = - rewriter.create(loc, memRefElementSize, numElementsValue); - Value int64Size = - rewriter.create(loc, llvmI64Ty, totalElementsSize); - // - Set volatile. - Value isVolatile = rewriter.create(loc, - LLVM::LLVMType::getInt1Ty(llvmDialect), - rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0)); - // - Copy constant data into the alloca. - auto memcpyRef = getOrInsertMemcpy(rewriter, module, llvmDialect); - rewriter.create(loc, memcpyRef, - LLVM::LLVMType::getVoidTy(llvmDialect), - ArrayRef({int8PtrAlloc, i8PtrGlobal, int64Size, isVolatile})); - // Prepare data to be inserted into MemRef. 
auto llvmConstantElementType = constantElementType.cast(); Value typedAlloc = rewriter.create( @@ -681,6 +705,115 @@ private: } } }; + +//===----------------------------------------------------------------------===// +// KRNL to LLVM: KrnlPackedConstOpLowering +//===----------------------------------------------------------------------===// + +class KrnlPackedConstOpLowering : public ConvertToLLVMPattern { +public: + explicit KrnlPackedConstOpLowering( + MLIRContext *context, LLVMTypeConverter &lowering_) + : ConvertToLLVMPattern( + KrnlPackedConstantOp::getOperationName(), context, lowering_) {} + + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto *context = op->getContext(); + ModuleOp module = op->getParentOfType(); + auto loc = op->getLoc(); + + auto *llvmDialect = + op->getContext()->getRegisteredDialect(); + assert(llvmDialect && "expected llvm dialect to be registered"); + + auto packedConstOp = llvm::dyn_cast(op); + LLVM::GlobalOp globalBase; + // Some frequently used types. + auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect); + auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect); + { + OpBuilder::InsertionGuard insertGuard(rewriter); + rewriter.setInsertionPointToStart(module.getBody()); + + globalBase = rewriter.create(loc, llvmI8PtrTy, + /*isConstant=*/false, LLVM::Linkage::Internal, "packedConst", + nullptr); + } + + auto mainFunc = module.lookupSymbol("main_graph"); + assert(mainFunc); + + rewriter.setInsertionPoint( + &mainFunc.getBody().front(), mainFunc.getBody().front().begin()); + + // - Initialize the global constant base. 
+ Value basePtrAddr = rewriter.create(loc, globalBase); + auto getEmbeddedConstPoolRef = getOrInsertExternFunc( + KrnlPackedConstantOp::getEmbeddedDataLoaderMethodName(), module, + LLVM::LLVMType::getFunctionTy( + llvmI8PtrTy, {llvmI64Ty}, /*isVarArg=*/false), + rewriter); + auto constPackSize = rewriter.create(loc, + LLVM::LLVMType::getInt64Ty(llvmDialect), + packedConstOp.size_in_bytesAttr()); + Value alloc = rewriter + .create(loc, getEmbeddedConstPoolRef, llvmI8PtrTy, + ArrayRef({constPackSize})) + .getResult(0); + rewriter.create(loc, alloc, basePtrAddr); + { + OpBuilder::InsertionGuard insertGuard(rewriter); + rewriter.setInsertionPointToStart(module.getBody()); + + // Record constant pack *file path* as a global variable (by recording the + // file path string's underlying char array + its length). + const auto &fileNameAttr = packedConstOp.file_nameAttr(); + auto type = + LLVM::LLVMType::getArrayTy(LLVM::LLVMType::getInt8Ty(llvmDialect), + fileNameAttr.getValue().size()); + rewriter.create(loc, type, /*isConstant=*/true, + LLVM::Linkage::External, + mlir::KrnlPackedConstantOp::getConstPackFilePathSymbolName(), + fileNameAttr); + type = LLVM::LLVMType::getInt64Ty(llvmDialect); + rewriter.create(loc, type, /*isConstant=*/true, + LLVM::Linkage::External, + mlir::KrnlPackedConstantOp::getConstPackFilePathStrLenSymbolName(), + rewriter.getI64IntegerAttr(fileNameAttr.getValue().size())); + + // Record constant pack *file name* as a global variable (by recording the + // file name string's underlying char array + its length). 
+ auto constPackFileName = + llvm::sys::path::filename(fileNameAttr.getValue()); + type = LLVM::LLVMType::getArrayTy( + LLVM::LLVMType::getInt8Ty(llvmDialect), constPackFileName.size()); + rewriter.create(loc, type, /*isConstant=*/true, + LLVM::Linkage::External, + mlir::KrnlPackedConstantOp::getConstPackFileNameSymbolName(), + rewriter.getStringAttr(constPackFileName)); + type = LLVM::LLVMType::getInt64Ty(llvmDialect); + rewriter.create(loc, type, /*isConstant=*/true, + LLVM::Linkage::External, + mlir::KrnlPackedConstantOp::getConstPackFileNameStrLenSymbolName(), + rewriter.getI64IntegerAttr(constPackFileName.size())); + + type = LLVM::LLVMType::getInt8Ty(llvmDialect); + rewriter.create(loc, type, /*isConstant=*/true, + LLVM::Linkage::External, + mlir::KrnlPackedConstantOp::getConstPackIsLESymbolName(), + rewriter.getI8IntegerAttr(packedConstOp.is_le())); + } + + rewriter.eraseOp(op); + return success(); + } + +private: + static int64_t ArrayAttrIntVal(ArrayAttr a, int i) { + return (a.getValue()[i]).cast().getInt(); + } +}; } // end namespace //===----------------------------------------------------------------------===// @@ -712,7 +845,8 @@ void KrnlToLLVMLoweringPass::runOnOperation() { /*emitCWrapperS=*/true, /*useAlignedAlloc=*/false); - patterns.insert(&getContext(), typeConverter); + patterns.insert( + &getContext(), typeConverter); patterns.insert(&getContext(), typeConverter); // Lower from the `krnl` dialect i.e. the Reshape operation. @@ -722,8 +856,9 @@ void KrnlToLLVMLoweringPass::runOnOperation() { // We want to completely lower to LLVM, so we use a `FullConversion`. This // ensures that only legal operations will remain after the conversion. if (failed(applyFullConversion( - getOperation(), target, patterns, &typeConverter))) + getOperation(), target, patterns, &typeConverter))) { signalPassFailure(); + } } /// Create the pass for lowering `Krnl`, `Affine` and `Std` dialects to LLVM. 
diff --git a/src/Transform/ONNX/ShapeInferencePass.cpp b/src/Transform/ONNX/ShapeInferencePass.cpp index 2e67d17..1ea68bb 100644 --- a/src/Transform/ONNX/ShapeInferencePass.cpp +++ b/src/Transform/ONNX/ShapeInferencePass.cpp @@ -50,8 +50,9 @@ public: int64_t dynamicOperations = 0; f.walk([&](mlir::Operation *op) { - if (returnsDynamicShape(op)) + if (returnsDynamicShape(op)) { dynamicOperations++; + } }); // If any dynamic operations remain, this indicates a failure. diff --git a/src/Transform/PackKrnlGlobalConstants.cpp b/src/Transform/PackKrnlGlobalConstants.cpp new file mode 100644 index 0000000..65c6166 --- /dev/null +++ b/src/Transform/PackKrnlGlobalConstants.cpp @@ -0,0 +1,127 @@ +//===- PackKrnlGlobalConstants.cpp - Krnl Global Constant Packing ---------===// +// +// Copyright 2019-2020 The IBM Research Authors. +// +// ============================================================================= +// +// In practice, the constant values of Global Krnl operations may be large +// enough to hinder the readability of the MLIR intermediate representation. +// +// This file creates a pass which packs the raw data of constant global +// operations whose element count exceeds a threshold into a single buffer, +// records byte offsets on the global operations, and elides the packed values. +// The buffer is stored on a krnl.packed_const operation or written to a file. +// +//===----------------------------------------------------------------------===// +#include + +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Pass/Pass.h" +#include "llvm/Support/FileSystem.h" + +#include "src/Dialect/Krnl/KrnlOps.hpp" +#include "src/Pass/Passes.hpp" +#include "src/Transform/ElideKrnlGlobalConstants.hpp" + +using namespace mlir; + +namespace { + +/*! + * Pass that packs the constant values of Krnl globals into a single buffer. 
+ */ +class PackKrnlGlobalConstantsPass + : public PassWrapper> { +public: + /// Make sure that we have a valid default constructor and copy constructor to + /// make sure that the options are initialized properly. + PackKrnlGlobalConstantsPass() = default; + PackKrnlGlobalConstantsPass(const PackKrnlGlobalConstantsPass &pass) {} + + void runOnOperation() override { + auto module = getOperation(); + OpBuilder builder(&getContext()); + + // Packing constant arrays to packedConst. + std::vector packedConst; + module.walk([&](KrnlGlobalOp op) { + assert(op.value()); + op.offsetAttr(builder.getI64IntegerAttr(packedConst.size())); + assert(op.value()->isa()); + const auto &denseAttr = op.valueAttr().cast(); + auto numElements = denseAttr.getNumElements(); + if (numElements <= elisionThreshold) + return; + + // TODO(tjingrant) verify we can actually use the raw data. + std::vector rawData = denseAttr.getRawData(); + packedConst.insert(packedConst.end(), rawData.begin(), rawData.end()); + }); + + // Remove value attributes from krnl constant op. + ConversionTarget target(getContext()); + OwningRewritePatternList patterns; + patterns.insert( + &getContext(), elisionThreshold); + // Apply constant value elision. 
+ module.walk( + [&](FuncOp func) { applyPatternsAndFoldGreedily(func, patterns); }); + + bool isLE = llvm::support::endian::system_endianness() == + llvm::support::endianness::little; + mlir::OperationState state(module.getLoc(), "krnl.packed_const"); + KrnlPackedConstantOp::build(builder, state, + builder.getIntegerType(/*width=*/64), + /*size_in_bytes=*/builder.getI64IntegerAttr(packedConst.size()), + /*is_le=*/builder.getBoolAttr(isLE), + /*value=*/nullptr, + /*file_name=*/nullptr); + auto packedConstOp = + llvm::cast(mlir::Operation::create(state)); + module.insert(module.begin(), packedConstOp); + if (moveToFile) { + std::string pathStr; + if (filename.hasValue()) { + pathStr = filename.getValue(); + } else { + llvm::SmallVector path; + llvm::sys::fs::createTemporaryFile("packed_const", "tmp", path); + pathStr = std::string(path.begin(), path.end()); + } + packedConstOp.file_nameAttr(builder.getStringAttr(pathStr)); + std::ofstream outfile(pathStr, std::ofstream::binary); + outfile.write(packedConst.data(), packedConst.size()); + } else { + auto shapeTy = + RankedTensorType::get({static_cast(packedConst.size())}, + builder.getIntegerType(8)); + auto denseAttr = + DenseIntElementsAttr::get(shapeTy, llvm::makeArrayRef(packedConst)); + packedConstOp.valueAttr(denseAttr); + } + } + + Option moveToFile{*this, "move-to-file", + llvm::cl::desc("Whether to move the packed constant to a file."), + llvm::cl::init(true)}; + Option elisionThreshold{*this, "elision-threshold", + llvm::cl::desc( + "A threshold value specifying the maximum number of elements a " + "constant operation can hold as an attribute. If the number exceeds " + "this threshold, constants will be packed together and, in the case " + "where `move-to-file` option is enabled, stored as a binary file on " + "disk. 
This can help preserve readability of IR dump and improve " + "compilation speed."), + llvm::cl::init(KrnlConstGlobalValueElision::kDefaultElisionThreshold)}; + Option filename{*this, "filename", + llvm::cl::desc( + "Specify a file in which the packed constant is to be stored.")}; +}; +} // namespace + +std::unique_ptr mlir::createPackKrnlGlobalConstantsPass() { + return std::make_unique(); +} + +static PassRegistration pass("pack-krnl-constants", + "Elide the constant values of the Global Krnl operations."); \ No newline at end of file diff --git a/test/backend/test.py b/test/backend/test.py index 5792967..d46a27d 100644 --- a/test/backend/test.py +++ b/test/backend/test.py @@ -6,6 +6,7 @@ from __future__ import unicode_literals import os import sys import unittest +import warnings import onnx.backend.base import onnx.backend.test @@ -29,7 +30,57 @@ from PyRuntime import ExecutionSession def execute_commands(cmds): if (VERBOSE): print(" ".join(cmds)) - subprocess.run(cmds, stdout=subprocess.PIPE) + subprocess.run(cmds) + + +# There are two issues, which necessitate the adoption of this endianness +# aware wrapper around Execution Session: +# 1. Input arrays are given sometimes in native byte order, sometimes in +# LE byte order, and as soon as the python array enters into py::array +# C++ objects through pybind, we will no longer be able to query their +# endianness. So we must intercept the inputs and convert them into +# native endianness. +# 2. Output arrays are compared with reference outputs, the comparison +# unfortunately includes checking that our outputs and reference outputs +# share the same endianness. So we try to figure out what is the desired +# reference output endianness, and convert our outputs to this desired +# endianness. 
+class EndiannessAwareExecutionSession(ExecutionSession): + def __init__(self, path, entry_point): + super().__init__(path, entry_point) + + def is_input_le(self, inputs): + inputs_endianness = list(map(lambda x: x.dtype.byteorder, inputs)) + endianness_is_consistent = len(set(inputs_endianness)) <= 1 + assert endianness_is_consistent, \ + "Input arrays contain a mixture of endianness configuration." + + sys_is_le = sys.byteorder == 'little' + # To interpret character symbols indicating endianness: + # https://numpy.org/doc/stable/reference/generated/numpy.dtype.byteorder.html + explicitly_le = inputs_endianness[0] == "<" + implicitly_le = (inputs_endianness[0] == "=" and sys_is_le) + return explicitly_le or implicitly_le + + def run(self, inputs, **kwargs): + if len(inputs): + # Deduce desired endianness of output from inputs. + sys_is_le = sys.byteorder == 'little' + inp_is_le = self.is_input_le(inputs) + if (sys_is_le != inp_is_le): + inputs = list( + map(lambda x: x.byteswap().newbyteorder(), inputs)) + outputs = super().run(inputs) + if (sys_is_le != inp_is_le): + outputs = list( + map(lambda x: x.byteswap().newbyteorder(), outputs)) + return outputs + else: + # Can't deduce desired output endianness, fingers crossed. + warnings.warn( + "Cannot deduce desired output endianness, using native endianness by default." + ) + return super().run(inputs) class DummyBackend(onnx.backend.base.Backend): @@ -40,7 +91,8 @@ class DummyBackend(onnx.backend.base.Backend): onnx.save(model, "temp_model.onnx") # Call frontend to process temp_model.onnx, bit code will be generated. 
execute_commands([ONNX_MLIR, "temp_model.onnx"]) - return ExecutionSession("./temp_model.so", "_dyn_entry_point_main_graph") + return EndiannessAwareExecutionSession("./temp_model.so", + "_dyn_entry_point_main_graph") @classmethod def supports_device(cls, device): @@ -80,7 +132,6 @@ test_to_enable = [ "test_concat_3d_axis_0_cpu", "test_concat_3d_axis_1_cpu", "test_concat_3d_axis_2_cpu", - "test_concat_1d_axis_negative_1_cpu", "test_concat_2d_axis_negative_1_cpu", "test_concat_2d_axis_negative_2_cpu", @@ -359,12 +410,18 @@ test_to_enable = [ "test_split_variable_parts_1d_cpu", "test_split_variable_parts_2d_cpu", "test_split_variable_parts_default_axis_cpu", + + # ResNet + "test_resnet50_cpu", ] # Extract name of all test cases. import inspect -all_tests = inspect.getmembers( +all_tests = [] +all_tests += inspect.getmembers( + backend_test.test_cases["OnnxBackendRealModelTest"]) +all_tests += inspect.getmembers( backend_test.test_cases["OnnxBackendNodeModelTest"]) all_test_names = list(map(lambda x: x[0], all_tests)) @@ -372,8 +429,7 @@ all_test_names = list(map(lambda x: x[0], all_tests)) for test_name in test_to_enable: assert test_name in all_test_names, """test name {} not found, it is likely that you may have misspelled the test name or the specified test does not - exist in the version of onnx package you installed.""".format( - test_name) + exist in the version of onnx package you installed.""".format(test_name) backend_test.include(r"^{}$".format(test_name)) # import all test cases at global scope to make them visible to python.unittest diff --git a/test/mlir/CMakeLists.txt b/test/mlir/CMakeLists.txt index 86db166..09ac94c 100644 --- a/test/mlir/CMakeLists.txt +++ b/test/mlir/CMakeLists.txt @@ -6,7 +6,7 @@ configure_lit_site_cfg(${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in MAIN_CONFIG ${CMAKE_CURRENT_SOURCE_DIR}/lit.cfg.py) -set(ONNX_MLIR_TEST_DEPENDS onnx-mlir-opt) +set(ONNX_MLIR_TEST_DEPENDS onnx-mlir-opt binary-decoder) add_lit_testsuite(check-onnx-lit 
"Running the ONNX MLIR regression tests" diff --git a/test/mlir/krnl/pack_consts_mix_types.mlir b/test/mlir/krnl/pack_consts_mix_types.mlir new file mode 100644 index 0000000..87041b7 --- /dev/null +++ b/test/mlir/krnl/pack_consts_mix_types.mlir @@ -0,0 +1,10 @@ +// RUN: onnx-mlir-opt --pack-krnl-constants='elision-threshold=3 move-to-file=true filename=test1.bin' %s -split-input-file && binary-decoder test1.bin -s 0 -n 16 --onnx::TensorProto::FLOAT -rm | FileCheck %s -check-prefix=BINARY_DECODER_1 +// RUN: onnx-mlir-opt --pack-krnl-constants='elision-threshold=3 move-to-file=true filename=test2.bin' %s -split-input-file && binary-decoder test2.bin -s 16 -n 32 --onnx::TensorProto::INT32 -rm | FileCheck %s -check-prefix=BINARY_DECODER_2 + +// BINARY_DECODER_1: 0.1 0.2 0.3 0.4 +// BINARY_DECODER_2: 1 2 3 4 +func @test_krnl_const_packing_file_mixing_types() -> memref<1x4xf32> { + %0 = "krnl.global"() {name = "constant_0", shape = [1, 4], value = dense<[[0.1, 0.2, 0.3, 0.4]]> : tensor<1x4xf32>} : () -> memref<1x4xf32> + %1 = "krnl.global"() {name = "constant_1", shape = [1, 4], value = dense<[[1, 2, 3, 4]]> : tensor<1x4xi32>} : () -> memref<1x4xi32> + return %0 : memref<1x4xf32> +} \ No newline at end of file diff --git a/test/mlir/krnl/pack_krnl_constants_be/lit.local.cfg b/test/mlir/krnl/pack_krnl_constants_be/lit.local.cfg new file mode 100644 index 0000000..8ad73b1 --- /dev/null +++ b/test/mlir/krnl/pack_krnl_constants_be/lit.local.cfg @@ -0,0 +1,8 @@ +import sys + +root = config.root + +if sys.byteorder == "little": + config.unsupported = True +else: + config.unsupported = False \ No newline at end of file diff --git a/test/mlir/krnl/pack_krnl_constants_be/pack_krnl_constants.mlir b/test/mlir/krnl/pack_krnl_constants_be/pack_krnl_constants.mlir new file mode 100644 index 0000000..ddce9ac --- /dev/null +++ b/test/mlir/krnl/pack_krnl_constants_be/pack_krnl_constants.mlir @@ -0,0 +1,11 @@ +// RUN: onnx-mlir-opt --pack-krnl-constants='elision-threshold=3 
move-to-file=false' %s -split-input-file | FileCheck %s + +// CHECK: [[CONST_PACK:%.+]] = "krnl.packed_const"() {is_le = false, size_in_bytes = 32 : i64, value = dense<[0, 0, 0, 0, 63, -128, 0, 0, 64, 0, 0, 0, 64, 64, 0, 0, 0, 0, 0, 0, 63, -128, 0, 0, 64, 0, 0, 0, 64, 64, 0, 0]> : tensor<32xi8>} : () -> i64 +// CHECK-LABEL: func @test_krnl_const_packing() -> memref<1x4xf32> { +// CHECK-NEXT: [[CONST0:%.+]] = "krnl.global"() {name = "constant_0", offset = 0 : i64, shape = [1, 4]} : () -> memref<1x4xf32> +// CHECK-NEXT: [[CONST1:%.+]] = "krnl.global"() {name = "constant_1", offset = 16 : i64, shape = [1, 4]} : () -> memref<1x4xf32> +func @test_krnl_const_packing() -> memref<1x4xf32> { + %0 = "krnl.global"() {name = "constant_0", shape = [1, 4], value = dense<[[0., 1., 2., 3.]]> : tensor<1x4xf32>} : () -> memref<1x4xf32> + %1 = "krnl.global"() {name = "constant_1", shape = [1, 4], value = dense<[[0., 1., 2., 3.]]> : tensor<1x4xf32>} : () -> memref<1x4xf32> + return %0 : memref<1x4xf32> +} \ No newline at end of file diff --git a/test/mlir/krnl/pack_krnl_constants_le/lit.local.cfg b/test/mlir/krnl/pack_krnl_constants_le/lit.local.cfg new file mode 100644 index 0000000..c9cd0e9 --- /dev/null +++ b/test/mlir/krnl/pack_krnl_constants_le/lit.local.cfg @@ -0,0 +1,8 @@ +import sys + +root = config.root + +if sys.byteorder == "little": + config.unsupported = False +else: + config.unsupported = True \ No newline at end of file diff --git a/test/mlir/krnl/pack_krnl_constants_le/pack_krnl_constants.mlir b/test/mlir/krnl/pack_krnl_constants_le/pack_krnl_constants.mlir new file mode 100644 index 0000000..938357e --- /dev/null +++ b/test/mlir/krnl/pack_krnl_constants_le/pack_krnl_constants.mlir @@ -0,0 +1,11 @@ +// RUN: onnx-mlir-opt --pack-krnl-constants='elision-threshold=3 move-to-file=false' %s -split-input-file | FileCheck %s + +// CHECK: [[CONST_PACK:%.+]] = "krnl.packed_const"() {is_le = true, size_in_bytes = 32 : i64, value = dense<[0, 0, 0, 0, 0, 0, -128, 63, 0, 0, 0, 64, 
0, 0, 64, 64, 0, 0, 0, 0, 0, 0, -128, 63, 0, 0, 0, 64, 0, 0, 64, 64]> : tensor<32xi8>} : () -> i64 +// CHECK-LABEL: func @test_krnl_const_packing() -> memref<1x4xf32> { +// CHECK-NEXT: [[CONST0:%.+]] = "krnl.global"() {name = "constant_0", offset = 0 : i64, shape = [1, 4]} : () -> memref<1x4xf32> +// CHECK-NEXT: [[CONST1:%.+]] = "krnl.global"() {name = "constant_1", offset = 16 : i64, shape = [1, 4]} : () -> memref<1x4xf32> +func @test_krnl_const_packing() -> memref<1x4xf32> { + %0 = "krnl.global"() {name = "constant_0", shape = [1, 4], value = dense<[[0., 1., 2., 3.]]> : tensor<1x4xf32>} : () -> memref<1x4xf32> + %1 = "krnl.global"() {name = "constant_1", shape = [1, 4], value = dense<[[0., 1., 2., 3.]]> : tensor<1x4xf32>} : () -> memref<1x4xf32> + return %0 : memref<1x4xf32> +} \ No newline at end of file diff --git a/test/mlir/krnl/pack_krnl_consts_to_file.mlir b/test/mlir/krnl/pack_krnl_consts_to_file.mlir new file mode 100644 index 0000000..c491ab9 --- /dev/null +++ b/test/mlir/krnl/pack_krnl_consts_to_file.mlir @@ -0,0 +1,8 @@ +// RUN: onnx-mlir-opt --pack-krnl-constants='elision-threshold=3 move-to-file=true filename=test-pack-consts-to-file-same-type.bin' %s -split-input-file && binary-decoder test-pack-consts-to-file-same-type.bin -s 0 -n 32 --onnx::TensorProto::FLOAT -rm | FileCheck %s + +// CHECK: 0 1 2 3 0 1 2 3 +func @test_krnl_const_packing_file() -> memref<1x4xf32> { + %0 = "krnl.global"() {name = "constant_0", shape = [1, 4], value = dense<[[0., 1., 2., 3.]]> : tensor<1x4xf32>} : () -> memref<1x4xf32> + %1 = "krnl.global"() {name = "constant_1", shape = [1, 4], value = dense<[[0., 1., 2., 3.]]> : tensor<1x4xf32>} : () -> memref<1x4xf32> + return %0 : memref<1x4xf32> +} \ No newline at end of file diff --git a/test/mlir/lit.cfg.py b/test/mlir/lit.cfg.py index ac22dc2..7dc6819 100644 --- a/test/mlir/lit.cfg.py +++ b/test/mlir/lit.cfg.py @@ -31,7 +31,7 @@ tool_dirs = [ config.onnx_mlir_tools_dir, config.mlir_tools_dir, config.llvm_tools_dir ] 
tool_names = [ - 'onnx-mlir-opt', 'mlir-opt', 'mlir-translate' + 'onnx-mlir-opt', 'mlir-opt', 'mlir-translate', "binary-decoder" ] tools = [ToolSubst(s, unresolved='ignore') for s in tool_names] llvm_config.add_tool_substitutions(tools, tool_dirs) \ No newline at end of file diff --git a/test/mlir/lit.site.cfg.py.in b/test/mlir/lit.site.cfg.py.in index 28df6b2..b820168 100644 --- a/test/mlir/lit.site.cfg.py.in +++ b/test/mlir/lit.site.cfg.py.in @@ -1,4 +1,3 @@ - import lit.llvm if '@BUILD_SHARED_LIBS@' == 'ON': diff --git a/test/mlir/onnx/onnx_krnl_global_elision.mlir b/test/mlir/onnx/onnx_krnl_global_elision.mlir index 9bce410..31aa2d1 100644 --- a/test/mlir/onnx/onnx_krnl_global_elision.mlir +++ b/test/mlir/onnx/onnx_krnl_global_elision.mlir @@ -1,10 +1,10 @@ // RUN: onnx-mlir-opt --elide-krnl-constants %s -split-input-file | FileCheck %s -// CHECK-LABEL: func @test_elide_krnl_global_constant(%arg0: memref<1xf32>) -> memref<1x10xf32> -func @test_elide_krnl_global_constant(%arg0: memref<1xf32>) -> memref<1x10xf32> { - %0 = "krnl.global"() {name = "constant_0", shape = [1, 10], value = dense<[[-0.0448560268, 0.00779166119, 0.0681008175, 0.0299937408, -0.126409635, 0.14021875, -0.0552849025, -0.0493838154, 0.0843220502, -0.0545404144]]> : tensor<1x10xf32>} : () -> memref<1x10xf32> - return %0 : memref<1x10xf32> +// CHECK-LABEL: func @test_elide_krnl_global_constant(%arg0: memref<1xf32>) -> memref<1x70xf32> +func @test_elide_krnl_global_constant(%arg0: memref<1xf32>) -> memref<1x70xf32> { + %0 = "krnl.global"() {name = "constant_0", shape = [1, 70], value = dense<[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]> : tensor<1x70xf32>} : () -> memref<1x70xf32> + return %0 : memref<1x70xf32> - // CHECK: %0 = "krnl.global"() {name = 
"constant_0", shape = [1, 10]} : () -> memref<1x10xf32> - // CHECK: return %0 : memref<1x10xf32> + // CHECK: {{.*}} = "krnl.global"() {name = "constant_0", shape = [1, 70]} : () -> memref<1x70xf32> + // CHECK: return {{.*}} : memref<1x70xf32> } \ No newline at end of file diff --git a/test/numerical/TestConv.cpp b/test/numerical/TestConv.cpp index c2cbd66..c381a10 100644 --- a/test/numerical/TestConv.cpp +++ b/test/numerical/TestConv.cpp @@ -37,7 +37,7 @@ bool isOMConvTheSameAsNaiveImplFor(const int N, const int C, const int H, llvm::SmallVector outputsType{yType}; auto funcType = builder.getFunctionType(inputsType, outputsType); - string funcName = "test_conv"; + string funcName = "main_graph"; llvm::SmallVector attrs; auto funcOp = builder.create(UnknownLoc::get(&ctx), funcName, funcType, attrs); @@ -88,13 +88,13 @@ bool isOMConvTheSameAsNaiveImplFor(const int N, const int C, const int H, OwningModuleRef moduleRef(module); llvm::SmallVector path; - llvm::sys::fs::createTemporaryFile("_test_conv", "", path); + llvm::sys::fs::createTemporaryFile("_main_graph", "", path); string pathStr(path.begin(), path.end()); llvm::FileRemover remover(path); compileModule(moduleRef, ctx, pathStr, EmitLib); onnx_mlir::ExecutionSession sess( - pathStr + ".so", "_dyn_entry_point_test_conv"); + pathStr + ".so", "_dyn_entry_point_main_graph"); std::vector> inputs; auto xDmr = unique_ptr(getRndRealDmr({N, C, H, W}));