Add emitjni target (#204)

* Detect llvm-project commit change in utils/clone-mlir.sh and rebuild llvm-project
for zLinux Jenkins build bot

* Add --EmitJNI target (tested working with mnist and resnet50)
- MainUtils
  * first shot at refactoring compileModuleToSharedLibrary
  * add setExecPath call to allow resolving runtime directory from onnx-mlir
    executable path when ONNX_MLIR_RUNTIME_DIR is not set. This allows
    tests to run without having to install onnx-mlir or to explicitly set
    ONNX_MLIR_RUNTIME_DIR
- RtMemRef
  * add getDataSize for C (equivalent of size() for C++).
  * fix setStrides bug (setting sizes, not strides)
- TestConv
  * _main_graph-*.so were filling up /tmp. Change to use fixed shared library
    in build directory

* Fix clang-format-lint complaints

* - getRuntimeDir checks lib64
- install targets for javaruntime and jniruntime
- remove ONNX_MLIR_LD_PRELOAD_onnx-mlir and ONNX_MLIR_LD_PRELOAD_onnx-mlir-opt

* See what happens when `kExecPath` decl is dropped.

Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
gongsu832 2020-07-11 01:23:13 -04:00 committed by GitHub
parent af75b4c75e
commit d235f248e4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 1404 additions and 55 deletions

View File

@ -71,5 +71,5 @@ cmake -DCMAKE_INSTALL_PREFIX=${INSTALL_PATH} .. \
make -j$(nproc) make -j$(nproc)
make -j$(nproc) check-onnx-lit make -j$(nproc) check-onnx-lit
RUNTIME_DIR=$(pwd)/lib make -j$(nproc) check-onnx-backend make -j$(nproc) check-onnx-backend
RUNTIME_DIR=$(pwd)/lib PATH=$(pwd)/bin:$PATH make -j$(nproc) test PATH=$(pwd)/bin:$PATH make -j$(nproc) test

View File

@ -55,12 +55,6 @@ endif()
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/ExternalUtil.hpp.in configure_file(${CMAKE_CURRENT_SOURCE_DIR}/ExternalUtil.hpp.in
${CMAKE_CURRENT_BINARY_DIR}/ExternalUtil.hpp) ${CMAKE_CURRENT_BINARY_DIR}/ExternalUtil.hpp)
set(ONNX_MLIR_LD_PRELOAD_onnx-mlir "" CACHE STRING "" FORCE)
if(BUILD_SHARED_LIBS)
message(STATUS "To run dynamically linked onnx-mlir, you must specify:")
message(STATUS "LD_PRELOAD=${ONNX_MLIR_LD_PRELOAD_onnx-mlir}")
endif()
# Libraries specified on the target_link_libraries for the add_subdirectory # Libraries specified on the target_link_libraries for the add_subdirectory
# targets get added to the end of the list here. This creates two problems: # targets get added to the end of the list here. This creates two problems:
# 1. It produces duplicated libraries being specified for the link command. # 1. It produces duplicated libraries being specified for the link command.

View File

