Add emitjni target (#204)
* Detect llvm-project commit change in utils/clone-mlir.sh and rebuild llvm-project for zLinux Jenkins build bot * Add --EmitJNI target (tested working with mnist and resnet50) - MainUtils * first shot at refactoring compileModuleToSharedLibrary * add setExecPath call to allow resolving runtime directory from onnx-mlir executable path when ONNX_MLIR_RUNTIME_DIR is not set. This allows tests to run without having to install onnx-mlir or to explicitly set ONNX_MLIR_RUNTIME_DIR - RtMemRef * add getDataSize for C (equivalent of size() for C++). * fix setStrides bug (setting sizes, not strides) - TestConv * _main_graph-*.so were filling up /tmp. Change to use fixed shared library in build directory * Fix clang-format-lint complaints * - getRuntimeDir checks lib64 - install targets for javaruntime and jniruntime - remove ONNX_MLIR_LD_PRELOAD_onnx-mlir and ONNX_MLIR_LD_PRELOAD_onnx-mlir-opt * See what happens when `kExecPath` decl is dropped. Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
parent
af75b4c75e
commit
d235f248e4
|
@ -71,5 +71,5 @@ cmake -DCMAKE_INSTALL_PREFIX=${INSTALL_PATH} .. \
|
|||
|
||||
make -j$(nproc)
|
||||
make -j$(nproc) check-onnx-lit
|
||||
RUNTIME_DIR=$(pwd)/lib make -j$(nproc) check-onnx-backend
|
||||
RUNTIME_DIR=$(pwd)/lib PATH=$(pwd)/bin:$PATH make -j$(nproc) test
|
||||
make -j$(nproc) check-onnx-backend
|
||||
PATH=$(pwd)/bin:$PATH make -j$(nproc) test
|
||||
|
|
|
@ -55,12 +55,6 @@ endif()
|
|||
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/ExternalUtil.hpp.in
|
||||
${CMAKE_CURRENT_BINARY_DIR}/ExternalUtil.hpp)
|
||||
|
||||
set(ONNX_MLIR_LD_PRELOAD_onnx-mlir "" CACHE STRING "" FORCE)
|
||||
if(BUILD_SHARED_LIBS)
|
||||
message(STATUS "To run dynamically linked onnx-mlir, you must specify:")
|
||||
message(STATUS "LD_PRELOAD=${ONNX_MLIR_LD_PRELOAD_onnx-mlir}")
|
||||
endif()
|
||||
|
||||
# Libraries specified on the target_link_libraries for the add_subdirectory
|
||||
# targets get added to the end of the list here. This creates two problems:
|
||||
# 1. It produces duplicated libraries being specified for the link command.
|
||||
|
|
|
@ -3,8 +3,12 @@
|
|||
#include <string>
|
||||
|
||||
namespace onnx_mlir {
|
||||
std::string kExecPath = "@CMAKE_INSTALL_PREFIX@/bin/onnx-mlir"; /* fallback if not set by main */
|
||||
const std::string kInstPath = "@CMAKE_INSTALL_PREFIX@";
|
||||
const std::string kLlcPath = "@LLVM_PROJ_BUILD@/bin/llc";
|
||||
const std::string kCxxPath = "@CMAKE_CXX_COMPILER@";
|
||||
const std::string kLinkerPath = "@CMAKE_LINKER@";
|
||||
const std::string kObjCopyPath = "@CMAKE_OBJCOPY@";
|
||||
const std::string kArPath = "@CMAKE_AR@";
|
||||
const std::string kJarPath = "@Java_JAR_EXECUTABLE@";
|
||||
} // namespace onnx_mlir
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
//===--------------------------- main_utils.cpp ---------------------------===//
|
||||
//===--------------------------- MainUtils.cpp ---------------------------===//
|
||||
//
|
||||
// Copyright 2019-2020 The IBM Research Authors.
|
||||
//
|
||||
|
@ -13,6 +13,7 @@
|
|||
#include <fcntl.h>
|
||||
#include <regex>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include <llvm/Support/FileSystem.h>
|
||||
#include <llvm/Support/Program.h>
|
||||
|
@ -22,8 +23,6 @@
|
|||
#include "src/ExternalUtil.hpp"
|
||||
#include "src/MainUtils.hpp"
|
||||
|
||||
#include "MainUtils.hpp"
|
||||
|
||||
#ifdef _WIN32
|
||||
#include <io.h>
|
||||
#else
|
||||
|
@ -41,6 +40,42 @@ llvm::Optional<std::string> getEnvVar(std::string name) {
|
|||
return llvm::None;
|
||||
}
|
||||
|
||||
// Runtime directory contains all the libraries, jars, etc. that are
|
||||
// necessary for running onnx-mlir. It's resolved in the following order:
|
||||
//
|
||||
// - if ONNX_MLIR_RUNTIME_DIR is set, use it, otherwise
|
||||
// - get path from where onnx-mlir is run, if it's of the form
|
||||
// /foo/bar/bin/onnx-mlir,
|
||||
// the runtime directory is /foo/bar/lib (note that when onnx-mlir is
|
||||
// installed system wide, which is typically /usr/local/bin, this will
|
||||
// correctly resolve to /usr/local/lib), but some systems still have
|
||||
// lib64 so we check that first. If neither exists, then
|
||||
// - use CMAKE_INSTALL_PREFIX/lib, which is typically /usr/local/lib
|
||||
string getRuntimeDir() {
|
||||
const auto &envDir = getEnvVar("ONNX_MLIR_RUNTIME_DIR");
|
||||
if (envDir && llvm::sys::fs::exists(envDir.getValue()))
|
||||
return envDir.getValue();
|
||||
|
||||
string execDir = llvm::sys::path::parent_path(kExecPath).str();
|
||||
if (llvm::sys::path::stem(execDir).str().compare("bin") == 0) {
|
||||
string p = execDir.substr(0, execDir.size() - 3);
|
||||
if (llvm::sys::fs::exists(p + "lib64"))
|
||||
return p + "lib64";
|
||||
if (llvm::sys::fs::exists(p + "lib"))
|
||||
return p + "lib";
|
||||
}
|
||||
|
||||
llvm::SmallString<8> instDir64(kInstPath);
|
||||
llvm::sys::path::append(instDir64, "lib64");
|
||||
string p = llvm::StringRef(instDir64).str();
|
||||
if (llvm::sys::fs::exists(p))
|
||||
return p;
|
||||
|
||||
llvm::SmallString<8> instDir(kInstPath);
|
||||
llvm::sys::path::append(instDir, "lib");
|
||||
return llvm::StringRef(instDir).str();
|
||||
}
|
||||
|
||||
// Helper struct to make command construction and execution easy & readable.
|
||||
struct Command {
|
||||
std::string _path;
|
||||
|
@ -97,6 +132,12 @@ struct Command {
|
|||
};
|
||||
} // namespace
|
||||
|
||||
void setExecPath(const char *argv0, void *fmain) {
|
||||
string p;
|
||||
if (!(p = llvm::sys::fs::getMainExecutable(argv0, fmain)).empty())
|
||||
kExecPath = p;
|
||||
}
|
||||
|
||||
void LoadMLIR(string inputFilename, mlir::MLIRContext &context,
|
||||
mlir::OwningModuleRef &module) {
|
||||
// Handle '.mlir' input to the ONNX MLIR frontend.
|
||||
|
@ -119,8 +160,8 @@ void LoadMLIR(string inputFilename, mlir::MLIRContext &context,
|
|||
}
|
||||
}
|
||||
|
||||
void compileModuleToSharedLibrary(
|
||||
const mlir::OwningModuleRef &module, string outputBaseName) {
|
||||
void genConstPackObj(const mlir::OwningModuleRef &module,
|
||||
llvm::Optional<string> &constPackObjPath) {
|
||||
// Extract constant pack file name, which is embedded as a symbol in the
|
||||
// module being compiled.
|
||||
auto constPackFilePathSym = (*module).lookupSymbol<mlir::LLVM::GlobalOp>(
|
||||
|
@ -131,7 +172,6 @@ void compileModuleToSharedLibrary(
|
|||
.str();
|
||||
llvm::FileRemover constPackRemover(constPackFilePath);
|
||||
|
||||
llvm::Optional<std::string> constPackObjPath;
|
||||
#if __APPLE__
|
||||
// Create a empty stub file, compile it to an empty obj file.
|
||||
llvm::SmallVector<char, 20> stubSrcPath;
|
||||
|
@ -153,7 +193,6 @@ void compileModuleToSharedLibrary(
|
|||
.appendList({"-sectcreate", "binary", "param", constPackFilePath})
|
||||
.appendStr(stubObjPathStr)
|
||||
.exec();
|
||||
llvm::FileRemover constPackObjRemover(constPackObjPath.getValue());
|
||||
|
||||
#elif __linux__
|
||||
// Create param.o holding packed parameter values.
|
||||
|
@ -164,7 +203,6 @@ void compileModuleToSharedLibrary(
|
|||
.appendList({"-o", constPackObjPath.getValue()})
|
||||
.appendStr(constPackFilePath)
|
||||
.exec();
|
||||
llvm::FileRemover constPackObjRemover(constPackObjPath.getValue());
|
||||
|
||||
// Figure out what is the default symbol name describing the start/end
|
||||
// address of the embedded data.
|
||||
|
@ -204,40 +242,119 @@ void compileModuleToSharedLibrary(
|
|||
mlir::KrnlPackedConstantOp::getConstPackFileNameStrLenSymbolName())
|
||||
.valueAttr(builder.getI64IntegerAttr(constPackFileName.str().size()));
|
||||
#endif
|
||||
}
|
||||
|
||||
// Write LLVM bitcode.
|
||||
string outputFilename = outputBaseName + ".bc";
|
||||
// Write LLVM bitcode.
|
||||
void genLLVMBitcode(const mlir::OwningModuleRef &module, string bitcodePath) {
|
||||
error_code error;
|
||||
|
||||
llvm::raw_fd_ostream moduleBitcodeStream(
|
||||
outputFilename, error, llvm::sys::fs::F_None);
|
||||
bitcodePath, error, llvm::sys::fs::F_None);
|
||||
|
||||
llvm::WriteBitcodeToFile(
|
||||
*mlir::translateModuleToLLVMIR(*module), moduleBitcodeStream);
|
||||
moduleBitcodeStream.flush();
|
||||
llvm::FileRemover bcRemover(outputFilename);
|
||||
}
|
||||
|
||||
// Compile LLVM bitcode to object file.
|
||||
// Compile LLVM bitcode to object file.
|
||||
void genModelObject(const mlir::OwningModuleRef &module, string bitcodePath,
|
||||
string modelObjPath) {
|
||||
Command llvmToObj(/*exePath=*/kLlcPath);
|
||||
llvmToObj.appendStr("-filetype=obj");
|
||||
llvmToObj.appendStr("-relocation-model=pic");
|
||||
llvmToObj.appendStr(outputFilename);
|
||||
llvmToObj.exec();
|
||||
std::string modelObjPath = outputBaseName + ".o";
|
||||
llvmToObj.appendStr("-filetype=obj")
|
||||
.appendStr("-relocation-model=pic")
|
||||
.appendList({"-o", modelObjPath})
|
||||
.appendStr(bitcodePath)
|
||||
.exec();
|
||||
}
|
||||
|
||||
void genJniObject(const mlir::OwningModuleRef &module, string jniSharedLibPath,
|
||||
string jniObjPath) {
|
||||
Command ar(/*exePath=*/kArPath);
|
||||
ar.appendStr("x").appendStr(jniSharedLibPath).appendStr(jniObjPath).exec();
|
||||
}
|
||||
|
||||
// Link everything into a shared object.
|
||||
void genSharedLib(const mlir::OwningModuleRef &module,
|
||||
string modelSharedLibPath, std::vector<string> opts,
|
||||
std::vector<string> objs, std::vector<string> libs) {
|
||||
|
||||
string runtimeDirInclFlag = "-L" + getRuntimeDir();
|
||||
|
||||
Command link(kCxxPath);
|
||||
link.appendList(opts)
|
||||
.appendList(objs)
|
||||
.appendList({"-o", modelSharedLibPath})
|
||||
.appendStrOpt(runtimeDirInclFlag)
|
||||
.appendList(libs)
|
||||
.exec();
|
||||
}
|
||||
|
||||
// Create jar containing java runtime and model shared library (which includes
|
||||
// jni runtime).
|
||||
void genJniJar(const mlir::OwningModuleRef &module, string modelSharedLibPath,
|
||||
string modelJniJarPath) {
|
||||
llvm::SmallString<8> runtimeDir(getRuntimeDir());
|
||||
llvm::sys::path::append(runtimeDir, "javaruntime.jar");
|
||||
string javaRuntimeJarPath = llvm::StringRef(runtimeDir).str();
|
||||
|
||||
// Copy javaruntime.jar to model jar.
|
||||
llvm::sys::fs::copy_file(javaRuntimeJarPath, modelJniJarPath);
|
||||
|
||||
// Add shared library to model jar.
|
||||
Command jar(kJarPath);
|
||||
jar.appendList({"uf", modelJniJarPath}).appendStr(modelSharedLibPath).exec();
|
||||
}
|
||||
|
||||
void compileModuleToSharedLibrary(
|
||||
const mlir::OwningModuleRef &module, std::string outputBaseName) {
|
||||
|
||||
llvm::Optional<string> constPackObjPath;
|
||||
genConstPackObj(module, constPackObjPath);
|
||||
llvm::FileRemover constPackObjRemover(constPackObjPath.getValue());
|
||||
|
||||
string bitcodePath = outputBaseName + ".bc";
|
||||
genLLVMBitcode(module, bitcodePath);
|
||||
llvm::FileRemover bitcodeRemover(bitcodePath);
|
||||
|
||||
string modelObjPath = outputBaseName + ".o";
|
||||
genModelObject(module, bitcodePath, modelObjPath);
|
||||
llvm::FileRemover modelObjRemover(modelObjPath);
|
||||
|
||||
llvm::Optional<std::string> runtimeDirInclFlag;
|
||||
if (getEnvVar("RUNTIME_DIR").hasValue())
|
||||
runtimeDirInclFlag = "-L" + getEnvVar("RUNTIME_DIR").getValue();
|
||||
string modelSharedLibPath = outputBaseName + ".so";
|
||||
genSharedLib(module, modelSharedLibPath, {"-shared", "-fPIC"},
|
||||
{constPackObjPath.getValueOr(""), modelObjPath},
|
||||
{"-lEmbeddedDataLoader", "-lcruntime"});
|
||||
}
|
||||
|
||||
// Link everything into a shared object.
|
||||
Command link(kCxxPath);
|
||||
link.appendList({"-shared", "-fPIC"})
|
||||
.appendStr(modelObjPath)
|
||||
.appendStr(constPackObjPath.getValueOr(""))
|
||||
.appendList({"-o", outputBaseName + ".so"})
|
||||
.appendStrOpt(runtimeDirInclFlag)
|
||||
.appendList({"-lEmbeddedDataLoader", "-lcruntime"})
|
||||
.exec();
|
||||
void compileModuleToJniJar(
|
||||
const mlir::OwningModuleRef &module, std::string outputBaseName) {
|
||||
|
||||
llvm::Optional<string> constPackObjPath;
|
||||
genConstPackObj(module, constPackObjPath);
|
||||
llvm::FileRemover constPackObjRemover(constPackObjPath.getValue());
|
||||
|
||||
string bitcodePath = outputBaseName + ".bc";
|
||||
genLLVMBitcode(module, bitcodePath);
|
||||
llvm::FileRemover bitcodeRemover(bitcodePath);
|
||||
|
||||
string modelObjPath = outputBaseName + ".o";
|
||||
genModelObject(module, bitcodePath, modelObjPath);
|
||||
llvm::FileRemover modelObjRemover(modelObjPath);
|
||||
|
||||
string jniSharedLibPath = getRuntimeDir() + "/libjniruntime.a";
|
||||
string jniObjPath = "jnidummy.c.o";
|
||||
genJniObject(module, jniSharedLibPath, jniObjPath);
|
||||
llvm::FileRemover jniObjRemover(jniObjPath);
|
||||
|
||||
string modelSharedLibPath = "libmodel.so";
|
||||
genSharedLib(module, modelSharedLibPath,
|
||||
{"-shared", "-fPIC", "-z", "noexecstack"},
|
||||
{constPackObjPath.getValueOr(""), modelObjPath, jniObjPath},
|
||||
{"-lEmbeddedDataLoader", "-lcruntime", "-ljniruntime"});
|
||||
llvm::FileRemover modelSharedLibRemover(modelSharedLibPath);
|
||||
|
||||
string modelJniJarPath = outputBaseName + ".jar";
|
||||
genJniJar(module, modelSharedLibPath, modelJniJarPath);
|
||||
}
|
||||
|
||||
void registerDialects() {
|
||||
|
@ -349,6 +466,9 @@ void emitOutputFiles(string outputBaseName, EmissionTargetType emissionTarget,
|
|||
// Write LLVM bitcode to disk, compile & link.
|
||||
compileModuleToSharedLibrary(module, outputBaseName);
|
||||
printf("Shared library %s.so has been compiled.\n", outputBaseName.c_str());
|
||||
} else if (emissionTarget == EmitJNI) {
|
||||
compileModuleToJniJar(module, outputBaseName);
|
||||
printf("JNI archive %s.jar has been compiled.\n", outputBaseName.c_str());
|
||||
} else {
|
||||
// Emit the version with all constants included.
|
||||
outputCode(module, outputBaseName, ".onnx.mlir");
|
||||
|
|
|
@ -43,14 +43,20 @@ enum EmissionTargetType {
|
|||
EmitMLIR,
|
||||
EmitLLVMIR,
|
||||
EmitLib,
|
||||
EmitJNI,
|
||||
};
|
||||
|
||||
void setExecPath(const char *argv0, void *fmain);
|
||||
|
||||
void LoadMLIR(std::string inputFilename, mlir::MLIRContext &context,
|
||||
mlir::OwningModuleRef &module);
|
||||
|
||||
void compileModuleToSharedLibrary(
|
||||
const mlir::OwningModuleRef &module, std::string outputBaseName);
|
||||
|
||||
void compileModuleToJniJar(
|
||||
const mlir::OwningModuleRef &module, std::string outputBaseName);
|
||||
|
||||
void registerDialects();
|
||||
|
||||
void addONNXToMLIRPasses(mlir::PassManager &pm);
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
add_subdirectory(jni)
|
||||
|
||||
# Create static libcruntime.a to be embedded in model.so to make model.so self contained.
|
||||
# However, by default object code for static library is not compiled with -fPIC. Embedding
|
||||
# such static library in a shared library can cause runtime failure on some architectures,
|
||||
|
|
|
@ -150,6 +150,13 @@ int64_t *getStrides(RtMemRef *dynMemRef) { return dynMemRef->strides; }
|
|||
|
||||
int64_t getSize(OrderedRtMemRefDict *dict) { return dict->orderedNames.size(); }
|
||||
|
||||
INDEX_TYPE getDataSize(RtMemRef *rtMemRef) {
|
||||
INDEX_TYPE n = rtMemRef->sizes[0];
|
||||
for (int i = 1; i < rtMemRef->rank; i++)
|
||||
n *= rtMemRef->sizes[i];
|
||||
return n;
|
||||
}
|
||||
|
||||
void setDType(RtMemRef *dynMemRef, int onnxType) {
|
||||
dynMemRef->onnx_dtype = onnxType;
|
||||
}
|
||||
|
@ -160,5 +167,5 @@ unsigned int getRank(RtMemRef *dynMemRef) { return dynMemRef->rank; }
|
|||
|
||||
void setStrides(RtMemRef *dynMemRef, int64_t *strides) {
|
||||
for (int i = 0; i < dynMemRef->rank; i++)
|
||||
dynMemRef->sizes[i] = strides[i];
|
||||
dynMemRef->strides[i] = strides[i];
|
||||
}
|
||||
|
|
|
@ -151,6 +151,9 @@ OrderedRtMemRefDict *createOrderedRtMemRefDict();
|
|||
// Get how many dynamic memrefs are in dict.
|
||||
int64_t getSize(OrderedRtMemRefDict *dict);
|
||||
|
||||
// Get how many data elements are in RtMemRef.
|
||||
INDEX_TYPE getDataSize(RtMemRef *rtMemRef);
|
||||
|
||||
// Create a dynmemref with a certain rank.
|
||||
RtMemRef *createRtMemRef(int rank);
|
||||
|
||||
|
@ -277,4 +280,4 @@ inline bool isRmrClose(
|
|||
#endif
|
||||
|
||||
// Will transition from RtMemRef to RtMemRef soon.
|
||||
typedef RtMemRef RtMemRef;
|
||||
typedef RtMemRef RtMemRef;
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
find_package(Java COMPONENTS Development)
|
||||
find_package(JNI)
|
||||
|
||||
if(Java_Development_FOUND AND JNI_FOUND)
|
||||
include(UseJava)
|
||||
|
||||
# Target for Java runtime jar
|
||||
add_jar(javaruntime
|
||||
src/com/ibm/onnxmlir/DynEntryPoint.java
|
||||
src/com/ibm/onnxmlir/OrderedRtMemRefDict.java
|
||||
src/com/ibm/onnxmlir/RtMemRef.java
|
||||
OUTPUT_DIR ${CMAKE_LIBRARY_OUTPUT_DIRECTORY})
|
||||
|
||||
# Target for JNI runtime lib
|
||||
add_library(jniruntime STATIC
|
||||
jniwrapper.c jnilog.c jnidummy.c
|
||||
com_ibm_onnxmlir_DynEntryPoint.h jnilog.h ../RtMemRef.h)
|
||||
set_target_properties(jniruntime PROPERTIES
|
||||
POSITION_INDEPENDENT_CODE TRUE)
|
||||
target_include_directories(jniruntime PRIVATE
|
||||
${ONNX_MLIR_SRC_ROOT}/src/Runtime
|
||||
${JAVA_INCLUDE_PATH}
|
||||
${JAVA_INCLUDE_PATH2})
|
||||
|
||||
install_jar(javaruntime DESTINATION lib)
|
||||
install(TARGETS jniruntime DESTINATION lib)
|
||||
|
||||
else()
|
||||
message(WARNING "Java Development component or JNI not found, JNI targets will not work")
|
||||
endif()
|
|
@ -0,0 +1,22 @@
|
|||
/* DO NOT EDIT THIS FILE - it is machine generated */
|
||||
#include <jni.h>
|
||||
/* Header for class com_ibm_onnxmlir_DynEntryPoint */
|
||||
|
||||
#ifndef _Included_com_ibm_onnxmlir_DynEntryPoint
|
||||
#define _Included_com_ibm_onnxmlir_DynEntryPoint
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
/*
|
||||
* Class: com_ibm_onnxmlir_DynEntryPoint
|
||||
* Method: main_graph_jni
|
||||
* Signature:
|
||||
* (Lcom/ibm/onnxmlir/OrderedRtMemRefDict;)Lcom/ibm/onnxmlir/OrderedRtMemRefDict;
|
||||
*/
|
||||
JNIEXPORT jobject JNICALL Java_com_ibm_onnxmlir_DynEntryPoint_main_1graph_1jni(
|
||||
JNIEnv *, jclass, jobject);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
#endif
|
|
@ -0,0 +1,7 @@
|
|||
#include "com_ibm_onnxmlir_DynEntryPoint.h"
|
||||
|
||||
/* Dummy routine to force the link editor to embed code in libjniruntime.a
|
||||
into libmodel.so */
|
||||
void __dummy_do_not_call__(JNIEnv *env, jclass cls, jobject obj) {
|
||||
Java_com_ibm_onnxmlir_DynEntryPoint_main_1graph_1jni(NULL, NULL, NULL);
|
||||
}
|
|
@ -0,0 +1,94 @@
|
|||
#include <errno.h>
|
||||
#include <libgen.h>
|
||||
#include <stdarg.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include <time.h>
|
||||
|
||||
#include "jnilog.h"
|
||||
|
||||
static int log_initd = 0;
|
||||
static int log_level;
|
||||
static FILE *log_fp;
|
||||
|
||||
/* Must match enum in log.h */
|
||||
static char *log_level_name[] = {
|
||||
"trace", "debug", "info", "warning", "error", "fatal"};
|
||||
|
||||
/* Return numerical log level of give level name */
|
||||
static int get_log_level_by_name(char *name) {
|
||||
int level = -1;
|
||||
for (int i = 0; i < sizeof(log_level_name) / sizeof(char *); i++) {
|
||||
if (!strcmp(name, log_level_name[i])) {
|
||||
level = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
return level;
|
||||
}
|
||||
|
||||
/* Return FILE pointer of given file name */
|
||||
static FILE *get_log_file_by_name(char *name) {
|
||||
FILE *fp = NULL;
|
||||
if (!strcmp(name, "stdout"))
|
||||
fp = stdout;
|
||||
else if (!strcmp(name, "stderr"))
|
||||
fp = stderr;
|
||||
else
|
||||
fp = fopen(name, "w");
|
||||
return fp;
|
||||
}
|
||||
|
||||
/* Initialize log system. Set default log level and file or use environment
|
||||
* variables ONNX_MLIR_JNI_LOG_LEVEL and ONNX_MLIR_JNI_LOG_FILE, respectively.
|
||||
*/
|
||||
static void log_init() {
|
||||
if (log_initd)
|
||||
return;
|
||||
|
||||
log_level = LOG_INFO;
|
||||
char *strlevel = getenv("ONNX_MLIR_JNI_LOG_LEVEL");
|
||||
int level;
|
||||
if (strlevel && (level = get_log_level_by_name(strlevel)) != -1)
|
||||
log_level = level;
|
||||
|
||||
log_fp = stderr;
|
||||
char *strfname = getenv("ONNX_MLIR_JNI_LOG_FILE");
|
||||
FILE *fp;
|
||||
if (strfname && (fp = get_log_file_by_name(strfname)))
|
||||
log_fp = fp;
|
||||
|
||||
tzset();
|
||||
log_initd = 1;
|
||||
}
|
||||
|
||||
/* Generic log routine */
|
||||
void log_printf(
|
||||
int level, char *file, const char *func, int line, char *fmt, ...) {
|
||||
if (!log_initd)
|
||||
log_init();
|
||||
if (level < log_level)
|
||||
return;
|
||||
|
||||
time_t now;
|
||||
struct tm *tm;
|
||||
char buf[80];
|
||||
|
||||
/* Get local time and format as 2020-07-03 05:17:42 -0400 */
|
||||
if (time(&now) == -1 || (tm = localtime(&now)) == NULL ||
|
||||
strftime(buf, sizeof(buf), "%F %T %z", tm) == 0)
|
||||
sprintf(buf, "-");
|
||||
|
||||
/* Output log prefix */
|
||||
fprintf(log_fp, "[%s][%s]%s:%s:%d ", buf, log_level_name[level],
|
||||
basename(file), func, line);
|
||||
|
||||
/* Output actually log data */
|
||||
va_list va_list;
|
||||
va_start(va_list, fmt);
|
||||
vfprintf(log_fp, fmt, va_list);
|
||||
va_end(va_list);
|
||||
|
||||
fprintf(log_fp, "\n");
|
||||
}
|
|
@ -0,0 +1,125 @@
|
|||
#ifndef __JNILOG_H__
|
||||
#define __JNILOG_H__
|
||||
|
||||
#include <stdio.h>
|
||||
|
||||
enum { LOG_TRACE, LOG_DEBUG, LOG_INFO, LOG_WARNING, LOG_ERROR, LOG_FATAL };
|
||||
|
||||
#define LOG_MAX_NUM 16 /* max number of elements to output */
|
||||
|
||||
#define MIN(x, y) ((x) > (y) ? y : x)
|
||||
|
||||
/* Construct string of up to LOG_MAX_NUM elements of a char array */
|
||||
#define LOG_CHAR_BUF(buf, data, n) \
|
||||
do { \
|
||||
buf[0] = '\0'; \
|
||||
for (int i = 0; i < MIN(n, LOG_MAX_NUM); i++) \
|
||||
sprintf(buf + strlen(buf), " %02x", ((char *)data)[i]); \
|
||||
sprintf(buf + strlen(buf), n > LOG_MAX_NUM ? " ... " : " "); \
|
||||
} while (0)
|
||||
|
||||
/* Construct string of up to LOG_MAX_NUM elements of a short array */
|
||||
#define LOG_SHORT_BUF(buf, data, n) \
|
||||
do { \
|
||||
buf[0] = '\0'; \
|
||||
for (int i = 0; i < MIN(n, LOG_MAX_NUM); i++) \
|
||||
sprintf(buf + strlen(buf), " %d", ((short *)data)[i]); \
|
||||
sprintf(buf + strlen(buf), n > LOG_MAX_NUM ? " ... " : " "); \
|
||||
} while (0)
|
||||
|
||||
/* Construct string of up to LOG_MAX_NUM elements of a int array */
|
||||
#define LOG_INT_BUF(buf, data, n) \
|
||||
do { \
|
||||
buf[0] = '\0'; \
|
||||
for (int i = 0; i < MIN(n, LOG_MAX_NUM); i++) \
|
||||
sprintf(buf + strlen(buf), " %d", ((int *)data)[i]); \
|
||||
sprintf(buf + strlen(buf), n > LOG_MAX_NUM ? " ... " : " "); \
|
||||
} while (0)
|
||||
|
||||
/* Construct string of up to LOG_MAX_NUM elements of a long array */
|
||||
#define LOG_LONG_BUF(buf, data, n) \
|
||||
do { \
|
||||
buf[0] = '\0'; \
|
||||
for (int i = 0; i < MIN(n, LOG_MAX_NUM); i++) \
|
||||
sprintf(buf + strlen(buf), " %ld", ((long *)data)[i]); \
|
||||
sprintf(buf + strlen(buf), n > LOG_MAX_NUM ? " ... " : " "); \
|
||||
} while (0)
|
||||
|
||||
/* Construct string of up to LOG_MAX_NUM elements of a float array */
|
||||
#define LOG_FLOAT_BUF(buf, data, n) \
|
||||
do { \
|
||||
buf[0] = '\0'; \
|
||||
for (int i = 0; i < MIN(n, LOG_MAX_NUM); i++) \
|
||||
sprintf(buf + strlen(buf), " %f", ((float *)data)[i]); \
|
||||
sprintf(buf + strlen(buf), n > LOG_MAX_NUM ? " ... " : " "); \
|
||||
} while (0)
|
||||
|
||||
/* Construct string of up to LOG_MAX_NUM elements of a double array */
|
||||
#define LOG_DOUBLE_BUF(buf, data, n) \
|
||||
do { \
|
||||
buf[0] = '\0'; \
|
||||
for (int i = 0; i < MIN(n, LOG_MAX_NUM); i++) \
|
||||
sprintf(buf + strlen(buf), " %f", ((double *)data)[i]); \
|
||||
sprintf(buf + strlen(buf), n > LOG_MAX_NUM ? " ... " : " "); \
|
||||
} while (0)
|
||||
|
||||
enum {
|
||||
ONNX_TYPE_UNDEFINED, /* 0 */
|
||||
ONNX_TYPE_FLOAT, /* 1 */
|
||||
ONNX_TYPE_UINT8, /* 2 */
|
||||
ONNX_TYPE_INT8, /* 3 */
|
||||
ONNX_TYPE_UINT16, /* 4 */
|
||||
ONNX_TYPE_INT16, /* 5 */
|
||||
ONNX_TYPE_INT32, /* 6 */
|
||||
ONNX_TYPE_INT64, /* 7 */
|
||||
ONNX_TYPE_STRING, /* 8 */
|
||||
ONNX_TYPE_BOOL, /* 9 */
|
||||
ONNX_TYPE_FLOAT16, /* 10 */
|
||||
ONNX_TYPE_DOUBLE, /* 11 */
|
||||
ONNX_TYPE_UINT32, /* 12 */
|
||||
ONNX_TYPE_UINT64, /* 13 */
|
||||
ONNX_TYPE_COMPLEX64, /* 14 */
|
||||
ONNX_TYPE_COMPLEX128, /* 15 */
|
||||
ONNX_TYPE_BFLOAT16, /* 16 */
|
||||
};
|
||||
|
||||
/* Construct string of up to LOG_MAX_NUM elements of a "type" array */
|
||||
#define LOG_TYPE_BUF(type, buf, data, n) \
|
||||
do { \
|
||||
switch (type) { \
|
||||
case ONNX_TYPE_UINT8: \
|
||||
case ONNX_TYPE_INT8: \
|
||||
LOG_CHAR_BUF(buf, data, n); \
|
||||
break; \
|
||||
case ONNX_TYPE_UINT16: \
|
||||
case ONNX_TYPE_INT16: \
|
||||
LOG_SHORT_BUF(buf, data, n); \
|
||||
break; \
|
||||
case ONNX_TYPE_UINT32: \
|
||||
case ONNX_TYPE_INT32: \
|
||||
LOG_INT_BUF(buf, data, n); \
|
||||
break; \
|
||||
case ONNX_TYPE_UINT64: \
|
||||
case ONNX_TYPE_INT64: \
|
||||
LOG_LONG_BUF(buf, data, n); \
|
||||
break; \
|
||||
case ONNX_TYPE_FLOAT: \
|
||||
LOG_FLOAT_BUF(buf, data, n); \
|
||||
break; \
|
||||
case ONNX_TYPE_DOUBLE: \
|
||||
LOG_DOUBLE_BUF(buf, data, n); \
|
||||
break; \
|
||||
defaut: \
|
||||
sprintf(buf, " unsupported data type %d ", type); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
/* Main macro for log output */
|
||||
#define LOG_PRINTF(level, ...) \
|
||||
log_printf(level, __FILE__, __FUNCTION__, __LINE__, __VA_ARGS__)
|
||||
|
||||
/* Generic log routine */
|
||||
void log_printf(
|
||||
int level, char *file, const char *func, int line, char *fmt, ...);
|
||||
|
||||
#endif
|
|
@ -0,0 +1,356 @@
|
|||
#include <assert.h>
|
||||
#include <malloc.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "RtMemRef.h"
|
||||
#include "com_ibm_onnxmlir_DynEntryPoint.h"
|
||||
#include "jnilog.h"
|
||||
|
||||
/* Declare type var, make call and assign to var, check against val */
|
||||
#define CHECK_CALL(type, var, call, val) \
|
||||
type var = call; \
|
||||
if (var == val) \
|
||||
return NULL
|
||||
|
||||
/* Make a JNI call and throw Java exception if the call failed */
|
||||
#define JNI_CALL(env, stmt) \
|
||||
stmt; \
|
||||
do { \
|
||||
jthrowable e = (*env)->ExceptionOccurred(env); \
|
||||
if (e) { \
|
||||
LOG_PRINTF(LOG_ERROR, "JNI call exception occurred"); \
|
||||
(*env)->Throw(env, e); \
|
||||
return NULL; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
/* Make a JNI call and assign return value to var,
|
||||
* throw Java exception if the call failed
|
||||
*/
|
||||
#define JNI_VAR_CALL(env, var, call) JNI_CALL(env, var = call)
|
||||
|
||||
/* Declare type var, make a JNI call and assign return value to var,
|
||||
* throw Java exception if the call failed
|
||||
*/
|
||||
#define JNI_TYPE_VAR_CALL(env, type, var, call) JNI_CALL(env, type var = call);
|
||||
|
||||
/* If cond is true (native code failed), log error and throw Java exception */
|
||||
#define JNI_COND(type, var, call, val, env, cls, ...) \
|
||||
type var = call; \
|
||||
do { \
|
||||
if (var == val) { \
|
||||
LOG_PRINTF(LOG_ERROR, __VA_ARGS__); \
|
||||
(*env)->ThrowNew(env, cls, "native code error"); \
|
||||
return NULL; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
/* Debug output of RtMemRef fields */
|
||||
#define RMR_DEBUG(i, type, rank, sizes, strides, data, datasize) \
|
||||
do { \
|
||||
char tmp[1024]; \
|
||||
LOG_PRINTF(LOG_DEBUG, "rmr[%d]:type=%d", i, type); \
|
||||
LOG_PRINTF(LOG_DEBUG, "rmr[%d]:rank=%d", i, rank); \
|
||||
LOG_LONG_BUF(tmp, sizes, rank); \
|
||||
LOG_PRINTF(LOG_DEBUG, "rmr[%d]:sizes=[%s]", i, tmp); \
|
||||
LOG_LONG_BUF(tmp, strides, rank); \
|
||||
LOG_PRINTF(LOG_DEBUG, "rmr[%d]:strides=[%s]", i, tmp); \
|
||||
LOG_TYPE_BUF(type, tmp, data, datasize); \
|
||||
LOG_PRINTF(LOG_DEBUG, "rmr[%d]:data=[%s]", i, tmp); \
|
||||
} while (0)
|
||||
|
||||
/* Model shared library entry point */
|
||||
extern OrderedRtMemRefDict *_dyn_entry_point_main_graph(OrderedRtMemRefDict *);
|
||||
|
||||
/* ONNX type to size (number of bytes) mapping */
|
||||
int onnx_type_size[] = {
|
||||
0, /* UNDEFINED = 0 */
|
||||
4, /* FLOAT = 1 */
|
||||
1, /* UINT8 = 2 */
|
||||
1, /* INT8 = 3 */
|
||||
2, /* UINT16 = 4 */
|
||||
2, /* INT16 = 5 */
|
||||
4, /* INT32 = 6 */
|
||||
8, /* INT64 = 7 */
|
||||
0, /* STRING = 8 */
|
||||
1, /* BOOL = 9 */
|
||||
2, /* FLOAT16 = 10 */
|
||||
8, /* DOUBLE = 11 */
|
||||
4, /* UINT32 = 12 */
|
||||
8, /* UINT64 = 13 */
|
||||
8, /* COMPLEX64 = 14 */
|
||||
16, /* COMPLEX128 = 15 */
|
||||
2, /* BFLOAT16 = 16 */
|
||||
};
|
||||
|
||||
/* Java classes and methods needed for making various JNI API calls */
|
||||
typedef struct {
|
||||
jclass ecpt_cls; /* java/lang/Exception class */
|
||||
jclass long_cls; /* java/lang/Long class */
|
||||
jclass string_cls; /* java/lang/String class */
|
||||
jclass ormrd_cls; /* com/ibm/onnxmlir/OrderedRtMemRefDict class */
|
||||
jclass rmr_cls; /* com/ibm/onnxmlir/RtMemRef class */
|
||||
|
||||
jmethodID ormrd_constructor; /* OrderedRtMemRefDict constructor */
|
||||
jmethodID ormrd_getRmrs; /* OrderedRtMemRefDict getRmrs method */
|
||||
jmethodID ormrd_getNames; /* OrderedRtMemRefDict getNames method */
|
||||
|
||||
jmethodID rmr_constructor; /* RtMemRef constructor */
|
||||
jmethodID rmr_getType; /* RtMemRef getType method */
|
||||
jmethodID rmr_setType; /* RtMemRef setType method */
|
||||
jmethodID rmr_getRank; /* RtMemRef getRank method */
|
||||
jmethodID rmr_getData; /* RtMemRef getData method */
|
||||
jmethodID rmr_setData; /* RtMemRef setData method */
|
||||
jmethodID rmr_getSizes; /* RtMemRef getSizes method */
|
||||
jmethodID rmr_setSizes; /* RtMemRef setSizes method */
|
||||
jmethodID rmr_getStrides; /* RtMemRef getStrides method */
|
||||
jmethodID rmr_setStrides; /* RtMemRef setStrides method */
|
||||
jmethodID rmr_getDataSize; /* RtMemRef getDataSize method */
|
||||
} jniapi_t;
|
||||
|
||||
jniapi_t jniapi;
|
||||
|
||||
/* Fill in struct jniapi */
|
||||
jniapi_t *fill_jniapi(JNIEnv *env, jniapi_t *japi) {
|
||||
/* Get Java Exception, Long, String, OrderedRtMemRefDict, and RtMemRef classes
|
||||
*/
|
||||
JNI_VAR_CALL(
|
||||
env, japi->ecpt_cls, (*env)->FindClass(env, "java/lang/Exception"));
|
||||
JNI_VAR_CALL(env, japi->long_cls, (*env)->FindClass(env, "java/lang/Long"));
|
||||
JNI_VAR_CALL(
|
||||
env, japi->string_cls, (*env)->FindClass(env, "java/lang/String"));
|
||||
JNI_VAR_CALL(env, japi->ormrd_cls,
|
||||
(*env)->FindClass(env, "com/ibm/onnxmlir/OrderedRtMemRefDict"));
|
||||
JNI_VAR_CALL(
|
||||
env, japi->rmr_cls, (*env)->FindClass(env, "com/ibm/onnxmlir/RtMemRef"));
|
||||
|
||||
/* Get method ID of constructor and various methods in OrderedRtMemRefDict */
|
||||
JNI_VAR_CALL(env, japi->ormrd_constructor,
|
||||
(*env)->GetMethodID(
|
||||
env, japi->ormrd_cls, "<init>", "([Lcom/ibm/onnxmlir/RtMemRef;)V"));
|
||||
JNI_VAR_CALL(env, japi->ormrd_getRmrs,
|
||||
(*env)->GetMethodID(
|
||||
env, japi->ormrd_cls, "getRmrs", "()[Lcom/ibm/onnxmlir/RtMemRef;"));
|
||||
JNI_VAR_CALL(env, japi->ormrd_getNames,
|
||||
(*env)->GetMethodID(
|
||||
env, japi->ormrd_cls, "getNames", "()[Ljava/lang/String;"));
|
||||
|
||||
/* Get method ID of constructor and various methods in RtMemRef */
|
||||
JNI_VAR_CALL(env, japi->rmr_constructor,
|
||||
(*env)->GetMethodID(env, japi->rmr_cls, "<init>", "(I)V"));
|
||||
JNI_VAR_CALL(env, japi->rmr_getType,
|
||||
(*env)->GetMethodID(env, japi->rmr_cls, "getType", "()I"));
|
||||
JNI_VAR_CALL(env, japi->rmr_setType,
|
||||
(*env)->GetMethodID(env, japi->rmr_cls, "setType", "(I)V"));
|
||||
JNI_VAR_CALL(env, japi->rmr_getRank,
|
||||
(*env)->GetMethodID(env, japi->rmr_cls, "getRank", "()I"));
|
||||
JNI_VAR_CALL(env, japi->rmr_getData,
|
||||
(*env)->GetMethodID(
|
||||
env, japi->rmr_cls, "getData", "()Ljava/nio/ByteBuffer;"));
|
||||
JNI_VAR_CALL(env, japi->rmr_setData,
|
||||
(*env)->GetMethodID(
|
||||
env, japi->rmr_cls, "setData", "(Ljava/nio/ByteBuffer;)V"));
|
||||
JNI_VAR_CALL(env, japi->rmr_getSizes,
|
||||
(*env)->GetMethodID(env, japi->rmr_cls, "getSizes", "()[J"));
|
||||
JNI_VAR_CALL(env, japi->rmr_setSizes,
|
||||
(*env)->GetMethodID(env, japi->rmr_cls, "setSizes", "([J)V"));
|
||||
JNI_VAR_CALL(env, japi->rmr_getStrides,
|
||||
(*env)->GetMethodID(env, japi->rmr_cls, "getStrides", "()[J"));
|
||||
JNI_VAR_CALL(env, japi->rmr_setStrides,
|
||||
(*env)->GetMethodID(env, japi->rmr_cls, "setStrides", "([J)V"));
|
||||
JNI_VAR_CALL(env, japi->rmr_getDataSize,
|
||||
(*env)->GetMethodID(env, japi->rmr_cls, "getDataSize", "()J"));
|
||||
|
||||
return japi;
|
||||
}
|
||||
|
||||
/* Convert Java object to native data structure */
|
||||
OrderedRtMemRefDict *ormrd_java_to_native(
|
||||
JNIEnv *env, jclass cls, jobject obj, jniapi_t *japi) {
|
||||
/* Get object array "rmrs" and "names" in OrderedRtMemRefDict */
|
||||
JNI_TYPE_VAR_CALL(env, jobjectArray, ormrd_rmrs,
|
||||
(*env)->CallObjectMethod(env, obj, japi->ormrd_getRmrs));
|
||||
JNI_TYPE_VAR_CALL(env, jobjectArray, ormrd_names,
|
||||
(*env)->CallObjectMethod(env, obj, japi->ormrd_getNames));
|
||||
|
||||
/* Get length of object array "rmrs" and "names" in OrderedRtMemRefDict */
|
||||
JNI_TYPE_VAR_CALL(
|
||||
env, jsize, ormrd_rmrs_len, (*env)->GetArrayLength(env, ormrd_rmrs));
|
||||
JNI_TYPE_VAR_CALL(
|
||||
env, jsize, ormrd_names_len, (*env)->GetArrayLength(env, ormrd_names));
|
||||
|
||||
/* Allocate memory for holding each Java rmr object and name string,
|
||||
* and RtMemRef and char pointers for constructing native RtMemRef and name
|
||||
* array
|
||||
*/
|
||||
JNI_COND(jobject *, obj_rmr, malloc(ormrd_rmrs_len * sizeof(jobject)), NULL,
|
||||
env, japi->ecpt_cls, "obj_rmr=null");
|
||||
JNI_COND(jstring *, obj_name, malloc(ormrd_names_len * sizeof(jstring)), NULL,
|
||||
env, japi->ecpt_cls, "obj_name=null");
|
||||
JNI_COND(RtMemRef **, jni_rmr, malloc(ormrd_rmrs_len * sizeof(RtMemRef *)),
|
||||
NULL, env, japi->ecpt_cls, "jni_rmr=null");
|
||||
JNI_COND(const char **, jni_name,
|
||||
malloc(ormrd_names_len * sizeof(const char *)), NULL, env, japi->ecpt_cls,
|
||||
"jni_name=null");
|
||||
|
||||
/* Create OrderedRtMemRefDict to be constructed and passed to the model shared
|
||||
* library */
|
||||
JNI_COND(OrderedRtMemRefDict *, ormrd, createOrderedRtMemRefDict(), NULL, env,
|
||||
japi->ecpt_cls, "ormrd=null");
|
||||
|
||||
/* Loop through all the ormrd_rmrs and ormrd_names */
|
||||
for (int i = 0; i < ormrd_rmrs_len; i++) {
|
||||
JNI_VAR_CALL(
|
||||
env, obj_rmr[i], (*env)->GetObjectArrayElement(env, ormrd_rmrs, i));
|
||||
JNI_VAR_CALL(
|
||||
env, obj_name[i], (*env)->GetObjectArrayElement(env, ormrd_names, i));
|
||||
|
||||
/* Get type, rank, data, sizes, and strides by calling corresponding methods
|
||||
*/
|
||||
JNI_TYPE_VAR_CALL(env, jint, rmr_type,
|
||||
(*env)->CallIntMethod(env, obj_rmr[i], japi->rmr_getType));
|
||||
JNI_TYPE_VAR_CALL(env, jint, rmr_rank,
|
||||
(*env)->CallIntMethod(env, obj_rmr[i], japi->rmr_getRank));
|
||||
JNI_TYPE_VAR_CALL(env, jlong, rmr_datasize,
|
||||
(*env)->CallLongMethod(env, obj_rmr[i], japi->rmr_getDataSize));
|
||||
JNI_TYPE_VAR_CALL(env, jobject, rmr_data,
|
||||
(*env)->CallObjectMethod(env, obj_rmr[i], japi->rmr_getData));
|
||||
JNI_TYPE_VAR_CALL(env, jobject, rmr_sizes,
|
||||
(*env)->CallObjectMethod(env, obj_rmr[i], japi->rmr_getSizes));
|
||||
JNI_TYPE_VAR_CALL(env, jobject, rmr_strides,
|
||||
(*env)->CallObjectMethod(env, obj_rmr[i], japi->rmr_getStrides));
|
||||
|
||||
/* Primitive type int and long can be directly used */
|
||||
int jni_type = rmr_type, jni_rank = rmr_rank;
|
||||
long jni_datasize = rmr_datasize;
|
||||
|
||||
/* Get direct buffer associated with data */
|
||||
JNI_TYPE_VAR_CALL(
|
||||
env, void *, jni_data, (*env)->GetDirectBufferAddress(env, rmr_data));
|
||||
|
||||
/* Get long array associated with sizes and strides */
|
||||
JNI_TYPE_VAR_CALL(env, long *, jni_sizes,
|
||||
(*env)->GetLongArrayElements(env, rmr_sizes, NULL));
|
||||
JNI_TYPE_VAR_CALL(env, long *, jni_strides,
|
||||
(*env)->GetLongArrayElements(env, rmr_strides, NULL));
|
||||
|
||||
/* Print debug info on what we got from the Java side */
|
||||
RMR_DEBUG(
|
||||
i, jni_type, jni_rank, jni_sizes, jni_strides, jni_data, jni_datasize);
|
||||
|
||||
/* Create native RtMemRef struct and fill in its fields */
|
||||
jni_rmr[i] = createRtMemRef(jni_rank);
|
||||
setDType(jni_rmr[i], jni_type);
|
||||
setData(jni_rmr[i], jni_data);
|
||||
setSizes(jni_rmr[i], jni_sizes);
|
||||
setStrides(jni_rmr[i], jni_strides);
|
||||
|
||||
/*jni_name[i] = (*env)->GetStringUTFChars(env, obj_name[i], NULL);
|
||||
printf("jni_name=%s\n", jni_name[i]);*/
|
||||
|
||||
/* Install RtMemRef into OrderedRtMemRefDict */
|
||||
setRtMemRef(ormrd, i, jni_rmr[i]);
|
||||
|
||||
/* Release reference to the java objects */
|
||||
JNI_CALL(
|
||||
env, (*env)->ReleaseLongArrayElements(env, rmr_sizes, jni_sizes, 0));
|
||||
JNI_CALL(env,
|
||||
(*env)->ReleaseLongArrayElements(env, rmr_strides, jni_strides, 0));
|
||||
}
|
||||
|
||||
/* setRtMemRef(ormrd, jni_rmr, jni_name); */
|
||||
return ormrd;
|
||||
}
|
||||
|
||||
/* Convert native data structure to Java object */
|
||||
jobject ormrd_native_to_java(
|
||||
JNIEnv *env, jclass cls, OrderedRtMemRefDict *dict, jniapi_t *japi) {
|
||||
JNI_COND(int, nrmr, numRtMemRefs(dict), 0, env, japi->ecpt_cls, "nrmr=0");
|
||||
|
||||
/* Create RtMemRef java object array */
|
||||
JNI_TYPE_VAR_CALL(env, jobjectArray, rmrs,
|
||||
(*env)->NewObjectArray(env, nrmr, japi->rmr_cls, NULL));
|
||||
|
||||
/* Loop through the native RtMemRef structs */
|
||||
for (int i = 0; i < nrmr; i++) {
|
||||
JNI_COND(RtMemRef *, rmr, getRtMemRef(dict, i), NULL, env, japi->ecpt_cls,
|
||||
"rmr[%d]=null", i);
|
||||
|
||||
JNI_COND(int, jni_type, getDType(rmr), 0, env, japi->ecpt_cls,
|
||||
"rmr[%d]:type=0", i);
|
||||
JNI_COND(int, jni_rank, getRank(rmr), 0, env, japi->ecpt_cls,
|
||||
"rmr[%d]:rank=0", i);
|
||||
JNI_COND(long, jni_datasize, getDataSize(rmr), 0, env, japi->ecpt_cls,
|
||||
"rmr[%d]:datasize=0", i);
|
||||
JNI_COND(void *, jni_data, getData(rmr), NULL, env, japi->ecpt_cls,
|
||||
"rmr[%d]:data=null", i);
|
||||
JNI_COND(long *, jni_sizes, getSizes(rmr), NULL, env, japi->ecpt_cls,
|
||||
"rmr[%d]:sizes=null", i);
|
||||
JNI_COND(long *, jni_strides, getStrides(rmr), NULL, env, japi->ecpt_cls,
|
||||
"rmr[%d]:strides=null", i);
|
||||
|
||||
/* Print debug info on what we got from the native side */
|
||||
RMR_DEBUG(
|
||||
i, jni_type, jni_rank, jni_sizes, jni_strides, jni_data, jni_datasize);
|
||||
|
||||
/* create the following Java objects:
|
||||
* - RtMemRef
|
||||
* - DirectByteBuffer (from native buffers)
|
||||
* - long array for sizes and strides
|
||||
*/
|
||||
JNI_TYPE_VAR_CALL(env, jobject, obj_rmr,
|
||||
(*env)->NewObject(env, japi->rmr_cls, japi->rmr_constructor, jni_rank));
|
||||
JNI_TYPE_VAR_CALL(env, jobject, rmr_data,
|
||||
(*env)->NewDirectByteBuffer(
|
||||
env, jni_data, jni_datasize * onnx_type_size[jni_type]));
|
||||
JNI_TYPE_VAR_CALL(
|
||||
env, jlongArray, rmr_sizes, (*env)->NewLongArray(env, jni_rank));
|
||||
JNI_TYPE_VAR_CALL(
|
||||
env, jlongArray, rmr_strides, (*env)->NewLongArray(env, jni_rank));
|
||||
|
||||
/* Call setType method */
|
||||
JNI_CALL(env,
|
||||
(*env)->CallObjectMethod(env, obj_rmr, japi->rmr_setType, jni_type));
|
||||
|
||||
/* Call setData method */
|
||||
JNI_CALL(env,
|
||||
(*env)->CallObjectMethod(env, obj_rmr, japi->rmr_setData, rmr_data));
|
||||
|
||||
/* Fill in sizes array from native array and call setSizes method */
|
||||
JNI_CALL(env,
|
||||
(*env)->SetLongArrayRegion(env, rmr_sizes, 0, jni_rank, jni_sizes));
|
||||
JNI_CALL(env,
|
||||
(*env)->CallObjectMethod(env, obj_rmr, japi->rmr_setSizes, rmr_sizes));
|
||||
|
||||
/* Fill in strides array from native array and call setStrides method */
|
||||
JNI_CALL(env,
|
||||
(*env)->SetLongArrayRegion(env, rmr_strides, 0, jni_rank, jni_strides));
|
||||
JNI_CALL(env, (*env)->CallObjectMethod(
|
||||
env, obj_rmr, japi->rmr_setStrides, rmr_strides));
|
||||
|
||||
/* Set DynMemRef object in the object array */
|
||||
JNI_CALL(env, (*env)->SetObjectArrayElement(env, rmrs, i, obj_rmr));
|
||||
}
|
||||
|
||||
/* Create the OrderedRtMemRefDict java object */
|
||||
JNI_TYPE_VAR_CALL(env, jobject, ormrd,
|
||||
(*env)->NewObject(env, japi->ormrd_cls, japi->ormrd_constructor, rmrs));
|
||||
|
||||
return ormrd;
|
||||
}
|
||||
|
||||
JNIEXPORT jobject JNICALL Java_com_ibm_onnxmlir_DynEntryPoint_main_1graph_1jni(
|
||||
JNIEnv *env, jclass cls, jobject obj) {
|
||||
CHECK_CALL(jniapi_t *, japi, fill_jniapi(env, &jniapi), NULL);
|
||||
|
||||
CHECK_CALL(OrderedRtMemRefDict *, input_ormrd,
|
||||
ormrd_java_to_native(env, cls, obj, japi), NULL);
|
||||
|
||||
CHECK_CALL(OrderedRtMemRefDict *, dict,
|
||||
_dyn_entry_point_main_graph(input_ormrd), NULL);
|
||||
|
||||
CHECK_CALL(
|
||||
jobject, output_ormrd, ormrd_native_to_java(env, cls, dict, japi), NULL);
|
||||
|
||||
return output_ormrd;
|
||||
}
|
|
@ -0,0 +1,64 @@
|
|||
package com.ibm.onnxmlir;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.FileNotFoundException;
|
||||
import java.io.IOException;
|
||||
import java.net.URISyntaxException;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.StandardCopyOption;
|
||||
import java.util.jar.JarFile;
|
||||
|
||||
public class DynEntryPoint {
|
||||
static String libname = "libmodel.so";
|
||||
|
||||
static {
|
||||
File jar;
|
||||
JarFile jf;
|
||||
String jarDir = null;
|
||||
String libPath = null;
|
||||
try {
|
||||
// Get path name of jar
|
||||
jar = new File(DynEntryPoint.class.getProtectionDomain()
|
||||
.getCodeSource()
|
||||
.getLocation().toURI());
|
||||
jarDir = jar.getParentFile().getAbsolutePath();
|
||||
libPath = jarDir + "/" + libname;
|
||||
|
||||
// Open jar file to read and check libname inside jar.
|
||||
// If IOException thrown, load .so from where .jar is.
|
||||
//
|
||||
// Checking whether DynEntryPoint.class.getResourceAsStream returns null
|
||||
// does NOT work. Because it checks whether the resource is
|
||||
// available on the classpath, not only just inside the jar file.
|
||||
jf = new JarFile(jar);
|
||||
if (jf.getEntry(libname) != null) {
|
||||
File libFile = new File(libPath);
|
||||
// Copy .so to where jar is
|
||||
Files.copy(jf.getInputStream(jf.getEntry(libname)),
|
||||
libFile.toPath(), StandardCopyOption.REPLACE_EXISTING);
|
||||
// z/OS USS requires "x" permission bit
|
||||
libFile.setExecutable(true, false);
|
||||
// Load the temporary .so copy
|
||||
System.load(libPath);
|
||||
// POSIX can unlink file after loading
|
||||
libFile.delete();
|
||||
}
|
||||
else {
|
||||
// Throw subclass of IOException
|
||||
throw new FileNotFoundException(".so not found inside jar");
|
||||
}
|
||||
} catch (URISyntaxException e) {
|
||||
// Failed to find jar path, assume the .so is in cwd
|
||||
System.load(libname);
|
||||
} catch (IOException e) {
|
||||
// .so not found in jar, assume it's where jar is
|
||||
System.load(libPath);
|
||||
}
|
||||
}
|
||||
|
||||
private static native OrderedRtMemRefDict main_graph_jni(OrderedRtMemRefDict ormrd);
|
||||
|
||||
public static OrderedRtMemRefDict main_graph(OrderedRtMemRefDict ormrd) {
|
||||
return main_graph_jni(ormrd);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,118 @@
|
|||
package com.ibm.onnxmlir;
|
||||
|
||||
import java.util.HashMap;
|
||||
|
||||
public class OrderedRtMemRefDict {
|
||||
|
||||
private RtMemRef[] _rmrs;
|
||||
private String[] _names;
|
||||
private HashMap<String, Integer> _n2i;
|
||||
|
||||
/**
|
||||
* Constructor
|
||||
*
|
||||
* @param rmrs DynMemRef array
|
||||
*/
|
||||
public OrderedRtMemRefDict(RtMemRef[] rmrs) {
|
||||
this(rmrs, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructor
|
||||
*
|
||||
* @param rmrs DynMemRef array
|
||||
* @param names name array
|
||||
*/
|
||||
public OrderedRtMemRefDict(RtMemRef[] rmrs, String[] names) {
|
||||
/* rmrs cannot be null or empty */
|
||||
if (rmrs == null || rmrs.length == 0)
|
||||
throw new IllegalArgumentException(
|
||||
"Number of dmrs is invalid");
|
||||
|
||||
/* If names is null or empty, construct a default one with
|
||||
* index as name.
|
||||
*/
|
||||
if (names == null || names.length == 0) {
|
||||
names = new String[rmrs.length];
|
||||
for (int i = 0; i < rmrs.length; i++)
|
||||
names[i] = Integer.toString(i);
|
||||
}
|
||||
|
||||
/* Number of rmrs and names must match */
|
||||
if (rmrs.length != names.length)
|
||||
throw new IllegalArgumentException(
|
||||
"Number of dmrs and names do not match");
|
||||
|
||||
/* Establish name to index mapping. Individual rmr is
|
||||
* checked for validity.
|
||||
*/
|
||||
_n2i = new HashMap<String, Integer>();
|
||||
for (int i = 0; i < names.length; i++) {
|
||||
if (rmrs[i] == null || !rmrs[i].validRmr())
|
||||
throw new IllegalArgumentException(
|
||||
"rmrs[" + i + "] is invalid");
|
||||
if (_n2i.put(names[i], i) != null)
|
||||
throw new IllegalArgumentException(
|
||||
"name[" + i + "] = " + names[i] + " not unique");
|
||||
}
|
||||
_rmrs = rmrs;
|
||||
_names = names;
|
||||
}
|
||||
|
||||
/**
|
||||
* RtMemRef getter by index
|
||||
*
|
||||
* @param idx index of RtMemRef instance to get
|
||||
* @return RtMemRef instance
|
||||
*/
|
||||
public RtMemRef getRmrbyIndex(int idx) {
|
||||
return _rmrs[idx];
|
||||
}
|
||||
|
||||
/**
|
||||
* RtMemRef getter by name
|
||||
*
|
||||
* @param name name of RtMemRef instance to get
|
||||
* @return RtMemRef instance
|
||||
*/
|
||||
public RtMemRef getRmrByName(String name) {
|
||||
return _rmrs[_n2i.get(name)];
|
||||
}
|
||||
|
||||
/**
|
||||
* RtMemRef array getter
|
||||
*
|
||||
* @return RtMemRef array
|
||||
*/
|
||||
public RtMemRef[] getRmrs() {
|
||||
return _rmrs;
|
||||
}
|
||||
|
||||
/**
|
||||
* Name getter
|
||||
*
|
||||
* @param idx index of name to get
|
||||
* @return name string
|
||||
*/
|
||||
public String getName(int idx) {
|
||||
return _names[idx];
|
||||
}
|
||||
|
||||
/**
|
||||
* Name array getter
|
||||
*
|
||||
* @return name array
|
||||
*/
|
||||
public String[] getNames() {
|
||||
return _names;
|
||||
}
|
||||
|
||||
/**
|
||||
* RtMemRef array size getter
|
||||
*
|
||||
* @return RtMemRef array size
|
||||
*/
|
||||
public int size() {
|
||||
return _rmrs.length;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,400 @@
|
|||
package com.ibm.onnxmlir;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.nio.DoubleBuffer;
|
||||
import java.nio.FloatBuffer;
|
||||
import java.nio.IntBuffer;
|
||||
import java.nio.LongBuffer;
|
||||
import java.nio.ShortBuffer;
|
||||
|
||||
public class RtMemRef {
|
||||
final ByteOrder endian = ByteOrder.nativeOrder();
|
||||
|
||||
/* We can use enum but that creates another class
|
||||
* which complicates things for JNI.
|
||||
*/
|
||||
final int ONNX_TYPE_UNDEFINED = 0;
|
||||
final int ONNX_TYPE_FLOAT = 1;
|
||||
final int ONNX_TYPE_UINT8 = 2;
|
||||
final int ONNX_TYPE_INT8 = 3;
|
||||
final int ONNX_TYPE_UINT16 = 4;
|
||||
final int ONNX_TYPE_INT16 = 5;
|
||||
final int ONNX_TYPE_INT32 = 6;
|
||||
final int ONNX_TYPE_INT64 = 7;
|
||||
final int ONNX_TYPE_STRING = 8;
|
||||
final int ONNX_TYPE_BOOL = 9;
|
||||
final int ONNX_TYPE_FLOAT16 = 10;
|
||||
final int ONNX_TYPE_DOUBLE = 11;
|
||||
final int ONNX_TYPE_UINT32 = 12;
|
||||
final int ONNX_TYPE_UINT64 = 13;
|
||||
final int ONNX_TYPE_COMPLEX64 = 14;
|
||||
final int ONNX_TYPE_COMPLEX128 = 15;
|
||||
final int ONNX_TYPE_BFLOAT16 = 16;
|
||||
|
||||
final int[] ONNX_TYPE_SIZE = new int[] {
|
||||
0, /* UNDEFINED */
|
||||
4, /* FLOAT */
|
||||
1, /* UINT8 */
|
||||
1, /* INT8 */
|
||||
2, /* UINT16 */
|
||||
2, /* INT16 */
|
||||
4, /* INT32 */
|
||||
8, /* INT64 */
|
||||
0, /* STRING */
|
||||
1, /* BOOL */
|
||||
2, /* FLOAT16 */
|
||||
8, /* DOUBLE */
|
||||
4, /* UINT32 */
|
||||
8, /* UINT64 */
|
||||
8, /* COMPLEX64 */
|
||||
16, /* COMPLEX128 */
|
||||
2, /* BFLOAT16 */
|
||||
};
|
||||
|
||||
private ByteBuffer _data;
|
||||
private int _type;
|
||||
private int _rank;
|
||||
private long[] _sizes;
|
||||
private long[] _strides;
|
||||
|
||||
/**
|
||||
* Constructor
|
||||
*/
|
||||
public RtMemRef(int rank) {
|
||||
if (rank <= 0)
|
||||
throw new IllegalArgumentException(
|
||||
"invalid rank " + rank);
|
||||
_data = null;
|
||||
_type = ONNX_TYPE_UNDEFINED;
|
||||
_rank = rank;
|
||||
_sizes = new long[rank];
|
||||
_strides = new long[rank];
|
||||
}
|
||||
|
||||
/* ---------- Data type getter and setter ---------- */
|
||||
/* For JNI wrapper only. Not intended for end user. */
|
||||
|
||||
/**
|
||||
* Type getter
|
||||
*
|
||||
* @return data type
|
||||
*/
|
||||
@SuppressWarnings("unused")
|
||||
private int getType() {
|
||||
return _type;
|
||||
}
|
||||
|
||||
/**
|
||||
* Type setter
|
||||
*
|
||||
* @param type data type to be set
|
||||
*/
|
||||
@SuppressWarnings("unused")
|
||||
private void setType(int type) {
|
||||
_type = type;
|
||||
}
|
||||
|
||||
/* ---------- Raw data getter and setter ---------- */
|
||||
/* For JNI wrapper only. Not intended for end user. */
|
||||
|
||||
/**
|
||||
* Raw data getter
|
||||
*
|
||||
* @return raw data
|
||||
*/
|
||||
@SuppressWarnings("unused")
|
||||
private ByteBuffer getData() {
|
||||
return _data;
|
||||
}
|
||||
|
||||
/**
|
||||
* Raw data setter
|
||||
*
|
||||
* @param data raw data to be set
|
||||
*/
|
||||
@SuppressWarnings("unused")
|
||||
private void setData(ByteBuffer data) {
|
||||
_data = data.order(endian);
|
||||
}
|
||||
|
||||
/* ---------- Byte data getter and setter ---------- */
|
||||
|
||||
/**
|
||||
* Byte data getter
|
||||
*
|
||||
* @return byte data array
|
||||
*/
|
||||
public byte[] getByteData() {
|
||||
if (_data == null) return null;
|
||||
|
||||
/* asReadOnlyBuffer() creates a new view so the position of the
|
||||
* original data will stay at 0 for subsequent getByteData()
|
||||
* after get(b).
|
||||
*/
|
||||
byte[] b = new byte[_data.limit()];
|
||||
_data.asReadOnlyBuffer().get(b);
|
||||
return b;
|
||||
}
|
||||
|
||||
/**
|
||||
* Byte data setter
|
||||
*
|
||||
* @param data byte array to be set
|
||||
*/
|
||||
public void setByteData(byte[] data) {
|
||||
/* slice() creates a new view so the position of the
|
||||
* original data will stay at 0 for getByteData() after put(data).
|
||||
*/
|
||||
_data = ByteBuffer.allocateDirect(data.length);
|
||||
_data.slice().put(data);
|
||||
_type = ONNX_TYPE_INT8;
|
||||
}
|
||||
|
||||
/* ---------- Short data getter and setter ---------- */
|
||||
|
||||
/**
|
||||
* Short data getter
|
||||
*
|
||||
* @return short data array
|
||||
*/
|
||||
public short[] getShortData() {
|
||||
if (_data == null) return null;
|
||||
|
||||
/* asShortBuffer() creates a new view so the position of the
|
||||
* original data will stay at 0 for subsequent getShortData()
|
||||
* after get(s).
|
||||
*/
|
||||
ShortBuffer sb = _data.asShortBuffer();
|
||||
short[] s = new short[sb.limit()];
|
||||
sb.get(s);
|
||||
return s;
|
||||
}
|
||||
|
||||
/**
|
||||
* Short data setter
|
||||
*
|
||||
* @param data short array to be set
|
||||
*/
|
||||
public void setShortData(short[] data) {
|
||||
/* asShortBuffer() creates a new view so the position of the
|
||||
* original data will stay at 0 for getShortData() after put(data).
|
||||
*/
|
||||
_data = ByteBuffer.allocateDirect(data.length*2).order(endian);
|
||||
_data.asShortBuffer().put(data);
|
||||
_type = ONNX_TYPE_INT16;
|
||||
}
|
||||
|
||||
/* ---------- Int data getter and setter ---------- */
|
||||
|
||||
/**
|
||||
* Int data getter
|
||||
*
|
||||
* @return int data array
|
||||
*/
|
||||
public int[] getIntData() {
|
||||
if (_data == null) return null;
|
||||
|
||||
/* asIntBuffer() creates a new view so the position of the
|
||||
* original data will stay at 0 for subsequent getIntData()
|
||||
* after get(i).
|
||||
*/
|
||||
IntBuffer ib = _data.asIntBuffer();
|
||||
int[] i = new int[ib.limit()];
|
||||
ib.get(i);
|
||||
return i;
|
||||
}
|
||||
|
||||
/**
|
||||
* Int data setter
|
||||
*
|
||||
* @param data int array to be set
|
||||
*/
|
||||
public void setIntData(int[] data) {
|
||||
/* asIntBuffer() creates a new view so the position of the
|
||||
* original data will stay at 0 for getIntData() after put(data).
|
||||
*/
|
||||
_data = ByteBuffer.allocateDirect(data.length*4).order(endian);
|
||||
_data.asIntBuffer().put(data);
|
||||
_type = ONNX_TYPE_INT32;
|
||||
}
|
||||
|
||||
/* ---------- Long data getter and setter ---------- */
|
||||
|
||||
/**
|
||||
* Long data getter
|
||||
*
|
||||
* @return long data array
|
||||
*/
|
||||
public long[] getLongData() {
|
||||
if (_data == null) return null;
|
||||
|
||||
/* asLongBuffer() creates a new view so the position of the
|
||||
* original data will stay at 0 for subsequent getLongData()
|
||||
* after get(l).
|
||||
*/
|
||||
LongBuffer lb = _data.asLongBuffer();
|
||||
long[] l = new long[lb.limit()];
|
||||
lb.get(l);
|
||||
return l;
|
||||
}
|
||||
|
||||
/**
|
||||
* Long data setter
|
||||
*
|
||||
* @param data long array to be set
|
||||
*/
|
||||
public void setLongData(long[] data) {
|
||||
/* asLongBuffer() creates a new view so the position of the
|
||||
* original data will stay at 0 for getLongData() after put(data).
|
||||
*/
|
||||
_data = ByteBuffer.allocateDirect(data.length*8).order(endian);
|
||||
_data.asLongBuffer().put(data);
|
||||
_type = ONNX_TYPE_INT64;
|
||||
}
|
||||
|
||||
/* ---------- Float data getter and setter ---------- */
|
||||
|
||||
/**
|
||||
* Float data getter
|
||||
*
|
||||
* @return float data array
|
||||
*/
|
||||
public float[] getFloatData() {
|
||||
if (_data == null) return null;
|
||||
|
||||
/* asFloatBuffer() creates a new view so the position of the
|
||||
* original data will stay at 0 for subsequent getFloatData()
|
||||
* after get(f).
|
||||
*/
|
||||
FloatBuffer fb = _data.asFloatBuffer();
|
||||
float[] f = new float[fb.limit()];
|
||||
fb.get(f);
|
||||
return f;
|
||||
}
|
||||
|
||||
/**
|
||||
* Float data setter
|
||||
*
|
||||
* @param data float array to be set
|
||||
*/
|
||||
public void setFloatData(float[] data) {
|
||||
/* asFloatBuffer() creates a new view so the position of the
|
||||
* original data will stay at 0 for getFloatData() after put(data).
|
||||
*/
|
||||
_data = ByteBuffer.allocateDirect(data.length*4).order(endian);
|
||||
_data.asFloatBuffer().put(data);
|
||||
_type = ONNX_TYPE_FLOAT;
|
||||
}
|
||||
|
||||
/* ---------- Double data getter and setter ---------- */
|
||||
|
||||
/**
|
||||
* Double data getter
|
||||
*
|
||||
* @return double data array
|
||||
*/
|
||||
public double[] getDoubleData() {
|
||||
if (_data == null) return null;
|
||||
|
||||
/* asDoubleBuffer() creates a new view so the position of the
|
||||
* original data will stay at 0 for subsequent getDoubleData()
|
||||
* after get(d).
|
||||
*/
|
||||
DoubleBuffer db = _data.asDoubleBuffer();
|
||||
double[] d = new double[db.limit()];
|
||||
db.get(d);
|
||||
return d;
|
||||
}
|
||||
|
||||
/**
|
||||
* Double data setter
|
||||
*
|
||||
* @param data double array to be set
|
||||
*/
|
||||
public void setDoubleData(double[] data) {
|
||||
/* asDoubleBuffer() creates a new view so the position of the
|
||||
* original data will stay at 0 for getDoubleData() after put(data).
|
||||
*/
|
||||
_data = ByteBuffer.allocateDirect(data.length*8).order(endian);
|
||||
_data.asDoubleBuffer().put(data);
|
||||
_type = ONNX_TYPE_DOUBLE;
|
||||
}
|
||||
|
||||
/**
|
||||
* Rank getter
|
||||
*
|
||||
* @return rank
|
||||
*/
|
||||
public int getRank() {
|
||||
return _rank;
|
||||
}
|
||||
|
||||
/* ---------- Sizes getter and setter ---------- */
|
||||
|
||||
/**
|
||||
* Sizes getter
|
||||
*
|
||||
* @return sizes array
|
||||
*/
|
||||
public long[] getSizes() {
|
||||
return _sizes;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sizes setter
|
||||
*
|
||||
* @param sizes sizes array to be set
|
||||
*/
|
||||
public void setSizes(long[] sizes) {
|
||||
if (sizes.length != _rank)
|
||||
throw new IllegalArgumentException(
|
||||
"array length " + sizes.length + " != rank " + _rank);
|
||||
_sizes = sizes.clone();
|
||||
}
|
||||
|
||||
/* ---------- Strides getter and setter ---------- */
|
||||
|
||||
/**
|
||||
* Strides getter
|
||||
*
|
||||
* @return strides array
|
||||
*/
|
||||
public long[] getStrides() {
|
||||
return _strides;
|
||||
}
|
||||
|
||||
/**
|
||||
* Strides setter
|
||||
*
|
||||
* @param strides strides array to be set
|
||||
*/
|
||||
public void setStrides(long[] strides) {
|
||||
if (strides.length != _rank)
|
||||
throw new IllegalArgumentException(
|
||||
"array length " + strides.length + " != rank " + _rank);
|
||||
_strides = strides.clone();
|
||||
}
|
||||
|
||||
/**
|
||||
* Size getter
|
||||
*
|
||||
* @return product of sizes array, i.e., total number of data elements
|
||||
*/
|
||||
public long getDataSize() {
|
||||
long n = _sizes[0];
|
||||
for (int i = 1; i < _sizes.length; i++) n *= _sizes[i];
|
||||
return n;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check validity of RtMemRef
|
||||
*
|
||||
* @return true if RtMemRef is valid, false otherwise
|
||||
*/
|
||||
public boolean validRmr() {
|
||||
return (_data != null &&
|
||||
_data.limit() != 0 &&
|
||||
_data.limit() == getDataSize() * ONNX_TYPE_SIZE[_type]);
|
||||
}
|
||||
}
|
|
@ -4,12 +4,6 @@ add_dependencies(onnx-mlir-opt OMKrnlOpsInc OMONNXOpsInc)
|
|||
target_include_directories(onnx-mlir-opt PRIVATE ${ONNX_MLIR_SRC_ROOT})
|
||||
target_include_directories(onnx-mlir-opt PRIVATE ${ONNX_MLIR_BIN_ROOT})
|
||||
|
||||
set(ONNX_MLIR_LD_PRELOAD_onnx-mlir-opt "" CACHE STRING "" FORCE)
|
||||
if(BUILD_SHARED_LIBS)
|
||||
message(STATUS "To run dynamically linked onnx-mlir-opt, you must specify:")
|
||||
message(STATUS "LD_PRELOAD=${ONNX_MLIR_LD_PRELOAD_onnx-mlir-opt}")
|
||||
endif()
|
||||
|
||||
target_link_libraries(onnx-mlir-opt
|
||||
${OMLibs}
|
||||
${MLIRLibs}
|
||||
|
|
|
@ -12,6 +12,7 @@ using namespace std;
|
|||
using namespace onnx_mlir;
|
||||
|
||||
int main(int argc, char *argv[]) {
|
||||
setExecPath(argv[0], (void *)main);
|
||||
registerDialects();
|
||||
|
||||
llvm::cl::OptionCategory OnnxMlirOptions(
|
||||
|
@ -33,7 +34,9 @@ int main(int argc, char *argv[]) {
|
|||
clEnumVal(EmitLLVMIR, "Lower model to LLVM IR (LLVM dialect)."),
|
||||
clEnumVal(EmitLib, "Lower model to LLVM IR, emit (to file) "
|
||||
"LLVM bitcode for model, compile and link it to a "
|
||||
"shared library.")),
|
||||
"shared library."),
|
||||
clEnumVal(EmitJNI, "Lower model to LLMV IR -> LLVM bitcode "
|
||||
"-> JNI shared library -> jar")),
|
||||
llvm::cl::init(EmitLib), llvm::cl::cat(OnnxMlirOptions));
|
||||
|
||||
llvm::cl::HideUnrelatedOptions(OnnxMlirOptions);
|
||||
|
|
|
@ -13,6 +13,8 @@
|
|||
#include "src/MainUtils.hpp"
|
||||
#include "src/Runtime/ExecusionSession.hpp"
|
||||
|
||||
#define SHARED_LIB_BASE string("./TestConv_main_graph")
|
||||
|
||||
using namespace std;
|
||||
|
||||
// Returns whether onnx-mlir compiled convolution is producing the same results
|
||||
|
@ -87,14 +89,9 @@ bool isOMConvTheSameAsNaiveImplFor(const int N, const int C, const int H,
|
|||
|
||||
OwningModuleRef moduleRef(module);
|
||||
|
||||
llvm::SmallVector<char, 10> path;
|
||||
llvm::sys::fs::createTemporaryFile("_main_graph", "", path);
|
||||
string pathStr(path.begin(), path.end());
|
||||
llvm::FileRemover remover(path);
|
||||
|
||||
compileModule(moduleRef, ctx, pathStr, EmitLib);
|
||||
compileModule(moduleRef, ctx, SHARED_LIB_BASE, EmitLib);
|
||||
onnx_mlir::ExecutionSession sess(
|
||||
pathStr + ".so", "_dyn_entry_point_main_graph");
|
||||
SHARED_LIB_BASE + ".so", "_dyn_entry_point_main_graph");
|
||||
|
||||
std::vector<unique_ptr<RtMemRef>> inputs;
|
||||
auto xRmr = unique_ptr<RtMemRef>(getRndRealRmr<float>({N, C, H, W}));
|
||||
|
@ -127,7 +124,10 @@ bool isOMConvTheSameAsNaiveImplFor(const int N, const int C, const int H,
|
|||
return isRmrClose<float>(conv.get(), ref);
|
||||
}
|
||||
|
||||
int main() {
|
||||
int main(int argc, char *argv[]) {
|
||||
setExecPath(argv[0], (void *)main);
|
||||
llvm::FileRemover remover(SHARED_LIB_BASE + ".so");
|
||||
|
||||
// RapidCheck test case generation.
|
||||
rc::check("convolution implementation correctness", []() {
|
||||
const auto N = *rc::gen::inRange(1, 10);
|
||||
|
|
Loading…
Reference in New Issue