Add emitjni target (#204)
* Detect llvm-project commit change in utils/clone-mlir.sh and rebuild llvm-project for zLinux Jenkins build bot * Add --EmitJNI target (tested working with mnist and resnet50) - MainUtils * first shot at refactoring compileModuleToSharedLibrary * add setExecPath call to allow resolving runtime directory from onnx-mlir executable path when ONNX_MLIR_RUNTIME_DIR is not set. This allows tests to run without having to install onnx-mlir or to explicitly set ONNX_MLIR_RUNTIME_DIR - RtMemRef * add getDataSize for C (equivalent of size() for C++). * fix setStrides bug (setting sizes, not strides) - TestConv * _main_graph-*.so were filling up /tmp. Change to use fixed shared library in build directory * Fix clang-format-lint complaints * - getRuntimeDir checks lib64 - install targets for javaruntime and jniruntime - remove ONNX_MLIR_LD_PRELOAD_onnx-mlir and ONNX_MLIR_LD_PRELOAD_onnx-mlir-opt * See what happens when `kExecPath` decl is dropped. Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
parent
af75b4c75e
commit
d235f248e4
|
@ -71,5 +71,5 @@ cmake -DCMAKE_INSTALL_PREFIX=${INSTALL_PATH} .. \
|
||||||
|
|
||||||
make -j$(nproc)
|
make -j$(nproc)
|
||||||
make -j$(nproc) check-onnx-lit
|
make -j$(nproc) check-onnx-lit
|
||||||
RUNTIME_DIR=$(pwd)/lib make -j$(nproc) check-onnx-backend
|
make -j$(nproc) check-onnx-backend
|
||||||
RUNTIME_DIR=$(pwd)/lib PATH=$(pwd)/bin:$PATH make -j$(nproc) test
|
PATH=$(pwd)/bin:$PATH make -j$(nproc) test
|
||||||
|
|
|
@ -55,12 +55,6 @@ endif()
|
||||||
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/ExternalUtil.hpp.in
|
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/ExternalUtil.hpp.in
|
||||||
${CMAKE_CURRENT_BINARY_DIR}/ExternalUtil.hpp)
|
${CMAKE_CURRENT_BINARY_DIR}/ExternalUtil.hpp)
|
||||||
|
|
||||||
set(ONNX_MLIR_LD_PRELOAD_onnx-mlir "" CACHE STRING "" FORCE)
|
|
||||||
if(BUILD_SHARED_LIBS)
|
|
||||||
message(STATUS "To run dynamically linked onnx-mlir, you must specify:")
|
|
||||||
message(STATUS "LD_PRELOAD=${ONNX_MLIR_LD_PRELOAD_onnx-mlir}")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
# Libraries specified on the target_link_libraries for the add_subdirectory
|
# Libraries specified on the target_link_libraries for the add_subdirectory
|
||||||
# targets get added to the end of the list here. This creates two problems:
|
# targets get added to the end of the list here. This creates two problems:
|
||||||
# 1. It produces duplicated libraries being specified for the link command.
|
# 1. It produces duplicated libraries being specified for the link command.
|
||||||
|
|
|
@ -3,8 +3,12 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
std::string kExecPath = "@CMAKE_INSTALL_PREFIX@/bin/onnx-mlir"; /* fallback if not set by main */
|
||||||
|
const std::string kInstPath = "@CMAKE_INSTALL_PREFIX@";
|
||||||
const std::string kLlcPath = "@LLVM_PROJ_BUILD@/bin/llc";
|
const std::string kLlcPath = "@LLVM_PROJ_BUILD@/bin/llc";
|
||||||
const std::string kCxxPath = "@CMAKE_CXX_COMPILER@";
|
const std::string kCxxPath = "@CMAKE_CXX_COMPILER@";
|
||||||
const std::string kLinkerPath = "@CMAKE_LINKER@";
|
const std::string kLinkerPath = "@CMAKE_LINKER@";
|
||||||
const std::string kObjCopyPath = "@CMAKE_OBJCOPY@";
|
const std::string kObjCopyPath = "@CMAKE_OBJCOPY@";
|
||||||
|
const std::string kArPath = "@CMAKE_AR@";
|
||||||
|
const std::string kJarPath = "@Java_JAR_EXECUTABLE@";
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
//===--------------------------- main_utils.cpp ---------------------------===//
|
//===--------------------------- MainUtils.cpp ---------------------------===//
|
||||||
//
|
//
|
||||||
// Copyright 2019-2020 The IBM Research Authors.
|
// Copyright 2019-2020 The IBM Research Authors.
|
||||||
//
|
//
|
||||||
|
@ -13,6 +13,7 @@
|
||||||
#include <fcntl.h>
|
#include <fcntl.h>
|
||||||
#include <regex>
|
#include <regex>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include <llvm/Support/FileSystem.h>
|
#include <llvm/Support/FileSystem.h>
|
||||||
#include <llvm/Support/Program.h>
|
#include <llvm/Support/Program.h>
|
||||||
|
@ -22,8 +23,6 @@
|
||||||
#include "src/ExternalUtil.hpp"
|
#include "src/ExternalUtil.hpp"
|
||||||
#include "src/MainUtils.hpp"
|
#include "src/MainUtils.hpp"
|
||||||
|
|
||||||
#include "MainUtils.hpp"
|
|
||||||
|
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
#include <io.h>
|
#include <io.h>
|
||||||
#else
|
#else
|
||||||
|
@ -41,6 +40,42 @@ llvm::Optional<std::string> getEnvVar(std::string name) {
|
||||||
return llvm::None;
|
return llvm::None;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Runtime directory contains all the libraries, jars, etc. that are
|
||||||
|
// necessary for running onnx-mlir. It's resolved in the following order:
|
||||||
|
//
|
||||||
|
// - if ONNX_MLIR_RUNTIME_DIR is set, use it, otherwise
|
||||||
|
// - get path from where onnx-mlir is run, if it's of the form
|
||||||
|
// /foo/bar/bin/onnx-mlir,
|
||||||
|
// the runtime directory is /foo/bar/lib (note that when onnx-mlir is
|
||||||
|
// installed system wide, which is typically /usr/local/bin, this will
|
||||||
|
// correctly resolve to /usr/local/lib), but some systems still have
|
||||||
|
// lib64 so we check that first. If neither exists, then
|
||||||
|
// - use CMAKE_INSTALL_PREFIX/lib, which is typically /usr/local/lib
|
||||||
|
string getRuntimeDir() {
|
||||||
|
const auto &envDir = getEnvVar("ONNX_MLIR_RUNTIME_DIR");
|
||||||
|
if (envDir && llvm::sys::fs::exists(envDir.getValue()))
|
||||||
|
return envDir.getValue();
|
||||||
|
|
||||||
|
string execDir = llvm::sys::path::parent_path(kExecPath).str();
|
||||||
|
if (llvm::sys::path::stem(execDir).str().compare("bin") == 0) {
|
||||||
|
string p = execDir.substr(0, execDir.size() - 3);
|
||||||
|
if (llvm::sys::fs::exists(p + "lib64"))
|
||||||
|
return p + "lib64";
|
||||||
|
if (llvm::sys::fs::exists(p + "lib"))
|
||||||
|
return p + "lib";
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::SmallString<8> instDir64(kInstPath);
|
||||||
|
llvm::sys::path::append(instDir64, "lib64");
|
||||||
|
string p = llvm::StringRef(instDir64).str();
|
||||||
|
if (llvm::sys::fs::exists(p))
|
||||||
|
return p;
|
||||||
|
|
||||||
|
llvm::SmallString<8> instDir(kInstPath);
|
||||||
|
llvm::sys::path::append(instDir, "lib");
|
||||||
|
return llvm::StringRef(instDir).str();
|
||||||
|
}
|
||||||
|
|
||||||
// Helper struct to make command construction and execution easy & readable.
|
// Helper struct to make command construction and execution easy & readable.
|
||||||
struct Command {
|
struct Command {
|
||||||
std::string _path;
|
std::string _path;
|
||||||
|
@ -97,6 +132,12 @@ struct Command {
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
void setExecPath(const char *argv0, void *fmain) {
|
||||||
|
string p;
|
||||||
|
if (!(p = llvm::sys::fs::getMainExecutable(argv0, fmain)).empty())
|
||||||
|
kExecPath = p;
|
||||||
|
}
|
||||||
|
|
||||||
void LoadMLIR(string inputFilename, mlir::MLIRContext &context,
|
void LoadMLIR(string inputFilename, mlir::MLIRContext &context,
|
||||||
mlir::OwningModuleRef &module) {
|
mlir::OwningModuleRef &module) {
|
||||||
// Handle '.mlir' input to the ONNX MLIR frontend.
|
// Handle '.mlir' input to the ONNX MLIR frontend.
|
||||||
|
@ -119,8 +160,8 @@ void LoadMLIR(string inputFilename, mlir::MLIRContext &context,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void compileModuleToSharedLibrary(
|
void genConstPackObj(const mlir::OwningModuleRef &module,
|
||||||
const mlir::OwningModuleRef &module, string outputBaseName) {
|
llvm::Optional<string> &constPackObjPath) {
|
||||||
// Extract constant pack file name, which is embedded as a symbol in the
|
// Extract constant pack file name, which is embedded as a symbol in the
|
||||||
// module being compiled.
|
// module being compiled.
|
||||||
auto constPackFilePathSym = (*module).lookupSymbol<mlir::LLVM::GlobalOp>(
|
auto constPackFilePathSym = (*module).lookupSymbol<mlir::LLVM::GlobalOp>(
|
||||||
|
@ -131,7 +172,6 @@ void compileModuleToSharedLibrary(
|
||||||
.str();
|
.str();
|
||||||
llvm::FileRemover constPackRemover(constPackFilePath);
|
llvm::FileRemover constPackRemover(constPackFilePath);
|
||||||
|
|
||||||
llvm::Optional<std::string> constPackObjPath;
|
|
||||||
#if __APPLE__
|
#if __APPLE__
|
||||||
// Create a empty stub file, compile it to an empty obj file.
|
// Create a empty stub file, compile it to an empty obj file.
|
||||||
llvm::SmallVector<char, 20> stubSrcPath;
|
llvm::SmallVector<char, 20> stubSrcPath;
|
||||||
|
@ -153,7 +193,6 @@ void compileModuleToSharedLibrary(
|
||||||
.appendList({"-sectcreate", "binary", "param", constPackFilePath})
|
.appendList({"-sectcreate", "binary", "param", constPackFilePath})
|
||||||
.appendStr(stubObjPathStr)
|
.appendStr(stubObjPathStr)
|
||||||
.exec();
|
.exec();
|
||||||
llvm::FileRemover constPackObjRemover(constPackObjPath.getValue());
|
|
||||||
|
|
||||||
#elif __linux__
|
#elif __linux__
|
||||||
// Create param.o holding packed parameter values.
|
// Create param.o holding packed parameter values.
|
||||||
|
@ -164,7 +203,6 @@ void compileModuleToSharedLibrary(
|
||||||
.appendList({"-o", constPackObjPath.getValue()})
|
.appendList({"-o", constPackObjPath.getValue()})
|
||||||
.appendStr(constPackFilePath)
|
.appendStr(constPackFilePath)
|
||||||
.exec();
|
.exec();
|
||||||
llvm::FileRemover constPackObjRemover(constPackObjPath.getValue());
|
|
||||||
|
|
||||||
// Figure out what is the default symbol name describing the start/end
|
// Figure out what is the default symbol name describing the start/end
|
||||||
// address of the embedded data.
|
// address of the embedded data.
|
||||||
|
@ -204,40 +242,119 @@ void compileModuleToSharedLibrary(
|
||||||
mlir::KrnlPackedConstantOp::getConstPackFileNameStrLenSymbolName())
|
mlir::KrnlPackedConstantOp::getConstPackFileNameStrLenSymbolName())
|
||||||
.valueAttr(builder.getI64IntegerAttr(constPackFileName.str().size()));
|
.valueAttr(builder.getI64IntegerAttr(constPackFileName.str().size()));
|
||||||
#endif
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
// Write LLVM bitcode.
|
// Write LLVM bitcode.
|
||||||
string outputFilename = outputBaseName + ".bc";
|
void genLLVMBitcode(const mlir::OwningModuleRef &module, string bitcodePath) {
|
||||||
error_code error;
|
error_code error;
|
||||||
|
|
||||||
llvm::raw_fd_ostream moduleBitcodeStream(
|
llvm::raw_fd_ostream moduleBitcodeStream(
|
||||||
outputFilename, error, llvm::sys::fs::F_None);
|
bitcodePath, error, llvm::sys::fs::F_None);
|
||||||
|
|
||||||
llvm::WriteBitcodeToFile(
|
llvm::WriteBitcodeToFile(
|
||||||
*mlir::translateModuleToLLVMIR(*module), moduleBitcodeStream);
|
*mlir::translateModuleToLLVMIR(*module), moduleBitcodeStream);
|
||||||
moduleBitcodeStream.flush();
|
moduleBitcodeStream.flush();
|
||||||
llvm::FileRemover bcRemover(outputFilename);
|
}
|
||||||
|
|
||||||
// Compile LLVM bitcode to object file.
|
// Compile LLVM bitcode to object file.
|
||||||
|
void genModelObject(const mlir::OwningModuleRef &module, string bitcodePath,
|
||||||
|
string modelObjPath) {
|
||||||
Command llvmToObj(/*exePath=*/kLlcPath);
|
Command llvmToObj(/*exePath=*/kLlcPath);
|
||||||
llvmToObj.appendStr("-filetype=obj");
|
llvmToObj.appendStr("-filetype=obj")
|
||||||
llvmToObj.appendStr("-relocation-model=pic");
|
.appendStr("-relocation-model=pic")
|
||||||
llvmToObj.appendStr(outputFilename);
|
.appendList({"-o", modelObjPath})
|
||||||
llvmToObj.exec();
|
.appendStr(bitcodePath)
|
||||||
std::string modelObjPath = outputBaseName + ".o";
|
.exec();
|
||||||
|
}
|
||||||
|
|
||||||
|
void genJniObject(const mlir::OwningModuleRef &module, string jniSharedLibPath,
|
||||||
|
string jniObjPath) {
|
||||||
|
Command ar(/*exePath=*/kArPath);
|
||||||
|
ar.appendStr("x").appendStr(jniSharedLibPath).appendStr(jniObjPath).exec();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Link everything into a shared object.
|
||||||
|
void genSharedLib(const mlir::OwningModuleRef &module,
|
||||||
|
string modelSharedLibPath, std::vector<string> opts,
|
||||||
|
std::vector<string> objs, std::vector<string> libs) {
|
||||||
|
|
||||||
|
string runtimeDirInclFlag = "-L" + getRuntimeDir();
|
||||||
|
|
||||||
|
Command link(kCxxPath);
|
||||||
|
link.appendList(opts)
|
||||||
|
.appendList(objs)
|
||||||
|
.appendList({"-o", modelSharedLibPath})
|
||||||
|
.appendStrOpt(runtimeDirInclFlag)
|
||||||
|
.appendList(libs)
|
||||||
|
.exec();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create jar containing java runtime and model shared library (which includes
|
||||||
|
// jni runtime).
|
||||||
|
void genJniJar(const mlir::OwningModuleRef &module, string modelSharedLibPath,
|
||||||
|
string modelJniJarPath) {
|
||||||
|
llvm::SmallString<8> runtimeDir(getRuntimeDir());
|
||||||
|
llvm::sys::path::append(runtimeDir, "javaruntime.jar");
|
||||||
|
string javaRuntimeJarPath = llvm::StringRef(runtimeDir).str();
|
||||||
|
|
||||||
|
// Copy javaruntime.jar to model jar.
|
||||||
|
llvm::sys::fs::copy_file(javaRuntimeJarPath, modelJniJarPath);
|
||||||
|
|
||||||
|
// Add shared library to model jar.
|
||||||
|
Command jar(kJarPath);
|
||||||
|
jar.appendList({"uf", modelJniJarPath}).appendStr(modelSharedLibPath).exec();
|
||||||
|
}
|
||||||
|
|
||||||
|
void compileModuleToSharedLibrary(
|
||||||
|
const mlir::OwningModuleRef &module, std::string outputBaseName) {
|
||||||
|
|
||||||
|
llvm::Optional<string> constPackObjPath;
|
||||||
|
genConstPackObj(module, constPackObjPath);
|
||||||
|
llvm::FileRemover constPackObjRemover(constPackObjPath.getValue());
|
||||||
|
|
||||||
|
string bitcodePath = outputBaseName + ".bc";
|
||||||
|
genLLVMBitcode(module, bitcodePath);
|
||||||
|
llvm::FileRemover bitcodeRemover(bitcodePath);
|
||||||
|
|
||||||
|
string modelObjPath = outputBaseName + ".o";
|
||||||
|
genModelObject(module, bitcodePath, modelObjPath);
|
||||||
llvm::FileRemover modelObjRemover(modelObjPath);
|
llvm::FileRemover modelObjRemover(modelObjPath);
|
||||||
|
|
||||||
llvm::Optional<std::string> runtimeDirInclFlag;
|
string modelSharedLibPath = outputBaseName + ".so";
|
||||||
if (getEnvVar("RUNTIME_DIR").hasValue())
|
genSharedLib(module, modelSharedLibPath, {"-shared", "-fPIC"},
|
||||||
runtimeDirInclFlag = "-L" + getEnvVar("RUNTIME_DIR").getValue();
|
{constPackObjPath.getValueOr(""), modelObjPath},
|
||||||
|
{"-lEmbeddedDataLoader", "-lcruntime"});
|
||||||
|
}
|
||||||
|
|
||||||
// Link everything into a shared object.
|
void compileModuleToJniJar(
|
||||||
Command link(kCxxPath);
|
const mlir::OwningModuleRef &module, std::string outputBaseName) {
|
||||||
link.appendList({"-shared", "-fPIC"})
|
|
||||||
.appendStr(modelObjPath)
|
llvm::Optional<string> constPackObjPath;
|
||||||
.appendStr(constPackObjPath.getValueOr(""))
|
genConstPackObj(module, constPackObjPath);
|
||||||
.appendList({"-o", outputBaseName + ".so"})
|
llvm::FileRemover constPackObjRemover(constPackObjPath.getValue());
|
||||||
.appendStrOpt(runtimeDirInclFlag)
|
|
||||||
.appendList({"-lEmbeddedDataLoader", "-lcruntime"})
|
string bitcodePath = outputBaseName + ".bc";
|
||||||
.exec();
|
genLLVMBitcode(module, bitcodePath);
|
||||||
|
llvm::FileRemover bitcodeRemover(bitcodePath);
|
||||||
|
|
||||||
|
string modelObjPath = outputBaseName + ".o";
|
||||||
|
genModelObject(module, bitcodePath, modelObjPath);
|
||||||
|
llvm::FileRemover modelObjRemover(modelObjPath);
|
||||||
|
|
||||||
|
string jniSharedLibPath = getRuntimeDir() + "/libjniruntime.a";
|
||||||
|
string jniObjPath = "jnidummy.c.o";
|
||||||
|
genJniObject(module, jniSharedLibPath, jniObjPath);
|
||||||
|
llvm::FileRemover jniObjRemover(jniObjPath);
|
||||||
|
|
||||||
|
string modelSharedLibPath = "libmodel.so";
|
||||||
|
genSharedLib(module, modelSharedLibPath,
|
||||||
|
{"-shared", "-fPIC", "-z", "noexecstack"},
|
||||||
|
{constPackObjPath.getValueOr(""), modelObjPath, jniObjPath},
|
||||||
|
{"-lEmbeddedDataLoader", "-lcruntime", "-ljniruntime"});
|
||||||
|
llvm::FileRemover modelSharedLibRemover(modelSharedLibPath);
|
||||||
|
|
||||||
|
string modelJniJarPath = outputBaseName + ".jar";
|
||||||
|
genJniJar(module, modelSharedLibPath, modelJniJarPath);
|
||||||
}
|
}
|
||||||
|
|
||||||
void registerDialects() {
|
void registerDialects() {
|
||||||
|
@ -349,6 +466,9 @@ void emitOutputFiles(string outputBaseName, EmissionTargetType emissionTarget,
|
||||||
// Write LLVM bitcode to disk, compile & link.
|
// Write LLVM bitcode to disk, compile & link.
|
||||||
compileModuleToSharedLibrary(module, outputBaseName);
|
compileModuleToSharedLibrary(module, outputBaseName);
|
||||||
printf("Shared library %s.so has been compiled.\n", outputBaseName.c_str());
|
printf("Shared library %s.so has been compiled.\n", outputBaseName.c_str());
|
||||||
|
} else if (emissionTarget == EmitJNI) {
|
||||||
|
compileModuleToJniJar(module, outputBaseName);
|
||||||
|
printf("JNI archive %s.jar has been compiled.\n", outputBaseName.c_str());
|
||||||
} else {
|
} else {
|
||||||
// Emit the version with all constants included.
|
// Emit the version with all constants included.
|
||||||
outputCode(module, outputBaseName, ".onnx.mlir");
|
outputCode(module, outputBaseName, ".onnx.mlir");
|
||||||
|
|
|
@ -43,14 +43,20 @@ enum EmissionTargetType {
|
||||||
EmitMLIR,
|
EmitMLIR,
|
||||||
EmitLLVMIR,
|
EmitLLVMIR,
|
||||||
EmitLib,
|
EmitLib,
|
||||||
|
EmitJNI,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
void setExecPath(const char *argv0, void *fmain);
|
||||||
|
|
||||||
void LoadMLIR(std::string inputFilename, mlir::MLIRContext &context,
|
void LoadMLIR(std::string inputFilename, mlir::MLIRContext &context,
|
||||||
mlir::OwningModuleRef &module);
|
mlir::OwningModuleRef &module);
|
||||||
|
|
||||||
void compileModuleToSharedLibrary(
|
void compileModuleToSharedLibrary(
|
||||||
const mlir::OwningModuleRef &module, std::string outputBaseName);
|
const mlir::OwningModuleRef &module, std::string outputBaseName);
|
||||||
|
|
||||||
|
void compileModuleToJniJar(
|
||||||
|
const mlir::OwningModuleRef &module, std::string outputBaseName);
|
||||||
|
|
||||||
void registerDialects();
|
void registerDialects();
|
||||||
|
|
||||||
void addONNXToMLIRPasses(mlir::PassManager &pm);
|
void addONNXToMLIRPasses(mlir::PassManager &pm);
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
add_subdirectory(jni)
|
||||||
|
|
||||||
# Create static libcruntime.a to be embedded in model.so to make model.so self contained.
|
# Create static libcruntime.a to be embedded in model.so to make model.so self contained.
|
||||||
# However, by default object code for static library is not compiled with -fPIC. Embedding
|
# However, by default object code for static library is not compiled with -fPIC. Embedding
|
||||||
# such static library in a shared library can cause runtime failure on some architectures,
|
# such static library in a shared library can cause runtime failure on some architectures,
|
||||||
|
|
|
@ -150,6 +150,13 @@ int64_t *getStrides(RtMemRef *dynMemRef) { return dynMemRef->strides; }
|
||||||
|
|
||||||
int64_t getSize(OrderedRtMemRefDict *dict) { return dict->orderedNames.size(); }
|
int64_t getSize(OrderedRtMemRefDict *dict) { return dict->orderedNames.size(); }
|
||||||
|
|
||||||
|
INDEX_TYPE getDataSize(RtMemRef *rtMemRef) {
|
||||||
|
INDEX_TYPE n = rtMemRef->sizes[0];
|
||||||
|
for (int i = 1; i < rtMemRef->rank; i++)
|
||||||
|
n *= rtMemRef->sizes[i];
|
||||||
|
return n;
|
||||||
|
}
|
||||||
|
|
||||||
void setDType(RtMemRef *dynMemRef, int onnxType) {
|
void setDType(RtMemRef *dynMemRef, int onnxType) {
|
||||||
dynMemRef->onnx_dtype = onnxType;
|
dynMemRef->onnx_dtype = onnxType;
|
||||||
}
|
}
|
||||||
|
@ -160,5 +167,5 @@ unsigned int getRank(RtMemRef *dynMemRef) { return dynMemRef->rank; }
|
||||||
|
|
||||||
void setStrides(RtMemRef *dynMemRef, int64_t *strides) {
|
void setStrides(RtMemRef *dynMemRef, int64_t *strides) {
|
||||||
for (int i = 0; i < dynMemRef->rank; i++)
|
for (int i = 0; i < dynMemRef->rank; i++)
|
||||||
dynMemRef->sizes[i] = strides[i];
|
dynMemRef->strides[i] = strides[i];
|
||||||
}
|
}
|
||||||
|
|
|
@ -151,6 +151,9 @@ OrderedRtMemRefDict *createOrderedRtMemRefDict();
|
||||||
// Get how many dynamic memrefs are in dict.
|
// Get how many dynamic memrefs are in dict.
|
||||||
int64_t getSize(OrderedRtMemRefDict *dict);
|
int64_t getSize(OrderedRtMemRefDict *dict);
|
||||||
|
|
||||||
|
// Get how many data elements are in RtMemRef.
|
||||||
|
INDEX_TYPE getDataSize(RtMemRef *rtMemRef);
|
||||||
|
|
||||||
// Create a dynmemref with a certain rank.
|
// Create a dynmemref with a certain rank.
|
||||||
RtMemRef *createRtMemRef(int rank);
|
RtMemRef *createRtMemRef(int rank);
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,30 @@
|
||||||
|
find_package(Java COMPONENTS Development)
|
||||||
|
find_package(JNI)
|
||||||
|
|
||||||
|
if(Java_Development_FOUND AND JNI_FOUND)
|
||||||
|
include(UseJava)
|
||||||
|
|
||||||
|
# Target for Java runtime jar
|
||||||
|
add_jar(javaruntime
|
||||||
|
src/com/ibm/onnxmlir/DynEntryPoint.java
|
||||||
|
src/com/ibm/onnxmlir/OrderedRtMemRefDict.java
|
||||||
|
src/com/ibm/onnxmlir/RtMemRef.java
|
||||||
|
OUTPUT_DIR ${CMAKE_LIBRARY_OUTPUT_DIRECTORY})
|
||||||
|
|
||||||
|
# Target for JNI runtime lib
|
||||||
|
add_library(jniruntime STATIC
|
||||||
|
jniwrapper.c jnilog.c jnidummy.c
|
||||||
|
com_ibm_onnxmlir_DynEntryPoint.h jnilog.h ../RtMemRef.h)
|
||||||
|
set_target_properties(jniruntime PROPERTIES
|
||||||
|
POSITION_INDEPENDENT_CODE TRUE)
|
||||||
|
target_include_directories(jniruntime PRIVATE
|
||||||
|
${ONNX_MLIR_SRC_ROOT}/src/Runtime
|
||||||
|
${JAVA_INCLUDE_PATH}
|
||||||
|
${JAVA_INCLUDE_PATH2})
|
||||||
|
|
||||||
|
install_jar(javaruntime DESTINATION lib)
|
||||||
|
install(TARGETS jniruntime DESTINATION lib)
|
||||||
|
|
||||||
|
else()
|
||||||
|
message(WARNING "Java Development component or JNI not found, JNI targets will not work")
|
||||||
|
endif()
|
|
@ -0,0 +1,22 @@
|
||||||
|
/* DO NOT EDIT THIS FILE - it is machine generated */
|
||||||
|
#include <jni.h>
|
||||||
|
/* Header for class com_ibm_onnxmlir_DynEntryPoint */
|
||||||
|
|
||||||
|
#ifndef _Included_com_ibm_onnxmlir_DynEntryPoint
|
||||||
|
#define _Included_com_ibm_onnxmlir_DynEntryPoint
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
/*
|
||||||
|
* Class: com_ibm_onnxmlir_DynEntryPoint
|
||||||
|
* Method: main_graph_jni
|
||||||
|
* Signature:
|
||||||
|
* (Lcom/ibm/onnxmlir/OrderedRtMemRefDict;)Lcom/ibm/onnxmlir/OrderedRtMemRefDict;
|
||||||
|
*/
|
||||||
|
JNIEXPORT jobject JNICALL Java_com_ibm_onnxmlir_DynEntryPoint_main_1graph_1jni(
|
||||||
|
JNIEnv *, jclass, jobject);
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
#endif
|
|
@ -0,0 +1,7 @@
|
||||||
|
#include "com_ibm_onnxmlir_DynEntryPoint.h"
|
||||||
|
|
||||||
|
/* Dummy routine to force the link editor to embed code in libjniruntime.a
|
||||||
|
into libmodel.so */
|
||||||
|
void __dummy_do_not_call__(JNIEnv *env, jclass cls, jobject obj) {
|
||||||
|
Java_com_ibm_onnxmlir_DynEntryPoint_main_1graph_1jni(NULL, NULL, NULL);
|
||||||
|
}
|
|
@ -0,0 +1,94 @@
|
||||||
|
#include <errno.h>
|
||||||
|
#include <libgen.h>
|
||||||
|
#include <stdarg.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include <time.h>
|
||||||
|
|
||||||
|
#include "jnilog.h"
|
||||||
|
|
||||||
|
static int log_initd = 0;
|
||||||
|
static int log_level;
|
||||||
|
static FILE *log_fp;
|
||||||
|
|
||||||
|
/* Must match enum in log.h */
|
||||||
|
static char *log_level_name[] = {
|
||||||
|
"trace", "debug", "info", "warning", "error", "fatal"};
|
||||||
|
|
||||||
|
/* Return numerical log level of give level name */
|
||||||
|
static int get_log_level_by_name(char *name) {
|
||||||
|
int level = -1;
|
||||||
|
for (int i = 0; i < sizeof(log_level_name) / sizeof(char *); i++) {
|
||||||
|
if (!strcmp(name, log_level_name[i])) {
|
||||||
|
level = i;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return level;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Return FILE pointer of given file name */
|
||||||
|
static FILE *get_log_file_by_name(char *name) {
|
||||||
|
FILE *fp = NULL;
|
||||||
|
if (!strcmp(name, "stdout"))
|
||||||
|
fp = stdout;
|
||||||
|
else if (!strcmp(name, "stderr"))
|
||||||
|
fp = stderr;
|
||||||
|
else
|
||||||
|
fp = fopen(name, "w");
|
||||||
|
return fp;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Initialize log system. Set default log level and file or use environment
|
||||||
|
* variables ONNX_MLIR_JNI_LOG_LEVEL and ONNX_MLIR_JNI_LOG_FILE, respectively.
|
||||||
|
*/
|
||||||
|
static void log_init() {
|
||||||
|
if (log_initd)
|
||||||
|
return;
|
||||||
|
|
||||||
|
log_level = LOG_INFO;
|
||||||
|
char *strlevel = getenv("ONNX_MLIR_JNI_LOG_LEVEL");
|
||||||
|
int level;
|
||||||
|
if (strlevel && (level = get_log_level_by_name(strlevel)) != -1)
|
||||||
|
log_level = level;
|
||||||
|
|
||||||
|
log_fp = stderr;
|
||||||
|
char *strfname = getenv("ONNX_MLIR_JNI_LOG_FILE");
|
||||||
|
FILE *fp;
|
||||||
|
if (strfname && (fp = get_log_file_by_name(strfname)))
|
||||||
|
log_fp = fp;
|
||||||
|
|
||||||
|
tzset();
|
||||||
|
log_initd = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Generic log routine */
|
||||||
|
void log_printf(
|
||||||
|
int level, char *file, const char *func, int line, char *fmt, ...) {
|
||||||
|
if (!log_initd)
|
||||||
|
log_init();
|
||||||
|
if (level < log_level)
|
||||||
|
return;
|
||||||
|
|
||||||
|
time_t now;
|
||||||
|
struct tm *tm;
|
||||||
|
char buf[80];
|
||||||
|
|
||||||
|
/* Get local time and format as 2020-07-03 05:17:42 -0400 */
|
||||||
|
if (time(&now) == -1 || (tm = localtime(&now)) == NULL ||
|
||||||
|
strftime(buf, sizeof(buf), "%F %T %z", tm) == 0)
|
||||||
|
sprintf(buf, "-");
|
||||||
|
|
||||||
|
/* Output log prefix */
|
||||||
|
fprintf(log_fp, "[%s][%s]%s:%s:%d ", buf, log_level_name[level],
|
||||||
|
basename(file), func, line);
|
||||||
|
|
||||||
|
/* Output actually log data */
|
||||||
|
va_list va_list;
|
||||||
|
va_start(va_list, fmt);
|
||||||
|
vfprintf(log_fp, fmt, va_list);
|
||||||
|
va_end(va_list);
|
||||||
|
|
||||||
|
fprintf(log_fp, "\n");
|
||||||
|
}
|
|
@ -0,0 +1,125 @@
|
||||||
|
#ifndef __JNILOG_H__
|
||||||
|
#define __JNILOG_H__
|
||||||
|
|
||||||
|
#include <stdio.h>
|
||||||
|
|
||||||
|
enum { LOG_TRACE, LOG_DEBUG, LOG_INFO, LOG_WARNING, LOG_ERROR, LOG_FATAL };
|
||||||
|
|
||||||
|
#define LOG_MAX_NUM 16 /* max number of elements to output */
|
||||||
|
|
||||||
|
#define MIN(x, y) ((x) > (y) ? y : x)
|
||||||
|
|
||||||
|
/* Construct string of up to LOG_MAX_NUM elements of a char array */
|
||||||
|
#define LOG_CHAR_BUF(buf, data, n) \
|
||||||
|
do { \
|
||||||
|
buf[0] = '\0'; \
|
||||||
|
for (int i = 0; i < MIN(n, LOG_MAX_NUM); i++) \
|
||||||
|
sprintf(buf + strlen(buf), " %02x", ((char *)data)[i]); \
|
||||||
|
sprintf(buf + strlen(buf), n > LOG_MAX_NUM ? " ... " : " "); \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
/* Construct string of up to LOG_MAX_NUM elements of a short array */
|
||||||
|
#define LOG_SHORT_BUF(buf, data, n) \
|
||||||
|
do { \
|
||||||
|
buf[0] = '\0'; \
|
||||||
|
for (int i = 0; i < MIN(n, LOG_MAX_NUM); i++) \
|
||||||
|
sprintf(buf + strlen(buf), " %d", ((short *)data)[i]); \
|
||||||
|
sprintf(buf + strlen(buf), n > LOG_MAX_NUM ? " ... " : " "); \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
/* Construct string of up to LOG_MAX_NUM elements of a int array */
|
||||||
|
#define LOG_INT_BUF(buf, data, n) \
|
||||||
|
do { \
|
||||||
|
buf[0] = '\0'; \
|
||||||
|
for (int i = 0; i < MIN(n, LOG_MAX_NUM); i++) \
|
||||||
|
sprintf(buf + strlen(buf), " %d", ((int *)data)[i]); \
|
||||||
|
sprintf(buf + strlen(buf), n > LOG_MAX_NUM ? " ... " : " "); \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
/* Construct string of up to LOG_MAX_NUM elements of a long array */
|
||||||
|
#define LOG_LONG_BUF(buf, data, n) \
|
||||||
|
do { \
|
||||||
|
buf[0] = '\0'; \
|
||||||
|
for (int i = 0; i < MIN(n, LOG_MAX_NUM); i++) \
|
||||||
|
sprintf(buf + strlen(buf), " %ld", ((long *)data)[i]); \
|
||||||
|
sprintf(buf + strlen(buf), n > LOG_MAX_NUM ? " ... " : " "); \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
/* Construct string of up to LOG_MAX_NUM elements of a float array */
|
||||||
|
#define LOG_FLOAT_BUF(buf, data, n) \
|
||||||
|
do { \
|
||||||
|
buf[0] = '\0'; \
|
||||||
|
for (int i = 0; i < MIN(n, LOG_MAX_NUM); i++) \
|
||||||
|
sprintf(buf + strlen(buf), " %f", ((float *)data)[i]); \
|
||||||
|
sprintf(buf + strlen(buf), n > LOG_MAX_NUM ? " ... " : " "); \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
/* Construct string of up to LOG_MAX_NUM elements of a double array */
|
||||||
|
#define LOG_DOUBLE_BUF(buf, data, n) \
|
||||||
|
do { \
|
||||||
|
buf[0] = '\0'; \
|
||||||
|
for (int i = 0; i < MIN(n, LOG_MAX_NUM); i++) \
|
||||||
|
sprintf(buf + strlen(buf), " %f", ((double *)data)[i]); \
|
||||||
|
sprintf(buf + strlen(buf), n > LOG_MAX_NUM ? " ... " : " "); \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
enum {
|
||||||
|
ONNX_TYPE_UNDEFINED, /* 0 */
|
||||||
|
ONNX_TYPE_FLOAT, /* 1 */
|
||||||
|
ONNX_TYPE_UINT8, /* 2 */
|
||||||
|
ONNX_TYPE_INT8, /* 3 */
|
||||||
|
ONNX_TYPE_UINT16, /* 4 */
|
||||||
|
ONNX_TYPE_INT16, /* 5 */
|
||||||
|
ONNX_TYPE_INT32, /* 6 */
|
||||||
|
ONNX_TYPE_INT64, /* 7 */
|
||||||
|
ONNX_TYPE_STRING, /* 8 */
|
||||||
|
ONNX_TYPE_BOOL, /* 9 */
|
||||||
|
ONNX_TYPE_FLOAT16, /* 10 */
|
||||||
|
ONNX_TYPE_DOUBLE, /* 11 */
|
||||||
|
ONNX_TYPE_UINT32, /* 12 */
|
||||||
|
ONNX_TYPE_UINT64, /* 13 */
|
||||||
|
ONNX_TYPE_COMPLEX64, /* 14 */
|
||||||
|
ONNX_TYPE_COMPLEX128, /* 15 */
|
||||||
|
ONNX_TYPE_BFLOAT16, /* 16 */
|
||||||
|
};
|
||||||
|
|
||||||
|
/* Construct string of up to LOG_MAX_NUM elements of a "type" array */
|
||||||
|
#define LOG_TYPE_BUF(type, buf, data, n) \
|
||||||
|
do { \
|
||||||
|
switch (type) { \
|
||||||
|
case ONNX_TYPE_UINT8: \
|
||||||
|
case ONNX_TYPE_INT8: \
|
||||||
|
LOG_CHAR_BUF(buf, data, n); \
|
||||||
|
break; \
|
||||||
|
case ONNX_TYPE_UINT16: \
|
||||||
|
case ONNX_TYPE_INT16: \
|
||||||
|
LOG_SHORT_BUF(buf, data, n); \
|
||||||
|
break; \
|
||||||
|
case ONNX_TYPE_UINT32: \
|
||||||
|
case ONNX_TYPE_INT32: \
|
||||||
|
LOG_INT_BUF(buf, data, n); \
|
||||||
|
break; \
|
||||||
|
case ONNX_TYPE_UINT64: \
|
||||||
|
case ONNX_TYPE_INT64: \
|
||||||
|
LOG_LONG_BUF(buf, data, n); \
|
||||||
|
break; \
|
||||||
|
case ONNX_TYPE_FLOAT: \
|
||||||
|
LOG_FLOAT_BUF(buf, data, n); \
|
||||||
|
break; \
|
||||||
|
case ONNX_TYPE_DOUBLE: \
|
||||||
|
LOG_DOUBLE_BUF(buf, data, n); \
|
||||||
|
break; \
|
||||||
|
defaut: \
|
||||||
|
sprintf(buf, " unsupported data type %d ", type); \
|
||||||
|
} \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
/* Main macro for log output */
|
||||||
|
#define LOG_PRINTF(level, ...) \
|
||||||
|
log_printf(level, __FILE__, __FUNCTION__, __LINE__, __VA_ARGS__)
|
||||||
|
|
||||||
|
/* Generic log routine */
|
||||||
|
void log_printf(
|
||||||
|
int level, char *file, const char *func, int line, char *fmt, ...);
|
||||||
|
|
||||||
|
#endif
|
|
@ -0,0 +1,356 @@
|
||||||
|
#include <assert.h>
|
||||||
|
#include <malloc.h>
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#include "RtMemRef.h"
|
||||||
|
#include "com_ibm_onnxmlir_DynEntryPoint.h"
|
||||||
|
#include "jnilog.h"
|
||||||
|
|
||||||
|
/* Declare type var, make call and assign to var, check against val */
|
||||||
|
#define CHECK_CALL(type, var, call, val) \
|
||||||
|
type var = call; \
|
||||||
|
if (var == val) \
|
||||||
|
return NULL
|
||||||
|
|
||||||
|
/* Make a JNI call and throw Java exception if the call failed */
|
||||||
|
#define JNI_CALL(env, stmt) \
|
||||||
|
stmt; \
|
||||||
|
do { \
|
||||||
|
jthrowable e = (*env)->ExceptionOccurred(env); \
|
||||||
|
if (e) { \
|
||||||
|
LOG_PRINTF(LOG_ERROR, "JNI call exception occurred"); \
|
||||||
|
(*env)->Throw(env, e); \
|
||||||
|
return NULL; \
|
||||||
|
} \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
/* Make a JNI call and assign return value to var,
|
||||||
|
* throw Java exception if the call failed
|
||||||
|
*/
|
||||||
|
#define JNI_VAR_CALL(env, var, call) JNI_CALL(env, var = call)
|
||||||
|
|
||||||
|
/* Declare type var, make a JNI call and assign return value to var,
|
||||||
|
* throw Java exception if the call failed
|
||||||
|
*/
|
||||||
|
#define JNI_TYPE_VAR_CALL(env, type, var, call) JNI_CALL(env, type var = call);
|
||||||
|
|
||||||
|
/* If cond is true (native code failed), log error and throw Java exception */
|
||||||
|
#define JNI_COND(type, var, call, val, env, cls, ...) \
|
||||||
|
type var = call; \
|
||||||
|
do { \
|
||||||
|
if (var == val) { \
|
||||||
|
LOG_PRINTF(LOG_ERROR, __VA_ARGS__); \
|
||||||
|
(*env)->ThrowNew(env, cls, "native code error"); \
|
||||||
|
return NULL; \
|
||||||
|
} \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
/* Debug output of RtMemRef fields */
|
||||||
|
#define RMR_DEBUG(i, type, rank, sizes, strides, data, datasize) \
|
||||||
|
do { \
|
||||||
|
char tmp[1024]; \
|
||||||
|
LOG_PRINTF(LOG_DEBUG, "rmr[%d]:type=%d", i, type); \
|
||||||
|
LOG_PRINTF(LOG_DEBUG, "rmr[%d]:rank=%d", i, rank); \
|
||||||
|
LOG_LONG_BUF(tmp, sizes, rank); \
|
||||||
|
LOG_PRINTF(LOG_DEBUG, "rmr[%d]:sizes=[%s]", i, tmp); \
|
||||||
|
LOG_LONG_BUF(tmp, strides, rank); \
|
||||||
|
LOG_PRINTF(LOG_DEBUG, "rmr[%d]:strides=[%s]", i, tmp); \
|
||||||
|
LOG_TYPE_BUF(type, tmp, data, datasize); \
|
||||||
|
LOG_PRINTF(LOG_DEBUG, "rmr[%d]:data=[%s]", i, tmp); \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
/* Model shared library entry point */
|
||||||
|
extern OrderedRtMemRefDict *_dyn_entry_point_main_graph(OrderedRtMemRefDict *);
|
||||||
|
|
||||||
|
/* ONNX type to size (number of bytes) mapping */
|
||||||
|
int onnx_type_size[] = {
|
||||||
|
0, /* UNDEFINED = 0 */
|
||||||
|
4, /* FLOAT = 1 */
|
||||||
|
1, /* UINT8 = 2 */
|
||||||
|
1, /* INT8 = 3 */
|
||||||
|
2, /* UINT16 = 4 */
|
||||||
|
2, /* INT16 = 5 */
|
||||||
|
4, /* INT32 = 6 */
|
||||||
|
8, /* INT64 = 7 */
|
||||||
|
0, /* STRING = 8 */
|
||||||
|
1, /* BOOL = 9 */
|
||||||
|
2, /* FLOAT16 = 10 */
|
||||||
|
8, /* DOUBLE = 11 */
|
||||||
|
4, /* UINT32 = 12 */
|
||||||
|
8, /* UINT64 = 13 */
|
||||||
|
8, /* COMPLEX64 = 14 */
|
||||||
|
16, /* COMPLEX128 = 15 */
|
||||||
|
2, /* BFLOAT16 = 16 */
|
||||||
|
};
|
||||||
|
|
||||||
|
/* Java classes and methods needed for making various JNI API calls */
|
||||||
|
typedef struct {
|
||||||
|
jclass ecpt_cls; /* java/lang/Exception class */
|
||||||
|
jclass long_cls; /* java/lang/Long class */
|
||||||
|
jclass string_cls; /* java/lang/String class */
|
||||||
|
jclass ormrd_cls; /* com/ibm/onnxmlir/OrderedRtMemRefDict class */
|
||||||
|
jclass rmr_cls; /* com/ibm/onnxmlir/RtMemRef class */
|
||||||
|
|
||||||
|
jmethodID ormrd_constructor; /* OrderedRtMemRefDict constructor */
|
||||||
|
jmethodID ormrd_getRmrs; /* OrderedRtMemRefDict getRmrs method */
|
||||||
|
jmethodID ormrd_getNames; /* OrderedRtMemRefDict getNames method */
|
||||||
|
|
||||||
|
jmethodID rmr_constructor; /* RtMemRef constructor */
|
||||||
|
jmethodID rmr_getType; /* RtMemRef getType method */
|
||||||
|
jmethodID rmr_setType; /* RtMemRef setType method */
|
||||||
|
jmethodID rmr_getRank; /* RtMemRef getRank method */
|
||||||
|
jmethodID rmr_getData; /* RtMemRef getData method */
|
||||||
|
jmethodID rmr_setData; /* RtMemRef setData method */
|
||||||
|
jmethodID rmr_getSizes; /* RtMemRef getSizes method */
|
||||||
|
jmethodID rmr_setSizes; /* RtMemRef setSizes method */
|
||||||
|
jmethodID rmr_getStrides; /* RtMemRef getStrides method */
|
||||||
|
jmethodID rmr_setStrides; /* RtMemRef setStrides method */
|
||||||
|
jmethodID rmr_getDataSize; /* RtMemRef getDataSize method */
|
||||||
|
} jniapi_t;
|
||||||
|
|
||||||
|
jniapi_t jniapi;
|
||||||
|
|
||||||
|
/* Fill in struct jniapi */
|
||||||
|
jniapi_t *fill_jniapi(JNIEnv *env, jniapi_t *japi) {
|
||||||
|
/* Get Java Exception, Long, String, OrderedRtMemRefDict, and RtMemRef classes
|
||||||
|
*/
|
||||||
|
JNI_VAR_CALL(
|
||||||
|
env, japi->ecpt_cls, (*env)->FindClass(env, "java/lang/Exception"));
|
||||||
|
JNI_VAR_CALL(env, japi->long_cls, (*env)->FindClass(env, "java/lang/Long"));
|
||||||
|
JNI_VAR_CALL(
|
||||||
|
env, japi->string_cls, (*env)->FindClass(env, "java/lang/String"));
|
||||||
|
JNI_VAR_CALL(env, japi->ormrd_cls,
|
||||||
|
(*env)->FindClass(env, "com/ibm/onnxmlir/OrderedRtMemRefDict"));
|
||||||
|
JNI_VAR_CALL(
|
||||||
|
env, japi->rmr_cls, (*env)->FindClass(env, "com/ibm/onnxmlir/RtMemRef"));
|
||||||
|
|
||||||
|
/* Get method ID of constructor and various methods in OrderedRtMemRefDict */
|
||||||
|
JNI_VAR_CALL(env, japi->ormrd_constructor,
|
||||||
|
(*env)->GetMethodID(
|
||||||
|
env, japi->ormrd_cls, "<init>", "([Lcom/ibm/onnxmlir/RtMemRef;)V"));
|
||||||
|
JNI_VAR_CALL(env, japi->ormrd_getRmrs,
|
||||||
|
(*env)->GetMethodID(
|
||||||
|
env, japi->ormrd_cls, "getRmrs", "()[Lcom/ibm/onnxmlir/RtMemRef;"));
|
||||||
|
JNI_VAR_CALL(env, japi->ormrd_getNames,
|
||||||
|
(*env)->GetMethodID(
|
||||||
|
env, japi->ormrd_cls, "getNames", "()[Ljava/lang/String;"));
|
||||||
|
|
||||||
|
/* Get method ID of constructor and various methods in RtMemRef */
|
||||||
|
JNI_VAR_CALL(env, japi->rmr_constructor,
|
||||||
|
(*env)->GetMethodID(env, japi->rmr_cls, "<init>", "(I)V"));
|
||||||
|
JNI_VAR_CALL(env, japi->rmr_getType,
|
||||||
|
(*env)->GetMethodID(env, japi->rmr_cls, "getType", "()I"));
|
||||||
|
JNI_VAR_CALL(env, japi->rmr_setType,
|
||||||
|
(*env)->GetMethodID(env, japi->rmr_cls, "setType", "(I)V"));
|
||||||
|
JNI_VAR_CALL(env, japi->rmr_getRank,
|
||||||
|
(*env)->GetMethodID(env, japi->rmr_cls, "getRank", "()I"));
|
||||||
|
JNI_VAR_CALL(env, japi->rmr_getData,
|
||||||
|
(*env)->GetMethodID(
|
||||||
|
env, japi->rmr_cls, "getData", "()Ljava/nio/ByteBuffer;"));
|
||||||
|
JNI_VAR_CALL(env, japi->rmr_setData,
|
||||||
|
(*env)->GetMethodID(
|
||||||
|
env, japi->rmr_cls, "setData", "(Ljava/nio/ByteBuffer;)V"));
|
||||||
|
JNI_VAR_CALL(env, japi->rmr_getSizes,
|
||||||
|
(*env)->GetMethodID(env, japi->rmr_cls, "getSizes", "()[J"));
|
||||||
|
JNI_VAR_CALL(env, japi->rmr_setSizes,
|
||||||
|
(*env)->GetMethodID(env, japi->rmr_cls, "setSizes", "([J)V"));
|
||||||
|
JNI_VAR_CALL(env, japi->rmr_getStrides,
|
||||||
|
(*env)->GetMethodID(env, japi->rmr_cls, "getStrides", "()[J"));
|
||||||
|
JNI_VAR_CALL(env, japi->rmr_setStrides,
|
||||||
|
(*env)->GetMethodID(env, japi->rmr_cls, "setStrides", "([J)V"));
|
||||||
|
JNI_VAR_CALL(env, japi->rmr_getDataSize,
|
||||||
|
(*env)->GetMethodID(env, japi->rmr_cls, "getDataSize", "()J"));
|
||||||
|
|
||||||
|
return japi;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Convert Java object to native data structure */
|
||||||
|
OrderedRtMemRefDict *ormrd_java_to_native(
|
||||||
|
JNIEnv *env, jclass cls, jobject obj, jniapi_t *japi) {
|
||||||
|
/* Get object array "rmrs" and "names" in OrderedRtMemRefDict */
|
||||||
|
JNI_TYPE_VAR_CALL(env, jobjectArray, ormrd_rmrs,
|
||||||
|
(*env)->CallObjectMethod(env, obj, japi->ormrd_getRmrs));
|
||||||
|
JNI_TYPE_VAR_CALL(env, jobjectArray, ormrd_names,
|
||||||
|
(*env)->CallObjectMethod(env, obj, japi->ormrd_getNames));
|
||||||
|
|
||||||
|
/* Get length of object array "rmrs" and "names" in OrderedRtMemRefDict */
|
||||||
|
JNI_TYPE_VAR_CALL(
|
||||||
|
env, jsize, ormrd_rmrs_len, (*env)->GetArrayLength(env, ormrd_rmrs));
|
||||||
|
JNI_TYPE_VAR_CALL(
|
||||||
|
env, jsize, ormrd_names_len, (*env)->GetArrayLength(env, ormrd_names));
|
||||||
|
|
||||||
|
/* Allocate memory for holding each Java rmr object and name string,
|
||||||
|
* and RtMemRef and char pointers for constructing native RtMemRef and name
|
||||||
|
* array
|
||||||
|
*/
|
||||||
|
JNI_COND(jobject *, obj_rmr, malloc(ormrd_rmrs_len * sizeof(jobject)), NULL,
|
||||||
|
env, japi->ecpt_cls, "obj_rmr=null");
|
||||||
|
JNI_COND(jstring *, obj_name, malloc(ormrd_names_len * sizeof(jstring)), NULL,
|
||||||
|
env, japi->ecpt_cls, "obj_name=null");
|
||||||
|
JNI_COND(RtMemRef **, jni_rmr, malloc(ormrd_rmrs_len * sizeof(RtMemRef *)),
|
||||||
|
NULL, env, japi->ecpt_cls, "jni_rmr=null");
|
||||||
|
JNI_COND(const char **, jni_name,
|
||||||
|
malloc(ormrd_names_len * sizeof(const char *)), NULL, env, japi->ecpt_cls,
|
||||||
|
"jni_name=null");
|
||||||
|
|
||||||
|
/* Create OrderedRtMemRefDict to be constructed and passed to the model shared
|
||||||
|
* library */
|
||||||
|
JNI_COND(OrderedRtMemRefDict *, ormrd, createOrderedRtMemRefDict(), NULL, env,
|
||||||
|
japi->ecpt_cls, "ormrd=null");
|
||||||
|
|
||||||
|
/* Loop through all the ormrd_rmrs and ormrd_names */
|
||||||
|
for (int i = 0; i < ormrd_rmrs_len; i++) {
|
||||||
|
JNI_VAR_CALL(
|
||||||
|
env, obj_rmr[i], (*env)->GetObjectArrayElement(env, ormrd_rmrs, i));
|
||||||
|
JNI_VAR_CALL(
|
||||||
|
env, obj_name[i], (*env)->GetObjectArrayElement(env, ormrd_names, i));
|
||||||
|
|
||||||
|
/* Get type, rank, data, sizes, and strides by calling corresponding methods
|
||||||
|
*/
|
||||||
|
JNI_TYPE_VAR_CALL(env, jint, rmr_type,
|
||||||
|
(*env)->CallIntMethod(env, obj_rmr[i], japi->rmr_getType));
|
||||||
|
JNI_TYPE_VAR_CALL(env, jint, rmr_rank,
|
||||||
|
(*env)->CallIntMethod(env, obj_rmr[i], japi->rmr_getRank));
|
||||||
|
JNI_TYPE_VAR_CALL(env, jlong, rmr_datasize,
|
||||||
|
(*env)->CallLongMethod(env, obj_rmr[i], japi->rmr_getDataSize));
|
||||||
|
JNI_TYPE_VAR_CALL(env, jobject, rmr_data,
|
||||||
|
(*env)->CallObjectMethod(env, obj_rmr[i], japi->rmr_getData));
|
||||||
|
JNI_TYPE_VAR_CALL(env, jobject, rmr_sizes,
|
||||||
|
(*env)->CallObjectMethod(env, obj_rmr[i], japi->rmr_getSizes));
|
||||||
|
JNI_TYPE_VAR_CALL(env, jobject, rmr_strides,
|
||||||
|
(*env)->CallObjectMethod(env, obj_rmr[i], japi->rmr_getStrides));
|
||||||
|
|
||||||
|
/* Primitive type int and long can be directly used */
|
||||||
|
int jni_type = rmr_type, jni_rank = rmr_rank;
|
||||||
|
long jni_datasize = rmr_datasize;
|
||||||
|
|
||||||
|
/* Get direct buffer associated with data */
|
||||||
|
JNI_TYPE_VAR_CALL(
|
||||||
|
env, void *, jni_data, (*env)->GetDirectBufferAddress(env, rmr_data));
|
||||||
|
|
||||||
|
/* Get long array associated with sizes and strides */
|
||||||
|
JNI_TYPE_VAR_CALL(env, long *, jni_sizes,
|
||||||
|
(*env)->GetLongArrayElements(env, rmr_sizes, NULL));
|
||||||
|
JNI_TYPE_VAR_CALL(env, long *, jni_strides,
|
||||||
|
(*env)->GetLongArrayElements(env, rmr_strides, NULL));
|
||||||
|
|
||||||
|
/* Print debug info on what we got from the Java side */
|
||||||
|
RMR_DEBUG(
|
||||||
|
i, jni_type, jni_rank, jni_sizes, jni_strides, jni_data, jni_datasize);
|
||||||
|
|
||||||
|
/* Create native RtMemRef struct and fill in its fields */
|
||||||
|
jni_rmr[i] = createRtMemRef(jni_rank);
|
||||||
|
setDType(jni_rmr[i], jni_type);
|
||||||
|
setData(jni_rmr[i], jni_data);
|
||||||
|
setSizes(jni_rmr[i], jni_sizes);
|
||||||
|
setStrides(jni_rmr[i], jni_strides);
|
||||||
|
|
||||||
|
/*jni_name[i] = (*env)->GetStringUTFChars(env, obj_name[i], NULL);
|
||||||
|
printf("jni_name=%s\n", jni_name[i]);*/
|
||||||
|
|
||||||
|
/* Install RtMemRef into OrderedRtMemRefDict */
|
||||||
|
setRtMemRef(ormrd, i, jni_rmr[i]);
|
||||||
|
|
||||||
|
/* Release reference to the java objects */
|
||||||
|
JNI_CALL(
|
||||||
|
env, (*env)->ReleaseLongArrayElements(env, rmr_sizes, jni_sizes, 0));
|
||||||
|
JNI_CALL(env,
|
||||||
|
(*env)->ReleaseLongArrayElements(env, rmr_strides, jni_strides, 0));
|
||||||
|
}
|
||||||
|
|
||||||
|
/* setRtMemRef(ormrd, jni_rmr, jni_name); */
|
||||||
|
return ormrd;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Convert native data structure to Java object */
|
||||||
|
jobject ormrd_native_to_java(
|
||||||
|
JNIEnv *env, jclass cls, OrderedRtMemRefDict *dict, jniapi_t *japi) {
|
||||||
|
JNI_COND(int, nrmr, numRtMemRefs(dict), 0, env, japi->ecpt_cls, "nrmr=0");
|
||||||
|
|
||||||
|
/* Create RtMemRef java object array */
|
||||||
|
JNI_TYPE_VAR_CALL(env, jobjectArray, rmrs,
|
||||||
|
(*env)->NewObjectArray(env, nrmr, japi->rmr_cls, NULL));
|
||||||
|
|
||||||
|
/* Loop through the native RtMemRef structs */
|
||||||
|
for (int i = 0; i < nrmr; i++) {
|
||||||
|
JNI_COND(RtMemRef *, rmr, getRtMemRef(dict, i), NULL, env, japi->ecpt_cls,
|
||||||
|
"rmr[%d]=null", i);
|
||||||
|
|
||||||
|
JNI_COND(int, jni_type, getDType(rmr), 0, env, japi->ecpt_cls,
|
||||||
|
"rmr[%d]:type=0", i);
|
||||||
|
JNI_COND(int, jni_rank, getRank(rmr), 0, env, japi->ecpt_cls,
|
||||||
|
"rmr[%d]:rank=0", i);
|
||||||
|
JNI_COND(long, jni_datasize, getDataSize(rmr), 0, env, japi->ecpt_cls,
|
||||||
|
"rmr[%d]:datasize=0", i);
|
||||||
|
JNI_COND(void *, jni_data, getData(rmr), NULL, env, japi->ecpt_cls,
|
||||||
|
"rmr[%d]:data=null", i);
|
||||||
|
JNI_COND(long *, jni_sizes, getSizes(rmr), NULL, env, japi->ecpt_cls,
|
||||||
|
"rmr[%d]:sizes=null", i);
|
||||||
|
JNI_COND(long *, jni_strides, getStrides(rmr), NULL, env, japi->ecpt_cls,
|
||||||
|
"rmr[%d]:strides=null", i);
|
||||||
|
|
||||||
|
/* Print debug info on what we got from the native side */
|
||||||
|
RMR_DEBUG(
|
||||||
|
i, jni_type, jni_rank, jni_sizes, jni_strides, jni_data, jni_datasize);
|
||||||
|
|
||||||
|
/* create the following Java objects:
|
||||||
|
* - RtMemRef
|
||||||
|
* - DirectByteBuffer (from native buffers)
|
||||||
|
* - long array for sizes and strides
|
||||||
|
*/
|
||||||
|
JNI_TYPE_VAR_CALL(env, jobject, obj_rmr,
|
||||||
|
(*env)->NewObject(env, japi->rmr_cls, japi->rmr_constructor, jni_rank));
|
||||||
|
JNI_TYPE_VAR_CALL(env, jobject, rmr_data,
|
||||||
|
(*env)->NewDirectByteBuffer(
|
||||||
|
env, jni_data, jni_datasize * onnx_type_size[jni_type]));
|
||||||
|
JNI_TYPE_VAR_CALL(
|
||||||
|
env, jlongArray, rmr_sizes, (*env)->NewLongArray(env, jni_rank));
|
||||||
|
JNI_TYPE_VAR_CALL(
|
||||||
|
env, jlongArray, rmr_strides, (*env)->NewLongArray(env, jni_rank));
|
||||||
|
|
||||||
|
/* Call setType method */
|
||||||
|
JNI_CALL(env,
|
||||||
|
(*env)->CallObjectMethod(env, obj_rmr, japi->rmr_setType, jni_type));
|
||||||
|
|
||||||
|
/* Call setData method */
|
||||||
|
JNI_CALL(env,
|
||||||
|
(*env)->CallObjectMethod(env, obj_rmr, japi->rmr_setData, rmr_data));
|
||||||
|
|
||||||
|
/* Fill in sizes array from native array and call setSizes method */
|
||||||
|
JNI_CALL(env,
|
||||||
|
(*env)->SetLongArrayRegion(env, rmr_sizes, 0, jni_rank, jni_sizes));
|
||||||
|
JNI_CALL(env,
|
||||||
|
(*env)->CallObjectMethod(env, obj_rmr, japi->rmr_setSizes, rmr_sizes));
|
||||||
|
|
||||||
|
/* Fill in strides array from native array and call setStrides method */
|
||||||
|
JNI_CALL(env,
|
||||||
|
(*env)->SetLongArrayRegion(env, rmr_strides, 0, jni_rank, jni_strides));
|
||||||
|
JNI_CALL(env, (*env)->CallObjectMethod(
|
||||||
|
env, obj_rmr, japi->rmr_setStrides, rmr_strides));
|
||||||
|
|
||||||
|
/* Set DynMemRef object in the object array */
|
||||||
|
JNI_CALL(env, (*env)->SetObjectArrayElement(env, rmrs, i, obj_rmr));
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Create the OrderedRtMemRefDict java object */
|
||||||
|
JNI_TYPE_VAR_CALL(env, jobject, ormrd,
|
||||||
|
(*env)->NewObject(env, japi->ormrd_cls, japi->ormrd_constructor, rmrs));
|
||||||
|
|
||||||
|
return ormrd;
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jobject JNICALL Java_com_ibm_onnxmlir_DynEntryPoint_main_1graph_1jni(
|
||||||
|
JNIEnv *env, jclass cls, jobject obj) {
|
||||||
|
CHECK_CALL(jniapi_t *, japi, fill_jniapi(env, &jniapi), NULL);
|
||||||
|
|
||||||
|
CHECK_CALL(OrderedRtMemRefDict *, input_ormrd,
|
||||||
|
ormrd_java_to_native(env, cls, obj, japi), NULL);
|
||||||
|
|
||||||
|
CHECK_CALL(OrderedRtMemRefDict *, dict,
|
||||||
|
_dyn_entry_point_main_graph(input_ormrd), NULL);
|
||||||
|
|
||||||
|
CHECK_CALL(
|
||||||
|
jobject, output_ormrd, ormrd_native_to_java(env, cls, dict, japi), NULL);
|
||||||
|
|
||||||
|
return output_ormrd;
|
||||||
|
}
|
|
@ -0,0 +1,64 @@
|
||||||
|
package com.ibm.onnxmlir;
|
||||||
|
|
||||||
|
import java.io.File;
|
||||||
|
import java.io.FileNotFoundException;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.net.URISyntaxException;
|
||||||
|
import java.nio.file.Files;
|
||||||
|
import java.nio.file.StandardCopyOption;
|
||||||
|
import java.util.jar.JarFile;
|
||||||
|
|
||||||
|
public class DynEntryPoint {
|
||||||
|
static String libname = "libmodel.so";
|
||||||
|
|
||||||
|
static {
|
||||||
|
File jar;
|
||||||
|
JarFile jf;
|
||||||
|
String jarDir = null;
|
||||||
|
String libPath = null;
|
||||||
|
try {
|
||||||
|
// Get path name of jar
|
||||||
|
jar = new File(DynEntryPoint.class.getProtectionDomain()
|
||||||
|
.getCodeSource()
|
||||||
|
.getLocation().toURI());
|
||||||
|
jarDir = jar.getParentFile().getAbsolutePath();
|
||||||
|
libPath = jarDir + "/" + libname;
|
||||||
|
|
||||||
|
// Open jar file to read and check libname inside jar.
|
||||||
|
// If IOException thrown, load .so from where .jar is.
|
||||||
|
//
|
||||||
|
// Checking whether DynEntryPoint.class.getResourceAsStream returns null
|
||||||
|
// does NOT work. Because it checks whether the resource is
|
||||||
|
// available on the classpath, not only just inside the jar file.
|
||||||
|
jf = new JarFile(jar);
|
||||||
|
if (jf.getEntry(libname) != null) {
|
||||||
|
File libFile = new File(libPath);
|
||||||
|
// Copy .so to where jar is
|
||||||
|
Files.copy(jf.getInputStream(jf.getEntry(libname)),
|
||||||
|
libFile.toPath(), StandardCopyOption.REPLACE_EXISTING);
|
||||||
|
// z/OS USS requires "x" permission bit
|
||||||
|
libFile.setExecutable(true, false);
|
||||||
|
// Load the temporary .so copy
|
||||||
|
System.load(libPath);
|
||||||
|
// POSIX can unlink file after loading
|
||||||
|
libFile.delete();
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
// Throw subclass of IOException
|
||||||
|
throw new FileNotFoundException(".so not found inside jar");
|
||||||
|
}
|
||||||
|
} catch (URISyntaxException e) {
|
||||||
|
// Failed to find jar path, assume the .so is in cwd
|
||||||
|
System.load(libname);
|
||||||
|
} catch (IOException e) {
|
||||||
|
// .so not found in jar, assume it's where jar is
|
||||||
|
System.load(libPath);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static native OrderedRtMemRefDict main_graph_jni(OrderedRtMemRefDict ormrd);
|
||||||
|
|
||||||
|
public static OrderedRtMemRefDict main_graph(OrderedRtMemRefDict ormrd) {
|
||||||
|
return main_graph_jni(ormrd);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,118 @@
|
||||||
|
package com.ibm.onnxmlir;
|
||||||
|
|
||||||
|
import java.util.HashMap;
|
||||||
|
|
||||||
|
public class OrderedRtMemRefDict {
|
||||||
|
|
||||||
|
private RtMemRef[] _rmrs;
|
||||||
|
private String[] _names;
|
||||||
|
private HashMap<String, Integer> _n2i;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Constructor
|
||||||
|
*
|
||||||
|
* @param rmrs DynMemRef array
|
||||||
|
*/
|
||||||
|
public OrderedRtMemRefDict(RtMemRef[] rmrs) {
|
||||||
|
this(rmrs, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Constructor
|
||||||
|
*
|
||||||
|
* @param rmrs DynMemRef array
|
||||||
|
* @param names name array
|
||||||
|
*/
|
||||||
|
public OrderedRtMemRefDict(RtMemRef[] rmrs, String[] names) {
|
||||||
|
/* rmrs cannot be null or empty */
|
||||||
|
if (rmrs == null || rmrs.length == 0)
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"Number of dmrs is invalid");
|
||||||
|
|
||||||
|
/* If names is null or empty, construct a default one with
|
||||||
|
* index as name.
|
||||||
|
*/
|
||||||
|
if (names == null || names.length == 0) {
|
||||||
|
names = new String[rmrs.length];
|
||||||
|
for (int i = 0; i < rmrs.length; i++)
|
||||||
|
names[i] = Integer.toString(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Number of rmrs and names must match */
|
||||||
|
if (rmrs.length != names.length)
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"Number of dmrs and names do not match");
|
||||||
|
|
||||||
|
/* Establish name to index mapping. Individual rmr is
|
||||||
|
* checked for validity.
|
||||||
|
*/
|
||||||
|
_n2i = new HashMap<String, Integer>();
|
||||||
|
for (int i = 0; i < names.length; i++) {
|
||||||
|
if (rmrs[i] == null || !rmrs[i].validRmr())
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"rmrs[" + i + "] is invalid");
|
||||||
|
if (_n2i.put(names[i], i) != null)
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"name[" + i + "] = " + names[i] + " not unique");
|
||||||
|
}
|
||||||
|
_rmrs = rmrs;
|
||||||
|
_names = names;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* RtMemRef getter by index
|
||||||
|
*
|
||||||
|
* @param idx index of RtMemRef instance to get
|
||||||
|
* @return RtMemRef instance
|
||||||
|
*/
|
||||||
|
public RtMemRef getRmrbyIndex(int idx) {
|
||||||
|
return _rmrs[idx];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* RtMemRef getter by name
|
||||||
|
*
|
||||||
|
* @param name name of RtMemRef instance to get
|
||||||
|
* @return RtMemRef instance
|
||||||
|
*/
|
||||||
|
public RtMemRef getRmrByName(String name) {
|
||||||
|
return _rmrs[_n2i.get(name)];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* RtMemRef array getter
|
||||||
|
*
|
||||||
|
* @return RtMemRef array
|
||||||
|
*/
|
||||||
|
public RtMemRef[] getRmrs() {
|
||||||
|
return _rmrs;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Name getter
|
||||||
|
*
|
||||||
|
* @param idx index of name to get
|
||||||
|
* @return name string
|
||||||
|
*/
|
||||||
|
public String getName(int idx) {
|
||||||
|
return _names[idx];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Name array getter
|
||||||
|
*
|
||||||
|
* @return name array
|
||||||
|
*/
|
||||||
|
public String[] getNames() {
|
||||||
|
return _names;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* RtMemRef array size getter
|
||||||
|
*
|
||||||
|
* @return RtMemRef array size
|
||||||
|
*/
|
||||||
|
public int size() {
|
||||||
|
return _rmrs.length;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,400 @@
|
||||||
|
package com.ibm.onnxmlir;
|
||||||
|
|
||||||
|
import java.nio.ByteBuffer;
|
||||||
|
import java.nio.ByteOrder;
|
||||||
|
import java.nio.DoubleBuffer;
|
||||||
|
import java.nio.FloatBuffer;
|
||||||
|
import java.nio.IntBuffer;
|
||||||
|
import java.nio.LongBuffer;
|
||||||
|
import java.nio.ShortBuffer;
|
||||||
|
|
||||||
|
public class RtMemRef {
|
||||||
|
final ByteOrder endian = ByteOrder.nativeOrder();
|
||||||
|
|
||||||
|
/* We can use enum but that creates another class
|
||||||
|
* which complicates things for JNI.
|
||||||
|
*/
|
||||||
|
final int ONNX_TYPE_UNDEFINED = 0;
|
||||||
|
final int ONNX_TYPE_FLOAT = 1;
|
||||||
|
final int ONNX_TYPE_UINT8 = 2;
|
||||||
|
final int ONNX_TYPE_INT8 = 3;
|
||||||
|
final int ONNX_TYPE_UINT16 = 4;
|
||||||
|
final int ONNX_TYPE_INT16 = 5;
|
||||||
|
final int ONNX_TYPE_INT32 = 6;
|
||||||
|
final int ONNX_TYPE_INT64 = 7;
|
||||||
|
final int ONNX_TYPE_STRING = 8;
|
||||||
|
final int ONNX_TYPE_BOOL = 9;
|
||||||
|
final int ONNX_TYPE_FLOAT16 = 10;
|
||||||
|
final int ONNX_TYPE_DOUBLE = 11;
|
||||||
|
final int ONNX_TYPE_UINT32 = 12;
|
||||||
|
final int ONNX_TYPE_UINT64 = 13;
|
||||||
|
final int ONNX_TYPE_COMPLEX64 = 14;
|
||||||
|
final int ONNX_TYPE_COMPLEX128 = 15;
|
||||||
|
final int ONNX_TYPE_BFLOAT16 = 16;
|
||||||
|
|
||||||
|
final int[] ONNX_TYPE_SIZE = new int[] {
|
||||||
|
0, /* UNDEFINED */
|
||||||
|
4, /* FLOAT */
|
||||||
|
1, /* UINT8 */
|
||||||
|
1, /* INT8 */
|
||||||
|
2, /* UINT16 */
|
||||||
|
2, /* INT16 */
|
||||||
|
4, /* INT32 */
|
||||||
|
8, /* INT64 */
|
||||||
|
0, /* STRING */
|
||||||
|
1, /* BOOL */
|
||||||
|
2, /* FLOAT16 */
|
||||||
|
8, /* DOUBLE */
|
||||||
|
4, /* UINT32 */
|
||||||
|
8, /* UINT64 */
|
||||||
|
8, /* COMPLEX64 */
|
||||||
|
16, /* COMPLEX128 */
|
||||||
|
2, /* BFLOAT16 */
|
||||||
|
};
|
||||||
|
|
||||||
|
private ByteBuffer _data;
|
||||||
|
private int _type;
|
||||||
|
private int _rank;
|
||||||
|
private long[] _sizes;
|
||||||
|
private long[] _strides;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Constructor
|
||||||
|
*/
|
||||||
|
public RtMemRef(int rank) {
|
||||||
|
if (rank <= 0)
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"invalid rank " + rank);
|
||||||
|
_data = null;
|
||||||
|
_type = ONNX_TYPE_UNDEFINED;
|
||||||
|
_rank = rank;
|
||||||
|
_sizes = new long[rank];
|
||||||
|
_strides = new long[rank];
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ---------- Data type getter and setter ---------- */
|
||||||
|
/* For JNI wrapper only. Not intended for end user. */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Type getter
|
||||||
|
*
|
||||||
|
* @return data type
|
||||||
|
*/
|
||||||
|
@SuppressWarnings("unused")
|
||||||
|
private int getType() {
|
||||||
|
return _type;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Type setter
|
||||||
|
*
|
||||||
|
* @param type data type to be set
|
||||||
|
*/
|
||||||
|
@SuppressWarnings("unused")
|
||||||
|
private void setType(int type) {
|
||||||
|
_type = type;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ---------- Raw data getter and setter ---------- */
|
||||||
|
/* For JNI wrapper only. Not intended for end user. */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Raw data getter
|
||||||
|
*
|
||||||
|
* @return raw data
|
||||||
|
*/
|
||||||
|
@SuppressWarnings("unused")
|
||||||
|
private ByteBuffer getData() {
|
||||||
|
return _data;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Raw data setter
|
||||||
|
*
|
||||||
|
* @param data raw data to be set
|
||||||
|
*/
|
||||||
|
@SuppressWarnings("unused")
|
||||||
|
private void setData(ByteBuffer data) {
|
||||||
|
_data = data.order(endian);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ---------- Byte data getter and setter ---------- */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Byte data getter
|
||||||
|
*
|
||||||
|
* @return byte data array
|
||||||
|
*/
|
||||||
|
public byte[] getByteData() {
|
||||||
|
if (_data == null) return null;
|
||||||
|
|
||||||
|
/* asReadOnlyBuffer() creates a new view so the position of the
|
||||||
|
* original data will stay at 0 for subsequent getByteData()
|
||||||
|
* after get(b).
|
||||||
|
*/
|
||||||
|
byte[] b = new byte[_data.limit()];
|
||||||
|
_data.asReadOnlyBuffer().get(b);
|
||||||
|
return b;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Byte data setter
|
||||||
|
*
|
||||||
|
* @param data byte array to be set
|
||||||
|
*/
|
||||||
|
public void setByteData(byte[] data) {
|
||||||
|
/* slice() creates a new view so the position of the
|
||||||
|
* original data will stay at 0 for getByteData() after put(data).
|
||||||
|
*/
|
||||||
|
_data = ByteBuffer.allocateDirect(data.length);
|
||||||
|
_data.slice().put(data);
|
||||||
|
_type = ONNX_TYPE_INT8;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ---------- Short data getter and setter ---------- */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Short data getter
|
||||||
|
*
|
||||||
|
* @return short data array
|
||||||
|
*/
|
||||||
|
public short[] getShortData() {
|
||||||
|
if (_data == null) return null;
|
||||||
|
|
||||||
|
/* asShortBuffer() creates a new view so the position of the
|
||||||
|
* original data will stay at 0 for subsequent getShortData()
|
||||||
|
* after get(s).
|
||||||
|
*/
|
||||||
|
ShortBuffer sb = _data.asShortBuffer();
|
||||||
|
short[] s = new short[sb.limit()];
|
||||||
|
sb.get(s);
|
||||||
|
return s;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Short data setter
|
||||||
|
*
|
||||||
|
* @param data short array to be set
|
||||||
|
*/
|
||||||
|
public void setShortData(short[] data) {
|
||||||
|
/* asShortBuffer() creates a new view so the position of the
|
||||||
|
* original data will stay at 0 for getShortData() after put(data).
|
||||||
|
*/
|
||||||
|
_data = ByteBuffer.allocateDirect(data.length*2).order(endian);
|
||||||
|
_data.asShortBuffer().put(data);
|
||||||
|
_type = ONNX_TYPE_INT16;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ---------- Int data getter and setter ---------- */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Int data getter
|
||||||
|
*
|
||||||
|
* @return int data array
|
||||||
|
*/
|
||||||
|
public int[] getIntData() {
|
||||||
|
if (_data == null) return null;
|
||||||
|
|
||||||
|
/* asIntBuffer() creates a new view so the position of the
|
||||||
|
* original data will stay at 0 for subsequent getIntData()
|
||||||
|
* after get(i).
|
||||||
|
*/
|
||||||
|
IntBuffer ib = _data.asIntBuffer();
|
||||||
|
int[] i = new int[ib.limit()];
|
||||||
|
ib.get(i);
|
||||||
|
return i;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Int data setter
|
||||||
|
*
|
||||||
|
* @param data int array to be set
|
||||||
|
*/
|
||||||
|
public void setIntData(int[] data) {
|
||||||
|
/* asIntBuffer() creates a new view so the position of the
|
||||||
|
* original data will stay at 0 for getIntData() after put(data).
|
||||||
|
*/
|
||||||
|
_data = ByteBuffer.allocateDirect(data.length*4).order(endian);
|
||||||
|
_data.asIntBuffer().put(data);
|
||||||
|
_type = ONNX_TYPE_INT32;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ---------- Long data getter and setter ---------- */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Long data getter
|
||||||
|
*
|
||||||
|
* @return long data array
|
||||||
|
*/
|
||||||
|
public long[] getLongData() {
|
||||||
|
if (_data == null) return null;
|
||||||
|
|
||||||
|
/* asLongBuffer() creates a new view so the position of the
|
||||||
|
* original data will stay at 0 for subsequent getLongData()
|
||||||
|
* after get(l).
|
||||||
|
*/
|
||||||
|
LongBuffer lb = _data.asLongBuffer();
|
||||||
|
long[] l = new long[lb.limit()];
|
||||||
|
lb.get(l);
|
||||||
|
return l;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Long data setter
|
||||||
|
*
|
||||||
|
* @param data long array to be set
|
||||||
|
*/
|
||||||
|
public void setLongData(long[] data) {
|
||||||
|
/* asLongBuffer() creates a new view so the position of the
|
||||||
|
* original data will stay at 0 for getLongData() after put(data).
|
||||||
|
*/
|
||||||
|
_data = ByteBuffer.allocateDirect(data.length*8).order(endian);
|
||||||
|
_data.asLongBuffer().put(data);
|
||||||
|
_type = ONNX_TYPE_INT64;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ---------- Float data getter and setter ---------- */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Float data getter
|
||||||
|
*
|
||||||
|
* @return float data array
|
||||||
|
*/
|
||||||
|
public float[] getFloatData() {
|
||||||
|
if (_data == null) return null;
|
||||||
|
|
||||||
|
/* asFloatBuffer() creates a new view so the position of the
|
||||||
|
* original data will stay at 0 for subsequent getFloatData()
|
||||||
|
* after get(f).
|
||||||
|
*/
|
||||||
|
FloatBuffer fb = _data.asFloatBuffer();
|
||||||
|
float[] f = new float[fb.limit()];
|
||||||
|
fb.get(f);
|
||||||
|
return f;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Float data setter
|
||||||
|
*
|
||||||
|
* @param data float array to be set
|
||||||
|
*/
|
||||||
|
public void setFloatData(float[] data) {
|
||||||
|
/* asFloatBuffer() creates a new view so the position of the
|
||||||
|
* original data will stay at 0 for getFloatData() after put(data).
|
||||||
|
*/
|
||||||
|
_data = ByteBuffer.allocateDirect(data.length*4).order(endian);
|
||||||
|
_data.asFloatBuffer().put(data);
|
||||||
|
_type = ONNX_TYPE_FLOAT;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ---------- Double data getter and setter ---------- */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Double data getter
|
||||||
|
*
|
||||||
|
* @return double data array
|
||||||
|
*/
|
||||||
|
public double[] getDoubleData() {
|
||||||
|
if (_data == null) return null;
|
||||||
|
|
||||||
|
/* asDoubleBuffer() creates a new view so the position of the
|
||||||
|
* original data will stay at 0 for subsequent getDoubleData()
|
||||||
|
* after get(d).
|
||||||
|
*/
|
||||||
|
DoubleBuffer db = _data.asDoubleBuffer();
|
||||||
|
double[] d = new double[db.limit()];
|
||||||
|
db.get(d);
|
||||||
|
return d;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Double data setter
|
||||||
|
*
|
||||||
|
* @param data double array to be set
|
||||||
|
*/
|
||||||
|
public void setDoubleData(double[] data) {
|
||||||
|
/* asDoubleBuffer() creates a new view so the position of the
|
||||||
|
* original data will stay at 0 for getDoubleData() after put(data).
|
||||||
|
*/
|
||||||
|
_data = ByteBuffer.allocateDirect(data.length*8).order(endian);
|
||||||
|
_data.asDoubleBuffer().put(data);
|
||||||
|
_type = ONNX_TYPE_DOUBLE;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Rank getter
|
||||||
|
*
|
||||||
|
* @return rank
|
||||||
|
*/
|
||||||
|
public int getRank() {
|
||||||
|
return _rank;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ---------- Sizes getter and setter ---------- */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sizes getter
|
||||||
|
*
|
||||||
|
* @return sizes array
|
||||||
|
*/
|
||||||
|
public long[] getSizes() {
|
||||||
|
return _sizes;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sizes setter
|
||||||
|
*
|
||||||
|
* @param sizes sizes array to be set
|
||||||
|
*/
|
||||||
|
public void setSizes(long[] sizes) {
|
||||||
|
if (sizes.length != _rank)
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"array length " + sizes.length + " != rank " + _rank);
|
||||||
|
_sizes = sizes.clone();
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ---------- Strides getter and setter ---------- */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Strides getter
|
||||||
|
*
|
||||||
|
* @return strides array
|
||||||
|
*/
|
||||||
|
public long[] getStrides() {
|
||||||
|
return _strides;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Strides setter
|
||||||
|
*
|
||||||
|
* @param strides strides array to be set
|
||||||
|
*/
|
||||||
|
public void setStrides(long[] strides) {
|
||||||
|
if (strides.length != _rank)
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"array length " + strides.length + " != rank " + _rank);
|
||||||
|
_strides = strides.clone();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Size getter
|
||||||
|
*
|
||||||
|
* @return product of sizes array, i.e., total number of data elements
|
||||||
|
*/
|
||||||
|
public long getDataSize() {
|
||||||
|
long n = _sizes[0];
|
||||||
|
for (int i = 1; i < _sizes.length; i++) n *= _sizes[i];
|
||||||
|
return n;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check validity of RtMemRef
|
||||||
|
*
|
||||||
|
* @return true if RtMemRef is valid, false otherwise
|
||||||
|
*/
|
||||||
|
public boolean validRmr() {
|
||||||
|
return (_data != null &&
|
||||||
|
_data.limit() != 0 &&
|
||||||
|
_data.limit() == getDataSize() * ONNX_TYPE_SIZE[_type]);
|
||||||
|
}
|
||||||
|
}
|
|
@ -4,12 +4,6 @@ add_dependencies(onnx-mlir-opt OMKrnlOpsInc OMONNXOpsInc)
|
||||||
target_include_directories(onnx-mlir-opt PRIVATE ${ONNX_MLIR_SRC_ROOT})
|
target_include_directories(onnx-mlir-opt PRIVATE ${ONNX_MLIR_SRC_ROOT})
|
||||||
target_include_directories(onnx-mlir-opt PRIVATE ${ONNX_MLIR_BIN_ROOT})
|
target_include_directories(onnx-mlir-opt PRIVATE ${ONNX_MLIR_BIN_ROOT})
|
||||||
|
|
||||||
set(ONNX_MLIR_LD_PRELOAD_onnx-mlir-opt "" CACHE STRING "" FORCE)
|
|
||||||
if(BUILD_SHARED_LIBS)
|
|
||||||
message(STATUS "To run dynamically linked onnx-mlir-opt, you must specify:")
|
|
||||||
message(STATUS "LD_PRELOAD=${ONNX_MLIR_LD_PRELOAD_onnx-mlir-opt}")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
target_link_libraries(onnx-mlir-opt
|
target_link_libraries(onnx-mlir-opt
|
||||||
${OMLibs}
|
${OMLibs}
|
||||||
${MLIRLibs}
|
${MLIRLibs}
|
||||||
|
|
|
@ -12,6 +12,7 @@ using namespace std;
|
||||||
using namespace onnx_mlir;
|
using namespace onnx_mlir;
|
||||||
|
|
||||||
int main(int argc, char *argv[]) {
|
int main(int argc, char *argv[]) {
|
||||||
|
setExecPath(argv[0], (void *)main);
|
||||||
registerDialects();
|
registerDialects();
|
||||||
|
|
||||||
llvm::cl::OptionCategory OnnxMlirOptions(
|
llvm::cl::OptionCategory OnnxMlirOptions(
|
||||||
|
@ -33,7 +34,9 @@ int main(int argc, char *argv[]) {
|
||||||
clEnumVal(EmitLLVMIR, "Lower model to LLVM IR (LLVM dialect)."),
|
clEnumVal(EmitLLVMIR, "Lower model to LLVM IR (LLVM dialect)."),
|
||||||
clEnumVal(EmitLib, "Lower model to LLVM IR, emit (to file) "
|
clEnumVal(EmitLib, "Lower model to LLVM IR, emit (to file) "
|
||||||
"LLVM bitcode for model, compile and link it to a "
|
"LLVM bitcode for model, compile and link it to a "
|
||||||
"shared library.")),
|
"shared library."),
|
||||||
|
clEnumVal(EmitJNI, "Lower model to LLMV IR -> LLVM bitcode "
|
||||||
|
"-> JNI shared library -> jar")),
|
||||||
llvm::cl::init(EmitLib), llvm::cl::cat(OnnxMlirOptions));
|
llvm::cl::init(EmitLib), llvm::cl::cat(OnnxMlirOptions));
|
||||||
|
|
||||||
llvm::cl::HideUnrelatedOptions(OnnxMlirOptions);
|
llvm::cl::HideUnrelatedOptions(OnnxMlirOptions);
|
||||||
|
|
|
@ -13,6 +13,8 @@
|
||||||
#include "src/MainUtils.hpp"
|
#include "src/MainUtils.hpp"
|
||||||
#include "src/Runtime/ExecusionSession.hpp"
|
#include "src/Runtime/ExecusionSession.hpp"
|
||||||
|
|
||||||
|
#define SHARED_LIB_BASE string("./TestConv_main_graph")
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
// Returns whether onnx-mlir compiled convolution is producing the same results
|
// Returns whether onnx-mlir compiled convolution is producing the same results
|
||||||
|
@ -87,14 +89,9 @@ bool isOMConvTheSameAsNaiveImplFor(const int N, const int C, const int H,
|
||||||
|
|
||||||
OwningModuleRef moduleRef(module);
|
OwningModuleRef moduleRef(module);
|
||||||
|
|
||||||
llvm::SmallVector<char, 10> path;
|
compileModule(moduleRef, ctx, SHARED_LIB_BASE, EmitLib);
|
||||||
llvm::sys::fs::createTemporaryFile("_main_graph", "", path);
|
|
||||||
string pathStr(path.begin(), path.end());
|
|
||||||
llvm::FileRemover remover(path);
|
|
||||||
|
|
||||||
compileModule(moduleRef, ctx, pathStr, EmitLib);
|
|
||||||
onnx_mlir::ExecutionSession sess(
|
onnx_mlir::ExecutionSession sess(
|
||||||
pathStr + ".so", "_dyn_entry_point_main_graph");
|
SHARED_LIB_BASE + ".so", "_dyn_entry_point_main_graph");
|
||||||
|
|
||||||
std::vector<unique_ptr<RtMemRef>> inputs;
|
std::vector<unique_ptr<RtMemRef>> inputs;
|
||||||
auto xRmr = unique_ptr<RtMemRef>(getRndRealRmr<float>({N, C, H, W}));
|
auto xRmr = unique_ptr<RtMemRef>(getRndRealRmr<float>({N, C, H, W}));
|
||||||
|
@ -127,7 +124,10 @@ bool isOMConvTheSameAsNaiveImplFor(const int N, const int C, const int H,
|
||||||
return isRmrClose<float>(conv.get(), ref);
|
return isRmrClose<float>(conv.get(), ref);
|
||||||
}
|
}
|
||||||
|
|
||||||
int main() {
|
int main(int argc, char *argv[]) {
|
||||||
|
setExecPath(argv[0], (void *)main);
|
||||||
|
llvm::FileRemover remover(SHARED_LIB_BASE + ".so");
|
||||||
|
|
||||||
// RapidCheck test case generation.
|
// RapidCheck test case generation.
|
||||||
rc::check("convolution implementation correctness", []() {
|
rc::check("convolution implementation correctness", []() {
|
||||||
const auto N = *rc::gen::inRange(1, 10);
|
const auto N = *rc::gen::inRange(1, 10);
|
||||||
|
|
Loading…
Reference in New Issue