@ -3,8 +3,12 @@
#include <string> #include <string>
namespace onnx_mlir { namespace onnx_mlir {
std::string kExecPath = "@CMAKE_INSTALL_PREFIX@/bin/onnx-mlir"; /* fallback if not set by main */
const std::string kInstPath = "@CMAKE_INSTALL_PREFIX@";
const std::string kLlcPath = "@LLVM_PROJ_BUILD@/bin/llc"; const std::string kLlcPath = "@LLVM_PROJ_BUILD@/bin/llc";
const std::string kCxxPath = "@CMAKE_CXX_COMPILER@"; const std::string kCxxPath = "@CMAKE_CXX_COMPILER@";
const std::string kLinkerPath = "@CMAKE_LINKER@"; const std::string kLinkerPath = "@CMAKE_LINKER@";
const std::string kObjCopyPath = "@CMAKE_OBJCOPY@"; const std::string kObjCopyPath = "@CMAKE_OBJCOPY@";
const std::string kArPath = "@CMAKE_AR@";
const std::string kJarPath = "@Java_JAR_EXECUTABLE@";
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@ -1,4 +1,4 @@
//===--------------------------- main_utils.cpp ---------------------------===// //===--------------------------- MainUtils.cpp ---------------------------===//
// //
// Copyright 2019-2020 The IBM Research Authors. // Copyright 2019-2020 The IBM Research Authors.
// //
@ -13,6 +13,7 @@
#include <fcntl.h> #include <fcntl.h>
#include <regex> #include <regex>
#include <string> #include <string>
#include <vector>
#include <llvm/Support/FileSystem.h> #include <llvm/Support/FileSystem.h>
#include <llvm/Support/Program.h> #include <llvm/Support/Program.h>
@ -22,8 +23,6 @@
#include "src/ExternalUtil.hpp" #include "src/ExternalUtil.hpp"
#include "src/MainUtils.hpp" #include "src/MainUtils.hpp"
#include "MainUtils.hpp"
#ifdef _WIN32 #ifdef _WIN32
#include <io.h> #include <io.h>
#else #else
@ -41,6 +40,42 @@ llvm::Optional<std::string> getEnvVar(std::string name) {
return llvm::None; return llvm::None;
} }
// Runtime directory contains all the libraries, jars, etc. that are
// necessary for running onnx-mlir. It's resolved in the following order:
//
// - if ONNX_MLIR_RUNTIME_DIR is set, use it, otherwise
// - get path from where onnx-mlir is run, if it's of the form
// /foo/bar/bin/onnx-mlir,
// the runtime directory is /foo/bar/lib (note that when onnx-mlir is
// installed system wide, which is typically /usr/local/bin, this will
// correctly resolve to /usr/local/lib), but some systems still have
// lib64 so we check that first. If neither exists, then
// - use CMAKE_INSTALL_PREFIX/lib, which is typically /usr/local/lib
string getRuntimeDir() {
const auto &envDir = getEnvVar("ONNX_MLIR_RUNTIME_DIR");
if (envDir && llvm::sys::fs::exists(envDir.getValue()))
return envDir.getValue();
string execDir = llvm::sys::path::parent_path(kExecPath).str();
if (llvm::sys::path::stem(execDir).str().compare("bin") == 0) {
string p = execDir.substr(0, execDir.size() - 3);
if (llvm::sys::fs::exists(p + "lib64"))
return p + "lib64";
if (llvm::sys::fs::exists(p + "lib"))
return p + "lib";
}
llvm::SmallString<8> instDir64(kInstPath);
llvm::sys::path::append(instDir64, "lib64");
string p = llvm::StringRef(instDir64).str();
if (llvm::sys::fs::exists(p))
return p;
llvm::SmallString<8> instDir(kInstPath);
llvm::sys::path::append(instDir, "lib");
return llvm::StringRef(instDir).str();
}
// Helper struct to make command construction and execution easy & readable. // Helper struct to make command construction and execution easy & readable.
struct Command { struct Command {
std::string _path; std::string _path;
@ -97,6 +132,12 @@ struct Command {
}; };
} // namespace } // namespace
void setExecPath(const char *argv0, void *fmain) {
string p;
if (!(p = llvm::sys::fs::getMainExecutable(argv0, fmain)).empty())
kExecPath = p;
}
void LoadMLIR(string inputFilename, mlir::MLIRContext &context, void LoadMLIR(string inputFilename, mlir::MLIRContext &context,
mlir::OwningModuleRef &module) { mlir::OwningModuleRef &module) {
// Handle '.mlir' input to the ONNX MLIR frontend. // Handle '.mlir' input to the ONNX MLIR frontend.
@ -119,8 +160,8 @@ void LoadMLIR(string inputFilename, mlir::MLIRContext &context,
} }
} }
void compileModuleToSharedLibrary( void genConstPackObj(const mlir::OwningModuleRef &module,
const mlir::OwningModuleRef &module, string outputBaseName) { llvm::Optional<string> &constPackObjPath) {
// Extract constant pack file name, which is embedded as a symbol in the // Extract constant pack file name, which is embedded as a symbol in the
// module being compiled. // module being compiled.
auto constPackFilePathSym = (*module).lookupSymbol<mlir::LLVM::GlobalOp>( auto constPackFilePathSym = (*module).lookupSymbol<mlir::LLVM::GlobalOp>(
@ -131,7 +172,6 @@ void compileModuleToSharedLibrary(
.str(); .str();
llvm::FileRemover constPackRemover(constPackFilePath); llvm::FileRemover constPackRemover(constPackFilePath);
llvm::Optional<std::string> constPackObjPath;
#if __APPLE__ #if __APPLE__
// Create a empty stub file, compile it to an empty obj file. // Create a empty stub file, compile it to an empty obj file.
llvm::SmallVector<char, 20> stubSrcPath; llvm::SmallVector<char, 20> stubSrcPath;
@ -153,7 +193,6 @@ void compileModuleToSharedLibrary(
.appendList({"-sectcreate", "binary", "param", constPackFilePath}) .appendList({"-sectcreate", "binary", "param", constPackFilePath})
.appendStr(stubObjPathStr) .appendStr(stubObjPathStr)
.exec(); .exec();
llvm::FileRemover constPackObjRemover(constPackObjPath.getValue());
#elif __linux__ #elif __linux__
// Create param.o holding packed parameter values. // Create param.o holding packed parameter values.
@ -164,7 +203,6 @@ void compileModuleToSharedLibrary(
.appendList({"-o", constPackObjPath.getValue()}) .appendList({"-o", constPackObjPath.getValue()})
.appendStr(constPackFilePath) .appendStr(constPackFilePath)
.exec(); .exec();
llvm::FileRemover constPackObjRemover(constPackObjPath.getValue());
// Figure out what is the default symbol name describing the start/end // Figure out what is the default symbol name describing the start/end
// address of the embedded data. // address of the embedded data.
@ -204,42 +242,121 @@ void compileModuleToSharedLibrary(
mlir::KrnlPackedConstantOp::getConstPackFileNameStrLenSymbolName()) mlir::KrnlPackedConstantOp::getConstPackFileNameStrLenSymbolName())
.valueAttr(builder.getI64IntegerAttr(constPackFileName.str().size())); .valueAttr(builder.getI64IntegerAttr(constPackFileName.str().size()));
#endif #endif
}
// Write LLVM bitcode. // Write LLVM bitcode.
string outputFilename = outputBaseName + ".bc"; void genLLVMBitcode(const mlir::OwningModuleRef &module, string bitcodePath) {
error_code error; error_code error;
llvm::raw_fd_ostream moduleBitcodeStream( llvm::raw_fd_ostream moduleBitcodeStream(
outputFilename, error, llvm::sys::fs::F_None); bitcodePath, error, llvm::sys::fs::F_None);
llvm::WriteBitcodeToFile( llvm::WriteBitcodeToFile(
*mlir::translateModuleToLLVMIR(*module), moduleBitcodeStream); *mlir::translateModuleToLLVMIR(*module), moduleBitcodeStream);
moduleBitcodeStream.flush(); moduleBitcodeStream.flush();
llvm::FileRemover bcRemover(outputFilename); }
// Compile LLVM bitcode to object file. // Compile LLVM bitcode to object file.
void genModelObject(const mlir::OwningModuleRef &module, string bitcodePath,
string modelObjPath) {
Command llvmToObj(/*exePath=*/kLlcPath); Command llvmToObj(/*exePath=*/kLlcPath);
llvmToObj.appendStr("-filetype=obj"); llvmToObj.appendStr("-filetype=obj")
llvmToObj.appendStr("-relocation-model=pic"); .appendStr("-relocation-model=pic")
llvmToObj.appendStr(outputFilename); .appendList({"-o", modelObjPath})
llvmToObj.exec(); .appendStr(bitcodePath)
std::string modelObjPath = outputBaseName + ".o"; .exec();
llvm::FileRemover modelObjRemover(modelObjPath); }
llvm::Optional<std::string> runtimeDirInclFlag; void genJniObject(const mlir::OwningModuleRef &module, string jniSharedLibPath,
if (getEnvVar("RUNTIME_DIR").hasValue()) string jniObjPath) {
runtimeDirInclFlag = "-L" + getEnvVar("RUNTIME_DIR").getValue(); Command ar(/*exePath=*/kArPath);
ar.appendStr("x").appendStr(jniSharedLibPath).appendStr(jniObjPath).exec();
}
// Link everything into a shared object. // Link everything into a shared object.
void genSharedLib(const mlir::OwningModuleRef &module,
string modelSharedLibPath, std::vector<string> opts,
std::vector<string> objs, std::vector<string> libs) {
string runtimeDirInclFlag = "-L" + getRuntimeDir();
Command link(kCxxPath); Command link(kCxxPath);
link.appendList({"-shared", "-fPIC"}) link.appendList(opts)
.appendStr(modelObjPath) .appendList(objs)
.appendStr(constPackObjPath.getValueOr("")) .appendList({"-o", modelSharedLibPath})
.appendList({"-o", outputBaseName + ".so"})
.appendStrOpt(runtimeDirInclFlag) .appendStrOpt(runtimeDirInclFlag)
.appendList({"-lEmbeddedDataLoader", "-lcruntime"}) .appendList(libs)
.exec(); .exec();
} }
// Create jar containing java runtime and model shared library (which includes
// jni runtime).
void genJniJar(const mlir::OwningModuleRef &module, string modelSharedLibPath,
string modelJniJarPath) {
llvm::SmallString<8> runtimeDir(getRuntimeDir());
llvm::sys::path::append(runtimeDir, "javaruntime.jar");
string javaRuntimeJarPath = llvm::StringRef(runtimeDir).str();
// Copy javaruntime.jar to model jar.
llvm::sys::fs::copy_file(javaRuntimeJarPath, modelJniJarPath);
// Add shared library to model jar.
Command jar(kJarPath);
jar.appendList({"uf", modelJniJarPath}).appendStr(modelSharedLibPath).exec();
}
void compileModuleToSharedLibrary(
const mlir::OwningModuleRef &module, std::string outputBaseName) {
llvm::Optional<string> constPackObjPath;
genConstPackObj(module, constPackObjPath);
llvm::FileRemover constPackObjRemover(constPackObjPath.getValue());
string bitcodePath = outputBaseName + ".bc";
genLLVMBitcode(module, bitcodePath);
llvm::FileRemover bitcodeRemover(bitcodePath);
string modelObjPath = outputBaseName + ".o";
genModelObject(module, bitcodePath, modelObjPath);
llvm::FileRemover modelObjRemover(modelObjPath);
string modelSharedLibPath = outputBaseName + ".so";
genSharedLib(module, modelSharedLibPath, {"-shared", "-fPIC"},
{constPackObjPath.getValueOr(""), modelObjPath},
{"-lEmbeddedDataLoader", "-lcruntime"});
}
void compileModuleToJniJar(
const mlir::OwningModuleRef &module, std::string outputBaseName) {
llvm::Optional<string> constPackObjPath;
genConstPackObj(module, constPackObjPath);
llvm::FileRemover constPackObjRemover(constPackObjPath.getValue());
string bitcodePath = outputBaseName + ".bc";
genLLVMBitcode(module, bitcodePath);
llvm::FileRemover bitcodeRemover(bitcodePath);
string modelObjPath = outputBaseName + ".o";
genModelObject(module, bitcodePath, modelObjPath);
llvm::FileRemover modelObjRemover(modelObjPath);
string jniSharedLibPath = getRuntimeDir() + "/libjniruntime.a";
string jniObjPath = "jnidummy.c.o";
genJniObject(module, jniSharedLibPath, jniObjPath);
llvm::FileRemover jniObjRemover(jniObjPath);
string modelSharedLibPath = "libmodel.so";
genSharedLib(module, modelSharedLibPath,
{"-shared", "-fPIC", "-z", "noexecstack"},
{constPackObjPath.getValueOr(""), modelObjPath, jniObjPath},
{"-lEmbeddedDataLoader", "-lcruntime", "-ljniruntime"});
llvm::FileRemover modelSharedLibRemover(modelSharedLibPath);
string modelJniJarPath = outputBaseName + ".jar";
genJniJar(module, modelSharedLibPath, modelJniJarPath);
}
void registerDialects() { void registerDialects() {
mlir::registerDialect<mlir::AffineDialect>(); mlir::registerDialect<mlir::AffineDialect>();
mlir::registerDialect<mlir::LLVM::LLVMDialect>(); mlir::registerDialect<mlir::LLVM::LLVMDialect>();
@ -349,6 +466,9 @@ void emitOutputFiles(string outputBaseName, EmissionTargetType emissionTarget,
// Write LLVM bitcode to disk, compile & link. // Write LLVM bitcode to disk, compile & link.
compileModuleToSharedLibrary(module, outputBaseName); compileModuleToSharedLibrary(module, outputBaseName);
printf("Shared library %s.so has been compiled.\n", outputBaseName.c_str()); printf("Shared library %s.so has been compiled.\n", outputBaseName.c_str());
} else if (emissionTarget == EmitJNI) {
compileModuleToJniJar(module, outputBaseName);
printf("JNI archive %s.jar has been compiled.\n", outputBaseName.c_str());
} else { } else {
// Emit the version with all constants included. // Emit the version with all constants included.
outputCode(module, outputBaseName, ".onnx.mlir"); outputCode(module, outputBaseName, ".onnx.mlir");

View File

@ -43,14 +43,20 @@ enum EmissionTargetType {
EmitMLIR, EmitMLIR,
EmitLLVMIR, EmitLLVMIR,
EmitLib, EmitLib,
EmitJNI,
}; };
void setExecPath(const char *argv0, void *fmain);
void LoadMLIR(std::string inputFilename, mlir::MLIRContext &context, void LoadMLIR(std::string inputFilename, mlir::MLIRContext &context,
mlir::OwningModuleRef &module); mlir::OwningModuleRef &module);
void compileModuleToSharedLibrary( void compileModuleToSharedLibrary(
const mlir::OwningModuleRef &module, std::string outputBaseName); const mlir::OwningModuleRef &module, std::string outputBaseName);
void compileModuleToJniJar(
const mlir::OwningModuleRef &module, std::string outputBaseName);
void registerDialects(); void registerDialects();
void addONNXToMLIRPasses(mlir::PassManager &pm); void addONNXToMLIRPasses(mlir::PassManager &pm);

View File

@ -1,3 +1,5 @@
add_subdirectory(jni)
# Create static libcruntime.a to be embedded in model.so to make model.so self contained. # Create static libcruntime.a to be embedded in model.so to make model.so self contained.
# However, by default object code for static library is not compiled with -fPIC. Embedding # However, by default object code for static library is not compiled with -fPIC. Embedding
# such static library in a shared library can cause runtime failure on some architectures, # such static library in a shared library can cause runtime failure on some architectures,

View File

@ -150,6 +150,13 @@ int64_t *getStrides(RtMemRef *dynMemRef) { return dynMemRef->strides; }
int64_t getSize(OrderedRtMemRefDict *dict) { return dict->orderedNames.size(); } int64_t getSize(OrderedRtMemRefDict *dict) { return dict->orderedNames.size(); }
INDEX_TYPE getDataSize(RtMemRef *rtMemRef) {
INDEX_TYPE n = rtMemRef->sizes[0];
for (int i = 1; i < rtMemRef->rank; i++)
n *= rtMemRef->sizes[i];
return n;
}
void setDType(RtMemRef *dynMemRef, int onnxType) { void setDType(RtMemRef *dynMemRef, int onnxType) {
dynMemRef->onnx_dtype = onnxType; dynMemRef->onnx_dtype = onnxType;
} }
@ -160,5 +167,5 @@ unsigned int getRank(RtMemRef *dynMemRef) { return dynMemRef->rank; }
void setStrides(RtMemRef *dynMemRef, int64_t *strides) { void setStrides(RtMemRef *dynMemRef, int64_t *strides) {
for (int i = 0; i < dynMemRef->rank; i++) for (int i = 0; i < dynMemRef->rank; i++)
dynMemRef->sizes[i] = strides[i]; dynMemRef->strides[i] = strides[i];
} }

View File

@ -151,6 +151,9 @@ OrderedRtMemRefDict *createOrderedRtMemRefDict();
// Get how many dynamic memrefs are in dict. // Get how many dynamic memrefs are in dict.
int64_t getSize(OrderedRtMemRefDict *dict); int64_t getSize(OrderedRtMemRefDict *dict);
// Get how many data elements are in RtMemRef.
INDEX_TYPE getDataSize(RtMemRef *rtMemRef);
// Create a dynmemref with a certain rank. // Create a dynmemref with a certain rank.
RtMemRef *createRtMemRef(int rank); RtMemRef *createRtMemRef(int rank);

View File

@ -0,0 +1,30 @@
find_package(Java COMPONENTS Development)
find_package(JNI)
if(Java_Development_FOUND AND JNI_FOUND)
include(UseJava)
# Target for Java runtime jar
add_jar(javaruntime
src/com/ibm/onnxmlir/DynEntryPoint.java
src/com/ibm/onnxmlir/OrderedRtMemRefDict.java
src/com/ibm/onnxmlir/RtMemRef.java
OUTPUT_DIR ${CMAKE_LIBRARY_OUTPUT_DIRECTORY})
# Target for JNI runtime lib
add_library(jniruntime STATIC
jniwrapper.c jnilog.c jnidummy.c
com_ibm_onnxmlir_DynEntryPoint.h jnilog.h ../RtMemRef.h)
set_target_properties(jniruntime PROPERTIES
POSITION_INDEPENDENT_CODE TRUE)
target_include_directories(jniruntime PRIVATE
${ONNX_MLIR_SRC_ROOT}/src/Runtime
${JAVA_INCLUDE_PATH}
${JAVA_INCLUDE_PATH2})
install_jar(javaruntime DESTINATION lib)
install(TARGETS jniruntime DESTINATION lib)
else()
message(WARNING "Java Development component or JNI not found, JNI targets will not work")
endif()

View File

@ -0,0 +1,22 @@
/* DO NOT EDIT THIS FILE - it is machine generated */
#include <jni.h>
/* Header for class com_ibm_onnxmlir_DynEntryPoint */
#ifndef _Included_com_ibm_onnxmlir_DynEntryPoint
#define _Included_com_ibm_onnxmlir_DynEntryPoint
#ifdef __cplusplus
extern "C" {
#endif
/*
* Class: com_ibm_onnxmlir_DynEntryPoint
* Method: main_graph_jni
* Signature:
* (Lcom/ibm/onnxmlir/OrderedRtMemRefDict;)Lcom/ibm/onnxmlir/OrderedRtMemRefDict;
*/
JNIEXPORT jobject JNICALL Java_com_ibm_onnxmlir_DynEntryPoint_main_1graph_1jni(
JNIEnv *, jclass, jobject);
#ifdef __cplusplus
}
#endif
#endif

View File

@ -0,0 +1,7 @@
#include "com_ibm_onnxmlir_DynEntryPoint.h"
/* Dummy routine to force the link editor to embed code in libjniruntime.a
into libmodel.so */
void __dummy_do_not_call__(JNIEnv *env, jclass cls, jobject obj) {
Java_com_ibm_onnxmlir_DynEntryPoint_main_1graph_1jni(NULL, NULL, NULL);
}

94
src/Runtime/jni/jnilog.c Normal file
View File

@ -0,0 +1,94 @@
#include <errno.h>
#include <libgen.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include "jnilog.h"
static int log_initd = 0;
static int log_level;
static FILE *log_fp;
/* Must match enum in log.h */
static char *log_level_name[] = {
"trace", "debug", "info", "warning", "error", "fatal"};
/* Return numerical log level of give level name */
static int get_log_level_by_name(char *name) {
int level = -1;
for (int i = 0; i < sizeof(log_level_name) / sizeof(char *); i++) {
if (!strcmp(name, log_level_name[i])) {
level = i;
break;
}
}
return level;
}
/* Return FILE pointer of given file name */
static FILE *get_log_file_by_name(char *name) {
FILE *fp = NULL;
if (!strcmp(name, "stdout"))
fp = stdout;
else if (!strcmp(name, "stderr"))
fp = stderr;
else
fp = fopen(name, "w");
return fp;
}
/* Initialize log system. Set default log level and file or use environment
* variables ONNX_MLIR_JNI_LOG_LEVEL and ONNX_MLIR_JNI_LOG_FILE, respectively.
*/
static void log_init() {
if (log_initd)
return;
log_level = LOG_INFO;
char *strlevel = getenv("ONNX_MLIR_JNI_LOG_LEVEL");
int level;
if (strlevel && (level = get_log_level_by_name(strlevel)) != -1)
log_level = level;
log_fp = stderr;
char *strfname = getenv("ONNX_MLIR_JNI_LOG_FILE");
FILE *fp;
if (strfname && (fp = get_log_file_by_name(strfname)))
log_fp = fp;
tzset();
log_initd = 1;
}
/* Generic log routine */
void log_printf(
int level, char *file, const char *func, int line, char *fmt, ...) {
if (!log_initd)
log_init();
if (level < log_level)
return;
time_t now;
struct tm *tm;
char buf[80];
/* Get local time and format as 2020-07-03 05:17:42 -0400 */
if (time(&now) == -1 || (tm = localtime(&now)) == NULL ||
strftime(buf, sizeof(buf), "%F %T %z", tm) == 0)
sprintf(buf, "-");
/* Output log prefix */
fprintf(log_fp, "[%s][%s]%s:%s:%d ", buf, log_level_name[level],
basename(file), func, line);
/* Output actually log data */
va_list va_list;
va_start(va_list, fmt);
vfprintf(log_fp, fmt, va_list);
va_end(va_list);
fprintf(log_fp, "\n");
}

125
src/Runtime/jni/jnilog.h Normal file
View File

@ -0,0 +1,125 @@
#ifndef __JNILOG_H__
#define __JNILOG_H__
#include <stdio.h>
enum { LOG_TRACE, LOG_DEBUG, LOG_INFO, LOG_WARNING, LOG_ERROR, LOG_FATAL };
#define LOG_MAX_NUM 16 /* max number of elements to output */
#define MIN(x, y) ((x) > (y) ? y : x)
/* Construct string of up to LOG_MAX_NUM elements of a char array */
#define LOG_CHAR_BUF(buf, data, n) \
do { \
buf[0] = '\0'; \
for (int i = 0; i < MIN(n, LOG_MAX_NUM); i++) \
sprintf(buf + strlen(buf), " %02x", ((char *)data)[i]); \
sprintf(buf + strlen(buf), n > LOG_MAX_NUM ? " ... " : " "); \
} while (0)
/* Construct string of up to LOG_MAX_NUM elements of a short array */
#define LOG_SHORT_BUF(buf, data, n) \
do { \
buf[0] = '\0'; \
for (int i = 0; i < MIN(n, LOG_MAX_NUM); i++) \
sprintf(buf + strlen(buf), " %d", ((short *)data)[i]); \
sprintf(buf + strlen(buf), n > LOG_MAX_NUM ? " ... " : " "); \
} while (0)
/* Construct string of up to LOG_MAX_NUM elements of a int array */
#define LOG_INT_BUF(buf, data, n) \
do { \
buf[0] = '\0'; \
for (int i = 0; i < MIN(n, LOG_MAX_NUM); i++) \
sprintf(buf + strlen(buf), " %d", ((int *)data)[i]); \
sprintf(buf + strlen(buf), n > LOG_MAX_NUM ? " ... " : " "); \
} while (0)
/* Construct string of up to LOG_MAX_NUM elements of a long array */
#define LOG_LONG_BUF(buf, data, n) \
do { \
buf[0] = '\0'; \
for (int i = 0; i < MIN(n, LOG_MAX_NUM); i++) \
sprintf(buf + strlen(buf), " %ld", ((long *)data)[i]); \
sprintf(buf + strlen(buf), n > LOG_MAX_NUM ? " ... " : " "); \
} while (0)
/* Construct string of up to LOG_MAX_NUM elements of a float array */
#define LOG_FLOAT_BUF(buf, data, n) \
do { \
buf[0] = '\0'; \
for (int i = 0; i < MIN(n, LOG_MAX_NUM); i++) \
sprintf(buf + strlen(buf), " %f", ((float *)data)[i]); \
sprintf(buf + strlen(buf), n > LOG_MAX_NUM ? " ... " : " "); \
} while (0)
/* Construct string of up to LOG_MAX_NUM elements of a double array */
#define LOG_DOUBLE_BUF(buf, data, n) \
do { \
buf[0] = '\0'; \
for (int i = 0; i < MIN(n, LOG_MAX_NUM); i++) \
sprintf(buf + strlen(buf), " %f", ((double *)data)[i]); \
sprintf(buf + strlen(buf), n > LOG_MAX_NUM ? " ... " : " "); \
} while (0)
enum {
ONNX_TYPE_UNDEFINED, /* 0 */
ONNX_TYPE_FLOAT, /* 1 */
ONNX_TYPE_UINT8, /* 2 */
ONNX_TYPE_INT8, /* 3 */
ONNX_TYPE_UINT16, /* 4 */
ONNX_TYPE_INT16, /* 5 */
ONNX_TYPE_INT32, /* 6 */
ONNX_TYPE_INT64, /* 7 */
ONNX_TYPE_STRING, /* 8 */
ONNX_TYPE_BOOL, /* 9 */
ONNX_TYPE_FLOAT16, /* 10 */
ONNX_TYPE_DOUBLE, /* 11 */
ONNX_TYPE_UINT32, /* 12 */
ONNX_TYPE_UINT64, /* 13 */
ONNX_TYPE_COMPLEX64, /* 14 */
ONNX_TYPE_COMPLEX128, /* 15 */
ONNX_TYPE_BFLOAT16, /* 16 */
};
/* Construct string of up to LOG_MAX_NUM elements of a "type" array */
#define LOG_TYPE_BUF(type, buf, data, n) \
do { \
switch (type) { \
case ONNX_TYPE_UINT8: \
case ONNX_TYPE_INT8: \
LOG_CHAR_BUF(buf, data, n); \
break; \
case ONNX_TYPE_UINT16: \
case ONNX_TYPE_INT16: \
LOG_SHORT_BUF(buf, data, n); \
break; \
case ONNX_TYPE_UINT32: \
case ONNX_TYPE_INT32: \
LOG_INT_BUF(buf, data, n); \
break; \
case ONNX_TYPE_UINT64: \
case ONNX_TYPE_INT64: \
LOG_LONG_BUF(buf, data, n); \
break; \
case ONNX_TYPE_FLOAT: \
LOG_FLOAT_BUF(buf, data, n); \
break; \
case ONNX_TYPE_DOUBLE: \
LOG_DOUBLE_BUF(buf, data, n); \
break; \
defaut: \
sprintf(buf, " unsupported data type %d ", type); \
} \
} while (0)
/* Main macro for log output */
#define LOG_PRINTF(level, ...) \
log_printf(level, __FILE__, __FUNCTION__, __LINE__, __VA_ARGS__)
/* Generic log routine */
void log_printf(
int level, char *file, const char *func, int line, char *fmt, ...);
#endif

View File

@ -0,0 +1,356 @@
#include <assert.h>
#include <malloc.h>
#include <string.h>
#include "RtMemRef.h"
#include "com_ibm_onnxmlir_DynEntryPoint.h"
#include "jnilog.h"
/* Declare type var, make call and assign to var, check against val */
#define CHECK_CALL(type, var, call, val) \
type var = call; \
if (var == val) \
return NULL
/* Make a JNI call and throw Java exception if the call failed */
#define JNI_CALL(env, stmt) \
stmt; \
do { \
jthrowable e = (*env)->ExceptionOccurred(env); \
if (e) { \
LOG_PRINTF(LOG_ERROR, "JNI call exception occurred"); \
(*env)->Throw(env, e); \
return NULL; \
} \
} while (0)
/* Make a JNI call and assign return value to var,
* throw Java exception if the call failed
*/
#define JNI_VAR_CALL(env, var, call) JNI_CALL(env, var = call)
/* Declare type var, make a JNI call and assign return value to var,
* throw Java exception if the call failed
*/
#define JNI_TYPE_VAR_CALL(env, type, var, call) JNI_CALL(env, type var = call);
/* If cond is true (native code failed), log error and throw Java exception */
#define JNI_COND(type, var, call, val, env, cls, ...) \
type var = call; \
do { \
if (var == val) { \
LOG_PRINTF(LOG_ERROR, __VA_ARGS__); \
(*env)->ThrowNew(env, cls, "native code error"); \
return NULL; \
} \
} while (0)
/* Debug output of RtMemRef fields */
#define RMR_DEBUG(i, type, rank, sizes, strides, data, datasize) \
do { \
char tmp[1024]; \
LOG_PRINTF(LOG_DEBUG, "rmr[%d]:type=%d", i, type); \
LOG_PRINTF(LOG_DEBUG, "rmr[%d]:rank=%d", i, rank); \
LOG_LONG_BUF(tmp, sizes, rank); \
LOG_PRINTF(LOG_DEBUG, "rmr[%d]:sizes=[%s]", i, tmp); \
LOG_LONG_BUF(tmp, strides, rank); \
LOG_PRINTF(LOG_DEBUG, "rmr[%d]:strides=[%s]", i, tmp); \
LOG_TYPE_BUF(type, tmp, data, datasize); \
LOG_PRINTF(LOG_DEBUG, "rmr[%d]:data=[%s]", i, tmp); \
} while (0)
/* Model shared library entry point */
extern OrderedRtMemRefDict *_dyn_entry_point_main_graph(OrderedRtMemRefDict *);
/* ONNX type to size (number of bytes) mapping */
int onnx_type_size[] = {
0, /* UNDEFINED = 0 */
4, /* FLOAT = 1 */
1, /* UINT8 = 2 */
1, /* INT8 = 3 */
2, /* UINT16 = 4 */
2, /* INT16 = 5 */
4, /* INT32 = 6 */
8, /* INT64 = 7 */
0, /* STRING = 8 */
1, /* BOOL = 9 */
2, /* FLOAT16 = 10 */
8, /* DOUBLE = 11 */
4, /* UINT32 = 12 */
8, /* UINT64 = 13 */
8, /* COMPLEX64 = 14 */
16, /* COMPLEX128 = 15 */
2, /* BFLOAT16 = 16 */
};
/* Java classes and methods needed for making various JNI API calls */
typedef struct {
jclass ecpt_cls; /* java/lang/Exception class */
jclass long_cls; /* java/lang/Long class */
jclass string_cls; /* java/lang/String class */
jclass ormrd_cls; /* com/ibm/onnxmlir/OrderedRtMemRefDict class */
jclass rmr_cls; /* com/ibm/onnxmlir/RtMemRef class */
jmethodID ormrd_constructor; /* OrderedRtMemRefDict constructor */
jmethodID ormrd_getRmrs; /* OrderedRtMemRefDict getRmrs method */
jmethodID ormrd_getNames; /* OrderedRtMemRefDict getNames method */
jmethodID rmr_constructor; /* RtMemRef constructor */
jmethodID rmr_getType; /* RtMemRef getType method */
jmethodID rmr_setType; /* RtMemRef setType method */
jmethodID rmr_getRank; /* RtMemRef getRank method */
jmethodID rmr_getData; /* RtMemRef getData method */
jmethodID rmr_setData; /* RtMemRef setData method */
jmethodID rmr_getSizes; /* RtMemRef getSizes method */
jmethodID rmr_setSizes; /* RtMemRef setSizes method */
jmethodID rmr_getStrides; /* RtMemRef getStrides method */
jmethodID rmr_setStrides; /* RtMemRef setStrides method */
jmethodID rmr_getDataSize; /* RtMemRef getDataSize method */
} jniapi_t;
jniapi_t jniapi;
/* Fill in struct jniapi */
jniapi_t *fill_jniapi(JNIEnv *env, jniapi_t *japi) {
/* Get Java Exception, Long, String, OrderedRtMemRefDict, and RtMemRef classes
*/
JNI_VAR_CALL(
env, japi->ecpt_cls, (*env)->FindClass(env, "java/lang/Exception"));
JNI_VAR_CALL(env, japi->long_cls, (*env)->FindClass(env, "java/lang/Long"));
JNI_VAR_CALL(
env, japi->string_cls, (*env)->FindClass(env, "java/lang/String"));
JNI_VAR_CALL(env, japi->ormrd_cls,
(*env)->FindClass(env, "com/ibm/onnxmlir/OrderedRtMemRefDict"));
JNI_VAR_CALL(
env, japi->rmr_cls, (*env)->FindClass(env, "com/ibm/onnxmlir/RtMemRef"));
/* Get method ID of constructor and various methods in OrderedRtMemRefDict */
JNI_VAR_CALL(env, japi->ormrd_constructor,
(*env)->GetMethodID(
env, japi->ormrd_cls, "<init>", "([Lcom/ibm/onnxmlir/RtMemRef;)V"));
JNI_VAR_CALL(env, japi->ormrd_getRmrs,
(*env)->GetMethodID(
env, japi->ormrd_cls, "getRmrs", "()[Lcom/ibm/onnxmlir/RtMemRef;"));
JNI_VAR_CALL(env, japi->ormrd_getNames,
(*env)->GetMethodID(
env, japi->ormrd_cls, "getNames", "()[Ljava/lang/String;"));
/* Get method ID of constructor and various methods in RtMemRef */
JNI_VAR_CALL(env, japi->rmr_constructor,
(*env)->GetMethodID(env, japi->rmr_cls, "<init>", "(I)V"));
JNI_VAR_CALL(env, japi->rmr_getType,
(*env)->GetMethodID(env, japi->rmr_cls, "getType", "()I"));
JNI_VAR_CALL(env, japi->rmr_setType,
(*env)->GetMethodID(env, japi->rmr_cls, "setType", "(I)V"));
JNI_VAR_CALL(env, japi->rmr_getRank,
(*env)->GetMethodID(env, japi->rmr_cls, "getRank", "()I"));
JNI_VAR_CALL(env, japi->rmr_getData,
(*env)->GetMethodID(
env, japi->rmr_cls, "getData", "()Ljava/nio/ByteBuffer;"));
JNI_VAR_CALL(env, japi->rmr_setData,
(*env)->GetMethodID(
env, japi->rmr_cls, "setData", "(Ljava/nio/ByteBuffer;)V"));
JNI_VAR_CALL(env, japi->rmr_getSizes,
(*env)->GetMethodID(env, japi->rmr_cls, "getSizes", "()[J"));
JNI_VAR_CALL(env, japi->rmr_setSizes,
(*env)->GetMethodID(env, japi->rmr_cls, "setSizes", "([J)V"));
JNI_VAR_CALL(env, japi->rmr_getStrides,
(*env)->GetMethodID(env, japi->rmr_cls, "getStrides", "()[J"));
JNI_VAR_CALL(env, japi->rmr_setStrides,
(*env)->GetMethodID(env, japi->rmr_cls, "setStrides", "([J)V"));
JNI_VAR_CALL(env, japi->rmr_getDataSize,
(*env)->GetMethodID(env, japi->rmr_cls, "getDataSize", "()J"));
return japi;
}
/* Convert Java object to native data structure */
OrderedRtMemRefDict *ormrd_java_to_native(
JNIEnv *env, jclass cls, jobject obj, jniapi_t *japi) {
/* Get object array "rmrs" and "names" in OrderedRtMemRefDict */
JNI_TYPE_VAR_CALL(env, jobjectArray, ormrd_rmrs,
(*env)->CallObjectMethod(env, obj, japi->ormrd_getRmrs));
JNI_TYPE_VAR_CALL(env, jobjectArray, ormrd_names,
(*env)->CallObjectMethod(env, obj, japi->ormrd_getNames));
/* Get length of object array "rmrs" and "names" in OrderedRtMemRefDict */
JNI_TYPE_VAR_CALL(
env, jsize, ormrd_rmrs_len, (*env)->GetArrayLength(env, ormrd_rmrs));
JNI_TYPE_VAR_CALL(
env, jsize, ormrd_names_len, (*env)->GetArrayLength(env, ormrd_names));
/* Allocate memory for holding each Java rmr object and name string,
* and RtMemRef and char pointers for constructing native RtMemRef and name
* array
*/
JNI_COND(jobject *, obj_rmr, malloc(ormrd_rmrs_len * sizeof(jobject)), NULL,
env, japi->ecpt_cls, "obj_rmr=null");
JNI_COND(jstring *, obj_name, malloc(ormrd_names_len * sizeof(jstring)), NULL,
env, japi->ecpt_cls, "obj_name=null");
JNI_COND(RtMemRef **, jni_rmr, malloc(ormrd_rmrs_len * sizeof(RtMemRef *)),
NULL, env, japi->ecpt_cls, "jni_rmr=null");
JNI_COND(const char **, jni_name,
malloc(ormrd_names_len * sizeof(const char *)), NULL, env, japi->ecpt_cls,
"jni_name=null");
/* Create OrderedRtMemRefDict to be constructed and passed to the model shared
* library */
JNI_COND(OrderedRtMemRefDict *, ormrd, createOrderedRtMemRefDict(), NULL, env,
japi->ecpt_cls, "ormrd=null");
/* Loop through all the ormrd_rmrs and ormrd_names */
for (int i = 0; i < ormrd_rmrs_len; i++) {
JNI_VAR_CALL(
env, obj_rmr[i], (*env)->GetObjectArrayElement(env, ormrd_rmrs, i));
JNI_VAR_CALL(
env, obj_name[i], (*env)->GetObjectArrayElement(env, ormrd_names, i));
/* Get type, rank, data, sizes, and strides by calling corresponding methods
*/
JNI_TYPE_VAR_CALL(env, jint, rmr_type,
(*env)->CallIntMethod(env, obj_rmr[i], japi->rmr_getType));
JNI_TYPE_VAR_CALL(env, jint, rmr_rank,
(*env)->CallIntMethod(env, obj_rmr[i], japi->rmr_getRank));
JNI_TYPE_VAR_CALL(env, jlong, rmr_datasize,
(*env)->CallLongMethod(env, obj_rmr[i], japi->rmr_getDataSize));
JNI_TYPE_VAR_CALL(env, jobject, rmr_data,
(*env)->CallObjectMethod(env, obj_rmr[i], japi->rmr_getData));
JNI_TYPE_VAR_CALL(env, jobject, rmr_sizes,
(*env)->CallObjectMethod(env, obj_rmr[i], japi->rmr_getSizes));
JNI_TYPE_VAR_CALL(env, jobject, rmr_strides,
(*env)->CallObjectMethod(env, obj_rmr[i], japi->rmr_getStrides));
/* Primitive type int and long can be directly used */
int jni_type = rmr_type, jni_rank = rmr_rank;
long jni_datasize = rmr_datasize;
/* Get direct buffer associated with data */
JNI_TYPE_VAR_CALL(
env, void *, jni_data, (*env)->GetDirectBufferAddress(env, rmr_data));
/* Get long array associated with sizes and strides */
JNI_TYPE_VAR_CALL(env, long *, jni_sizes,
(*env)->GetLongArrayElements(env, rmr_sizes, NULL));
JNI_TYPE_VAR_CALL(env, long *, jni_strides,
(*env)->GetLongArrayElements(env, rmr_strides, NULL));
/* Print debug info on what we got from the Java side */
RMR_DEBUG(
i, jni_type, jni_rank, jni_sizes, jni_strides, jni_data, jni_datasize);
/* Create native RtMemRef struct and fill in its fields */
jni_rmr[i] = createRtMemRef(jni_rank);
setDType(jni_rmr[i], jni_type);
setData(jni_rmr[i], jni_data);
setSizes(jni_rmr[i], jni_sizes);
setStrides(jni_rmr[i], jni_strides);
/*jni_name[i] = (*env)->GetStringUTFChars(env, obj_name[i], NULL);
printf("jni_name=%s\n", jni_name[i]);*/
/* Install RtMemRef into OrderedRtMemRefDict */
setRtMemRef(ormrd, i, jni_rmr[i]);
/* Release reference to the java objects */
JNI_CALL(
env, (*env)->ReleaseLongArrayElements(env, rmr_sizes, jni_sizes, 0));
JNI_CALL(env,
(*env)->ReleaseLongArrayElements(env, rmr_strides, jni_strides, 0));
}
/* setRtMemRef(ormrd, jni_rmr, jni_name); */
return ormrd;
}
/* Convert native data structure to Java object */
jobject ormrd_native_to_java(
JNIEnv *env, jclass cls, OrderedRtMemRefDict *dict, jniapi_t *japi) {
JNI_COND(int, nrmr, numRtMemRefs(dict), 0, env, japi->ecpt_cls, "nrmr=0");
/* Create RtMemRef java object array */
JNI_TYPE_VAR_CALL(env, jobjectArray, rmrs,
(*env)->NewObjectArray(env, nrmr, japi->rmr_cls, NULL));
/* Loop through the native RtMemRef structs */
for (int i = 0; i < nrmr; i++) {
JNI_COND(RtMemRef *, rmr, getRtMemRef(dict, i), NULL, env, japi->ecpt_cls,
"rmr[%d]=null", i);
JNI_COND(int, jni_type, getDType(rmr), 0, env, japi->ecpt_cls,
"rmr[%d]:type=0", i);
JNI_COND(int, jni_rank, getRank(rmr), 0, env, japi->ecpt_cls,
"rmr[%d]:rank=0", i);
JNI_COND(long, jni_datasize, getDataSize(rmr), 0, env, japi->ecpt_cls,
"rmr[%d]:datasize=0", i);
JNI_COND(void *, jni_data, getData(rmr), NULL, env, japi->ecpt_cls,
"rmr[%d]:data=null", i);
JNI_COND(long *, jni_sizes, getSizes(rmr), NULL, env, japi->ecpt_cls,
"rmr[%d]:sizes=null", i);
JNI_COND(long *, jni_strides, getStrides(rmr), NULL, env, japi->ecpt_cls,
"rmr[%d]:strides=null", i);
/* Print debug info on what we got from the native side */
RMR_DEBUG(
i, jni_type, jni_rank, jni_sizes, jni_strides, jni_data, jni_datasize);
/* create the following Java objects:
* - RtMemRef
* - DirectByteBuffer (from native buffers)
* - long array for sizes and strides
*/
JNI_TYPE_VAR_CALL(env, jobject, obj_rmr,
(*env)->NewObject(env, japi->rmr_cls, japi->rmr_constructor, jni_rank));
JNI_TYPE_VAR_CALL(env, jobject, rmr_data,
(*env)->NewDirectByteBuffer(
env, jni_data, jni_datasize * onnx_type_size[jni_type]));
JNI_TYPE_VAR_CALL(
env, jlongArray, rmr_sizes, (*env)->NewLongArray(env, jni_rank));
JNI_TYPE_VAR_CALL(
env, jlongArray, rmr_strides, (*env)->NewLongArray(env, jni_rank));
/* Call setType method */
JNI_CALL(env,
(*env)->CallObjectMethod(env, obj_rmr, japi->rmr_setType, jni_type));
/* Call setData method */
JNI_CALL(env,
(*env)->CallObjectMethod(env, obj_rmr, japi->rmr_setData, rmr_data));
/* Fill in sizes array from native array and call setSizes method */
JNI_CALL(env,
(*env)->SetLongArrayRegion(env, rmr_sizes, 0, jni_rank, jni_sizes));
JNI_CALL(env,
(*env)->CallObjectMethod(env, obj_rmr, japi->rmr_setSizes, rmr_sizes));
/* Fill in strides array from native array and call setStrides method */
JNI_CALL(env,
(*env)->SetLongArrayRegion(env, rmr_strides, 0, jni_rank, jni_strides));
JNI_CALL(env, (*env)->CallObjectMethod(
env, obj_rmr, japi->rmr_setStrides, rmr_strides));
/* Set DynMemRef object in the object array */
JNI_CALL(env, (*env)->SetObjectArrayElement(env, rmrs, i, obj_rmr));
}
/* Create the OrderedRtMemRefDict java object */
JNI_TYPE_VAR_CALL(env, jobject, ormrd,
(*env)->NewObject(env, japi->ormrd_cls, japi->ormrd_constructor, rmrs));
return ormrd;
}
JNIEXPORT jobject JNICALL Java_com_ibm_onnxmlir_DynEntryPoint_main_1graph_1jni(
JNIEnv *env, jclass cls, jobject obj) {
CHECK_CALL(jniapi_t *, japi, fill_jniapi(env, &jniapi), NULL);
CHECK_CALL(OrderedRtMemRefDict *, input_ormrd,
ormrd_java_to_native(env, cls, obj, japi), NULL);
CHECK_CALL(OrderedRtMemRefDict *, dict,
_dyn_entry_point_main_graph(input_ormrd), NULL);
CHECK_CALL(
jobject, output_ormrd, ormrd_native_to_java(env, cls, dict, japi), NULL);
return output_ormrd;
}

View File

@ -0,0 +1,64 @@
package com.ibm.onnxmlir;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.net.URISyntaxException;
import java.nio.file.Files;
import java.nio.file.StandardCopyOption;
import java.util.jar.JarFile;
public class DynEntryPoint {
static String libname = "libmodel.so";
static {
File jar;
JarFile jf;
String jarDir = null;
String libPath = null;
try {
// Get path name of jar
jar = new File(DynEntryPoint.class.getProtectionDomain()
.getCodeSource()
.getLocation().toURI());
jarDir = jar.getParentFile().getAbsolutePath();
libPath = jarDir + "/" + libname;
// Open jar file to read and check libname inside jar.
// If IOException thrown, load .so from where .jar is.
//
// Checking whether DynEntryPoint.class.getResourceAsStream returns null
// does NOT work. Because it checks whether the resource is
// available on the classpath, not only just inside the jar file.
jf = new JarFile(jar);
if (jf.getEntry(libname) != null) {
File libFile = new File(libPath);
// Copy .so to where jar is
Files.copy(jf.getInputStream(jf.getEntry(libname)),
libFile.toPath(), StandardCopyOption.REPLACE_EXISTING);
// z/OS USS requires "x" permission bit
libFile.setExecutable(true, false);
// Load the temporary .so copy
System.load(libPath);
// POSIX can unlink file after loading
libFile.delete();
}
else {
// Throw subclass of IOException
throw new FileNotFoundException(".so not found inside jar");
}
} catch (URISyntaxException e) {
// Failed to find jar path, assume the .so is in cwd
System.load(libname);
} catch (IOException e) {
// .so not found in jar, assume it's where jar is
System.load(libPath);
}
}
private static native OrderedRtMemRefDict main_graph_jni(OrderedRtMemRefDict ormrd);
public static OrderedRtMemRefDict main_graph(OrderedRtMemRefDict ormrd) {
return main_graph_jni(ormrd);
}
}

View File

@ -0,0 +1,118 @@
package com.ibm.onnxmlir;
import java.util.HashMap;
public class OrderedRtMemRefDict {
private RtMemRef[] _rmrs;
private String[] _names;
private HashMap<String, Integer> _n2i;
/**
* Constructor
*
* @param rmrs DynMemRef array
*/
public OrderedRtMemRefDict(RtMemRef[] rmrs) {
this(rmrs, null);
}
/**
* Constructor
*
* @param rmrs DynMemRef array
* @param names name array
*/
public OrderedRtMemRefDict(RtMemRef[] rmrs, String[] names) {
/* rmrs cannot be null or empty */
if (rmrs == null || rmrs.length == 0)
throw new IllegalArgumentException(
"Number of dmrs is invalid");
/* If names is null or empty, construct a default one with
* index as name.
*/
if (names == null || names.length == 0) {
names = new String[rmrs.length];
for (int i = 0; i < rmrs.length; i++)
names[i] = Integer.toString(i);
}
/* Number of rmrs and names must match */
if (rmrs.length != names.length)
throw new IllegalArgumentException(
"Number of dmrs and names do not match");
/* Establish name to index mapping. Individual rmr is
* checked for validity.
*/
_n2i = new HashMap<String, Integer>();
for (int i = 0; i < names.length; i++) {
if (rmrs[i] == null || !rmrs[i].validRmr())
throw new IllegalArgumentException(
"rmrs[" + i + "] is invalid");
if (_n2i.put(names[i], i) != null)
throw new IllegalArgumentException(
"name[" + i + "] = " + names[i] + " not unique");
}
_rmrs = rmrs;
_names = names;
}
/**
* RtMemRef getter by index
*
* @param idx index of RtMemRef instance to get
* @return RtMemRef instance
*/
public RtMemRef getRmrbyIndex(int idx) {
return _rmrs[idx];
}
/**
* RtMemRef getter by name
*
* @param name name of RtMemRef instance to get
* @return RtMemRef instance
*/
public RtMemRef getRmrByName(String name) {
return _rmrs[_n2i.get(name)];
}
/**
* RtMemRef array getter
*
* @return RtMemRef array
*/
public RtMemRef[] getRmrs() {
return _rmrs;
}
/**
* Name getter
*
* @param idx index of name to get
* @return name string
*/
public String getName(int idx) {
return _names[idx];
}
/**
* Name array getter
*
* @return name array
*/
public String[] getNames() {
return _names;
}
/**
* RtMemRef array size getter
*
* @return RtMemRef array size
*/
public int size() {
return _rmrs.length;
}
}

View File

@ -0,0 +1,400 @@
package com.ibm.onnxmlir;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.nio.ShortBuffer;
public class RtMemRef {
final ByteOrder endian = ByteOrder.nativeOrder();
/* We can use enum but that creates another class
* which complicates things for JNI.
*/
final int ONNX_TYPE_UNDEFINED = 0;
final int ONNX_TYPE_FLOAT = 1;
final int ONNX_TYPE_UINT8 = 2;
final int ONNX_TYPE_INT8 = 3;
final int ONNX_TYPE_UINT16 = 4;
final int ONNX_TYPE_INT16 = 5;
final int ONNX_TYPE_INT32 = 6;
final int ONNX_TYPE_INT64 = 7;
final int ONNX_TYPE_STRING = 8;
final int ONNX_TYPE_BOOL = 9;
final int ONNX_TYPE_FLOAT16 = 10;
final int ONNX_TYPE_DOUBLE = 11;
final int ONNX_TYPE_UINT32 = 12;
final int ONNX_TYPE_UINT64 = 13;
final int ONNX_TYPE_COMPLEX64 = 14;
final int ONNX_TYPE_COMPLEX128 = 15;
final int ONNX_TYPE_BFLOAT16 = 16;
final int[] ONNX_TYPE_SIZE = new int[] {
0, /* UNDEFINED */
4, /* FLOAT */
1, /* UINT8 */
1, /* INT8 */
2, /* UINT16 */
2, /* INT16 */
4, /* INT32 */
8, /* INT64 */
0, /* STRING */
1, /* BOOL */
2, /* FLOAT16 */
8, /* DOUBLE */
4, /* UINT32 */
8, /* UINT64 */
8, /* COMPLEX64 */
16, /* COMPLEX128 */
2, /* BFLOAT16 */
};
private ByteBuffer _data;
private int _type;
private int _rank;
private long[] _sizes;
private long[] _strides;
/**
* Constructor
*/
public RtMemRef(int rank) {
if (rank <= 0)
throw new IllegalArgumentException(
"invalid rank " + rank);
_data = null;
_type = ONNX_TYPE_UNDEFINED;
_rank = rank;
_sizes = new long[rank];
_strides = new long[rank];
}
/* ---------- Data type getter and setter ---------- */
/* For JNI wrapper only. Not intended for end user. */
/**
* Type getter
*
* @return data type
*/
@SuppressWarnings("unused")
private int getType() {
return _type;
}
/**
* Type setter
*
* @param type data type to be set
*/
@SuppressWarnings("unused")
private void setType(int type) {
_type = type;
}
/* ---------- Raw data getter and setter ---------- */
/* For JNI wrapper only. Not intended for end user. */
/**
* Raw data getter
*
* @return raw data
*/
@SuppressWarnings("unused")
private ByteBuffer getData() {
return _data;
}
/**
* Raw data setter
*
* @param data raw data to be set
*/
@SuppressWarnings("unused")
private void setData(ByteBuffer data) {
_data = data.order(endian);
}
/* ---------- Byte data getter and setter ---------- */
/**
* Byte data getter
*
* @return byte data array
*/
public byte[] getByteData() {
if (_data == null) return null;
/* asReadOnlyBuffer() creates a new view so the position of the
* original data will stay at 0 for subsequent getByteData()
* after get(b).
*/
byte[] b = new byte[_data.limit()];
_data.asReadOnlyBuffer().get(b);
return b;
}
/**
* Byte data setter
*
* @param data byte array to be set
*/
public void setByteData(byte[] data) {
/* slice() creates a new view so the position of the
* original data will stay at 0 for getByteData() after put(data).
*/
_data = ByteBuffer.allocateDirect(data.length);
_data.slice().put(data);
_type = ONNX_TYPE_INT8;
}
/* ---------- Short data getter and setter ---------- */
/**
* Short data getter
*
* @return short data array
*/
public short[] getShortData() {
if (_data == null) return null;
/* asShortBuffer() creates a new view so the position of the
* original data will stay at 0 for subsequent getShortData()
* after get(s).
*/
ShortBuffer sb = _data.asShortBuffer();
short[] s = new short[sb.limit()];
sb.get(s);
return s;
}
/**
* Short data setter
*
* @param data short array to be set
*/
public void setShortData(short[] data) {
/* asShortBuffer() creates a new view so the position of the
* original data will stay at 0 for getShortData() after put(data).
*/
_data = ByteBuffer.allocateDirect(data.length*2).order(endian);
_data.asShortBuffer().put(data);
_type = ONNX_TYPE_INT16;
}
/* ---------- Int data getter and setter ---------- */
/**
* Int data getter
*
* @return int data array
*/
public int[] getIntData() {
if (_data == null) return null;
/* asIntBuffer() creates a new view so the position of the
* original data will stay at 0 for subsequent getIntData()
* after get(i).
*/
IntBuffer ib = _data.asIntBuffer();
int[] i = new int[ib.limit()];
ib.get(i);
return i;
}
/**
* Int data setter
*
* @param data int array to be set
*/
public void setIntData(int[] data) {
/* asIntBuffer() creates a new view so the position of the
* original data will stay at 0 for getIntData() after put(data).
*/
_data = ByteBuffer.allocateDirect(data.length*4).order(endian);
_data.asIntBuffer().put(data);
_type = ONNX_TYPE_INT32;
}
/* ---------- Long data getter and setter ---------- */
/**
* Long data getter
*
* @return long data array
*/
public long[] getLongData() {
if (_data == null) return null;
/* asLongBuffer() creates a new view so the position of the
* original data will stay at 0 for subsequent getLongData()
* after get(l).
*/
LongBuffer lb = _data.asLongBuffer();
long[] l = new long[lb.limit()];
lb.get(l);
return l;
}
/**
* Long data setter
*
* @param data long array to be set
*/
public void setLongData(long[] data) {
/* asLongBuffer() creates a new view so the position of the
* original data will stay at 0 for getLongData() after put(data).
*/
_data = ByteBuffer.allocateDirect(data.length*8).order(endian);
_data.asLongBuffer().put(data);
_type = ONNX_TYPE_INT64;
}
/* ---------- Float data getter and setter ---------- */
/**
* Float data getter
*
* @return float data array
*/
public float[] getFloatData() {
if (_data == null) return null;
/* asFloatBuffer() creates a new view so the position of the
* original data will stay at 0 for subsequent getFloatData()
* after get(f).
*/
FloatBuffer fb = _data.asFloatBuffer();
float[] f = new float[fb.limit()];
fb.get(f);
return f;
}
/**
* Float data setter
*
* @param data float array to be set
*/
public void setFloatData(float[] data) {
/* asFloatBuffer() creates a new view so the position of the
* original data will stay at 0 for getFloatData() after put(data).
*/
_data = ByteBuffer.allocateDirect(data.length*4).order(endian);
_data.asFloatBuffer().put(data);
_type = ONNX_TYPE_FLOAT;
}
/* ---------- Double data getter and setter ---------- */
/**
* Double data getter
*
* @return double data array
*/
public double[] getDoubleData() {
if (_data == null) return null;
/* asDoubleBuffer() creates a new view so the position of the
* original data will stay at 0 for subsequent getDoubleData()
* after get(d).
*/
DoubleBuffer db = _data.asDoubleBuffer();
double[] d = new double[db.limit()];
db.get(d);
return d;
}
/**
* Double data setter
*
* @param data double array to be set
*/
public void setDoubleData(double[] data) {
/* asDoubleBuffer() creates a new view so the position of the
* original data will stay at 0 for getDoubleData() after put(data).
*/
_data = ByteBuffer.allocateDirect(data.length*8).order(endian);
_data.asDoubleBuffer().put(data);
_type = ONNX_TYPE_DOUBLE;
}
/**
* Rank getter
*
* @return rank
*/
public int getRank() {
return _rank;
}
/* ---------- Sizes getter and setter ---------- */
/**
* Sizes getter
*
* @return sizes array
*/
public long[] getSizes() {
return _sizes;
}
/**
* Sizes setter
*
* @param sizes sizes array to be set
*/
public void setSizes(long[] sizes) {
if (sizes.length != _rank)
throw new IllegalArgumentException(
"array length " + sizes.length + " != rank " + _rank);
_sizes = sizes.clone();
}
/* ---------- Strides getter and setter ---------- */
/**
* Strides getter
*
* @return strides array
*/
public long[] getStrides() {
return _strides;
}
/**
* Strides setter
*
* @param strides strides array to be set
*/
public void setStrides(long[] strides) {
if (strides.length != _rank)
throw new IllegalArgumentException(
"array length " + strides.length + " != rank " + _rank);
_strides = strides.clone();
}
/**
* Size getter
*
* @return product of sizes array, i.e., total number of data elements
*/
public long getDataSize() {
long n = _sizes[0];
for (int i = 1; i < _sizes.length; i++) n *= _sizes[i];
return n;
}
/**
* Check validity of RtMemRef
*
* @return true if RtMemRef is valid, false otherwise
*/
public boolean validRmr() {
return (_data != null &&
_data.limit() != 0 &&
_data.limit() == getDataSize() * ONNX_TYPE_SIZE[_type]);
}
}

View File

@ -4,12 +4,6 @@ add_dependencies(onnx-mlir-opt OMKrnlOpsInc OMONNXOpsInc)
target_include_directories(onnx-mlir-opt PRIVATE ${ONNX_MLIR_SRC_ROOT}) target_include_directories(onnx-mlir-opt PRIVATE ${ONNX_MLIR_SRC_ROOT})
target_include_directories(onnx-mlir-opt PRIVATE ${ONNX_MLIR_BIN_ROOT}) target_include_directories(onnx-mlir-opt PRIVATE ${ONNX_MLIR_BIN_ROOT})
set(ONNX_MLIR_LD_PRELOAD_onnx-mlir-opt "" CACHE STRING "" FORCE)
if(BUILD_SHARED_LIBS)
message(STATUS "To run dynamically linked onnx-mlir-opt, you must specify:")
message(STATUS "LD_PRELOAD=${ONNX_MLIR_LD_PRELOAD_onnx-mlir-opt}")
endif()
target_link_libraries(onnx-mlir-opt target_link_libraries(onnx-mlir-opt
${OMLibs} ${OMLibs}
${MLIRLibs} ${MLIRLibs}

View File

@ -12,6 +12,7 @@ using namespace std;
using namespace onnx_mlir; using namespace onnx_mlir;
int main(int argc, char *argv[]) { int main(int argc, char *argv[]) {
setExecPath(argv[0], (void *)main);
registerDialects(); registerDialects();
llvm::cl::OptionCategory OnnxMlirOptions( llvm::cl::OptionCategory OnnxMlirOptions(
@ -33,7 +34,9 @@ int main(int argc, char *argv[]) {
clEnumVal(EmitLLVMIR, "Lower model to LLVM IR (LLVM dialect)."), clEnumVal(EmitLLVMIR, "Lower model to LLVM IR (LLVM dialect)."),
clEnumVal(EmitLib, "Lower model to LLVM IR, emit (to file) " clEnumVal(EmitLib, "Lower model to LLVM IR, emit (to file) "
"LLVM bitcode for model, compile and link it to a " "LLVM bitcode for model, compile and link it to a "
"shared library.")), "shared library."),
clEnumVal(EmitJNI, "Lower model to LLMV IR -> LLVM bitcode "
"-> JNI shared library -> jar")),
llvm::cl::init(EmitLib), llvm::cl::cat(OnnxMlirOptions)); llvm::cl::init(EmitLib), llvm::cl::cat(OnnxMlirOptions));
llvm::cl::HideUnrelatedOptions(OnnxMlirOptions); llvm::cl::HideUnrelatedOptions(OnnxMlirOptions);

View File

@ -13,6 +13,8 @@
#include "src/MainUtils.hpp" #include "src/MainUtils.hpp"
#include "src/Runtime/ExecusionSession.hpp" #include "src/Runtime/ExecusionSession.hpp"
#define SHARED_LIB_BASE string("./TestConv_main_graph")
using namespace std; using namespace std;
// Returns whether onnx-mlir compiled convolution is producing the same results // Returns whether onnx-mlir compiled convolution is producing the same results
@ -87,14 +89,9 @@ bool isOMConvTheSameAsNaiveImplFor(const int N, const int C, const int H,
OwningModuleRef moduleRef(module); OwningModuleRef moduleRef(module);
llvm::SmallVector<char, 10> path; compileModule(moduleRef, ctx, SHARED_LIB_BASE, EmitLib);
llvm::sys::fs::createTemporaryFile("_main_graph", "", path);
string pathStr(path.begin(), path.end());
llvm::FileRemover remover(path);
compileModule(moduleRef, ctx, pathStr, EmitLib);
onnx_mlir::ExecutionSession sess( onnx_mlir::ExecutionSession sess(
pathStr + ".so", "_dyn_entry_point_main_graph"); SHARED_LIB_BASE + ".so", "_dyn_entry_point_main_graph");
std::vector<unique_ptr<RtMemRef>> inputs; std::vector<unique_ptr<RtMemRef>> inputs;
auto xRmr = unique_ptr<RtMemRef>(getRndRealRmr<float>({N, C, H, W})); auto xRmr = unique_ptr<RtMemRef>(getRndRealRmr<float>({N, C, H, W}));
@ -127,7 +124,10 @@ bool isOMConvTheSameAsNaiveImplFor(const int N, const int C, const int H,
return isRmrClose<float>(conv.get(), ref); return isRmrClose<float>(conv.get(), ref);
} }
int main() { int main(int argc, char *argv[]) {
setExecPath(argv[0], (void *)main);
llvm::FileRemover remover(SHARED_LIB_BASE + ".so");
// RapidCheck test case generation. // RapidCheck test case generation.
rc::check("convolution implementation correctness", []() { rc::check("convolution implementation correctness", []() {
const auto N = *rc::gen::inRange(1, 10); const auto N = *rc::gen::inRange(1, 10);