From d235f248e420f973931d79eb18123b3d1d0639ef Mon Sep 17 00:00:00 2001 From: gongsu832 Date: Sat, 11 Jul 2020 01:23:13 -0400 Subject: [PATCH] Add emitjni target (#204) * Detect llvm-project commit change in utils/clone-mlir.sh and rebuild llvm-project for zLinux Jenkins build bot * Add --EmitJNI target (tested working with mnist and resnet50) - MainUtils * first shot at refactoring compileModuleToSharedLibrary * add setExecPath call to allow resolving runtime directory from onnx-mlir executable path when ONNX_MLIR_RUNTIME_DIR is not set. This allows tests to run without having to install onnx-mlir or to explicitly set ONNX_MLIR_RUNTIME_DIR - RtMemRef * add getDataSize for C (equivalent of size() for C++). * fix setStrides bug (setting sizes, not strides) - TestConv * _main_graph-*.so were filling up /tmp. Change to use fixed shared library in build directory * Fix clang-format-lint complaints * - getRuntimeDir checks lib64 - install targets for javaruntime and jniruntime - remove ONNX_MLIR_LD_PRELOAD_onnx-mlir and ONNX_MLIR_LD_PRELOAD_onnx-mlir-opt * See what happens when `kExecPath` decl is dropped. Co-authored-by: Tian Jin --- .buildbot/z13.sh | 4 +- src/CMakeLists.txt | 6 - src/ExternalUtil.hpp.in | 4 + src/MainUtils.cpp | 180 ++++++-- src/MainUtils.hpp | 6 + src/Runtime/CMakeLists.txt | 2 + src/Runtime/RtMemRef.cpp | 9 +- src/Runtime/RtMemRef.h | 5 +- src/Runtime/jni/CMakeLists.txt | 30 ++ .../jni/com_ibm_onnxmlir_DynEntryPoint.h | 22 + src/Runtime/jni/jnidummy.c | 7 + src/Runtime/jni/jnilog.c | 94 ++++ src/Runtime/jni/jnilog.h | 125 ++++++ src/Runtime/jni/jniwrapper.c | 356 ++++++++++++++++ .../src/com/ibm/onnxmlir/DynEntryPoint.java | 64 +++ .../com/ibm/onnxmlir/OrderedRtMemRefDict.java | 118 ++++++ .../jni/src/com/ibm/onnxmlir/RtMemRef.java | 400 ++++++++++++++++++ src/Tool/ONNXMLIROpt/CMakeLists.txt | 6 - src/main.cpp | 5 +- test/numerical/TestConv.cpp | 16 +- 20 files changed, 1404 insertions(+), 55 deletions(-) create mode 100644 src/Runtime/jni/CMakeLists.txt create mode 100644 src/Runtime/jni/com_ibm_onnxmlir_DynEntryPoint.h create mode 100644 src/Runtime/jni/jnidummy.c create mode 100644 src/Runtime/jni/jnilog.c create mode 100644 src/Runtime/jni/jnilog.h create mode 100644 src/Runtime/jni/jniwrapper.c create mode 100644 src/Runtime/jni/src/com/ibm/onnxmlir/DynEntryPoint.java create mode 100644 src/Runtime/jni/src/com/ibm/onnxmlir/OrderedRtMemRefDict.java create mode 100644 src/Runtime/jni/src/com/ibm/onnxmlir/RtMemRef.java diff --git a/.buildbot/z13.sh b/.buildbot/z13.sh index 3b9b2b5..2e2f01b 100755 --- a/.buildbot/z13.sh +++ b/.buildbot/z13.sh @@ -71,5 +71,5 @@ cmake -DCMAKE_INSTALL_PREFIX=${INSTALL_PATH} .. \ make -j$(nproc) make -j$(nproc) check-onnx-lit -RUNTIME_DIR=$(pwd)/lib make -j$(nproc) check-onnx-backend -RUNTIME_DIR=$(pwd)/lib PATH=$(pwd)/bin:$PATH make -j$(nproc) test +make -j$(nproc) check-onnx-backend +PATH=$(pwd)/bin:$PATH make -j$(nproc) test diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index da88be1..f28faca 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -55,12 +55,6 @@ endif() configure_file(${CMAKE_CURRENT_SOURCE_DIR}/ExternalUtil.hpp.in ${CMAKE_CURRENT_BINARY_DIR}/ExternalUtil.hpp) -set(ONNX_MLIR_LD_PRELOAD_onnx-mlir "" CACHE STRING "" FORCE) -if(BUILD_SHARED_LIBS) - message(STATUS "To run dynamically linked onnx-mlir, you must specify:") - message(STATUS "LD_PRELOAD=${ONNX_MLIR_LD_PRELOAD_onnx-mlir}") -endif() - # Libraries specified on the target_link_libraries for the add_subdirectory # targets get added to the end of the list here. This creates two problems: # 1. It produces duplicated libraries being specified for the link command. diff --git a/src/ExternalUtil.hpp.in b/src/ExternalUtil.hpp.in index 5aa81a2..4d3662f 100644 --- a/src/ExternalUtil.hpp.in +++ b/src/ExternalUtil.hpp.in @@ -3,8 +3,12 @@ #include namespace onnx_mlir { +std::string kExecPath = "@CMAKE_INSTALL_PREFIX@/bin/onnx-mlir"; /* fallback if not set by main */ +const std::string kInstPath = "@CMAKE_INSTALL_PREFIX@"; const std::string kLlcPath = "@LLVM_PROJ_BUILD@/bin/llc"; const std::string kCxxPath = "@CMAKE_CXX_COMPILER@"; const std::string kLinkerPath = "@CMAKE_LINKER@"; const std::string kObjCopyPath = "@CMAKE_OBJCOPY@"; +const std::string kArPath = "@CMAKE_AR@"; +const std::string kJarPath = "@Java_JAR_EXECUTABLE@"; } // namespace onnx_mlir diff --git a/src/MainUtils.cpp b/src/MainUtils.cpp index b916bb1..0a43d1d 100644 --- a/src/MainUtils.cpp +++ b/src/MainUtils.cpp @@ -1,4 +1,4 @@ -//===--------------------------- main_utils.cpp ---------------------------===// +//===--------------------------- MainUtils.cpp ---------------------------===// // // Copyright 2019-2020 The IBM Research Authors. // @@ -13,6 +13,7 @@ #include #include #include +#include #include #include @@ -22,8 +23,6 @@ #include "src/ExternalUtil.hpp" #include "src/MainUtils.hpp" -#include "MainUtils.hpp" - #ifdef _WIN32 #include #else @@ -41,6 +40,42 @@ llvm::Optional getEnvVar(std::string name) { return llvm::None; } +// Runtime directory contains all the libraries, jars, etc. that are +// necessary for running onnx-mlir. It's resolved in the following order: +// +// - if ONNX_MLIR_RUNTIME_DIR is set, use it, otherwise +// - get path from where onnx-mlir is run, if it's of the form +// /foo/bar/bin/onnx-mlir, +// the runtime directory is /foo/bar/lib (note that when onnx-mlir is +// installed system wide, which is typically /usr/local/bin, this will +// correctly resolve to /usr/local/lib), but some systems still have +// lib64 so we check that first. If neither exists, then +// - use CMAKE_INSTALL_PREFIX/lib, which is typically /usr/local/lib +string getRuntimeDir() { + const auto &envDir = getEnvVar("ONNX_MLIR_RUNTIME_DIR"); + if (envDir && llvm::sys::fs::exists(envDir.getValue())) + return envDir.getValue(); + + string execDir = llvm::sys::path::parent_path(kExecPath).str(); + if (llvm::sys::path::stem(execDir).str().compare("bin") == 0) { + string p = execDir.substr(0, execDir.size() - 3); + if (llvm::sys::fs::exists(p + "lib64")) + return p + "lib64"; + if (llvm::sys::fs::exists(p + "lib")) + return p + "lib"; + } + + llvm::SmallString<8> instDir64(kInstPath); + llvm::sys::path::append(instDir64, "lib64"); + string p = llvm::StringRef(instDir64).str(); + if (llvm::sys::fs::exists(p)) + return p; + + llvm::SmallString<8> instDir(kInstPath); + llvm::sys::path::append(instDir, "lib"); + return llvm::StringRef(instDir).str(); +} + // Helper struct to make command construction and execution easy & readable. struct Command { std::string _path; @@ -97,6 +132,12 @@ struct Command { }; } // namespace +void setExecPath(const char *argv0, void *fmain) { + string p; + if (!(p = llvm::sys::fs::getMainExecutable(argv0, fmain)).empty()) + kExecPath = p; +} + void LoadMLIR(string inputFilename, mlir::MLIRContext &context, mlir::OwningModuleRef &module) { // Handle '.mlir' input to the ONNX MLIR frontend. @@ -119,8 +160,8 @@ void LoadMLIR(string inputFilename, mlir::MLIRContext &context, } } -void compileModuleToSharedLibrary( - const mlir::OwningModuleRef &module, string outputBaseName) { +void genConstPackObj(const mlir::OwningModuleRef &module, + llvm::Optional &constPackObjPath) { // Extract constant pack file name, which is embedded as a symbol in the // module being compiled. auto constPackFilePathSym = (*module).lookupSymbol( @@ -131,7 +172,6 @@ void compileModuleToSharedLibrary( .str(); llvm::FileRemover constPackRemover(constPackFilePath); - llvm::Optional constPackObjPath; #if __APPLE__ // Create a empty stub file, compile it to an empty obj file. llvm::SmallVector stubSrcPath; @@ -153,7 +193,6 @@ void compileModuleToSharedLibrary( .appendList({"-sectcreate", "binary", "param", constPackFilePath}) .appendStr(stubObjPathStr) .exec(); - llvm::FileRemover constPackObjRemover(constPackObjPath.getValue()); #elif __linux__ // Create param.o holding packed parameter values. @@ -164,7 +203,6 @@ void compileModuleToSharedLibrary( .appendList({"-o", constPackObjPath.getValue()}) .appendStr(constPackFilePath) .exec(); - llvm::FileRemover constPackObjRemover(constPackObjPath.getValue()); // Figure out what is the default symbol name describing the start/end // address of the embedded data. @@ -204,40 +242,119 @@ void compileModuleToSharedLibrary( mlir::KrnlPackedConstantOp::getConstPackFileNameStrLenSymbolName()) .valueAttr(builder.getI64IntegerAttr(constPackFileName.str().size())); #endif +} - // Write LLVM bitcode. - string outputFilename = outputBaseName + ".bc"; +// Write LLVM bitcode. +void genLLVMBitcode(const mlir::OwningModuleRef &module, string bitcodePath) { error_code error; + llvm::raw_fd_ostream moduleBitcodeStream( - outputFilename, error, llvm::sys::fs::F_None); + bitcodePath, error, llvm::sys::fs::F_None); llvm::WriteBitcodeToFile( *mlir::translateModuleToLLVMIR(*module), moduleBitcodeStream); moduleBitcodeStream.flush(); - llvm::FileRemover bcRemover(outputFilename); +} - // Compile LLVM bitcode to object file. +// Compile LLVM bitcode to object file. +void genModelObject(const mlir::OwningModuleRef &module, string bitcodePath, + string modelObjPath) { Command llvmToObj(/*exePath=*/kLlcPath); - llvmToObj.appendStr("-filetype=obj"); - llvmToObj.appendStr("-relocation-model=pic"); - llvmToObj.appendStr(outputFilename); - llvmToObj.exec(); - std::string modelObjPath = outputBaseName + ".o"; + llvmToObj.appendStr("-filetype=obj") + .appendStr("-relocation-model=pic") + .appendList({"-o", modelObjPath}) + .appendStr(bitcodePath) + .exec(); +} + +void genJniObject(const mlir::OwningModuleRef &module, string jniSharedLibPath, + string jniObjPath) { + Command ar(/*exePath=*/kArPath); + ar.appendStr("x").appendStr(jniSharedLibPath).appendStr(jniObjPath).exec(); +} + +// Link everything into a shared object. +void genSharedLib(const mlir::OwningModuleRef &module, + string modelSharedLibPath, std::vector opts, + std::vector objs, std::vector libs) { + + string runtimeDirInclFlag = "-L" + getRuntimeDir(); + + Command link(kCxxPath); + link.appendList(opts) + .appendList(objs) + .appendList({"-o", modelSharedLibPath}) + .appendStrOpt(runtimeDirInclFlag) + .appendList(libs) + .exec(); +} + +// Create jar containing java runtime and model shared library (which includes +// jni runtime). +void genJniJar(const mlir::OwningModuleRef &module, string modelSharedLibPath, + string modelJniJarPath) { + llvm::SmallString<8> runtimeDir(getRuntimeDir()); + llvm::sys::path::append(runtimeDir, "javaruntime.jar"); + string javaRuntimeJarPath = llvm::StringRef(runtimeDir).str(); + + // Copy javaruntime.jar to model jar. + llvm::sys::fs::copy_file(javaRuntimeJarPath, modelJniJarPath); + + // Add shared library to model jar. + Command jar(kJarPath); + jar.appendList({"uf", modelJniJarPath}).appendStr(modelSharedLibPath).exec(); +} + +void compileModuleToSharedLibrary( + const mlir::OwningModuleRef &module, std::string outputBaseName) { + + llvm::Optional constPackObjPath; + genConstPackObj(module, constPackObjPath); + llvm::FileRemover constPackObjRemover(constPackObjPath.getValue()); + + string bitcodePath = outputBaseName + ".bc"; + genLLVMBitcode(module, bitcodePath); + llvm::FileRemover bitcodeRemover(bitcodePath); + + string modelObjPath = outputBaseName + ".o"; + genModelObject(module, bitcodePath, modelObjPath); llvm::FileRemover modelObjRemover(modelObjPath); - llvm::Optional runtimeDirInclFlag; - if (getEnvVar("RUNTIME_DIR").hasValue()) - runtimeDirInclFlag = "-L" + getEnvVar("RUNTIME_DIR").getValue(); + string modelSharedLibPath = outputBaseName + ".so"; + genSharedLib(module, modelSharedLibPath, {"-shared", "-fPIC"}, + {constPackObjPath.getValueOr(""), modelObjPath}, + {"-lEmbeddedDataLoader", "-lcruntime"}); +} - // Link everything into a shared object. - Command link(kCxxPath); - link.appendList({"-shared", "-fPIC"}) - .appendStr(modelObjPath) - .appendStr(constPackObjPath.getValueOr("")) - .appendList({"-o", outputBaseName + ".so"}) - .appendStrOpt(runtimeDirInclFlag) - .appendList({"-lEmbeddedDataLoader", "-lcruntime"}) - .exec(); +void compileModuleToJniJar( + const mlir::OwningModuleRef &module, std::string outputBaseName) { + + llvm::Optional constPackObjPath; + genConstPackObj(module, constPackObjPath); + llvm::FileRemover constPackObjRemover(constPackObjPath.getValue()); + + string bitcodePath = outputBaseName + ".bc"; + genLLVMBitcode(module, bitcodePath); + llvm::FileRemover bitcodeRemover(bitcodePath); + + string modelObjPath = outputBaseName + ".o"; + genModelObject(module, bitcodePath, modelObjPath); + llvm::FileRemover modelObjRemover(modelObjPath); + + string jniSharedLibPath = getRuntimeDir() + "/libjniruntime.a"; + string jniObjPath = "jnidummy.c.o"; + genJniObject(module, jniSharedLibPath, jniObjPath); + llvm::FileRemover jniObjRemover(jniObjPath); + + string modelSharedLibPath = "libmodel.so"; + genSharedLib(module, modelSharedLibPath, + {"-shared", "-fPIC", "-z", "noexecstack"}, + {constPackObjPath.getValueOr(""), modelObjPath, jniObjPath}, + {"-lEmbeddedDataLoader", "-lcruntime", "-ljniruntime"}); + llvm::FileRemover modelSharedLibRemover(modelSharedLibPath); + + string modelJniJarPath = outputBaseName + ".jar"; + genJniJar(module, modelSharedLibPath, modelJniJarPath); } void registerDialects() { @@ -349,6 +466,9 @@ void emitOutputFiles(string outputBaseName, EmissionTargetType emissionTarget, // Write LLVM bitcode to disk, compile & link. compileModuleToSharedLibrary(module, outputBaseName); printf("Shared library %s.so has been compiled.\n", outputBaseName.c_str()); + } else if (emissionTarget == EmitJNI) { + compileModuleToJniJar(module, outputBaseName); + printf("JNI archive %s.jar has been compiled.\n", outputBaseName.c_str()); } else { // Emit the version with all constants included. outputCode(module, outputBaseName, ".onnx.mlir"); diff --git a/src/MainUtils.hpp b/src/MainUtils.hpp index 2ccb60f..b680bc7 100644 --- a/src/MainUtils.hpp +++ b/src/MainUtils.hpp @@ -43,14 +43,20 @@ enum EmissionTargetType { EmitMLIR, EmitLLVMIR, EmitLib, + EmitJNI, }; +void setExecPath(const char *argv0, void *fmain); + void LoadMLIR(std::string inputFilename, mlir::MLIRContext &context, mlir::OwningModuleRef &module); void compileModuleToSharedLibrary( const mlir::OwningModuleRef &module, std::string outputBaseName); +void compileModuleToJniJar( + const mlir::OwningModuleRef &module, std::string outputBaseName); + void registerDialects(); void addONNXToMLIRPasses(mlir::PassManager &pm); diff --git a/src/Runtime/CMakeLists.txt b/src/Runtime/CMakeLists.txt index 4f3cc9d..82137fa 100644 --- a/src/Runtime/CMakeLists.txt +++ b/src/Runtime/CMakeLists.txt @@ -1,3 +1,5 @@ +add_subdirectory(jni) + # Create static libcruntime.a to be embedded in model.so to make model.so self contained. # However, by default object code for static library is not compiled with -fPIC. Embedding # such static library in a shared library can cause runtime failure on some architectures, diff --git a/src/Runtime/RtMemRef.cpp b/src/Runtime/RtMemRef.cpp index 26c4859..518958f 100644 --- a/src/Runtime/RtMemRef.cpp +++ b/src/Runtime/RtMemRef.cpp @@ -150,6 +150,13 @@ int64_t *getStrides(RtMemRef *dynMemRef) { return dynMemRef->strides; } int64_t getSize(OrderedRtMemRefDict *dict) { return dict->orderedNames.size(); } +INDEX_TYPE getDataSize(RtMemRef *rtMemRef) { + INDEX_TYPE n = rtMemRef->sizes[0]; + for (int i = 1; i < rtMemRef->rank; i++) + n *= rtMemRef->sizes[i]; + return n; +} + void setDType(RtMemRef *dynMemRef, int onnxType) { dynMemRef->onnx_dtype = onnxType; } @@ -160,5 +167,5 @@ unsigned int getRank(RtMemRef *dynMemRef) { return dynMemRef->rank; } void setStrides(RtMemRef *dynMemRef, int64_t *strides) { for (int i = 0; i < dynMemRef->rank; i++) - dynMemRef->sizes[i] = strides[i]; + dynMemRef->strides[i] = strides[i]; } diff --git a/src/Runtime/RtMemRef.h b/src/Runtime/RtMemRef.h index 08de392..16d8306 100644 --- a/src/Runtime/RtMemRef.h +++ b/src/Runtime/RtMemRef.h @@ -151,6 +151,9 @@ OrderedRtMemRefDict *createOrderedRtMemRefDict(); // Get how many dynamic memrefs are in dict. int64_t getSize(OrderedRtMemRefDict *dict); +// Get how many data elements are in RtMemRef. +INDEX_TYPE getDataSize(RtMemRef *rtMemRef); + // Create a dynmemref with a certain rank. RtMemRef *createRtMemRef(int rank); @@ -277,4 +280,4 @@ inline bool isRmrClose( #endif // Will transition from RtMemRef to RtMemRef soon. -typedef RtMemRef RtMemRef; \ No newline at end of file +typedef RtMemRef RtMemRef; diff --git a/src/Runtime/jni/CMakeLists.txt b/src/Runtime/jni/CMakeLists.txt new file mode 100644 index 0000000..58068af --- /dev/null +++ b/src/Runtime/jni/CMakeLists.txt @@ -0,0 +1,30 @@ +find_package(Java COMPONENTS Development) +find_package(JNI) + +if(Java_Development_FOUND AND JNI_FOUND) + include(UseJava) + + # Target for Java runtime jar + add_jar(javaruntime + src/com/ibm/onnxmlir/DynEntryPoint.java + src/com/ibm/onnxmlir/OrderedRtMemRefDict.java + src/com/ibm/onnxmlir/RtMemRef.java + OUTPUT_DIR ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}) + + # Target for JNI runtime lib + add_library(jniruntime STATIC + jniwrapper.c jnilog.c jnidummy.c + com_ibm_onnxmlir_DynEntryPoint.h jnilog.h ../RtMemRef.h) + set_target_properties(jniruntime PROPERTIES + POSITION_INDEPENDENT_CODE TRUE) + target_include_directories(jniruntime PRIVATE + ${ONNX_MLIR_SRC_ROOT}/src/Runtime + ${JAVA_INCLUDE_PATH} + ${JAVA_INCLUDE_PATH2}) + + install_jar(javaruntime DESTINATION lib) + install(TARGETS jniruntime DESTINATION lib) + +else() + message(WARNING "Java Development component or JNI not found, JNI targets will not work") +endif() diff --git a/src/Runtime/jni/com_ibm_onnxmlir_DynEntryPoint.h b/src/Runtime/jni/com_ibm_onnxmlir_DynEntryPoint.h new file mode 100644 index 0000000..979cfa7 --- /dev/null +++ b/src/Runtime/jni/com_ibm_onnxmlir_DynEntryPoint.h @@ -0,0 +1,22 @@ +/* DO NOT EDIT THIS FILE - it is machine generated */ +#include +/* Header for class com_ibm_onnxmlir_DynEntryPoint */ + +#ifndef _Included_com_ibm_onnxmlir_DynEntryPoint +#define _Included_com_ibm_onnxmlir_DynEntryPoint +#ifdef __cplusplus +extern "C" { +#endif +/* + * Class: com_ibm_onnxmlir_DynEntryPoint + * Method: main_graph_jni + * Signature: + * (Lcom/ibm/onnxmlir/OrderedRtMemRefDict;)Lcom/ibm/onnxmlir/OrderedRtMemRefDict; + */ +JNIEXPORT jobject JNICALL Java_com_ibm_onnxmlir_DynEntryPoint_main_1graph_1jni( + JNIEnv *, jclass, jobject); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/src/Runtime/jni/jnidummy.c b/src/Runtime/jni/jnidummy.c new file mode 100644 index 0000000..6ea584a --- /dev/null +++ b/src/Runtime/jni/jnidummy.c @@ -0,0 +1,7 @@ +#include "com_ibm_onnxmlir_DynEntryPoint.h" + +/* Dummy routine to force the link editor to embed code in libjniruntime.a + into libmodel.so */ +void __dummy_do_not_call__(JNIEnv *env, jclass cls, jobject obj) { + Java_com_ibm_onnxmlir_DynEntryPoint_main_1graph_1jni(NULL, NULL, NULL); +} diff --git a/src/Runtime/jni/jnilog.c b/src/Runtime/jni/jnilog.c new file mode 100644 index 0000000..f9fd6b3 --- /dev/null +++ b/src/Runtime/jni/jnilog.c @@ -0,0 +1,94 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "jnilog.h" + +static int log_initd = 0; +static int log_level; +static FILE *log_fp; + +/* Must match enum in log.h */ +static char *log_level_name[] = { + "trace", "debug", "info", "warning", "error", "fatal"}; + +/* Return numerical log level of give level name */ +static int get_log_level_by_name(char *name) { + int level = -1; + for (int i = 0; i < sizeof(log_level_name) / sizeof(char *); i++) { + if (!strcmp(name, log_level_name[i])) { + level = i; + break; + } + } + return level; +} + +/* Return FILE pointer of given file name */ +static FILE *get_log_file_by_name(char *name) { + FILE *fp = NULL; + if (!strcmp(name, "stdout")) + fp = stdout; + else if (!strcmp(name, "stderr")) + fp = stderr; + else + fp = fopen(name, "w"); + return fp; +} + +/* Initialize log system. Set default log level and file or use environment + * variables ONNX_MLIR_JNI_LOG_LEVEL and ONNX_MLIR_JNI_LOG_FILE, respectively. + */ +static void log_init() { + if (log_initd) + return; + + log_level = LOG_INFO; + char *strlevel = getenv("ONNX_MLIR_JNI_LOG_LEVEL"); + int level; + if (strlevel && (level = get_log_level_by_name(strlevel)) != -1) + log_level = level; + + log_fp = stderr; + char *strfname = getenv("ONNX_MLIR_JNI_LOG_FILE"); + FILE *fp; + if (strfname && (fp = get_log_file_by_name(strfname))) + log_fp = fp; + + tzset(); + log_initd = 1; +} + +/* Generic log routine */ +void log_printf( + int level, char *file, const char *func, int line, char *fmt, ...) { + if (!log_initd) + log_init(); + if (level < log_level) + return; + + time_t now; + struct tm *tm; + char buf[80]; + + /* Get local time and format as 2020-07-03 05:17:42 -0400 */ + if (time(&now) == -1 || (tm = localtime(&now)) == NULL || + strftime(buf, sizeof(buf), "%F %T %z", tm) == 0) + sprintf(buf, "-"); + + /* Output log prefix */ + fprintf(log_fp, "[%s][%s]%s:%s:%d ", buf, log_level_name[level], + basename(file), func, line); + + /* Output actually log data */ + va_list va_list; + va_start(va_list, fmt); + vfprintf(log_fp, fmt, va_list); + va_end(va_list); + + fprintf(log_fp, "\n"); +} diff --git a/src/Runtime/jni/jnilog.h b/src/Runtime/jni/jnilog.h new file mode 100644 index 0000000..0b8de76 --- /dev/null +++ b/src/Runtime/jni/jnilog.h @@ -0,0 +1,125 @@ +#ifndef __JNILOG_H__ +#define __JNILOG_H__ + +#include + +enum { LOG_TRACE, LOG_DEBUG, LOG_INFO, LOG_WARNING, LOG_ERROR, LOG_FATAL }; + +#define LOG_MAX_NUM 16 /* max number of elements to output */ + +#define MIN(x, y) ((x) > (y) ? y : x) + +/* Construct string of up to LOG_MAX_NUM elements of a char array */ +#define LOG_CHAR_BUF(buf, data, n) \ + do { \ + buf[0] = '\0'; \ + for (int i = 0; i < MIN(n, LOG_MAX_NUM); i++) \ + sprintf(buf + strlen(buf), " %02x", ((char *)data)[i]); \ + sprintf(buf + strlen(buf), n > LOG_MAX_NUM ? " ... " : " "); \ + } while (0) + +/* Construct string of up to LOG_MAX_NUM elements of a short array */ +#define LOG_SHORT_BUF(buf, data, n) \ + do { \ + buf[0] = '\0'; \ + for (int i = 0; i < MIN(n, LOG_MAX_NUM); i++) \ + sprintf(buf + strlen(buf), " %d", ((short *)data)[i]); \ + sprintf(buf + strlen(buf), n > LOG_MAX_NUM ? " ... " : " "); \ + } while (0) + +/* Construct string of up to LOG_MAX_NUM elements of a int array */ +#define LOG_INT_BUF(buf, data, n) \ + do { \ + buf[0] = '\0'; \ + for (int i = 0; i < MIN(n, LOG_MAX_NUM); i++) \ + sprintf(buf + strlen(buf), " %d", ((int *)data)[i]); \ + sprintf(buf + strlen(buf), n > LOG_MAX_NUM ? " ... " : " "); \ + } while (0) + +/* Construct string of up to LOG_MAX_NUM elements of a long array */ +#define LOG_LONG_BUF(buf, data, n) \ + do { \ + buf[0] = '\0'; \ + for (int i = 0; i < MIN(n, LOG_MAX_NUM); i++) \ + sprintf(buf + strlen(buf), " %ld", ((long *)data)[i]); \ + sprintf(buf + strlen(buf), n > LOG_MAX_NUM ? " ... " : " "); \ + } while (0) + +/* Construct string of up to LOG_MAX_NUM elements of a float array */ +#define LOG_FLOAT_BUF(buf, data, n) \ + do { \ + buf[0] = '\0'; \ + for (int i = 0; i < MIN(n, LOG_MAX_NUM); i++) \ + sprintf(buf + strlen(buf), " %f", ((float *)data)[i]); \ + sprintf(buf + strlen(buf), n > LOG_MAX_NUM ? " ... " : " "); \ + } while (0) + +/* Construct string of up to LOG_MAX_NUM elements of a double array */ +#define LOG_DOUBLE_BUF(buf, data, n) \ + do { \ + buf[0] = '\0'; \ + for (int i = 0; i < MIN(n, LOG_MAX_NUM); i++) \ + sprintf(buf + strlen(buf), " %f", ((double *)data)[i]); \ + sprintf(buf + strlen(buf), n > LOG_MAX_NUM ? " ... " : " "); \ + } while (0) + +enum { + ONNX_TYPE_UNDEFINED, /* 0 */ + ONNX_TYPE_FLOAT, /* 1 */ + ONNX_TYPE_UINT8, /* 2 */ + ONNX_TYPE_INT8, /* 3 */ + ONNX_TYPE_UINT16, /* 4 */ + ONNX_TYPE_INT16, /* 5 */ + ONNX_TYPE_INT32, /* 6 */ + ONNX_TYPE_INT64, /* 7 */ + ONNX_TYPE_STRING, /* 8 */ + ONNX_TYPE_BOOL, /* 9 */ + ONNX_TYPE_FLOAT16, /* 10 */ + ONNX_TYPE_DOUBLE, /* 11 */ + ONNX_TYPE_UINT32, /* 12 */ + ONNX_TYPE_UINT64, /* 13 */ + ONNX_TYPE_COMPLEX64, /* 14 */ + ONNX_TYPE_COMPLEX128, /* 15 */ + ONNX_TYPE_BFLOAT16, /* 16 */ +}; + +/* Construct string of up to LOG_MAX_NUM elements of a "type" array */ +#define LOG_TYPE_BUF(type, buf, data, n) \ + do { \ + switch (type) { \ + case ONNX_TYPE_UINT8: \ + case ONNX_TYPE_INT8: \ + LOG_CHAR_BUF(buf, data, n); \ + break; \ + case ONNX_TYPE_UINT16: \ + case ONNX_TYPE_INT16: \ + LOG_SHORT_BUF(buf, data, n); \ + break; \ + case ONNX_TYPE_UINT32: \ + case ONNX_TYPE_INT32: \ + LOG_INT_BUF(buf, data, n); \ + break; \ + case ONNX_TYPE_UINT64: \ + case ONNX_TYPE_INT64: \ + LOG_LONG_BUF(buf, data, n); \ + break; \ + case ONNX_TYPE_FLOAT: \ + LOG_FLOAT_BUF(buf, data, n); \ + break; \ + case ONNX_TYPE_DOUBLE: \ + LOG_DOUBLE_BUF(buf, data, n); \ + break; \ + defaut: \ + sprintf(buf, " unsupported data type %d ", type); \ + } \ + } while (0) + +/* Main macro for log output */ +#define LOG_PRINTF(level, ...) \ + log_printf(level, __FILE__, __FUNCTION__, __LINE__, __VA_ARGS__) + +/* Generic log routine */ +void log_printf( + int level, char *file, const char *func, int line, char *fmt, ...); + +#endif diff --git a/src/Runtime/jni/jniwrapper.c b/src/Runtime/jni/jniwrapper.c new file mode 100644 index 0000000..926251a --- /dev/null +++ b/src/Runtime/jni/jniwrapper.c @@ -0,0 +1,356 @@ +#include +#include +#include + +#include "RtMemRef.h" +#include "com_ibm_onnxmlir_DynEntryPoint.h" +#include "jnilog.h" + +/* Declare type var, make call and assign to var, check against val */ +#define CHECK_CALL(type, var, call, val) \ + type var = call; \ + if (var == val) \ + return NULL + +/* Make a JNI call and throw Java exception if the call failed */ +#define JNI_CALL(env, stmt) \ + stmt; \ + do { \ + jthrowable e = (*env)->ExceptionOccurred(env); \ + if (e) { \ + LOG_PRINTF(LOG_ERROR, "JNI call exception occurred"); \ + (*env)->Throw(env, e); \ + return NULL; \ + } \ + } while (0) + +/* Make a JNI call and assign return value to var, + * throw Java exception if the call failed + */ +#define JNI_VAR_CALL(env, var, call) JNI_CALL(env, var = call) + +/* Declare type var, make a JNI call and assign return value to var, + * throw Java exception if the call failed + */ +#define JNI_TYPE_VAR_CALL(env, type, var, call) JNI_CALL(env, type var = call); + +/* If cond is true (native code failed), log error and throw Java exception */ +#define JNI_COND(type, var, call, val, env, cls, ...) \ + type var = call; \ + do { \ + if (var == val) { \ + LOG_PRINTF(LOG_ERROR, __VA_ARGS__); \ + (*env)->ThrowNew(env, cls, "native code error"); \ + return NULL; \ + } \ + } while (0) + +/* Debug output of RtMemRef fields */ +#define RMR_DEBUG(i, type, rank, sizes, strides, data, datasize) \ + do { \ + char tmp[1024]; \ + LOG_PRINTF(LOG_DEBUG, "rmr[%d]:type=%d", i, type); \ + LOG_PRINTF(LOG_DEBUG, "rmr[%d]:rank=%d", i, rank); \ + LOG_LONG_BUF(tmp, sizes, rank); \ + LOG_PRINTF(LOG_DEBUG, "rmr[%d]:sizes=[%s]", i, tmp); \ + LOG_LONG_BUF(tmp, strides, rank); \ + LOG_PRINTF(LOG_DEBUG, "rmr[%d]:strides=[%s]", i, tmp); \ + LOG_TYPE_BUF(type, tmp, data, datasize); \ + LOG_PRINTF(LOG_DEBUG, "rmr[%d]:data=[%s]", i, tmp); \ + } while (0) + +/* Model shared library entry point */ +extern OrderedRtMemRefDict *_dyn_entry_point_main_graph(OrderedRtMemRefDict *); + +/* ONNX type to size (number of bytes) mapping */ +int onnx_type_size[] = { + 0, /* UNDEFINED = 0 */ + 4, /* FLOAT = 1 */ + 1, /* UINT8 = 2 */ + 1, /* INT8 = 3 */ + 2, /* UINT16 = 4 */ + 2, /* INT16 = 5 */ + 4, /* INT32 = 6 */ + 8, /* INT64 = 7 */ + 0, /* STRING = 8 */ + 1, /* BOOL = 9 */ + 2, /* FLOAT16 = 10 */ + 8, /* DOUBLE = 11 */ + 4, /* UINT32 = 12 */ + 8, /* UINT64 = 13 */ + 8, /* COMPLEX64 = 14 */ + 16, /* COMPLEX128 = 15 */ + 2, /* BFLOAT16 = 16 */ +}; + +/* Java classes and methods needed for making various JNI API calls */ +typedef struct { + jclass ecpt_cls; /* java/lang/Exception class */ + jclass long_cls; /* java/lang/Long class */ + jclass string_cls; /* java/lang/String class */ + jclass ormrd_cls; /* com/ibm/onnxmlir/OrderedRtMemRefDict class */ + jclass rmr_cls; /* com/ibm/onnxmlir/RtMemRef class */ + + jmethodID ormrd_constructor; /* OrderedRtMemRefDict constructor */ + jmethodID ormrd_getRmrs; /* OrderedRtMemRefDict getRmrs method */ + jmethodID ormrd_getNames; /* OrderedRtMemRefDict getNames method */ + + jmethodID rmr_constructor; /* RtMemRef constructor */ + jmethodID rmr_getType; /* RtMemRef getType method */ + jmethodID rmr_setType; /* RtMemRef setType method */ + jmethodID rmr_getRank; /* RtMemRef getRank method */ + jmethodID rmr_getData; /* RtMemRef getData method */ + jmethodID rmr_setData; /* RtMemRef setData method */ + jmethodID rmr_getSizes; /* RtMemRef getSizes method */ + jmethodID rmr_setSizes; /* RtMemRef setSizes method */ + jmethodID rmr_getStrides; /* RtMemRef getStrides method */ + jmethodID rmr_setStrides; /* RtMemRef setStrides method */ + jmethodID rmr_getDataSize; /* RtMemRef getDataSize method */ +} jniapi_t; + +jniapi_t jniapi; + +/* Fill in struct jniapi */ +jniapi_t *fill_jniapi(JNIEnv *env, jniapi_t *japi) { + /* Get Java Exception, Long, String, OrderedRtMemRefDict, and RtMemRef classes + */ + JNI_VAR_CALL( + env, japi->ecpt_cls, (*env)->FindClass(env, "java/lang/Exception")); + JNI_VAR_CALL(env, japi->long_cls, (*env)->FindClass(env, "java/lang/Long")); + JNI_VAR_CALL( + env, japi->string_cls, (*env)->FindClass(env, "java/lang/String")); + JNI_VAR_CALL(env, japi->ormrd_cls, + (*env)->FindClass(env, "com/ibm/onnxmlir/OrderedRtMemRefDict")); + JNI_VAR_CALL( + env, japi->rmr_cls, (*env)->FindClass(env, "com/ibm/onnxmlir/RtMemRef")); + + /* Get method ID of constructor and various methods in OrderedRtMemRefDict */ + JNI_VAR_CALL(env, japi->ormrd_constructor, + (*env)->GetMethodID( + env, japi->ormrd_cls, "", "([Lcom/ibm/onnxmlir/RtMemRef;)V")); + JNI_VAR_CALL(env, japi->ormrd_getRmrs, + (*env)->GetMethodID( + env, japi->ormrd_cls, "getRmrs", "()[Lcom/ibm/onnxmlir/RtMemRef;")); + JNI_VAR_CALL(env, japi->ormrd_getNames, + (*env)->GetMethodID( + env, japi->ormrd_cls, "getNames", "()[Ljava/lang/String;")); + + /* Get method ID of constructor and various methods in RtMemRef */ + JNI_VAR_CALL(env, japi->rmr_constructor, + (*env)->GetMethodID(env, japi->rmr_cls, "", "(I)V")); + JNI_VAR_CALL(env, japi->rmr_getType, + (*env)->GetMethodID(env, japi->rmr_cls, "getType", "()I")); + JNI_VAR_CALL(env, japi->rmr_setType, + (*env)->GetMethodID(env, japi->rmr_cls, "setType", "(I)V")); + JNI_VAR_CALL(env, japi->rmr_getRank, + (*env)->GetMethodID(env, japi->rmr_cls, "getRank", "()I")); + JNI_VAR_CALL(env, japi->rmr_getData, + (*env)->GetMethodID( + env, japi->rmr_cls, "getData", "()Ljava/nio/ByteBuffer;")); + JNI_VAR_CALL(env, japi->rmr_setData, + (*env)->GetMethodID( + env, japi->rmr_cls, "setData", "(Ljava/nio/ByteBuffer;)V")); + JNI_VAR_CALL(env, japi->rmr_getSizes, + (*env)->GetMethodID(env, japi->rmr_cls, "getSizes", "()[J")); + JNI_VAR_CALL(env, japi->rmr_setSizes, + (*env)->GetMethodID(env, japi->rmr_cls, "setSizes", "([J)V")); + JNI_VAR_CALL(env, japi->rmr_getStrides, + (*env)->GetMethodID(env, japi->rmr_cls, "getStrides", "()[J")); + JNI_VAR_CALL(env, japi->rmr_setStrides, + (*env)->GetMethodID(env, japi->rmr_cls, "setStrides", "([J)V")); + JNI_VAR_CALL(env, japi->rmr_getDataSize, + (*env)->GetMethodID(env, japi->rmr_cls, "getDataSize", "()J")); + + return japi; +} + +/* Convert Java object to native data structure */ +OrderedRtMemRefDict *ormrd_java_to_native( + JNIEnv *env, jclass cls, jobject obj, jniapi_t *japi) { + /* Get object array "rmrs" and "names" in OrderedRtMemRefDict */ + JNI_TYPE_VAR_CALL(env, jobjectArray, ormrd_rmrs, + (*env)->CallObjectMethod(env, obj, japi->ormrd_getRmrs)); + JNI_TYPE_VAR_CALL(env, jobjectArray, ormrd_names, + (*env)->CallObjectMethod(env, obj, japi->ormrd_getNames)); + + /* Get length of object array "rmrs" and "names" in OrderedRtMemRefDict */ + JNI_TYPE_VAR_CALL( + env, jsize, ormrd_rmrs_len, (*env)->GetArrayLength(env, ormrd_rmrs)); + JNI_TYPE_VAR_CALL( + env, jsize, ormrd_names_len, (*env)->GetArrayLength(env, ormrd_names)); + + /* Allocate memory for holding each Java rmr object and name string, + * and RtMemRef and char pointers for constructing native RtMemRef and name + * array + */ + JNI_COND(jobject *, obj_rmr, malloc(ormrd_rmrs_len * sizeof(jobject)), NULL, + env, japi->ecpt_cls, "obj_rmr=null"); + JNI_COND(jstring *, obj_name, malloc(ormrd_names_len * sizeof(jstring)), NULL, + env, japi->ecpt_cls, "obj_name=null"); + JNI_COND(RtMemRef **, jni_rmr, malloc(ormrd_rmrs_len * sizeof(RtMemRef *)), + NULL, env, japi->ecpt_cls, "jni_rmr=null"); + JNI_COND(const char **, jni_name, + malloc(ormrd_names_len * sizeof(const char *)), NULL, env, japi->ecpt_cls, + "jni_name=null"); + + /* Create OrderedRtMemRefDict to be constructed and passed to the model shared + * library */ + JNI_COND(OrderedRtMemRefDict *, ormrd, createOrderedRtMemRefDict(), NULL, env, + japi->ecpt_cls, "ormrd=null"); + + /* Loop through all the ormrd_rmrs and ormrd_names */ + for (int i = 0; i < ormrd_rmrs_len; i++) { + JNI_VAR_CALL( + env, obj_rmr[i], (*env)->GetObjectArrayElement(env, ormrd_rmrs, i)); + JNI_VAR_CALL( + env, obj_name[i], (*env)->GetObjectArrayElement(env, ormrd_names, i)); + + /* Get type, rank, data, sizes, and strides by calling corresponding methods + */ + JNI_TYPE_VAR_CALL(env, jint, rmr_type, + (*env)->CallIntMethod(env, obj_rmr[i], japi->rmr_getType)); + JNI_TYPE_VAR_CALL(env, jint, rmr_rank, + (*env)->CallIntMethod(env, obj_rmr[i], japi->rmr_getRank)); + JNI_TYPE_VAR_CALL(env, jlong, rmr_datasize, + (*env)->CallLongMethod(env, obj_rmr[i], japi->rmr_getDataSize)); + JNI_TYPE_VAR_CALL(env, jobject, rmr_data, + (*env)->CallObjectMethod(env, obj_rmr[i], japi->rmr_getData)); + JNI_TYPE_VAR_CALL(env, jobject, rmr_sizes, + (*env)->CallObjectMethod(env, obj_rmr[i], japi->rmr_getSizes)); + JNI_TYPE_VAR_CALL(env, jobject, rmr_strides, + (*env)->CallObjectMethod(env, obj_rmr[i], japi->rmr_getStrides)); + + /* Primitive type int and long can be directly used */ + int jni_type = rmr_type, jni_rank = rmr_rank; + long jni_datasize = rmr_datasize; + + /* Get direct buffer associated with data */ + JNI_TYPE_VAR_CALL( + env, void *, jni_data, (*env)->GetDirectBufferAddress(env, rmr_data)); + + /* Get long array associated with sizes and strides */ + JNI_TYPE_VAR_CALL(env, long *, jni_sizes, + (*env)->GetLongArrayElements(env, rmr_sizes, NULL)); + JNI_TYPE_VAR_CALL(env, long *, jni_strides, + (*env)->GetLongArrayElements(env, rmr_strides, NULL)); + + /* Print debug info on what we got from the Java side */ + RMR_DEBUG( + i, jni_type, jni_rank, jni_sizes, jni_strides, jni_data, jni_datasize); + + /* Create native RtMemRef struct and fill in its fields */ + jni_rmr[i] = createRtMemRef(jni_rank); + setDType(jni_rmr[i], jni_type); + setData(jni_rmr[i], jni_data); + setSizes(jni_rmr[i], jni_sizes); + setStrides(jni_rmr[i], jni_strides); + + /*jni_name[i] = (*env)->GetStringUTFChars(env, obj_name[i], NULL); + printf("jni_name=%s\n", jni_name[i]);*/ + + /* Install RtMemRef into OrderedRtMemRefDict */ + setRtMemRef(ormrd, i, jni_rmr[i]); + + /* Release reference to the java objects */ + JNI_CALL( + env, (*env)->ReleaseLongArrayElements(env, rmr_sizes, jni_sizes, 0)); + JNI_CALL(env, + (*env)->ReleaseLongArrayElements(env, rmr_strides, jni_strides, 0)); + } + + /* setRtMemRef(ormrd, jni_rmr, jni_name); */ + return ormrd; +} + +/* Convert native data structure to Java object */ +jobject ormrd_native_to_java( + JNIEnv *env, jclass cls, OrderedRtMemRefDict *dict, jniapi_t *japi) { + JNI_COND(int, nrmr, numRtMemRefs(dict), 0, env, japi->ecpt_cls, "nrmr=0"); + + /* Create RtMemRef java object array */ + JNI_TYPE_VAR_CALL(env, jobjectArray, rmrs, + (*env)->NewObjectArray(env, nrmr, japi->rmr_cls, NULL)); + + /* Loop through the native RtMemRef structs */ + for (int i = 0; i < nrmr; i++) { + JNI_COND(RtMemRef *, rmr, getRtMemRef(dict, i), NULL, env, japi->ecpt_cls, + "rmr[%d]=null", i); + + JNI_COND(int, jni_type, getDType(rmr), 0, env, japi->ecpt_cls, + "rmr[%d]:type=0", i); + JNI_COND(int, jni_rank, getRank(rmr), 0, env, japi->ecpt_cls, + "rmr[%d]:rank=0", i); + JNI_COND(long, jni_datasize, getDataSize(rmr), 0, env, japi->ecpt_cls, + "rmr[%d]:datasize=0", i); + JNI_COND(void *, jni_data, getData(rmr), NULL, env, japi->ecpt_cls, + "rmr[%d]:data=null", i); + JNI_COND(long *, jni_sizes, getSizes(rmr), NULL, env, japi->ecpt_cls, + "rmr[%d]:sizes=null", i); + JNI_COND(long *, jni_strides, getStrides(rmr), NULL, env, japi->ecpt_cls, + "rmr[%d]:strides=null", i); + + /* Print debug info on what we got from the native side */ + RMR_DEBUG( + i, jni_type, jni_rank, jni_sizes, jni_strides, jni_data, jni_datasize); + + /* create the following Java objects: + * - RtMemRef + * - DirectByteBuffer (from native buffers) + * - long array for sizes and strides + */ + JNI_TYPE_VAR_CALL(env, jobject, obj_rmr, + (*env)->NewObject(env, japi->rmr_cls, japi->rmr_constructor, jni_rank)); + JNI_TYPE_VAR_CALL(env, jobject, rmr_data, + (*env)->NewDirectByteBuffer( + env, jni_data, jni_datasize * onnx_type_size[jni_type])); + JNI_TYPE_VAR_CALL( + env, jlongArray, rmr_sizes, (*env)->NewLongArray(env, jni_rank)); + JNI_TYPE_VAR_CALL( + env, jlongArray, rmr_strides, (*env)->NewLongArray(env, jni_rank)); + + /* Call setType method */ + JNI_CALL(env, + (*env)->CallObjectMethod(env, obj_rmr, japi->rmr_setType, jni_type)); + + /* Call setData method */ + JNI_CALL(env, + (*env)->CallObjectMethod(env, obj_rmr, japi->rmr_setData, rmr_data)); + + /* Fill in sizes array from native array and call setSizes method */ + JNI_CALL(env, + (*env)->SetLongArrayRegion(env, rmr_sizes, 0, jni_rank, jni_sizes)); + JNI_CALL(env, + (*env)->CallObjectMethod(env, obj_rmr, japi->rmr_setSizes, rmr_sizes)); + + /* Fill in strides array from native array and call setStrides method */ + JNI_CALL(env, + (*env)->SetLongArrayRegion(env, rmr_strides, 0, jni_rank, jni_strides)); + JNI_CALL(env, (*env)->CallObjectMethod( + env, obj_rmr, japi->rmr_setStrides, rmr_strides)); + + /* Set DynMemRef object in the object array */ + JNI_CALL(env, (*env)->SetObjectArrayElement(env, rmrs, i, obj_rmr)); + } + + /* Create the OrderedRtMemRefDict java object */ + JNI_TYPE_VAR_CALL(env, jobject, ormrd, + (*env)->NewObject(env, japi->ormrd_cls, japi->ormrd_constructor, rmrs)); + + return ormrd; +} + +JNIEXPORT jobject JNICALL Java_com_ibm_onnxmlir_DynEntryPoint_main_1graph_1jni( + JNIEnv *env, jclass cls, jobject obj) { + CHECK_CALL(jniapi_t *, japi, fill_jniapi(env, &jniapi), NULL); + + CHECK_CALL(OrderedRtMemRefDict *, input_ormrd, + ormrd_java_to_native(env, cls, obj, japi), NULL); + + CHECK_CALL(OrderedRtMemRefDict *, dict, + _dyn_entry_point_main_graph(input_ormrd), NULL); + + CHECK_CALL( + jobject, output_ormrd, ormrd_native_to_java(env, cls, dict, japi), NULL); + + return output_ormrd; +} diff --git a/src/Runtime/jni/src/com/ibm/onnxmlir/DynEntryPoint.java b/src/Runtime/jni/src/com/ibm/onnxmlir/DynEntryPoint.java new file mode 100644 index 0000000..8bd6d4c --- /dev/null +++ b/src/Runtime/jni/src/com/ibm/onnxmlir/DynEntryPoint.java @@ -0,0 +1,64 @@ +package com.ibm.onnxmlir; + +import java.io.File; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.net.URISyntaxException; +import java.nio.file.Files; +import java.nio.file.StandardCopyOption; +import java.util.jar.JarFile; + +public class DynEntryPoint { + static String libname = "libmodel.so"; + + static { + File jar; + JarFile jf; + String jarDir = null; + String libPath = null; + try { + // Get path name of jar + jar = new File(DynEntryPoint.class.getProtectionDomain() + .getCodeSource() + .getLocation().toURI()); + jarDir = jar.getParentFile().getAbsolutePath(); + libPath = jarDir + "/" + libname; + + // Open jar file to read and check libname inside jar. + // If IOException thrown, load .so from where .jar is. + // + // Checking whether DynEntryPoint.class.getResourceAsStream returns null + // does NOT work. Because it checks whether the resource is + // available on the classpath, not only just inside the jar file. + jf = new JarFile(jar); + if (jf.getEntry(libname) != null) { + File libFile = new File(libPath); + // Copy .so to where jar is + Files.copy(jf.getInputStream(jf.getEntry(libname)), + libFile.toPath(), StandardCopyOption.REPLACE_EXISTING); + // z/OS USS requires "x" permission bit + libFile.setExecutable(true, false); + // Load the temporary .so copy + System.load(libPath); + // POSIX can unlink file after loading + libFile.delete(); + } + else { + // Throw subclass of IOException + throw new FileNotFoundException(".so not found inside jar"); + } + } catch (URISyntaxException e) { + // Failed to find jar path, assume the .so is in cwd + System.load(libname); + } catch (IOException e) { + // .so not found in jar, assume it's where jar is + System.load(libPath); + } + } + + private static native OrderedRtMemRefDict main_graph_jni(OrderedRtMemRefDict ormrd); + + public static OrderedRtMemRefDict main_graph(OrderedRtMemRefDict ormrd) { + return main_graph_jni(ormrd); + } +} diff --git a/src/Runtime/jni/src/com/ibm/onnxmlir/OrderedRtMemRefDict.java b/src/Runtime/jni/src/com/ibm/onnxmlir/OrderedRtMemRefDict.java new file mode 100644 index 0000000..deea1c5 --- /dev/null +++ b/src/Runtime/jni/src/com/ibm/onnxmlir/OrderedRtMemRefDict.java @@ -0,0 +1,118 @@ +package com.ibm.onnxmlir; + +import java.util.HashMap; + +public class OrderedRtMemRefDict { + + private RtMemRef[] _rmrs; + private String[] _names; + private HashMap _n2i; + + /** + * Constructor + * + * @param rmrs DynMemRef array + */ + public OrderedRtMemRefDict(RtMemRef[] rmrs) { + this(rmrs, null); + } + + /** + * Constructor + * + * @param rmrs DynMemRef array + * @param names name array + */ + public OrderedRtMemRefDict(RtMemRef[] rmrs, String[] names) { + /* rmrs cannot be null or empty */ + if (rmrs == null || rmrs.length == 0) + throw new IllegalArgumentException( + "Number of dmrs is invalid"); + + /* If names is null or empty, construct a default one with + * index as name. + */ + if (names == null || names.length == 0) { + names = new String[rmrs.length]; + for (int i = 0; i < rmrs.length; i++) + names[i] = Integer.toString(i); + } + + /* Number of rmrs and names must match */ + if (rmrs.length != names.length) + throw new IllegalArgumentException( + "Number of dmrs and names do not match"); + + /* Establish name to index mapping. Individual rmr is + * checked for validity. + */ + _n2i = new HashMap(); + for (int i = 0; i < names.length; i++) { + if (rmrs[i] == null || !rmrs[i].validRmr()) + throw new IllegalArgumentException( + "rmrs[" + i + "] is invalid"); + if (_n2i.put(names[i], i) != null) + throw new IllegalArgumentException( + "name[" + i + "] = " + names[i] + " not unique"); + } + _rmrs = rmrs; + _names = names; + } + + /** + * RtMemRef getter by index + * + * @param idx index of RtMemRef instance to get + * @return RtMemRef instance + */ + public RtMemRef getRmrbyIndex(int idx) { + return _rmrs[idx]; + } + + /** + * RtMemRef getter by name + * + * @param name name of RtMemRef instance to get + * @return RtMemRef instance + */ + public RtMemRef getRmrByName(String name) { + return _rmrs[_n2i.get(name)]; + } + + /** + * RtMemRef array getter + * + * @return RtMemRef array + */ + public RtMemRef[] getRmrs() { + return _rmrs; + } + + /** + * Name getter + * + * @param idx index of name to get + * @return name string + */ + public String getName(int idx) { + return _names[idx]; + } + + /** + * Name array getter + * + * @return name array + */ + public String[] getNames() { + return _names; + } + + /** + * RtMemRef array size getter + * + * @return RtMemRef array size + */ + public int size() { + return _rmrs.length; + } +} diff --git a/src/Runtime/jni/src/com/ibm/onnxmlir/RtMemRef.java b/src/Runtime/jni/src/com/ibm/onnxmlir/RtMemRef.java new file mode 100644 index 0000000..31b0aea --- /dev/null +++ b/src/Runtime/jni/src/com/ibm/onnxmlir/RtMemRef.java @@ -0,0 +1,400 @@ +package com.ibm.onnxmlir; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.DoubleBuffer; +import java.nio.FloatBuffer; +import java.nio.IntBuffer; +import java.nio.LongBuffer; +import java.nio.ShortBuffer; + +public class RtMemRef { + final ByteOrder endian = ByteOrder.nativeOrder(); + + /* We can use enum but that creates another class + * which complicates things for JNI. + */ + final int ONNX_TYPE_UNDEFINED = 0; + final int ONNX_TYPE_FLOAT = 1; + final int ONNX_TYPE_UINT8 = 2; + final int ONNX_TYPE_INT8 = 3; + final int ONNX_TYPE_UINT16 = 4; + final int ONNX_TYPE_INT16 = 5; + final int ONNX_TYPE_INT32 = 6; + final int ONNX_TYPE_INT64 = 7; + final int ONNX_TYPE_STRING = 8; + final int ONNX_TYPE_BOOL = 9; + final int ONNX_TYPE_FLOAT16 = 10; + final int ONNX_TYPE_DOUBLE = 11; + final int ONNX_TYPE_UINT32 = 12; + final int ONNX_TYPE_UINT64 = 13; + final int ONNX_TYPE_COMPLEX64 = 14; + final int ONNX_TYPE_COMPLEX128 = 15; + final int ONNX_TYPE_BFLOAT16 = 16; + + final int[] ONNX_TYPE_SIZE = new int[] { + 0, /* UNDEFINED */ + 4, /* FLOAT */ + 1, /* UINT8 */ + 1, /* INT8 */ + 2, /* UINT16 */ + 2, /* INT16 */ + 4, /* INT32 */ + 8, /* INT64 */ + 0, /* STRING */ + 1, /* BOOL */ + 2, /* FLOAT16 */ + 8, /* DOUBLE */ + 4, /* UINT32 */ + 8, /* UINT64 */ + 8, /* COMPLEX64 */ + 16, /* COMPLEX128 */ + 2, /* BFLOAT16 */ + }; + + private ByteBuffer _data; + private int _type; + private int _rank; + private long[] _sizes; + private long[] _strides; + + /** + * Constructor + */ + public RtMemRef(int rank) { + if (rank <= 0) + throw new IllegalArgumentException( + "invalid rank " + rank); + _data = null; + _type = ONNX_TYPE_UNDEFINED; + _rank = rank; + _sizes = new long[rank]; + _strides = new long[rank]; + } + + /* ---------- Data type getter and setter ---------- */ + /* For JNI wrapper only. Not intended for end user. */ + + /** + * Type getter + * + * @return data type + */ + @SuppressWarnings("unused") + private int getType() { + return _type; + } + + /** + * Type setter + * + * @param type data type to be set + */ + @SuppressWarnings("unused") + private void setType(int type) { + _type = type; + } + + /* ---------- Raw data getter and setter ---------- */ + /* For JNI wrapper only. Not intended for end user. */ + + /** + * Raw data getter + * + * @return raw data + */ + @SuppressWarnings("unused") + private ByteBuffer getData() { + return _data; + } + + /** + * Raw data setter + * + * @param data raw data to be set + */ + @SuppressWarnings("unused") + private void setData(ByteBuffer data) { + _data = data.order(endian); + } + + /* ---------- Byte data getter and setter ---------- */ + + /** + * Byte data getter + * + * @return byte data array + */ + public byte[] getByteData() { + if (_data == null) return null; + + /* asReadOnlyBuffer() creates a new view so the position of the + * original data will stay at 0 for subsequent getByteData() + * after get(b). + */ + byte[] b = new byte[_data.limit()]; + _data.asReadOnlyBuffer().get(b); + return b; + } + + /** + * Byte data setter + * + * @param data byte array to be set + */ + public void setByteData(byte[] data) { + /* slice() creates a new view so the position of the + * original data will stay at 0 for getByteData() after put(data). + */ + _data = ByteBuffer.allocateDirect(data.length); + _data.slice().put(data); + _type = ONNX_TYPE_INT8; + } + + /* ---------- Short data getter and setter ---------- */ + + /** + * Short data getter + * + * @return short data array + */ + public short[] getShortData() { + if (_data == null) return null; + + /* asShortBuffer() creates a new view so the position of the + * original data will stay at 0 for subsequent getShortData() + * after get(s). + */ + ShortBuffer sb = _data.asShortBuffer(); + short[] s = new short[sb.limit()]; + sb.get(s); + return s; + } + + /** + * Short data setter + * + * @param data short array to be set + */ + public void setShortData(short[] data) { + /* asShortBuffer() creates a new view so the position of the + * original data will stay at 0 for getShortData() after put(data). + */ + _data = ByteBuffer.allocateDirect(data.length*2).order(endian); + _data.asShortBuffer().put(data); + _type = ONNX_TYPE_INT16; + } + + /* ---------- Int data getter and setter ---------- */ + + /** + * Int data getter + * + * @return int data array + */ + public int[] getIntData() { + if (_data == null) return null; + + /* asIntBuffer() creates a new view so the position of the + * original data will stay at 0 for subsequent getIntData() + * after get(i). + */ + IntBuffer ib = _data.asIntBuffer(); + int[] i = new int[ib.limit()]; + ib.get(i); + return i; + } + + /** + * Int data setter + * + * @param data int array to be set + */ + public void setIntData(int[] data) { + /* asIntBuffer() creates a new view so the position of the + * original data will stay at 0 for getIntData() after put(data). + */ + _data = ByteBuffer.allocateDirect(data.length*4).order(endian); + _data.asIntBuffer().put(data); + _type = ONNX_TYPE_INT32; + } + + /* ---------- Long data getter and setter ---------- */ + + /** + * Long data getter + * + * @return long data array + */ + public long[] getLongData() { + if (_data == null) return null; + + /* asLongBuffer() creates a new view so the position of the + * original data will stay at 0 for subsequent getLongData() + * after get(l). + */ + LongBuffer lb = _data.asLongBuffer(); + long[] l = new long[lb.limit()]; + lb.get(l); + return l; + } + + /** + * Long data setter + * + * @param data long array to be set + */ + public void setLongData(long[] data) { + /* asLongBuffer() creates a new view so the position of the + * original data will stay at 0 for getLongData() after put(data). + */ + _data = ByteBuffer.allocateDirect(data.length*8).order(endian); + _data.asLongBuffer().put(data); + _type = ONNX_TYPE_INT64; + } + + /* ---------- Float data getter and setter ---------- */ + + /** + * Float data getter + * + * @return float data array + */ + public float[] getFloatData() { + if (_data == null) return null; + + /* asFloatBuffer() creates a new view so the position of the + * original data will stay at 0 for subsequent getFloatData() + * after get(f). + */ + FloatBuffer fb = _data.asFloatBuffer(); + float[] f = new float[fb.limit()]; + fb.get(f); + return f; + } + + /** + * Float data setter + * + * @param data float array to be set + */ + public void setFloatData(float[] data) { + /* asFloatBuffer() creates a new view so the position of the + * original data will stay at 0 for getFloatData() after put(data). + */ + _data = ByteBuffer.allocateDirect(data.length*4).order(endian); + _data.asFloatBuffer().put(data); + _type = ONNX_TYPE_FLOAT; + } + + /* ---------- Double data getter and setter ---------- */ + + /** + * Double data getter + * + * @return double data array + */ + public double[] getDoubleData() { + if (_data == null) return null; + + /* asDoubleBuffer() creates a new view so the position of the + * original data will stay at 0 for subsequent getDoubleData() + * after get(d). + */ + DoubleBuffer db = _data.asDoubleBuffer(); + double[] d = new double[db.limit()]; + db.get(d); + return d; + } + + /** + * Double data setter + * + * @param data double array to be set + */ + public void setDoubleData(double[] data) { + /* asDoubleBuffer() creates a new view so the position of the + * original data will stay at 0 for getDoubleData() after put(data). + */ + _data = ByteBuffer.allocateDirect(data.length*8).order(endian); + _data.asDoubleBuffer().put(data); + _type = ONNX_TYPE_DOUBLE; + } + + /** + * Rank getter + * + * @return rank + */ + public int getRank() { + return _rank; + } + + /* ---------- Sizes getter and setter ---------- */ + + /** + * Sizes getter + * + * @return sizes array + */ + public long[] getSizes() { + return _sizes; + } + + /** + * Sizes setter + * + * @param sizes sizes array to be set + */ + public void setSizes(long[] sizes) { + if (sizes.length != _rank) + throw new IllegalArgumentException( + "array length " + sizes.length + " != rank " + _rank); + _sizes = sizes.clone(); + } + + /* ---------- Strides getter and setter ---------- */ + + /** + * Strides getter + * + * @return strides array + */ + public long[] getStrides() { + return _strides; + } + + /** + * Strides setter + * + * @param strides strides array to be set + */ + public void setStrides(long[] strides) { + if (strides.length != _rank) + throw new IllegalArgumentException( + "array length " + strides.length + " != rank " + _rank); + _strides = strides.clone(); + } + + /** + * Size getter + * + * @return product of sizes array, i.e., total number of data elements + */ + public long getDataSize() { + long n = _sizes[0]; + for (int i = 1; i < _sizes.length; i++) n *= _sizes[i]; + return n; + } + + /** + * Check validity of RtMemRef + * + * @return true if RtMemRef is valid, false otherwise + */ + public boolean validRmr() { + return (_data != null && + _data.limit() != 0 && + _data.limit() == getDataSize() * ONNX_TYPE_SIZE[_type]); + } +} diff --git a/src/Tool/ONNXMLIROpt/CMakeLists.txt b/src/Tool/ONNXMLIROpt/CMakeLists.txt index edb8986..9b49d80 100644 --- a/src/Tool/ONNXMLIROpt/CMakeLists.txt +++ b/src/Tool/ONNXMLIROpt/CMakeLists.txt @@ -4,12 +4,6 @@ add_dependencies(onnx-mlir-opt OMKrnlOpsInc OMONNXOpsInc) target_include_directories(onnx-mlir-opt PRIVATE ${ONNX_MLIR_SRC_ROOT}) target_include_directories(onnx-mlir-opt PRIVATE ${ONNX_MLIR_BIN_ROOT}) -set(ONNX_MLIR_LD_PRELOAD_onnx-mlir-opt "" CACHE STRING "" FORCE) -if(BUILD_SHARED_LIBS) - message(STATUS "To run dynamically linked onnx-mlir-opt, you must specify:") - message(STATUS "LD_PRELOAD=${ONNX_MLIR_LD_PRELOAD_onnx-mlir-opt}") -endif() - target_link_libraries(onnx-mlir-opt ${OMLibs} ${MLIRLibs} diff --git a/src/main.cpp b/src/main.cpp index 69e115c..5585c07 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -12,6 +12,7 @@ using namespace std; using namespace onnx_mlir; int main(int argc, char *argv[]) { + setExecPath(argv[0], (void *)main); registerDialects(); llvm::cl::OptionCategory OnnxMlirOptions( @@ -33,7 +34,9 @@ int main(int argc, char *argv[]) { clEnumVal(EmitLLVMIR, "Lower model to LLVM IR (LLVM dialect)."), clEnumVal(EmitLib, "Lower model to LLVM IR, emit (to file) " "LLVM bitcode for model, compile and link it to a " - "shared library.")), + "shared library."), + clEnumVal(EmitJNI, "Lower model to LLMV IR -> LLVM bitcode " + "-> JNI shared library -> jar")), llvm::cl::init(EmitLib), llvm::cl::cat(OnnxMlirOptions)); llvm::cl::HideUnrelatedOptions(OnnxMlirOptions); diff --git a/test/numerical/TestConv.cpp b/test/numerical/TestConv.cpp index ab48915..c880afa 100644 --- a/test/numerical/TestConv.cpp +++ b/test/numerical/TestConv.cpp @@ -13,6 +13,8 @@ #include "src/MainUtils.hpp" #include "src/Runtime/ExecusionSession.hpp" +#define SHARED_LIB_BASE string("./TestConv_main_graph") + using namespace std; // Returns whether onnx-mlir compiled convolution is producing the same results @@ -87,14 +89,9 @@ bool isOMConvTheSameAsNaiveImplFor(const int N, const int C, const int H, OwningModuleRef moduleRef(module); - llvm::SmallVector path; - llvm::sys::fs::createTemporaryFile("_main_graph", "", path); - string pathStr(path.begin(), path.end()); - llvm::FileRemover remover(path); - - compileModule(moduleRef, ctx, pathStr, EmitLib); + compileModule(moduleRef, ctx, SHARED_LIB_BASE, EmitLib); onnx_mlir::ExecutionSession sess( - pathStr + ".so", "_dyn_entry_point_main_graph"); + SHARED_LIB_BASE + ".so", "_dyn_entry_point_main_graph"); std::vector> inputs; auto xRmr = unique_ptr(getRndRealRmr({N, C, H, W})); @@ -127,7 +124,10 @@ bool isOMConvTheSameAsNaiveImplFor(const int N, const int C, const int H, return isRmrClose(conv.get(), ref); } -int main() { +int main(int argc, char *argv[]) { + setExecPath(argv[0], (void *)main); + llvm::FileRemover remover(SHARED_LIB_BASE + ".so"); + // RapidCheck test case generation. rc::check("convolution implementation correctness", []() { const auto N = *rc::gen::inRange(1, 10);