Cleanup rtmemref api (#238)

* Detect llvm-project commit change in utils/clone-mlir.sh and rebuild llvm-project
for zLinux Jenkins build bot

* Cleanup RtMemRef API

-  use forward declaration to hide private data fields
-  RtMemRef.h: external user header, C/C++
- _RtMemRef.h: internal user header, C++ only
-  RtMemRef.hpp and RtMemRef.cpp: implementation header and file
-  add external APIs
       OrderedRtMemRefDict *ormrd_create(RtMemRef **rmrs, int n)
       RtMemRef **ormrd_getRmrs(OrderedRtMemRefDict *ormrd)
       int ormrd_getNumOfRmrs(OrderedRtMemRefDict *ormrd)
   for creating and querying OrderedRtMemRefDict with RtMemRef arrays
-  the data buffer installed by rmr_setData() is managed by the user
-  unique_ptr<RtMemRef> must use a custom deleter: unique_ptr<RtMemRef, decltype(&rmr_destroy)>
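
   A minimal sketch of the last point (illustrative only; rmr_create is a
   hypothetical creator function, rmr_destroy is the destroy function named
   above):

       #include <memory>
       // rmr_destroy is invoked automatically when rmr goes out of scope.
       std::unique_ptr<RtMemRef, decltype(&rmr_destroy)> rmr(
           rmr_create(/*rank=*/2), rmr_destroy);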

* See if I have write access.

* Remove test CMake code.

* Use new API.

* Format code.

* Format code & rename variables for readability.

* Remove unused API spec.

* Rename OrderedRtMemRefDict -> RtMemRefList, _dataMalloc -> _owningData.

* OrderedRtMemRefDict -> RtMemRefList

* Update KrnlToLLVM.cpp

* Trigger Jenkins

* Restart Jenkins

* OrderedRtMemRefDict -> RtMemRefList

* More OrderedRtMemRefDict -> RtMemRefList.

* Format jni wrapper.

* Rename API functions to maintain stylistic consistency.

* Bug fix.

* Bug fix.

* Format code.

* Fix RtMemRefUtils.

* Format code.

* Using llvm function naming scheme.

* Rename runtime api file name to project name (onnx-mlir) as per convention.

* Include the new runtime header file.

* Reflect api header file name change in build script.

* Bug fix.

* Remove C++ code.

* Revert "Remove C++ code."

This reverts commit b217dfabae99e42db30721600cb5507866d4dc98.

* Clarify memory management responsibility.

* Add constructor to specify name & data ownership.

* Include stdbool.

* Remove dictionary semantics from RtMemRefList

* Bug fix.

* Format code.

* Format code.

* Use macro to define database of metadata.

* Prevent formatter from acting on metadata decl.

* Nit.

* Restore backend unit tests.

* Use spaces instead of tabs for better formatting.

* Use explicit template instantiation.

* Update RtMemRef struct doc.

* Make runtime compilable in both C and C++ mode.
Build two versions of the runtime library: a C version as the user-facing C runtime, and a C++ version used inside this project.

* Bug fix, avoid stack allocation for output rmr list.

* Change _dyn_entry_point_main_graph -> run_main_graph for better memorability.

* Write a complete introductory tutorial on the C99 Runtime and a test for it.

* Add onnx installation as dependency.

* Use target_include_directory to avoid installation.

* Format code.

* Fix cmake target_include_directories.

* Address compiler warning.

* First pass of RtMemRef->OMTensor.

* Second pass of RtMemRef -> OMTensor.

* nit, omtList -> omTensorList.

* omt -> omTensor for clarity.

* Rename OnnxMlirInternal.h -> OnnxMlirRuntime.hpp because there's no internal/external API, only C/C++ API.

* Format code.

* Restructure Runtime source code and move header -> /include and test -> /test/unit.

* Bugfix.

* Format code.

* Add unit test for OMTensor ctor.

* Update JNI CMake include directory.

* Bugfix.

* No need to re-declare ONNX types.

* Disable runtime doc test on non-x86 platforms.

* Rename OMTensor fields to be more sensible.

* Fix data type mismatch.

* size_t -> int64_t, prefer fixed width integers.

* Use consistent header guard style.

* Further tweak OMTensor API.

* Bugfix.

* Bugfix.

* Format code.

* Add doxygen config file.

* Tweak OMTensor API.

* Tweak API doc, hide OMTensorList implementation.

* Format code.

* Add new documentation item for Runtime API.

* Hide internal use only API declarations, move their comments to their implementations.

* Clarify ownership semantics in relevant API documentations.

* Fix PyRuntime.

* Remove alignment concerns from public API and include explanation of alignment issue in struct OMTensor definition.

* Print out unsupported numpy dtype.

* Use preferred way of type comparison in pybind11.

* Debug s390x issue.

* Remove debug code.

* Clarify semantics of strides/shape setter/getter, use \brief to include short description of API function.

* Improve documentation.

* Single out unpolished C++ API declarations.

* Clarify OMTensorList API.

* Bugfix.

* Bugfix.

* Assert after malloc.

* Handle malloc failures.

* Nit.

* Tweak legal notices.

* Format code.

* Remove doxygen generated files.

* Tweak legal notice format.

* Upgrade Cython

Numpy installation depends on Cython.

Co-authored-by: Tian Jin <tjingrant@gmail.com>
Author: gongsu832 (committed 2020-10-10 10:32:09 -04:00)
Parent: 533d47acf1
Commit: 81c774ba5b
47 changed files with 4867 additions and 1176 deletions


@@ -8,7 +8,7 @@ jobs:
     steps:
       - run:
           name: Installing GCC, CMake, Ninja, Protobuf
-          command: sudo apt-get update && sudo apt-get install -y gcc g++ cmake ninja-build protobuf-compiler
+          command: sudo apt-get update && sudo apt-get install -y gcc g++ cmake ninja-build protobuf-compiler && pip install --upgrade -q cython
       - checkout:
           path: onnx-mlir
       - run:


@@ -27,6 +27,7 @@ if (MSVC)
 endif()
 add_subdirectory(third_party/onnx)
 add_subdirectory(third_party/googletest)
+SET(BENCHMARK_ENABLE_GTEST_TESTS OFF)
 add_subdirectory(third_party/benchmark)
 add_subdirectory(third_party/pybind11)
 add_subdirectory(third_party/variant)
@@ -35,6 +36,7 @@ add_subdirectory(third_party/rapidcheck)
 set(CMAKE_CXX_STANDARD 14)
 add_subdirectory(utils)
+add_subdirectory(include)
 add_subdirectory(src)
 add_subdirectory(docs)
 add_subdirectory(test)

Doxyfile (new file, 2568 lines; diff suppressed because it is too large)


@@ -11,10 +11,10 @@ toc:
 #   subfolderitems:
 #     - page: Placeholder
 #       url: /piece1.html
-#  - title: How-tos
-#    subfolderitems:
-#      - page: Placeholder
-#        url: /piece1.html
+  - title: How-tos
+    subfolderitems:
+      - page: Runtime API
+        url: /doxygen_html/index.html
   - title: References
     subfolderitems:
       - page: ONNX Dialect

include/CMakeLists.txt (new file, 3 lines)

add_subdirectory(onnx-mlir)
install(FILES OnnxMlirRuntime.h DESTINATION include)

include/OnnxMlirRuntime.h (new file, 116 lines)

//===------- OnnxMlirRuntime.h - ONNX-MLIR Runtime API Declarations -------===//
//
// Copyright 2019-2020 The IBM Research Authors.
//
// =============================================================================
//
// This file contains declaration of external OMTensor data structures and
// helper functions.
//
//===----------------------------------------------------------------------===//
#ifndef ONNX_MLIR_ONNXMLIRRUNTIME_H
#define ONNX_MLIR_ONNXMLIRRUNTIME_H
#ifdef __cplusplus
#include <cstdint>
#else
#include <stdbool.h>
#include <stdint.h>
#endif
#include <onnx-mlir/Runtime/OMTensor.h>
#include <onnx-mlir/Runtime/OMTensorList.h>
/*! \mainpage ONNX-MLIR Runtime API documentation
*
* \section intro_sec Introduction
*
 * The ONNX-MLIR project comes with an executable `onnx-mlir` capable
 * of compiling ONNX models to a shared library. In this documentation, we
* demonstrate how to interact programmatically with the compiled
* shared library using ONNX-MLIR's Runtime API.
*
* \section c-runtime-api C Runtime API
*
* \subsection data-structures Data Structures
*
* `OMTensor` is the data structure used to describe the runtime information
* (rank, shape, data type, etc) associated with a tensor input or output.
*
* `OMTensorList` is the data structure used to hold a list of pointers to
* OMTensor so that they can be passed into and out of the compiled model as
* inputs and outputs.
*
* \subsection model-entry-point-signature Model Entry Point Signature
*
 * All compiled models have exactly the same C function signature, equivalent
 * to:
*
* ```c
* OMTensorList* run_main_graph(OMTensorList*);
* ```
*
* Intuitively, the model takes a list of tensors as input and returns a list of
 * tensors as output.
*
* \subsection invoke-models-using-c-runtime-api Invoke Models Using C Runtime
* API
*
* We demonstrate using the API functions to run a simple ONNX model consisting
 * of an add operation. To create such an ONNX model, use this
* <a href="gen_add_onnx.py" target="_blank"><b>python script</b></a>
*
* To compile the above model, run `onnx-mlir add.onnx` and a binary library
* "add.so" should appear. We can use the following C code to call into the
* compiled function computing the sum of two inputs:
*
 * ```c
 * #include <OnnxMlirRuntime.h>
 * #include <stdio.h>
 *
 * OMTensorList *run_main_graph(OMTensorList *);
 *
 * int main() {
 *   // Shared shape & rank.
 *   int64_t shape[] = {2, 2};
 *   int64_t rank = 2;
 *   // Construct x1 omt filled with 1.
 *   float x1Data[] = {1., 1., 1., 1.};
 *   OMTensor *x1 = omTensorCreate(x1Data, shape, rank, ONNX_TYPE_FLOAT);
 *   // Construct x2 omt filled with 2.
 *   float x2Data[] = {2., 2., 2., 2.};
 *   OMTensor *x2 = omTensorCreate(x2Data, shape, rank, ONNX_TYPE_FLOAT);
 *   // Construct a list of omts as input.
 *   OMTensor *list[2] = {x1, x2};
 *   OMTensorList *input = omTensorListCreate(list, 2);
 *   // Call the compiled onnx model function.
 *   OMTensorList *outputList = run_main_graph(input);
 *   // Get the first omt as output.
 *   OMTensor *y = omTensorListGetOmtByIndex(outputList, 0);
 *   float *outputPtr = (float *)omTensorGetDataPtr(y);
 *   // Print its content; a 2x2 tensor holds four elements, all 3.
 *   for (int i = 0; i < 4; i++)
 *     printf("%f ", outputPtr[i]);
 *   return 0;
 * }
 * ```
*
 * Compile with `gcc main.c add.so -o add`, and you should see an executable
 * `add` appear. Run it, and the output should be:
 *
 * ```
 * 3.000000 3.000000 3.000000 3.000000
 * ```
 * Exactly as it should be.
*
* \subsection reference Reference
*
 * For a full reference to the available C Runtime API, refer to
*`include/onnx-mlir/Runtime/OMTensor.h` and
* `include/onnx-mlir/Runtime/OMTensorList.h`.
*
*/
#endif // ONNX_MLIR_ONNXMLIRRUNTIME_H


@@ -0,0 +1 @@
add_subdirectory(Runtime)


@@ -0,0 +1,2 @@
install(FILES OnnxDataType.h DESTINATION include/onnx-mlir/Runtime)
install(FILES OnnxDataTypeMetaData.inc DESTINATION include/onnx-mlir/Runtime)


@@ -0,0 +1,263 @@
//===-------------- OMTensor.h - OMTensor Declaration header --------------===//
//
// Copyright 2019-2020 The IBM Research Authors.
//
// =============================================================================
//
// This file contains declaration of OMTensor data structures and
// API functions.
//
//===----------------------------------------------------------------------===//
#ifndef ONNX_MLIR_OMTENSOR_H
#define ONNX_MLIR_OMTENSOR_H
#ifdef __cplusplus
#include <algorithm>
#include <iostream>
#include <map>
#include <numeric>
#include <string>
#include <vector>
#else
#include <stdbool.h>
#endif // #ifdef __cplusplus
#ifdef __APPLE__
#include <stdlib.h>
#else
#include <malloc.h>
#endif // #ifdef __APPLE__
#include "onnx-mlir/Runtime/OnnxDataType.h"
/* Typically, MemRefs in the MLIR context are used as compile-time constructs.
 * Information such as the element type and rank of the data payload is
 * statically encoded, meaning that they are determined and fixed at
 * compile-time. This presents a significant burden for any runtime components
 * trying to interact with the compiled executable.
 *
 * Thus a version of the MemRef struct that is amenable to runtime manipulation
 * is provided as a basis for building runtime-related components that offer
 * user-facing programming interfaces. All information is dynamically encoded
 * as members of this struct so that it can be accessed and modified easily
 * at runtime.
*
* We will refer to it as an OMTensor.
*/
struct OMTensor;
#ifndef __cplusplus
typedef struct OMTensor OMTensor;
#endif
/* Helper function to compute the number of data elements */
static inline int64_t getNumOfElems(int64_t *dataSizes, int rank) {
int64_t numElem = 1;
for (int i = 0; i < rank; i++)
numElem *= dataSizes[i];
return numElem;
}
/**
 * \brief Create an OMTensor with the specified data pointer, shape, rank and
 * element type.
 *
 * The call will not create a copy of the data. By default, the caller is
 * responsible for managing the memory this pointer refers to. Namely, the
 * OMTensor is not the owner of the data. To indicate OMTensor's ownership of
 * the data, use `omTensorCreateWithOwnership`. Ownership determines what
 * happens when the OMTensor is destroyed: with ownership of the data, the
 * destruction of the OMTensor will also free the data.
*
* @param data_ptr pointer to tensor data. By default, caller is responsible for
* managing the memory this pointer refers to.
* @param shape list of integers indicating the tensor shape.
* @param rank tensor rank.
* @param dtype tensor element data type.
* @return pointer to OMTensor created, NULL if creation failed.
*
*/
OMTensor *omTensorCreate(
void *data_ptr, int64_t *shape, int64_t rank, OM_DATA_TYPE dtype);
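
/* A minimal usage sketch (illustrative only, not part of the API): the tensor
 * wraps caller-owned storage, so destroying it frees only the struct.
 *
 *   float data[4] = {1.f, 2.f, 3.f, 4.f};
 *   int64_t shape[2] = {2, 2};
 *   OMTensor *t = omTensorCreate(data, shape, 2, ONNX_TYPE_FLOAT);
 *   // ... use t ...
 *   omTensorDestroy(t); // data remains valid; the caller manages it
 */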
/**
 * \brief Create an OMTensor with the specified shape, rank and element type,
 * and allocate uninitialized data for the specified shape.
*
* The OMTensor created using this constructor owns the underlying memory
* space allocated to the content of the tensor.
*
* @param shape list of integers indicating the tensor shape.
* @param rank tensor rank.
* @param dtype tensor element data type.
* @return pointer to OMTensor created, NULL if creation failed.
*
*/
OMTensor *omTensorCreateEmpty(int64_t *shape, int64_t rank, OM_DATA_TYPE dtype);
/**
* \brief Create an OMTensor with specified data pointer, shape, rank and
* element type, manually setting data ptr ownership.
*
* Using this constructor, users can
* specify whether OMTensor owns the data, which subsequently determines whether
* the memory space underlying the data will be freed or not when OMTensor gets
* destroyed.
*
* @param data_ptr pointer to tensor data.
* @param shape list of integers indicating the tensor shape.
* @param rank tensor rank.
* @param dtype tensor element data type.
* @param owning whether OMTensor owns the data, if set to true, OMTensor will
* release the data_ptr upon destruction.
* @return pointer to OMTensor created, NULL if creation failed.
*
*/
OMTensor *omTensorCreateWithOwnership(void *data_ptr, int64_t *shape,
int64_t rank, OM_DATA_TYPE dtype, int owning);
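
/* A minimal usage sketch (illustrative only): heap-allocated data handed over
 * to the OMTensor, which frees it upon destruction.
 *
 *   int64_t shape[2] = {2, 2};
 *   float *data = (float *)malloc(4 * sizeof(float));
 *   OMTensor *t = omTensorCreateWithOwnership(
 *       data, shape, 2, ONNX_TYPE_FLOAT, 1); // owning = 1
 *   // ... use t ...
 *   omTensorDestroy(t); // also frees data, since the tensor owns it
 */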
/**
* \brief Create an empty OMTensor with specified rank.
*
* This constructor returns a
* partially filled omTensor; prefer using the new omTensorCreateEmpty()
* function to fill shape & stride fields automatically.
*
* @param rank tensor rank
* @return pointer to OMTensor created, NULL if creation failed.
*
*/
OMTensor *omTensorCreateEmptyDeprecated(int rank);
/**
* \brief Destroy the OMTensor struct.
*
* If OMTensor does not own the data, destroying the omTensor does not free up
* the memory occupied by the tensor content. If OMTensor owns the data, this
* function will free up the memory space underlying the tensor as well. The
* documentation of OMTensor constructors clarifies the ownership semantics.
*
* @param tensor pointer to the OMTensor
*
*/
void omTensorDestroy(OMTensor *tensor);
/**
* \brief OMTensor data pointer getter.
*
* @param tensor pointer to the OMTensor
* @return pointer to the data buffer of the OMTensor,
* NULL if the data buffer is not set.
*/
void *omTensorGetDataPtr(OMTensor *tensor);
/**
* \brief OMTensor data shape getter.
*
* The data shape is returned as a pointer pointing to an array of
* n 64-bit integers where n is the rank of the tensor.
*
* The shape array is returned without copying, so caller should
* not free the returned pointer.
*
* @param tensor pointer to the OMTensor
* @return pointer to the data shape array.
*/
int64_t *omTensorGetDataShape(OMTensor *tensor);
/**
* \brief OMTensor data shape setter.
*
* n int64 elements are copied from the shape array to indicate the shape of the
* tensor, where n is the rank of the tensor.
*
 * The shape array is copied without being freed, so the caller is expected to
 * manage the shape array on their own.
*
* @param tensor pointer to the OMTensor
* @param shape data sizes array to be set
*
* Set the data sizes array of the OMTensor to the values in the input array.
*/
void omTensorSetShape(OMTensor *tensor, int64_t *shape);
/**
* \brief OMTensor data strides getter
*
* The data strides are returned as a pointer pointing to an array of
* n 64-bit integers where n is the rank of the tensor.
*
* The strides array is returned without copying, so caller should
* not free the returned pointer.
*
* @param tensor pointer to the OMTensor
* @return pointer to the data strides array.
*/
int64_t *omTensorGetStrides(OMTensor *tensor);
/**
* \brief OMTensor data strides setter
*
* n int64 elements are copied from the strides array to indicate the
* per-dimension stride of the tensor, where n is the rank of the tensor.
*
 * The strides array is copied without being freed, so the caller is expected
 * to manage the strides array on their own.
*
* @param tensor pointer to the OMTensor
* @param strides tensor strides array to be set.
*
* Set the data strides array of the OMTensor to the values in the input array.
*/
void omTensorSetStrides(OMTensor *tensor, int64_t *strides);
/**
* \brief OMTensor data type getter
*
* @param tensor pointer to the OMTensor
* @return ONNX data type of the data buffer elements.
*/
OM_DATA_TYPE omTensorGetDataType(OMTensor *tensor);
/**
* \brief OMTensor data type setter
*
* @param tensor pointer to the OMTensor
* @param dataType ONNX data type to be set
*
* Set the ONNX data type of the data buffer elements.
*/
void omTensorSetDataType(OMTensor *tensor, OM_DATA_TYPE dataType);
/* Helper function to get the ONNX data type size in bytes */
static inline int getDataTypeSize(OM_DATA_TYPE dataType) {
return OM_DATA_TYPE_SIZE[dataType];
}
/**
* \brief OMTensor data buffer size getter
*
* @param tensor pointer to the OMTensor
* @return the total size of the data buffer in bytes.
*/
int64_t omTensorGetDataBufferSize(OMTensor *tensor);
/**
* \brief OMTensor rank getter
*
 * @param tensor pointer to the OMTensor
* @return rank of data sizes and strides of the OMTensor.
*/
int omTensorGetRank(OMTensor *tensor);
/**
* \brief OMTensor number of elements getter
*
 * @param tensor pointer to the OMTensor
* @return the number of elements in the data buffer.
*/
int64_t omTensorGetNumElems(OMTensor *tensor);
#endif // ONNX_MLIR_OMTENSOR_H


@@ -0,0 +1,77 @@
//===-------- OMTensorList.h - OMTensorList Declaration header-------------===//
//
// Copyright 2019-2020 The IBM Research Authors.
//
// =============================================================================
//
// This file contains declaration of OMTensorList data structures and
// API functions.
//
//===----------------------------------------------------------------------===//
#ifndef ONNX_MLIR_OMTENSORLIST_H
#define ONNX_MLIR_OMTENSORLIST_H
#include "onnx-mlir/Runtime/OMTensor.h"
struct OMTensorList;
#ifndef __cplusplus
typedef struct OMTensorList OMTensorList;
#endif
/**
* \brief OMTensorList creator
*
 * Create an OMTensorList with the specified OMTensor array. The array of
 * OMTensor pointers is used without copying, so the caller should not free
 * the `tensors` ptr.
*
* @param tensors array of pointers to OMTensor
* @param n number of elements in tensors array
* @return pointer to the OMTensorList created, NULL if creation failed.
*
*/
OMTensorList *omTensorListCreate(OMTensor **tensors, int n);
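
/* A minimal usage sketch (illustrative only). Since the list takes the array
 * without copying and omTensorListDestroy frees it, the array is heap
 * allocated here:
 *
 *   OMTensor **tensors = (OMTensor **)malloc(2 * sizeof(OMTensor *));
 *   tensors[0] = t1; // t1, t2 created elsewhere
 *   tensors[1] = t2;
 *   OMTensorList *list = omTensorListCreate(tensors, 2);
 *   // ... pass list to the model entry point ...
 *   omTensorListDestroy(list); // frees the array and destroys t1, t2
 */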
/**
* \brief OMTensorList destroyer
*
* Destroy the OMTensorList struct recursively. That is to say, both the
* ptr to the OMTensor pointers AND the OMTensor pointers are freed.
*
* @param list pointer to the OMTensorList to be destroyed
*
*/
void omTensorListDestroy(OMTensorList *list);
/**
* \brief OMTensorList OMTensor array getter
*
 * The pointer to the OMTensor pointers is returned without copying, so the
 * caller should not free the returned pointer.
*
* @param list pointer to the OMTensorList
* @return pointer to the array of OMTensor pointers.
*/
OMTensor **omTensorListGetPtrToOmts(OMTensorList *list);
/**
* \brief OMTensorList size getter
*
*
* @param list pointer to the OMTensorList
* @return number of elements in the OMTensor array.
*/
int omTensorListGetSize(OMTensorList *list);
/**
* \brief OMTensorList OMTensor getter by index
*
* @param list pointer to the OMTensorList
* @param index index of the OMTensor
 * @return pointer to the OMTensor, NULL if not found.
*/
OMTensor *omTensorListGetOmtByIndex(OMTensorList *list, size_t index);
#endif // ONNX_MLIR_OMTENSORLIST_H


@@ -0,0 +1,57 @@
//===------------------- OnnxDataType.h - ONNX DataTypes ------------------===//
//
// Copyright 2019-2020 The IBM Research Authors.
//
// =============================================================================
//
// This file contains declaration of ONNX data types and type size mapping.
// It is provided as a convenience and not used by OMTensor implementation.
//
//===----------------------------------------------------------------------===//
#ifndef ONNX_MLIR_ONNXDATATYPE_H
#define ONNX_MLIR_ONNXDATATYPE_H
#ifdef __cplusplus
#include <cstdint>
#include <map>
#include <string>
#else
#include <stdbool.h>
#include <stdint.h>
#endif
enum OM_DATA_TYPE {
#define OM_TYPE_METADATA_DEF(ENUM_NAME, ENUM_VAL, DTYPE_SIZE) \
ENUM_NAME = ENUM_VAL,
#include "OnnxDataTypeMetaData.inc"
#undef OM_TYPE_METADATA_DEF
};
#ifndef __cplusplus
typedef enum OM_DATA_TYPE OM_DATA_TYPE;
#endif
extern const int OM_DATA_TYPE_SIZE[];
#ifdef __cplusplus
// Note: by design a const map has no [] operator, since [] creates a default
// key-value mapping when the key is not found, which would modify the map.
const std::map<std::string, OM_DATA_TYPE> OM_DATA_TYPE_CPP_TO_ONNX = {
{"b", ONNX_TYPE_BOOL}, // bool -> BOOL
{"c", ONNX_TYPE_INT8}, // char -> INT8 (platform dependent, can be UINT8)
{"a", ONNX_TYPE_INT8}, // int8_t -> INT8
{"h", ONNX_TYPE_UINT8}, // uint8_t -> UINT8, unsigned char -> UNIT 8
{"s", ONNX_TYPE_INT16}, // int16_t -> INT16, short -> INT16
{"t", ONNX_TYPE_UINT16}, // uint16_t -> UINT16, unsigned short -> UINT16
{"i", ONNX_TYPE_INT32}, // int32_t -> INT32, int -> INT32
{"j", ONNX_TYPE_UINT32}, // uint32_t -> UINT32, unsigned int -> UINT32
{"l", ONNX_TYPE_INT64}, // int64_t -> INT64, long -> INT64
{"m", ONNX_TYPE_UINT64}, // uint64_t -> UINT64, unsigned long -> UINT64
{"f", ONNX_TYPE_FLOAT}, // float -> FLOAT
{"d", ONNX_TYPE_DOUBLE}, // double -> DOUBLE
};
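
// A minimal lookup sketch (illustrative only): mapping a C++ element type to
// its ONNX enum through the typeid-name keys above. The keys follow the
// Itanium C++ ABI name encodings used by GCC/Clang; other compilers may
// encode type names differently.
//
//   OM_DATA_TYPE dtype =
//       OM_DATA_TYPE_CPP_TO_ONNX.at(std::string(typeid(float).name()));
//   // dtype == ONNX_TYPE_FLOAT, since typeid(float).name() is "f".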
#endif //__cplusplus
#endif // ONNX_MLIR_ONNXDATATYPE_H


@@ -0,0 +1,25 @@
#if defined(OM_TYPE_METADATA_DEF)
// Data type metadata declared in the following format:
// OM_TYPE_METADATA_DEF( dtype enum name, dtype enum value, dtype size)
// clang-format off
OM_TYPE_METADATA_DEF(ONNX_TYPE_UNDEFINED, 0, 0)
OM_TYPE_METADATA_DEF(ONNX_TYPE_FLOAT, 1, sizeof(float))
OM_TYPE_METADATA_DEF(ONNX_TYPE_UINT8, 2, sizeof(uint8_t))
OM_TYPE_METADATA_DEF(ONNX_TYPE_INT8, 3, sizeof(int8_t))
OM_TYPE_METADATA_DEF(ONNX_TYPE_UINT16, 4, sizeof(uint16_t))
OM_TYPE_METADATA_DEF(ONNX_TYPE_INT16, 5, sizeof(int16_t))
OM_TYPE_METADATA_DEF(ONNX_TYPE_INT32, 6, sizeof(int32_t))
OM_TYPE_METADATA_DEF(ONNX_TYPE_INT64, 7, sizeof(int64_t))
OM_TYPE_METADATA_DEF(ONNX_TYPE_STRING, 8, 0)
OM_TYPE_METADATA_DEF(ONNX_TYPE_BOOL, 9, sizeof(bool))
OM_TYPE_METADATA_DEF(ONNX_TYPE_FLOAT16, 10, 2)
OM_TYPE_METADATA_DEF(ONNX_TYPE_DOUBLE, 11, sizeof(double))
OM_TYPE_METADATA_DEF(ONNX_TYPE_UINT32, 12, sizeof(uint32_t))
OM_TYPE_METADATA_DEF(ONNX_TYPE_UINT64, 13, sizeof(uint64_t))
OM_TYPE_METADATA_DEF(ONNX_TYPE_COMPLEX64, 14, 8)
OM_TYPE_METADATA_DEF(ONNX_TYPE_COMPLEX128, 15, 16)
OM_TYPE_METADATA_DEF(ONNX_TYPE_BFLOAT16, 16, 2)
// clang-format on
#else
#error "Must define OM_TYPE_METADATA_DEF macro."
#endif


@@ -131,6 +131,26 @@ static FlatSymbolRefAttr getOrInsertMemcpy(
   return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context);
 }
 
+static FlatSymbolRefAttr getOrInsertMalloc(
+    PatternRewriter &rewriter, ModuleOp module) {
+  // Insert the malloc/aligned_alloc declaration if it is not already present.
+  auto allocFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("malloc");
+  auto ctx = rewriter.getContext();
+  LLVMTypeConverter converter(ctx);
+  if (!allocFunc) {
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(module.getBody());
+    SmallVector<LLVM::LLVMType, 2> callArgTypes = {converter.getIndexType()};
+    // aligned_alloc(size_t alignment, size_t size)
+    auto voidPtrType = LLVM::LLVMType::getInt8PtrTy(&converter.getContext());
+    allocFunc =
+        rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(), "malloc",
+            LLVM::LLVMType::getFunctionTy(voidPtrType, callArgTypes,
+                /*isVarArg=*/false));
+  }
+  return SymbolRefAttr::get("malloc", ctx);
+}
+
 //===----------------------------------------------------------------------===//
 // KRNL to LLVM: KrnlGetRefOpLowering
 //===----------------------------------------------------------------------===//
@@ -399,16 +419,15 @@ public:
   using OpRewritePattern<KrnlEntryPointOp>::OpRewritePattern;
 
   enum class API {
-    CREATE_ORDERED_DYN_MEM_REF_DICT,
-    CREATE_DYN_MEM_REF,
-    GET_DYN_MEM_REF,
-    SET_DYN_MEM_REF,
+    CREATE_OMTENSOR_LIST,
+    CREATE_OMTENSOR,
     GET_DATA,
     SET_DATA,
-    GET_SIZES,
-    GET_STRIDES,
-    SET_DTYPE,
-    GET_DTYPE,
+    GET_DATA_SIZES,
+    GET_DATA_STRIDES,
+    SET_DATA_TYPE,
+    GET_DATA_TYPE,
+    GET_OMTS,
   };
 
   struct ApiSpec {
@@ -443,6 +462,7 @@ public:
     using LLVMType = LLVM::LLVMType;
     auto opaquePtrTy = LLVMType::getInt8PtrTy(context);
     auto int32Ty = LLVMType::getInt32Ty(context);
+    auto int64Ty = LLVMType::getInt64Ty(context);
 
     // Rewrite Krnl Entry Point Operation to an LLVM function with a dynamic
     // signature. The signature is dynamic because it remains the same no matter
@@ -455,7 +475,7 @@ public:
         op.getAttrOfType<SymbolRefAttr>(
               KrnlEntryPointOp::getEntryPointFuncAttrName())
             .getLeafReference();
-    auto dynEntryPointName = "_dyn_entry_point_" + staticEntryPointFuncName;
+    auto dynEntryPointName = "run_" + staticEntryPointFuncName;
     assert(module.lookupSymbol(dynEntryPointName.str()) == nullptr &&
            "dynamic entry point name is not unique");
     rewriter.eraseOp(op);
@@ -484,12 +504,22 @@ public:
     // them to static mem refs.
     SmallVector<Value, 4> staticInputs;
     auto wrappedInput = entryPointEntryBlock.getArgument(0);
+
+    auto omTensorPtrArr =
+        callApi(rewriter, loc, apiRegistry, API::GET_OMTS, {wrappedInput});
     for (size_t i = 0; i < staticEntryPointTy.getFunctionNumParams(); i++) {
       // Call API function to retrieve the i-th dynamic memref.
       auto idxVal = rewriter.create<LLVM::ConstantOp>(
           loc, int32Ty, rewriter.getI32IntegerAttr(i));
-      auto dynMemRef = callApi(rewriter, loc, apiRegistry, API::GET_DYN_MEM_REF,
-          {wrappedInput, idxVal});
+
+      auto omTensorPtrAddrTy = opaquePtrTy.getPointerTo();
+      auto omTensorPtrAddr = rewriter
+                                 .create<LLVM::GEPOp>(loc, omTensorPtrAddrTy,
+                                     omTensorPtrArr, ArrayRef<Value>({idxVal}))
+                                 .getResult();
+      auto omTensorPtr =
+          rewriter.create<LLVM::LoadOp>(loc, opaquePtrTy, omTensorPtrAddr)
+              .getResult();
 
       // Create a (static) memref type corresponding to the i-th memref input to
       // the inference function on stack, and load it to memRef.
@@ -501,9 +531,9 @@ public:
           /*alignment=*/0);
 
       // Fill in the memref underlying ptrToMemRef with information extracted
-      // from dynMemRef.
-      fillPtrToMemRefWithRtMemRef(
-          dynMemRef, ptrToMemRef, rewriter, loc, apiRegistry, module);
+      // from omTensorPtr.
+      fillPtrToMemRefWithOMTensor(
+          omTensorPtr, ptrToMemRef, rewriter, loc, apiRegistry, module);
 
       // ptrToMemRef will be an input to main computation graph function.
       staticInputs.emplace_back(ptrToMemRef);
@@ -539,27 +569,57 @@ public:
       }
     }
 
-    // Create wrapped output.
-    auto wrappedOutput = callApi(
-        rewriter, loc, apiRegistry, API::CREATE_ORDERED_DYN_MEM_REF_DICT, {});
+    auto numOutput = rewriter.create<LLVM::ConstantOp>(
+        loc, int32Ty, rewriter.getI64IntegerAttr(outMemRefList.size()));
+
+    auto mallocSym = getOrInsertMalloc(rewriter, module);
+    // TODO(tjingrant): get pointer size from data layout.
+    size_t kPtrSize = 8;
+    auto outputOmtPtrsArraySizeInByte = rewriter.create<LLVM::ConstantOp>(loc,
+        int64Ty, rewriter.getI64IntegerAttr(outMemRefList.size() * kPtrSize));
+    auto outOmtPtrsArr =
+        rewriter
+            .create<LLVM::CallOp>(loc,
+                LLVM::LLVMType::getInt8PtrTy(module.getContext()), mallocSym,
+                ArrayRef<Value>(outputOmtPtrsArraySizeInByte))
+            .getResult(0);
+    outOmtPtrsArr = rewriter
+                        .create<LLVM::BitcastOp>(loc,
+                            LLVM::LLVMType::getInt8PtrTy(module.getContext())
+                                .getPointerTo(0),
+                            outOmtPtrsArr)
+                        .getResult();
 
     for (decltype(numOutputs) i = 0; i < outMemRefList.size(); i++) {
       // Get the i-th memref returned, convert to a dynamic memref and store it
       // in the wrappedOutput.
       auto memRef = outMemRefList.at(i);
       auto outMemRefTy = memRef.getType().dyn_cast<LLVMType>();
       auto outMemRefRank = getRankFromMemRefType(outMemRefTy);
       auto outMemRefRankVal = rewriter.create<LLVM::ConstantOp>(
           loc, int32Ty, rewriter.getI32IntegerAttr(outMemRefRank));
-      auto outRtMemRef = callApi(rewriter, loc, apiRegistry,
-          API::CREATE_DYN_MEM_REF, {outMemRefRankVal});
-      fillRtMemRefWithMemRef(
-          memRef, outRtMemRef, rewriter, loc, apiRegistry, module);
-      auto idx = rewriter.create<LLVM::ConstantOp>(
+      auto outOMTensor = callApi(
+          rewriter, loc, apiRegistry, API::CREATE_OMTENSOR, {outMemRefRankVal});
+      fillOMTensorWithMemRef(
+          memRef, outOMTensor, rewriter, loc, apiRegistry, module);
+
+      auto idxVal = rewriter.create<LLVM::ConstantOp>(
           loc, int32Ty, rewriter.getI32IntegerAttr(i));
-      callApi(rewriter, loc, apiRegistry, API::SET_DYN_MEM_REF,
-          {wrappedOutput, idx, outRtMemRef});
+
+      auto omTensorPtrAddrTy = opaquePtrTy.getPointerTo();
+      auto omTensorPtrAddr = rewriter
+                                 .create<LLVM::GEPOp>(loc, omTensorPtrAddrTy,
+                                     outOmtPtrsArr, ArrayRef<Value>({idxVal}))
+                                 .getResult();
+      rewriter.create<LLVM::StoreOp>(loc, outOMTensor, omTensorPtrAddr);
     }
 
+    // Create wrapped output.
+    auto wrappedOutput = callApi(rewriter, loc, apiRegistry,
+        API::CREATE_OMTENSOR_LIST, {outOmtPtrsArr, numOutput});
+
     // Return wrapped output.
     rewriter.create<LLVM::ReturnOp>(
         loc, SmallVector<Value, 1>({wrappedOutput}));
@@ -575,6 +635,7 @@ private:
     using LLVMType = LLVM::LLVMType;
     auto voidTy = LLVMType::getVoidTy(context);
     auto opaquePtrTy = LLVMType::getInt8PtrTy(context);
+    auto opaquePtrPtrTy = opaquePtrTy.getPointerTo();
    auto int32Ty = LLVMType::getInt32Ty(context);
    auto int64Ty = LLVMType::getInt64Ty(context);
    auto int64PtrTy = int64Ty.getPointerTo();
@@ -583,16 +644,15 @@ private:
     // specifying its signature.
     // clang-format off
     std::vector<ApiSpec> apiSpecs = {
-        ApiSpec(API::CREATE_ORDERED_DYN_MEM_REF_DICT, "createOrderedRtMemRefDict", opaquePtrTy, {}),
-        ApiSpec(API::CREATE_DYN_MEM_REF, "createRtMemRef", opaquePtrTy, {int32Ty}),
-        ApiSpec(API::GET_DATA, "getData", opaquePtrTy, {opaquePtrTy}),
-        ApiSpec(API::SET_DATA, "setData", voidTy, {opaquePtrTy, opaquePtrTy}),
-        ApiSpec(API::GET_DYN_MEM_REF, "getRtMemRef", opaquePtrTy, {opaquePtrTy, int32Ty}),
-        ApiSpec(API::SET_DYN_MEM_REF, "setRtMemRef", voidTy, {opaquePtrTy, int32Ty, opaquePtrTy}),
-        ApiSpec(API::GET_SIZES, "getSizes", int64PtrTy, {opaquePtrTy}),
-        ApiSpec(API::GET_STRIDES, "getStrides", int64PtrTy, {opaquePtrTy}),
-        ApiSpec(API::GET_DTYPE, "getDType", int32Ty, {opaquePtrTy}),
-        ApiSpec(API::SET_DTYPE, "setDType", voidTy, {opaquePtrTy, int32Ty}),
+        ApiSpec(API::CREATE_OMTENSOR_LIST, "omTensorListCreate", opaquePtrTy, {opaquePtrPtrTy, int32Ty}),
+        ApiSpec(API::CREATE_OMTENSOR, "omTensorCreateEmptyDeprecated", opaquePtrTy, {int32Ty}),
+        ApiSpec(API::GET_DATA, "omTensorGetDataPtr", opaquePtrTy, {opaquePtrTy}),
+        ApiSpec(API::SET_DATA, "omTensorSetPtr", voidTy, {opaquePtrTy, int32Ty, opaquePtrTy, opaquePtrTy}),
+        ApiSpec(API::GET_DATA_SIZES, "omTensorGetDataShape", int64PtrTy, {opaquePtrTy}),
+        ApiSpec(API::GET_DATA_STRIDES, "omTensorGetStrides", int64PtrTy, {opaquePtrTy}),
+        ApiSpec(API::GET_DATA_TYPE, "omTensorGetDataType", int32Ty, {opaquePtrTy}),
+        ApiSpec(API::SET_DATA_TYPE, "omTensorSetDataType", voidTy, {opaquePtrTy, int32Ty}),
+        ApiSpec(API::GET_OMTS, "omTensorListGetPtrToOmts", opaquePtrPtrTy, {opaquePtrTy}),
     };
     // clang-format on
@@ -645,7 +705,7 @@ private:
     return *entryPointEntryBlock;
   }
 
-  void fillPtrToMemRefWithRtMemRef(Value &dynMemRef, Value &ptrToMemRef,
+  void fillPtrToMemRefWithOMTensor(Value &rtMemRef, Value &ptrToMemRef,
       PatternRewriter &rewriter, const Location &loc,
       const std::map<API, ApiSpec> &apiRegistry, ModuleOp &module) const {
     auto *context = module.getContext();
@@ -657,7 +717,7 @@ private:
 
     // Set dataPtr and alignedDataPtr;
     auto dataPtr =
-        callApi(rewriter, loc, apiRegistry, API::GET_DATA, {dynMemRef});
+        callApi(rewriter, loc, apiRegistry, API::GET_DATA, {rtMemRef});
     dataPtr = rewriter.create<LLVM::BitcastOp>(
         loc, memRefTy.getStructElementType(0), dataPtr);
     memRef = rewriter.create<LLVM::InsertValueOp>(loc, memRefTy, memRef,
@@ -674,9 +734,9 @@ private:
     // Get rank, sizes array ptr and strides array ptr.
     auto rank = getRankFromMemRefType(memRefTy);
     auto sizesArrayPtr =
-        callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {dynMemRef});
+        callApi(rewriter, loc, apiRegistry, API::GET_DATA_SIZES, {rtMemRef});
     auto stridesArrayPtr =
-        callApi(rewriter, loc, apiRegistry, API::GET_STRIDES, {dynMemRef});
+        callApi(rewriter, loc, apiRegistry, API::GET_DATA_STRIDES, {rtMemRef});
 
     for (decltype(rank) i = 0; i < rank; i++) {
       auto dimIdx = rewriter.create<LLVM::ConstantOp>(
@@ -706,7 +766,7 @@ private:
     rewriter.create<LLVM::StoreOp>(loc, memRef, ptrToMemRef);
   }
 
-  void fillRtMemRefWithMemRef(Value &outMemRef, Value &outRtMemRef,
+  void fillOMTensorWithMemRef(Value &outMemRef, Value &outOMTensor,
       PatternRewriter &rewriter, const Location &loc,
      const std::map<API, ApiSpec> &apiRegistry, ModuleOp &module) const {
    auto *context = module.getContext();
@@ -714,26 +774,40 @@ private:
     auto int64Ty = LLVM::LLVMType::getInt64Ty(context);
     auto int32Ty = LLVM::LLVMType::getInt32Ty(context);
 
-    // Extract the data pointer, and record it in dynamic mem ref created.
-    Value outMemRefDataPtr = rewriter.create<LLVM::ExtractValueOp>(loc,
+    // Set ownership to true, i.e., free after OMTensor is destroyed.
+    Value owning = rewriter.create<LLVM::ConstantOp>(
+        loc, int32Ty, rewriter.getI32IntegerAttr(1));
+
+    // Extract the allocated pointer.
+    Value outMemRefAllocatedPtr = rewriter.create<LLVM::ExtractValueOp>(loc,
         outMemRefTy.getStructElementType(0), outMemRef,
         rewriter.getArrayAttr({rewriter.getI64IntegerAttr(0)}));
-    outMemRefDataPtr = rewriter.create<LLVM::BitcastOp>(
-        loc, LLVM::LLVMType::getInt8PtrTy(context), outMemRefDataPtr);
+    outMemRefAllocatedPtr = rewriter.create<LLVM::BitcastOp>(
+        loc, LLVM::LLVMType::getInt8PtrTy(context), outMemRefAllocatedPtr);
+
+    // Extract the aligned pointer.
+    Value outMemRefAlignedPtr = rewriter.create<LLVM::ExtractValueOp>(loc,
+        outMemRefTy.getStructElementType(1), outMemRef,
+        rewriter.getArrayAttr({rewriter.getI64IntegerAttr(1)}));
+    outMemRefAlignedPtr = rewriter.create<LLVM::BitcastOp>(
+        loc, LLVM::LLVMType::getInt8PtrTy(context), outMemRefAlignedPtr);
+
+    // Set ownership, allocated and aligned pointer.
     callApi(rewriter, loc, apiRegistry, API::SET_DATA,
-        {outRtMemRef, outMemRefDataPtr});
+        {outOMTensor, owning, outMemRefAllocatedPtr, outMemRefAlignedPtr});
 
     auto elemTy = outMemRefTy.getStructElementType(0).getPointerElementTy();
     auto onnxTy = llvmTypeToOnnxType(elemTy);
     auto onnxTyVal = rewriter.create<LLVM::ConstantOp>(
         loc, int32Ty, rewriter.getI32IntegerAttr(onnxTy));
-    callApi(
-        rewriter, loc, apiRegistry, API::SET_DTYPE, {outRtMemRef, onnxTyVal});
+    callApi(rewriter, loc, apiRegistry, API::SET_DATA_TYPE,
+        {outOMTensor, onnxTyVal});
 
     auto rank = getRankFromMemRefType(outMemRefTy);
     auto sizesArrayPtr =
-        callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {outRtMemRef});
+        callApi(rewriter, loc, apiRegistry, API::GET_DATA_SIZES, {outOMTensor});
     auto stridesArrayPtr =
-        callApi(rewriter, loc, apiRegistry, API::GET_STRIDES, {outRtMemRef});
+        callApi(
+            rewriter, loc, apiRegistry, API::GET_DATA_STRIDES, {outOMTensor});
 
     for (decltype(rank) i = 0; i < rank; i++) {
       auto dimIdx = rewriter.create<LLVM::ConstantOp>(


@@ -5,43 +5,55 @@ add_subdirectory(jni)
 # such static library in a shared library can cause runtime failure on some architectures,
 # such as z. So we override the default and explicitly compile with -fPIC.
 add_library(cruntime STATIC
-  RtMemRef.cpp
-  RtMemRef.h
-  DataType.h)
+  OMTensor.c
+  OMTensor.inc
+  OMTensorList.c
+  OMTensorList.inc
+  OnnxDataType.cpp)
+set_target_properties(cruntime PROPERTIES
+  LANGUAGE C)
 set_target_properties(cruntime PROPERTIES
   POSITION_INDEPENDENT_CODE TRUE)
+target_include_directories(cruntime PRIVATE
+  ${ONNX_MLIR_SRC_ROOT}
+  ${ONNX_MLIR_SRC_ROOT}/include)
 
-add_library(RtMemRefUtils
-  RtMemRef.h
-  RtMemRef.cpp
-  DataType.h)
-set_target_properties(RtMemRefUtils PROPERTIES
+add_library(OMTensorUtils
+  OMTensor.cpp
+  OMTensor.inc
+  OMTensorList.cpp
+  OMTensorList.inc
+  OnnxDataType.cpp)
+set_target_properties(OMTensorUtils PROPERTIES
   POSITION_INDEPENDENT_CODE TRUE)
+target_compile_definitions(OMTensorUtils PRIVATE RTMEMREF_INTERNAL_API)
+target_include_directories(OMTensorUtils PRIVATE
+  ${ONNX_MLIR_SRC_ROOT}
+  ${ONNX_MLIR_SRC_ROOT}/include)
 
 add_library(ExecutionSession
-  ExecusionSession.hpp
-  ExecusionSession.cpp)
+  ExecutionSession.hpp
+  ExecutionSession.cpp)
+target_include_directories(ExecutionSession PRIVATE
+  ${ONNX_MLIR_SRC_ROOT}/src/Runtime
+  ${ONNX_MLIR_SRC_ROOT}/include)
 target_link_libraries(ExecutionSession
   ${CMAKE_DL_LIBS})
-target_include_directories(ExecutionSession PRIVATE
-  ${ONNX_MLIR_SRC_ROOT}
-  ${ONNX_MLIR_BIN_ROOT}
-  ${ONNX_MLIR_SRC_ROOT})
 set_target_properties(ExecutionSession PROPERTIES
   POSITION_INDEPENDENT_CODE TRUE)
 
 pybind11_add_module(PyRuntime
   PyExecutionSession.cpp
   PyExecutionSession.hpp)
+target_include_directories(PyRuntime PRIVATE
+  ${ONNX_MLIR_SRC_ROOT}
+  ${ONNX_MLIR_SRC_ROOT}/src/Runtime
+  ${ONNX_MLIR_SRC_ROOT}/include)
 target_link_libraries(PyRuntime PRIVATE
   ${CMAKE_DL_LIBS}
   ExecutionSession
-  RtMemRefUtils
+  OMTensorUtils
   onnx)
-target_include_directories(PyRuntime PRIVATE
-  ${ONNX_MLIR_SRC_ROOT}
-  ${ONNX_MLIR_BIN_ROOT}
-  ${ONNX_MLIR_SRC_ROOT})
 
 # See comments above about libcruntime.a
 add_library(EmbeddedDataLoader STATIC
@@ -51,6 +63,6 @@ set_target_properties(EmbeddedDataLoader PROPERTIES
   POSITION_INDEPENDENT_CODE TRUE)
 
 add_dependencies(PyRuntime cruntime)
-install(FILES RtMemRef.h DESTINATION include)
 install(TARGETS cruntime DESTINATION lib)
 install(TARGETS EmbeddedDataLoader DESTINATION lib)


@@ -1,40 +0,0 @@
//===---------------------- DataType.h - ONNX DataTypes -------------------===//
//
// Copyright 2019-2020 The IBM Research Authors.
//
// =============================================================================
//
// This file contains declaration of ONNX data types.
//
//===----------------------------------------------------------------------===//
enum DYN_MEMREF_DATA_TYPE {
UNDEFINED = 0;
// Basic types.
FLOAT = 1; // float
UINT8 = 2; // uint8_t
INT8 = 3; // int8_t
UINT16 = 4; // uint16_t
INT16 = 5; // int16_t
INT32 = 6; // int32_t
INT64 = 7; // int64_t
STRING = 8; // string
BOOL = 9; // bool
// IEEE754 half-precision floating-point format (16 bits wide).
// This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits.
FLOAT16 = 10;
DOUBLE = 11;
UINT32 = 12;
UINT64 = 13;
COMPLEX64 = 14; // complex with float32 real and imaginary components
COMPLEX128 = 15; // complex with float64 real and imaginary components
// Non-IEEE floating-point format based on IEEE754 single-precision
// floating-point number truncated to 16 bits.
// This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits.
BFLOAT16 = 16;
// Future extensions go here.
};


@@ -1,10 +1,10 @@
-//===------- ExecusionSession.cpp - ExecutionSession Implementation -------===//
+//===------- ExecutionSession.cpp - ExecutionSession Implementation -------===//
 //
 // Copyright 2019-2020 The IBM Research Authors.
 //
 // =============================================================================
 //
-// This file contains implementations of ExecusionSession class, which helps C++
+// This file contains implementations of ExecutionSession class, which helps C++
 // programs interact with compiled binary model libraries.
 //
 //===----------------------------------------------------------------------===//
@@ -14,7 +14,7 @@
 #include <sstream>
 #include <vector>
 
-#include "ExecusionSession.hpp"
+#include "ExecutionSession.hpp"
 
 namespace onnx_mlir {
@@ -42,19 +42,22 @@ ExecutionSession::ExecutionSession(
   }
 }
 
-std::vector<std::unique_ptr<RtMemRef>> ExecutionSession::run(
-    std::vector<std::unique_ptr<RtMemRef>> ins) {
-  auto *wrappedInput = createOrderedRtMemRefDict();
-  for (size_t i = 0; i < ins.size(); i++)
-    setRtMemRef(wrappedInput, i, ins.at(i).get());
+std::vector<std::unique_ptr<OMTensor, decltype(&omTensorDestroy)>>
+ExecutionSession::run(
+    std::vector<std::unique_ptr<OMTensor, decltype(&omTensorDestroy)>> ins) {
+
+  std::vector<OMTensor *> omts;
+  for (const auto &inOmt : ins)
+    omts.emplace_back(inOmt.get());
+  auto *wrappedInput = omTensorListCreate(&omts[0], omts.size());
 
   auto *wrappedOutput = _entryPointFunc(wrappedInput);
 
-  std::vector<std::unique_ptr<RtMemRef>> outs;
+  std::vector<std::unique_ptr<OMTensor, decltype(&omTensorDestroy)>> outs;
 
-  auto outputSize = getSize(wrappedOutput);
-  for (size_t i = 0; i < getSize(wrappedOutput); i++) {
-    outs.emplace_back(std::unique_ptr<RtMemRef>(getRtMemRef(wrappedOutput, i)));
+  for (size_t i = 0; i < omTensorListGetSize(wrappedOutput); i++) {
+    outs.emplace_back(std::unique_ptr<OMTensor, decltype(&omTensorDestroy)>(
+        omTensorListGetOmtByIndex(wrappedOutput, i), omTensorDestroy));
   }
   return std::move(outs);
 }


@@ -1,10 +1,10 @@
-//===--------- ExecusionSession.hpp - ExecutionSession Declaration --------===//
+//===--------- ExecutionSession.hpp - ExecutionSession Declaration --------===//
 //
 // Copyright 2019-2020 The IBM Research Authors.
 //
 // =============================================================================
 //
-// This file contains declarations of ExecusionSession class, which helps C++
+// This file contains declarations of ExecutionSession class, which helps C++
 // programs interact with compiled binary model libraries.
 //
 //===----------------------------------------------------------------------===//
@@ -15,18 +15,19 @@
 #include <dlfcn.h>
 #include <string>
 
-#include "src/Runtime/RtMemRef.h"
+#include "OnnxMlirRuntime.h"
 
 namespace onnx_mlir {
 
-typedef OrderedRtMemRefDict *(*entryPointFuncType)(OrderedRtMemRefDict *);
+typedef OMTensorList *(*entryPointFuncType)(OMTensorList *);
 
 class ExecutionSession {
 public:
   ExecutionSession(std::string sharedLibPath, std::string entryPointName);
 
-  std::vector<std::unique_ptr<RtMemRef>> run(
-      std::vector<std::unique_ptr<RtMemRef>>);
+  // Use custom deleter since forward declared OMTensor hides destructor
+  std::vector<std::unique_ptr<OMTensor, decltype(&omTensorDestroy)>> run(
+      std::vector<std::unique_ptr<OMTensor, decltype(&omTensorDestroy)>>);
 
   ~ExecutionSession();
src/Runtime/OMTensor.c (new file, 12 lines)

//===--------------- OMTensor.c - OMTensor C Implementation ---------------===//
//
// Copyright 2019-2020 The IBM Research Authors.
//
// =============================================================================
//
// This file contains implementation of OMTensor and data structures and
// helper functions.
//
//===----------------------------------------------------------------------===//
#include "OMTensor.inc"

src/Runtime/OMTensor.cpp (new file, 12 lines)

//===------------- OMTensor.cpp - OMTensor C++ Implementation -------------===//
//
// Copyright 2019-2020 The IBM Research Authors.
//
// =============================================================================
//
// This file contains implementation of OMTensor and data structures and
// helper functions.
//
//===----------------------------------------------------------------------===//
#include "OMTensor.inc"

src/Runtime/OMTensor.inc (new file, 499 lines)

//===--------- OMTensor.inc - C/C++ Neutral OMTensor Implementation--------===//
//
// Copyright 2019-2020 The IBM Research Authors.
//
// =============================================================================
//
// This file contains implementations of OMTensor data structures
// and helper functions.
//
//===----------------------------------------------------------------------===//
#ifdef __cplusplus
#include <cassert>
#include <map>
#include <numeric>
#include <random>
#include <string>
#include <typeinfo>
#include <vector>
#else
#include <assert.h>
#endif
#ifdef __APPLE__
#include <stdlib.h>
#else
#include <malloc.h>
#endif
#include <stdio.h>
#include <string.h>
#include "onnx-mlir/Runtime/OMTensor.h"
#ifdef __cplusplus
#include "src/Runtime/OMTensorHelper.h"
#endif
struct OMTensor {
#ifdef __cplusplus
/**
* Constructor
*
 * @param rank rank of data sizes and strides
 *
 * Create an OMTensor with the specified rank. Memory for data sizes and
 * strides is allocated.
*/
OMTensor(int rank) {
if ((_shape = (int64_t *)malloc(rank * sizeof(int64_t))) &&
(_stride = (int64_t *)malloc(rank * sizeof(int64_t)))) {
assert(_shape);
assert(_stride);
_allocatedPtr = NULL;
_alignedPtr = NULL;
_offset = 0;
_dataType = ONNX_TYPE_UNDEFINED;
_rank = rank;
_owning = false;
} else {
throw std::runtime_error(
"OMTensor(" + std::to_string(rank) + ") malloc error");
}
};
OMTensor() = default;
/**
* Destructor
*
* Destroy the OMTensor struct.
*/
~OMTensor() {
if (_owning)
free(_allocatedPtr);
free(_shape);
free(_stride);
};
#endif
// Fields are named according to:
// https://mlir.llvm.org/docs/Dialects/SPIR-V/#lowering-memrefs-to-spvarray-and-spvrtarray
// On machines without alignment constraints the allocated and aligned pointers
// are the same. However, on machines with alignment constraints not supported
// by the memory allocation system, the allocated ptr points to the chunk of
// memory that is allocated, and the aligned pointer points to a chunk of
// memory that is allocated and also satisfies the alignment constraints on the
// machine. For example, on a machine for which malloc returns chunks aligned
// at 16-byte boundaries, but where tensors must be allocated at 1K boundaries
// for performance reasons, the allocated pointer may return 0x1000f and the
// aligned pointer may return 0x10400.
void *_allocatedPtr; // data buffer
void *_alignedPtr; // aligned data buffer that the omt indexes.
int64_t _offset; // offset of 1st element
int64_t *_shape; // sizes array
int64_t *_stride; // strides array
int64_t _rank; // rank
OM_DATA_TYPE _dataType; // ONNX data type
int _owning; // indicates whether the Omt owns the memory space
// referenced by _allocatedPtr. Omt struct will release the
// memory space referred to by _allocatedPtr upon destruction if
// and only if it owns it.
};
// Create an OMTensor.
OMTensor *omTensorCreate(
    void *data_ptr, int64_t *shape, int64_t rank, OM_DATA_TYPE dtype) {
  OMTensor *tensor = (OMTensor *)malloc(sizeof(OMTensor));
  if (!tensor)
    return NULL;
  tensor->_shape = (int64_t *)malloc(rank * sizeof(int64_t));
  tensor->_stride = (int64_t *)malloc(rank * sizeof(int64_t));
  // If malloc for _shape or _stride fails, free everything and return NULL.
  if (!tensor->_shape || !tensor->_stride) {
    free(tensor->_shape);
    free(tensor->_stride);
    free(tensor);
    return NULL;
  }
  tensor->_allocatedPtr = data_ptr;
  tensor->_alignedPtr = data_ptr;
  tensor->_offset = 0;
  tensor->_rank = rank;
  tensor->_dataType = dtype;
  tensor->_owning = false;
  // Using signed indices helps detect when the index falls below 0.
  for (int64_t i = rank - 1; i >= 0; i--) {
    tensor->_shape[i] = shape[i];
    if (i == rank - 1)
      tensor->_stride[i] = 1;
    else
      tensor->_stride[i] = tensor->_stride[i + 1] * tensor->_shape[i + 1];
  }
  return tensor;
}
// Create an OMTensor with specified ownership.
OMTensor *omTensorCreateWithOwnership(void *data_ptr, int64_t *shape,
int64_t rank, OM_DATA_TYPE dtype, int owning) {
OMTensor *tensor = omTensorCreate(data_ptr, shape, rank, dtype);
// If ctor fails, return NULL.
if (!tensor)
return NULL;
tensor->_owning = owning;
return tensor;
}
// Create an OMTensor (deprecated).
OMTensor *omTensorCreateEmptyDeprecated(int rank) {
  OMTensor *omt = (OMTensor *)malloc(sizeof(struct OMTensor));
  if (!omt)
    return NULL;
  omt->_shape = (int64_t *)malloc(rank * sizeof(int64_t));
  omt->_stride = (int64_t *)malloc(rank * sizeof(int64_t));
  // If malloc for _shape or _stride fails, free everything and return NULL.
  if (!omt->_shape || !omt->_stride) {
    free(omt->_shape);
    free(omt->_stride);
    free(omt);
    return NULL;
  }
  omt->_allocatedPtr = NULL;
  omt->_alignedPtr = NULL;
  omt->_offset = 0;
  omt->_dataType = ONNX_TYPE_UNDEFINED;
  omt->_rank = rank;
  omt->_owning = false;
  return omt;
}
OMTensor *omTensorCreateEmpty(
int64_t *shape, int64_t rank, OM_DATA_TYPE dtype) {
OMTensor *tensor =
omTensorCreateWithOwnership(NULL, shape, rank, dtype, /*owning=*/true);
// If ctor fails, return null.
if (!tensor)
return NULL;
  void *dataPtr = malloc(omTensorGetNumElems(tensor) * getDataTypeSize(dtype));
  // If the data buffer allocation fails, destroy the tensor and return NULL.
  if (!dataPtr) {
    omTensorDestroy(tensor);
    return NULL;
  }
tensor->_alignedPtr = dataPtr;
tensor->_allocatedPtr = dataPtr;
return tensor;
}
/* OMTensor destroyer */
void omTensorDestroy(OMTensor *tensor) {
if (tensor->_owning) {
free(tensor->_allocatedPtr);
tensor->_allocatedPtr = NULL;
tensor->_alignedPtr = NULL;
}
free(tensor);
}
/* OMTensor data getter */
void *omTensorGetDataPtr(OMTensor *tensor) { return tensor->_alignedPtr; }
/**
* OMTensor allocated and aligned pointer setter.
* This function is intentionally left out from the header because it is only
* used by the wrapper code we emit around inference function that converts
* MemRefs to OMTensors for user convenience.
*
* @param tensor pointer to the OMTensor
* @param owning whether allocatedPtr should be freed after tensor is destroyed.
* @param allocatedPtr allocated pointer to tensor content.
* @param alignedPtr aligned pointer to tensor content. If NULL will be set to
* allocatedPtr.
*
*/
void omTensorSetPtr(
OMTensor *tensor, int owning, void *allocatedPtr, void *alignedPtr) {
if (tensor->_owning) {
/* If we own the allocated buffer, free it first. */
free(tensor->_allocatedPtr);
}
tensor->_owning = owning;
tensor->_allocatedPtr = allocatedPtr;
if (alignedPtr)
tensor->_alignedPtr = alignedPtr;
else
tensor->_alignedPtr = allocatedPtr;
}
/* OMTensor data sizes getter */
int64_t *omTensorGetDataShape(OMTensor *tensor) { return tensor->_shape; }
/* OMTensor data sizes setter */
void omTensorSetShape(OMTensor *tensor, int64_t *shape) {
for (int i = 0; i < tensor->_rank; i++)
tensor->_shape[i] = shape[i];
}
/* OMTensor data strides getter */
int64_t *omTensorGetStrides(OMTensor *tensor) { return tensor->_stride; }
/* OMTensor data strides setter */
void omTensorSetStrides(OMTensor *tensor, int64_t *strides) {
for (int i = 0; i < tensor->_rank; i++)
tensor->_stride[i] = strides[i];
}
/* OMTensor data type getter */
OM_DATA_TYPE omTensorGetDataType(OMTensor *tensor) { return tensor->_dataType; }
/* OMTensor data type setter */
void omTensorSetDataType(OMTensor *tensor, OM_DATA_TYPE dataType) {
tensor->_dataType = dataType;
}
/* OMTensor data buffer size getter */
int64_t omTensorGetDataBufferSize(OMTensor *tensor) {
return getNumOfElems(tensor->_shape, tensor->_rank) *
getDataTypeSize(tensor->_dataType);
}
/* OMTensor rank getter */
int omTensorGetRank(OMTensor *tensor) { return tensor->_rank; }
/* OMTensor number of elements getter */
int64_t omTensorGetNumElems(OMTensor *tensor) {
// Using signed indices helps detect when index falls below 0.
// Verify that strides are dense, meaning that there are no skipped
// elements.
for (int64_t i = tensor->_rank - 1; i >= 0; i--) {
int64_t strideIfNotSkipping = 1;
for (int64_t j = i + 1; j < tensor->_rank; j++) {
strideIfNotSkipping *= tensor->_shape[j];
}
assert(tensor->_stride[i] == strideIfNotSkipping);
}
return getNumOfElems(tensor->_shape, tensor->_rank);
}
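/* Worked example: a dense tensor of shape {2, 3, 4} must carry strides
 * {12, 4, 1} and holds 2 * 3 * 4 = 24 elements; a stride array such as
 * {24, 4, 1} would skip elements and trip the assertion above.
 */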
/**
 * OMTensor allocated ptr getter.
 * Note that this function is intentionally left out of the header
 * because it is only used by the wrapper code we emit around the inference
 * function, which converts OMTensors into MemRefs for user convenience.
 *
 * @param tensor pointer to the OMTensor
 * @return pointer to the allocated data buffer of the OMTensor,
 * NULL if the allocated data buffer is not set.
 */
void *omTensorGetAllocatedPtr(OMTensor *tensor) {
return tensor->_allocatedPtr;
}
#ifdef __cplusplus
/* OMTensor creator with data sizes and element type */
template <typename T>
OMTensor *omTensorCreateWithShape(std::vector<int64_t> dataSizes) {
  /* Create an OMTensor with shape and strides allocated */
auto omt = omTensorCreateEmptyDeprecated(dataSizes.size());
if (omt == NULL)
return NULL;
/* Allocate data buffer */
omt->_rank = dataSizes.size();
if ((omt->_allocatedPtr = malloc(
getNumOfElems(dataSizes.data(), omt->_rank) * sizeof(T))) == NULL) {
omTensorDestroy(omt);
return NULL;
}
omt->_alignedPtr = omt->_allocatedPtr;
omt->_offset = 0;
  /* Copy dataSizes; _shape was already allocated by
   * omTensorCreateEmptyDeprecated */
  std::copy(dataSizes.begin(), dataSizes.end(), omt->_shape);
/* Compute and copy dataStrides, _stride already allocated by
* omTensorCreateEmptyDeprecated
*/
auto computedStrides = computeStridesFromSizes(omt->_shape, omt->_rank);
  std::copy(computedStrides.begin(), computedStrides.end(), omt->_stride);
/* Convert CPP type to ONNX type */
try {
omt->_dataType = OM_DATA_TYPE_CPP_TO_ONNX.at(std::string(typeid(T).name()));
} catch (const std::out_of_range &e) {
omt->_dataType = ONNX_TYPE_UNDEFINED;
}
/* Set flag for destructor */
omt->_owning = true;
return omt;
}
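/* Illustrative usage sketch (hypothetical helper): create a 2x3 float tensor
 * and write one element through the typed accessor defined below.
 */
inline OMTensor *makeSmallTensor() {
  OMTensor *omt = omTensorCreateWithShape<float>({2, 3});
  if (omt)
    omTensorGetElem<float>(omt, {1, 2}) = 42.0f;
  return omt;
}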
/* OMTensor creator with data sizes, element type and random data */
template <typename T>
OMTensor *omTensorCreateWithRandomData(
std::vector<int64_t> dataSizes, T lbound, T ubound) {
// Will be used to obtain a seed for the random number engine
std::random_device rd;
// Standard mersenne_twister_engine seeded with rd()
std::mt19937 gen(rd());
std::uniform_real_distribution<> dis(lbound, ubound);
auto omt = omTensorCreateWithShape<T>(dataSizes);
if (omt == NULL)
return NULL;
std::generate((T *)omt->_allocatedPtr,
(T *)omt->_allocatedPtr + getNumOfElems(omt->_shape, omt->_rank),
[&]() { return dis(gen); });
return omt;
}
/* Access an element (by reference) at offset computed by index array */
template <typename T>
T &omTensorGetElem(OMTensor *omt, std::vector<int64_t> indexes) {
int64_t elemOffset = omTensorComputeElemOffset(omt, indexes);
return ((T *)omt->_allocatedPtr)[elemOffset];
}
/* Access an element (by reference) at linear offset */
template <typename T>
T &omTensorGetElemByOffset(OMTensor *omt, int64_t index) {
return ((T *)omt->_allocatedPtr)[index];
}
/* Compute strides vector from sizes vector */
std::vector<int64_t> omTensorComputeStridesFromShape(OMTensor *omt) {
return computeStridesFromSizes(omt->_shape, omt->_rank);
}
/* Compute linear element offset from multi-dimensional index array */
int64_t omTensorComputeElemOffset(
OMTensor *omt, std::vector<int64_t> &indexes) {
return computeElemOffset(omt->_stride, omt->_rank, indexes);
}
/* Compute index set for the whole OMTensor */
std::vector<std::vector<int64_t>> omTensorComputeIndexSet(OMTensor *omt) {
// First, we create index set of each dimension separately.
// i.e., for a tensor/OMT of shape (2, 3), its dimWiseIdxSet will be:
// {{0,1}, {0,1,2}};
std::vector<std::vector<int64_t>> dimWiseIdxSet;
for (auto dimSize :
std::vector<int64_t>(omt->_shape, omt->_shape + omt->_rank)) {
std::vector<int64_t> dimIdxSet(dimSize);
    std::iota(std::begin(dimIdxSet), std::end(dimIdxSet), 0);
dimWiseIdxSet.emplace_back(dimIdxSet);
}
// Then, the cartesian product of vectors within dimWiseIdxSet will be the
// index set for the whole OMT.
return CartProduct(dimWiseIdxSet);
}
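// Worked example: for a tensor of shape (2, 3), the cartesian product of
// {{0,1}, {0,1,2}} yields the six indexes {0,0}, {0,1}, {0,2}, {1,0},
// {1,1}, {1,2}, i.e., every valid access into the tensor.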
/* Check whether two OMTensor data are "close" to each other */
template <typename T>
inline bool omTensorAreTwoOmtsClose(
OMTensor *a, OMTensor *b, float rtol, float atol) {
// Compare shape.
auto aShape = std::vector<int64_t>(a->_shape, a->_shape + a->_rank);
auto bShape = std::vector<int64_t>(b->_shape, b->_shape + b->_rank);
if (aShape != bShape) {
std::cerr << "Shape mismatch ";
printVector(aShape, ",", std::cerr);
std::cerr << " != ";
printVector(bShape, ",", std::cerr);
return false;
}
// Compute absolute difference, verify it's within tolerable range.
auto anum = omTensorGetNumElems(a);
std::vector<T> absoluteDiff(anum);
std::transform((T *)a->_allocatedPtr, (T *)a->_allocatedPtr + anum,
(T *)b->_allocatedPtr, absoluteDiff.begin(), std::minus<>());
std::transform(absoluteDiff.begin(), absoluteDiff.end(), absoluteDiff.begin(),
static_cast<T (*)(T)>(&std::abs));
bool atolSatisfied = std::all_of(
absoluteDiff.begin(), absoluteDiff.end(), [&](T a) { return a < atol; });
// Compute relative difference, verify it's within tolerable range.
std::vector<T> relativeDiff(anum);
std::transform(absoluteDiff.begin(), absoluteDiff.end(),
(T *)a->_allocatedPtr, relativeDiff.begin(), std::divides<>());
  bool rtolSatisfied = std::all_of(
relativeDiff.begin(), relativeDiff.end(), [&](T a) { return a < rtol; });
if (atolSatisfied && rtolSatisfied) {
return true;
} else {
// Figure out where and what went wrong, this can be slow; but hopefully we
// don't need this often.
for (const auto &idx : omTensorComputeIndexSet(a)) {
T aElem = omTensorGetElem<T>(a, idx);
T bElem = omTensorGetElem<T>(b, idx);
      auto elmAbsDiff = std::abs(aElem - bElem);
auto withinRtol = (elmAbsDiff / aElem < rtol);
auto withinAtol = (elmAbsDiff < atol);
if (!withinRtol || !withinAtol) {
std::cerr << "a[";
printVector(idx, ",", std::cerr);
std::cerr << "] = " << aElem << " != ";
std::cerr << "b[";
printVector(idx, ",", std::cerr);
std::cerr << "] = " << bElem << std::endl;
}
}
return false;
}
}
// Explicit instantiation of all templated API functions.
template OMTensor *omTensorCreateWithShape<int32_t>(
std::vector<int64_t> dataSizes);
template OMTensor *omTensorCreateWithShape<int64_t>(
std::vector<int64_t> dataSizes);
template OMTensor *omTensorCreateWithShape<float>(
std::vector<int64_t> dataSizes);
template OMTensor *omTensorCreateWithShape<double>(
std::vector<int64_t> dataSizes);
template OMTensor *omTensorCreateWithRandomData<int32_t>(
std::vector<int64_t> dataSizes, int32_t lbound, int32_t ubound);
template OMTensor *omTensorCreateWithRandomData<int64_t>(
std::vector<int64_t> dataSizes, int64_t lbound, int64_t ubound);
template OMTensor *omTensorCreateWithRandomData<float>(
std::vector<int64_t> dataSizes, float lbound, float ubound);
template OMTensor *omTensorCreateWithRandomData<double>(
std::vector<int64_t> dataSizes, double lbound, double ubound);
template int32_t &omTensorGetElem<int32_t>(
OMTensor *, std::vector<int64_t> indexes);
template int64_t &omTensorGetElem<int64_t>(
OMTensor *, std::vector<int64_t> indexes);
template float &omTensorGetElem<float>(
OMTensor *, std::vector<int64_t> indexes);
template double &omTensorGetElem<double>(
OMTensor *, std::vector<int64_t> indexes);
template int32_t &omTensorGetElemByOffset<int32_t>(OMTensor *, int64_t index);
template int64_t &omTensorGetElemByOffset<int64_t>(OMTensor *, int64_t index);
template float &omTensorGetElemByOffset<float>(OMTensor *, int64_t index);
template double &omTensorGetElemByOffset<double>(OMTensor *, int64_t index);
template bool omTensorAreTwoOmtsClose<int32_t>(
OMTensor *a, OMTensor *b, float rtol, float atol);
template bool omTensorAreTwoOmtsClose<int64_t>(
OMTensor *a, OMTensor *b, float rtol, float atol);
template bool omTensorAreTwoOmtsClose<float>(
OMTensor *a, OMTensor *b, float rtol, float atol);
template bool omTensorAreTwoOmtsClose<double>(
OMTensor *a, OMTensor *b, float rtol, float atol);
#endif

View File

@ -0,0 +1,160 @@
//===----------- OMTensorHelper.h - OMTensor Helper Func header -----------===//
//
// Copyright 2019-2020 The IBM Research Authors.
//
// =============================================================================
//
// This file contains declaration of OMTensor C++ helper functions. At some
// point, this file needs to be merged into the OMTensor.h along with other C++
// APIs operating on OMTensor.
//
//===----------------------------------------------------------------------===//
#ifndef ONNX_MLIR_OMTENSORHELPER_H
#define ONNX_MLIR_OMTENSORHELPER_H
#include <cstdint>
#include <iostream>
#include <numeric>
#include <string>
#include <utility>
#include <vector>
/* Helper function to compute cartesian product */
static inline std::vector<std::vector<int64_t>> CartProduct(
const std::vector<std::vector<int64_t>> &v) {
std::vector<std::vector<int64_t>> s = {{}};
for (const auto &u : v) {
std::vector<std::vector<int64_t>> r;
for (const auto &x : s) {
for (const auto y : u) {
r.push_back(x);
r.back().push_back(y);
}
}
    s = std::move(r);
}
return s;
}
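/* Worked example: CartProduct({{0,1}, {0,1,2}}) grows s from {{}} to
 * {{0},{1}} after the first dimension, and then to
 * {{0,0},{0,1},{0,2},{1,0},{1,1},{1,2}} after the second.
 */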
/* Helper function to compute data strides from sizes */
static inline std::vector<int64_t> computeStridesFromSizes(
int64_t *dataSizes, int rank) {
// Shift dimension sizes one to the left, fill in the vacated rightmost
// element with 1; this gets us a vector that'll be more useful for computing
// strides of memory access along each dimension using prefix product (aka
// partial_sum with a multiply operator below). The intuition is that the size
// of the leading dimension does not matter when computing strides.
std::vector<int64_t> sizesVec(dataSizes + 1, dataSizes + rank);
sizesVec.push_back(1);
std::vector<int64_t> dimStrides(rank);
  std::partial_sum(sizesVec.rbegin(), sizesVec.rend(), dimStrides.rbegin(),
std::multiplies<>());
return dimStrides;
}
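/* Worked example: for sizes {2, 3, 4} the shifted vector is {3, 4, 1}, and
 * the reverse prefix product yields strides {12, 4, 1}; moving one step
 * along dimension 0 skips 3 * 4 = 12 elements.
 */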
/* Helper function to compute linear offset from a multi-dimensional index array
*/
static inline int64_t computeElemOffset(
int64_t *dataStrides, int rank, std::vector<int64_t> &indexes) {
auto dimStrides = std::vector<int64_t>(dataStrides, dataStrides + rank);
  int64_t elemOffset = std::inner_product(
indexes.begin(), indexes.end(), dimStrides.begin(), (int64_t)0);
return elemOffset;
}
/* Helper function to print a vector with delimiter */
template <typename T>
static inline void printVector(std::vector<T> vec, std::string _delimiter = ",",
std::ostream &stream = std::cout) {
std::string delimiter;
for (const auto &elem : vec) {
stream << delimiter << elem;
delimiter = _delimiter;
}
}
/**
* OMTensor creator with data sizes and element type
*
* @param dataSizes, data sizes array
* @return pointer to OMTensor created, NULL if creation failed.
*
* Create a full OMTensor of data type T and shape dataSizes, with all
* data fields initialized to proper values and data pointers malloc'ed.
*/
template <typename T>
OMTensor *omTensorCreateWithShape(std::vector<int64_t> dataSizes);
/**
* OMTensor creator with data sizes, element type and random data
*
* @param dataSizes, data sizes array
* @param lbound (optional), lower bound of the random distribution
* @param ubound (optional), upper bound of the random distribution
* @return pointer to OMTensor created, NULL if creation failed.
*
* Create a full OMTensor like what omTensorCreateWithShape does
* and also fill the OMTensor data buffer with randomly generated
* real numbers from a uniform distribution between lbound and ubound.
*/
template <typename T>
OMTensor *omTensorCreateWithRandomData(
std::vector<int64_t> dataSizes, T lbound = -1.0, T ubound = 1.0);
/**
 * OMTensor data element getter by multi-dimensional index
 *
 * @param omt, pointer to the OMTensor
 * @param indexes, multi-dimensional index array of the element
 * @return typed element by reference at the offset computed by the index array.
 */
template <typename T>
T &omTensorGetElem(OMTensor *omt, std::vector<int64_t> indexes);
/**
 * OMTensor data element getter by linear offset
 *
 * @param omt, pointer to the OMTensor
 * @param index, linear offset of the element
 * @return typed element by reference at the linear offset.
 */
template <typename T>
T &omTensorGetElemByOffset(OMTensor *omt, int64_t index);
/**
* OMTensor strides computation
*
* @param omt, pointer to the OMTensor
* @return data strides of the OMTensor computed from the data sizes.
*/
std::vector<int64_t> omTensorComputeStridesFromShape(OMTensor *omt);
/**
* OMTensor linear offset computation
*
* @param omt, pointer to the OMTensor
* @param indexes, multi-dimensional index array
* @return linear offset.
*/
int64_t omTensorComputeElemOffset(OMTensor *omt, std::vector<int64_t> &indexes);
/**
* OMTensor index set computation
*
* @param omt, pointer to the OMTensor
* @return index set (i.e., all valid multi-dimensional array indexes
* that can be used to access this OMTensor's constituent elements)
* for the whole OMTensor.
*/
std::vector<std::vector<int64_t>> omTensorComputeIndexSet(OMTensor *omt);
/**
* OMTensor "distance" computation
*
* @param a, 1st OMTensor
* @param b, 2nd OMTensor
* @param rtol (optional), relative difference tolerance
* @param atol (optional), absolute difference tolerance
* @return true if both relative and absolute difference are within the
* specified tolerance, respectively, false otherwise.
*/
template <typename T>
bool omTensorAreTwoOmtsClose(
OMTensor *a, OMTensor *b, float rtol = 1e-5, float atol = 1e-5);
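/* Illustrative usage sketch (hypothetical helper): the typical verification
 * flow is to compare a model output against a reference tensor with the
 * default tolerances; shape mismatches and out-of-tolerance elements are
 * reported on stderr by omTensorAreTwoOmtsClose.
 */
inline bool verifyOutput(OMTensor *result, OMTensor *reference) {
  return omTensorAreTwoOmtsClose<float>(result, reference);
}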
#endif // ONNX_MLIR_OMTENSORHELPER_H

View File

@ -0,0 +1,12 @@
//===----------- OMTensorList.c - OMTensorList C Implementation -----------===//
//
// Copyright 2019-2020 The IBM Research Authors.
//
// =============================================================================
//
// This file contains implementation of OMTensorList data structures and
// helper functions.
//
//===----------------------------------------------------------------------===//
#include "OMTensorList.inc"

View File

@ -0,0 +1,12 @@
//===--------- OMTensorList.cpp - OMTensorList C++ Implementation ---------===//
//
// Copyright 2019-2020 The IBM Research Authors.
//
// =============================================================================
//
// This file contains implementation of OMTensorList data structures and
// helper functions.
//
//===----------------------------------------------------------------------===//
#include "OMTensorList.inc"

View File

@ -0,0 +1,93 @@
//===-------- OMTensorList.inc - OMTensorList C/C++ Implementation --------===//
//
// Copyright 2019-2020 The IBM Research Authors.
//
// =============================================================================
//
// This file contains C/C++ neutral implementation of OMTensorList data
// structures and helper functions.
//
//===----------------------------------------------------------------------===//
#ifdef __APPLE__
#include <stdlib.h>
#else
#include <malloc.h>
#endif
#ifdef __cplusplus
#include <cassert>
#else
#include <assert.h>
#endif
#include "onnx-mlir/Runtime/OMTensorList.h"
struct OMTensorList {
#ifdef __cplusplus
/**
* Constructor
*
* Create an OMTensorList with specified OMTensor pointer array
* and the size of the array
*/
OMTensorList(OMTensor *omts[], int n) : _omts(omts), _size(n){};
/**
* Constructor
*
* Create an empty OMTensorList for internal API calls.
*/
OMTensorList() = default;
/**
* Destructor
*
* Destroy the OMTensorList struct.
*/
~OMTensorList() {
/* Destroy all the OMTensors */
for (int i = 0; i < _size; i++)
if (_omts[i])
omTensorDestroy(_omts[i]);
};
#endif
  /* To facilitate the user-facing getter API, OMTensors are kept in an
   * array that can be returned directly along with its size.
   */
OMTensor **_omts; // OMTensor array
size_t _size; // Number of elements in _omts.
};
/* OMTensorList creator */
OMTensorList *omTensorListCreate(OMTensor **tensors, int n) {
OMTensorList *list = (OMTensorList *)malloc(sizeof(struct OMTensorList));
if (!list)
return NULL;
list->_omts = tensors;
list->_size = n;
return list;
}
/* OMTensorList destroyer */
void omTensorListDestroy(OMTensorList *list) {
  /* Destroy each OMTensor so that its shape/stride buffers (and owned data
   * buffer) are released, then free the list itself. */
  for (size_t i = 0; i < list->_size; i++)
    if (list->_omts[i])
      omTensorDestroy(list->_omts[i]);
  free(list);
}
/* OMTensorList OMTensor array getter */
OMTensor **omTensorListGetPtrToOmts(OMTensorList *list) { return list->_omts; }
/* OMTensorList number of OMTensor getter */
int omTensorListGetSize(OMTensorList *list) { return list->_size; }
/* Return OMTensor at specified index in the OMTensorList */
OMTensor *omTensorListGetOmtByIndex(OMTensorList *rlist, size_t index) {
  assert(index < rlist->_size);
  return rlist->_omts[index];
}
}
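/* Illustrative usage sketch (hypothetical helper): wrap two tensors in a
 * list and hand the list ownership of both. The list borrows the omts array
 * itself, so here the array must outlive the list.
 */
static void listRoundTrip(OMTensor *t0, OMTensor *t1) {
  OMTensor *omts[] = {t0, t1};
  OMTensorList *list = omTensorListCreate(omts, 2);
  if (!list)
    return;
  assert(omTensorListGetSize(list) == 2);
  assert(omTensorListGetOmtByIndex(list, 0) == t0);
  /* Destroys t0 and t1 as well. */
  omTensorListDestroy(list);
}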

View File

@ -0,0 +1,8 @@
#include "onnx-mlir/Runtime/OnnxDataType.h"
const int OM_DATA_TYPE_SIZE[] = {
#define OM_TYPE_METADATA_DEF(ENUM_NAME, ENUM_VAL, DTYPE_SIZE) DTYPE_SIZE,
#include "onnx-mlir/Runtime/OnnxDataTypeMetaData.inc"
#undef OM_TYPE_METADATA_DEF
};
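/* A minimal sketch of how the X-macro expands, assuming the .inc file
 * carries entries of the form OM_TYPE_METADATA_DEF(ENUM_NAME, ENUM_VAL,
 * DTYPE_SIZE), e.g.:
 *
 *   OM_TYPE_METADATA_DEF(ONNX_TYPE_FLOAT, 1, 4)
 *   OM_TYPE_METADATA_DEF(ONNX_TYPE_DOUBLE, 11, 8)
 *
 * With the definition above keeping only DTYPE_SIZE, the array body expands
 * to 4, 8, so OM_DATA_TYPE_SIZE[dtype] yields the element size in bytes for
 * a given OM_DATA_TYPE value.
 */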

View File

@ -17,74 +17,109 @@
namespace onnx_mlir {
std::vector<py::array> PyExecutionSession::pyRun(
    const std::vector<py::array> &inputsPyArray) {
  assert(_entryPointFunc && "Entry point not loaded.");
  std::vector<OMTensor *> omts;
  for (auto inputPyArray : inputsPyArray) {
    assert((inputPyArray.flags() & py::array::c_style) &&
           "Expect contiguous python array.");
    void *dataPtr;
    int ownData = 0;
    if (inputPyArray.writeable()) {
      dataPtr = inputPyArray.mutable_data();
    } else {
      // If data is not writable, copy them to a writable buffer.
      auto *copiedData = (float *)malloc(inputPyArray.nbytes());
      memcpy(copiedData, inputPyArray.data(), inputPyArray.nbytes());
      dataPtr = copiedData;
      // We want OMTensor to free up the memory space upon destruction.
      ownData = 1;
    }
    // Borrowed from:
    // https://github.com/pybind/pybind11/issues/563#issuecomment-267835542
    OM_DATA_TYPE dtype;
    if (py::isinstance<py::array_t<float>>(inputPyArray))
      dtype = ONNX_TYPE_FLOAT;
    else if (py::isinstance<py::array_t<std::uint8_t>>(inputPyArray))
      dtype = ONNX_TYPE_UINT8;
    else if (py::isinstance<py::array_t<std::int8_t>>(inputPyArray))
      dtype = ONNX_TYPE_INT8;
    else if (py::isinstance<py::array_t<std::uint16_t>>(inputPyArray))
      dtype = ONNX_TYPE_UINT16;
    else if (py::isinstance<py::array_t<std::int16_t>>(inputPyArray))
      dtype = ONNX_TYPE_INT16;
    else if (py::isinstance<py::array_t<std::int32_t>>(inputPyArray))
      dtype = ONNX_TYPE_INT32;
    else if (py::isinstance<py::array_t<std::int64_t>>(inputPyArray))
      dtype = ONNX_TYPE_INT64;
    else if (py::isinstance<py::array_t<bool>>(inputPyArray))
      dtype = ONNX_TYPE_BOOL;
    // Missing fp16 support.
    else if (py::isinstance<py::array_t<double>>(inputPyArray))
      dtype = ONNX_TYPE_DOUBLE;
    else if (py::isinstance<py::array_t<std::uint32_t>>(inputPyArray))
      dtype = ONNX_TYPE_UINT32;
    else if (py::isinstance<py::array_t<std::uint64_t>>(inputPyArray))
      dtype = ONNX_TYPE_UINT64;
    else {
      std::cerr << "Numpy type not supported: " << inputPyArray.dtype()
                << ".\n";
      exit(1);
    }
    auto *inputOMTensor = omTensorCreateWithOwnership(dataPtr,
        (int64_t *)inputPyArray.shape(), inputPyArray.ndim(), dtype, ownData);
    omTensorSetStrides(inputOMTensor, (int64_t *)inputPyArray.strides());
    omts.emplace_back(inputOMTensor);
  }
  auto *wrappedInput = omTensorListCreate(&omts[0], omts.size());
  auto *wrappedOutput = _entryPointFunc(wrappedInput);
  std::vector<py::array> outputPyArrays;
  for (int i = 0; i < omTensorListGetSize(wrappedOutput); i++) {
    auto *omt = omTensorListGetOmtByIndex(wrappedOutput, i);
    auto shape = std::vector<int64_t>(omTensorGetDataShape(omt),
        omTensorGetDataShape(omt) + omTensorGetRank(omt));
    // https://numpy.org/devdocs/user/basics.types.html
    py::dtype dtype;
    if (omTensorGetDataType(omt) == onnx::TensorProto::FLOAT)
      dtype = py::dtype("float32");
    else if (omTensorGetDataType(omt) == onnx::TensorProto::UINT8)
      dtype = py::dtype("uint8");
    else if (omTensorGetDataType(omt) == onnx::TensorProto::INT8)
      dtype = py::dtype("int8");
    else if (omTensorGetDataType(omt) == onnx::TensorProto::UINT16)
      dtype = py::dtype("uint16");
    else if (omTensorGetDataType(omt) == onnx::TensorProto::INT16)
      dtype = py::dtype("int16");
    else if (omTensorGetDataType(omt) == onnx::TensorProto::INT32)
      dtype = py::dtype("int32");
    else if (omTensorGetDataType(omt) == onnx::TensorProto::INT64)
      dtype = py::dtype("int64");
    // TODO(tjingrant) wait for Tong's input for how to represent string.
    else if (omTensorGetDataType(omt) == onnx::TensorProto::BOOL)
      dtype = py::dtype("bool_");
    else if (omTensorGetDataType(omt) == onnx::TensorProto::FLOAT16)
      dtype = py::dtype("float32");
    else if (omTensorGetDataType(omt) == onnx::TensorProto::DOUBLE)
      dtype = py::dtype("float64");
    else if (omTensorGetDataType(omt) == onnx::TensorProto::UINT32)
      dtype = py::dtype("uint32");
    else if (omTensorGetDataType(omt) == onnx::TensorProto::UINT64)
      dtype = py::dtype("uint64");
    else {
      fprintf(stderr, "Unsupported ONNX type in OMTensor.");
      exit(1);
    }
    outputPyArrays.emplace_back(
        py::array(dtype, shape, omTensorGetDataPtr(omt)));
  }
  return outputPyArrays;

View File

@ -15,7 +15,7 @@
namespace py = pybind11;
#include "ExecutionSession.hpp"
namespace onnx_mlir {
@ -24,7 +24,7 @@ public:
  PyExecutionSession(std::string sharedLibPath, std::string entryPointName)
      : onnx_mlir::ExecutionSession(sharedLibPath, entryPointName){};
  std::vector<py::array> pyRun(const std::vector<py::array> &inputsPyArray);
};
} // namespace onnx_mlir

View File

@ -1,171 +0,0 @@
//===----------- RtMemRef.cpp - Dynamic MemRef Implementation ------------===//
//
// Copyright 2019-2020 The IBM Research Authors.
//
// =============================================================================
//
// This file contains implementations of Dynamic MemRef data structures and
// helper functions.
//
//===----------------------------------------------------------------------===//
#include <cassert>
#include <map>
#include <string>
#include <vector>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "RtMemRef.h"
namespace {
// Helper function to compute cartesian product.
inline std::vector<std::vector<INDEX_TYPE>> CartProduct(
const std::vector<std::vector<INDEX_TYPE>> &v) {
std::vector<std::vector<INDEX_TYPE>> s = {{}};
for (const auto &u : v) {
std::vector<std::vector<INDEX_TYPE>> r;
for (const auto &x : s) {
for (const auto y : u) {
r.push_back(x);
r.back().push_back(y);
}
}
s = move(r);
}
return s;
}
} // namespace
RtMemRef::RtMemRef(int _rank) {
rank = _rank;
sizes = (INDEX_TYPE *)malloc(rank * sizeof(INDEX_TYPE));
strides = (int64_t *)malloc(rank * sizeof(int64_t));
}
INDEX_TYPE RtMemRef::size() const {
return std::accumulate(sizes, sizes + rank, 1, std::multiplies<>());
}
std::vector<std::vector<INDEX_TYPE>> RtMemRef::indexSet() const {
// First, we create index set of each dimension separately.
// i.e., for a tensor/RMR of shape (2, 3), its dimWiseIdxSet will be:
// {{0,1}, {0,1,2}};
std::vector<std::vector<INDEX_TYPE>> dimWiseIdxSet;
for (auto dimSize : std::vector<INDEX_TYPE>(sizes, sizes + rank)) {
std::vector<INDEX_TYPE> dimIdxSet(dimSize);
std::iota(std::begin(dimIdxSet), std::end(dimIdxSet), 0);
dimWiseIdxSet.emplace_back(dimIdxSet);
}
// Then, the cartesian product of vectors within dimWiseIdxSet will be the
// index set for the whole RMR.
return CartProduct(dimWiseIdxSet);
}
INDEX_TYPE RtMemRef::computeOffset(std::vector<INDEX_TYPE> &idxs) const {
auto dimStrides = std::vector<INDEX_TYPE>(strides, strides + rank);
INDEX_TYPE elemOffset = std::inner_product(
idxs.begin(), idxs.end(), dimStrides.begin(), (INDEX_TYPE)0);
return elemOffset;
}
std::vector<int64_t> RtMemRef::computeStridesFromSizes() const {
// Shift dimension sizes one to the left, fill in the vacated rightmost
// element with 1; this gets us a vector that'll be more useful for computing
// strides of memory access along each dimension using prefix product (aka
// partial_sum with a multiply operator below). The intuition is that the size
// of the leading dimension does not matter when computing strides.
std::vector<int64_t> sizesVec(sizes + 1, sizes + rank);
sizesVec.push_back(1);
std::vector<int64_t> dimStrides(rank);
std::partial_sum(sizesVec.rbegin(), sizesVec.rend(), dimStrides.rbegin(),
std::multiplies<>());
return dimStrides;
}
RtMemRef::~RtMemRef() {
free(data);
free(sizes);
free(strides);
}
// An ordered dynamic MemRef dictionary.
// The goal is to support accessing dynamic memory ref by name and by index.
// Currently, only accessing by index is supported.
struct OrderedRtMemRefDict {
std::map<std::string, RtMemRef *> tensorDict;
std::vector<std::string> orderedNames;
};
int numRtMemRefs(OrderedRtMemRefDict *dict) {
return dict->orderedNames.size();
}
OrderedRtMemRefDict *createOrderedRtMemRefDict() {
return new OrderedRtMemRefDict();
}
RtMemRef *createRtMemRef(int rank) { return new RtMemRef(rank); }
RtMemRef *getRtMemRef(OrderedRtMemRefDict *tensorDict, int idx) {
return tensorDict->tensorDict[tensorDict->orderedNames[idx]];
}
void setRtMemRef(OrderedRtMemRefDict *tensorDict, int idx, RtMemRef *tensor) {
if (tensorDict->orderedNames.size() <= idx)
tensorDict->orderedNames.resize(idx + 1);
// The dynamic memref is essentially anonymous, since we are storing it by
// indexed position.
// TODO: can use random string as names to reduce chance of collision.
auto unique_name = std::to_string(idx);
assert(tensorDict->tensorDict.count(unique_name) == 0 &&
"duplicate dynamic mem ref name");
tensorDict->orderedNames[idx] = unique_name;
tensorDict->tensorDict[tensorDict->orderedNames[idx]] = tensor;
}
void *getData(RtMemRef *dynMemRef) { return dynMemRef->data; }
void setData(RtMemRef *dynMemRef, void *dataPtr) { dynMemRef->data = dataPtr; }
void *getAlignedData(RtMemRef *dynMemRef) { return dynMemRef->alignedData; }
void setAlignedData(RtMemRef *dynMemRef, void *dataPtr) {
dynMemRef->alignedData = dataPtr;
}
INDEX_TYPE *getSizes(RtMemRef *dynMemRef) { return dynMemRef->sizes; }
void setSizes(RtMemRef *dynMemRef, INDEX_TYPE *sizes) {
for (int i = 0; i < dynMemRef->rank; i++)
dynMemRef->sizes[i] = sizes[i];
}
int64_t *getStrides(RtMemRef *dynMemRef) { return dynMemRef->strides; }
int64_t getSize(OrderedRtMemRefDict *dict) { return dict->orderedNames.size(); }
INDEX_TYPE getDataSize(RtMemRef *rtMemRef) {
INDEX_TYPE n = rtMemRef->sizes[0];
for (int i = 1; i < rtMemRef->rank; i++)
n *= rtMemRef->sizes[i];
return n;
}
void setDType(RtMemRef *dynMemRef, int onnxType) {
dynMemRef->onnx_dtype = onnxType;
}
int getDType(RtMemRef *dynMemRef) { return dynMemRef->onnx_dtype; }
unsigned int getRank(RtMemRef *dynMemRef) { return dynMemRef->rank; }
void setStrides(RtMemRef *dynMemRef, int64_t *strides) {
for (int i = 0; i < dynMemRef->rank; i++)
dynMemRef->strides[i] = strides[i];
}

View File

@ -1,283 +0,0 @@
//===------------ RtMemRef.h - Dynamic MemRef Implementation -------------===//
//
// Copyright 2019-2020 The IBM Research Authors.
//
// =============================================================================
//
// This file contains declaration of Dynamic MemRef data structures and helper
// functions.
//
//===----------------------------------------------------------------------===//
#ifdef __cplusplus
#pragma once
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <random>
#include <vector>
#else
#include <stdint.h>
#endif
typedef int64_t INDEX_TYPE;
// Typically, MemRefs in MLIR context are used as a compile-time constructs.
// Information such as element type and rank of the data payload is statically
// encoded, meaning that they are determined and fixed at compile-time. This
// presents significant burden for any runtime components trying to interact
// with the compiled executable.
//
// Thus a version of MemRef struct that is amenable to runtime manipulation is
// provided as a basis for building any runtime-related components providing
// user-facing programming interfaces. All information are dynamically encoded
// as members of this struct so that they can be accessed and modified easily
// during runtime.
//
// We will refer to it as a RMF (Runtime MemRef).
struct RtMemRef {
// Pointer to the raw memory space allocated to host the RMR content. This
// pointer should only be acessed for memory management purposes, not for
// reading RMR content.
void *data;
// Pointer to the properly aligned array of elements stored in this Rmr.
void *alignedData;
// Distance between the start of the raw memory space and the first element of
// the RMR content.
INDEX_TYPE offset;
// Number of dimensions of the array represented by the RMR.
unsigned int rank;
// An array recording the per-dimension sizes of the array represented by the
// RMR.
INDEX_TYPE *sizes;
// An array recording the per-dimension strides of the array represented by
// the RMR.
int64_t *strides;
// Refer to TensorProto_DataType at
// https://github.com/onnx/onnx/blob/cc2230603422bae893d5bc900d2d773ab34400a4/onnx/onnx-ml.proto#L451
// for enum value interpretation.
unsigned int onnx_dtype;
#ifdef __cplusplus
explicit RtMemRef(int _rank);
// Create a full RMR of type T and shape _sizes, with all data fields
// initialized to proper values and data pointers malloc'ed.
template <typename T>
static RtMemRef *create(std::vector<INDEX_TYPE> _sizes) {
auto rmr = new RtMemRef(_sizes.size());
rmr->offset = 0;
rmr->rank = _sizes.size();
rmr->sizes = (INDEX_TYPE *)malloc(rmr->rank * sizeof(INDEX_TYPE));
std::copy(_sizes.begin(), _sizes.end(), rmr->sizes);
rmr->strides = (int64_t *)malloc(rmr->rank * sizeof(int64_t));
auto computedStrides = rmr->computeStridesFromSizes();
std::copy(computedStrides.begin(), computedStrides.end(), rmr->strides);
rmr->data = malloc(rmr->size() * sizeof(T));
rmr->alignedData = rmr->data;
return rmr;
}
// Access an element (by reference) at index position idxs.
template <typename T>
T &elem(std::vector<INDEX_TYPE> idxs) {
INDEX_TYPE elemOffset = computeOffset(idxs);
T *typedPtr = (T *)data;
return typedPtr[elemOffset];
}
// Access an element (by reference) at *flattened* index position idx.
template <typename T>
T &elem(INDEX_TYPE idx) {
T *typedPtr = (T *)data;
return typedPtr[idx];
}
// Get a typed ptr to the data content of the RMR.
template <typename T>
T *typedPtr() {
return (T *)data;
}
// Get how many elements are stored in RMR, as implied by its shape.
INDEX_TYPE size() const;
// Helper function to compute strides of access along each dimensions from its
// shape.
std::vector<int64_t> computeStridesFromSizes() const;
// Compute flattened array idx from a multi-dimensional array idx.
INDEX_TYPE computeOffset(std::vector<INDEX_TYPE> &idxs) const;
// Get the index set (i.e., all valid multi-dimensional array indexes that can
// be used to access this RMR's constituent elements).
std::vector<std::vector<INDEX_TYPE>> indexSet() const;
~RtMemRef();
#endif
};
#ifdef __cplusplus
// Ordered RtMemRef Dictionary is a data structure for wrapping the input
// dynmemrefs so that they can be addressed both by index and by name.
struct OrderedRtMemRefDict;
#else
typedef struct RtMemRef RtMemRef;
typedef struct _OrderedRtMemRefDict OrderedRtMemRefDict;
#endif
#ifdef __cplusplus
extern "C" {
#endif
// Get number of dynamic memrefs in OrderedRtMemRefDict dict.
int numRtMemRefs(OrderedRtMemRefDict *dict);
// Create an ordered dynamic memref dictionary.
OrderedRtMemRefDict *createOrderedRtMemRefDict();
// Get how many dynamic memrefs are in dict.
int64_t getSize(OrderedRtMemRefDict *dict);
// Get how many data elements are in RtMemRef.
INDEX_TYPE getDataSize(RtMemRef *rtMemRef);
// Create a dynmemref with a certain rank.
RtMemRef *createRtMemRef(int rank);
// Get the i-th dynmemref from orderedDict.
RtMemRef *getRtMemRef(OrderedRtMemRefDict *orderedDict, int i);
// Set the i-th dynmemref in orderedDict to be dynMemRef.
void setRtMemRef(OrderedRtMemRefDict *tensorDict, int idx, RtMemRef *dynMemRef);
// Get data pointer from dynMemRef.
void *getData(RtMemRef *dynMemRef);
// Set data pointer for dynMemRef.
void setData(RtMemRef *dynMemRef, void *data);
// Get algined data pointer from dynMemRef.
void *getAlignedData(RtMemRef *);
// Set aligned data pointer for dynMemRef.
void setAlignedData(RtMemRef *, void *);
// Get the data type enum value of the dynMemRef.
int getDType(RtMemRef *dynMemRef);
// Set the data type enum value of the dynMemRef.
void setDType(RtMemRef *dynMemRef, int onnxType);
// Get the rank of the dynMemRef.
unsigned int getRank(RtMemRef *dynMemRef);
// Get ptr to sizes array.
INDEX_TYPE *getSizes(RtMemRef *);
// Set the sizes array (by copying size values from array `sizes`).
void setSizes(RtMemRef *, INDEX_TYPE *sizes);
// Get ptr to strides array.
int64_t *getStrides(RtMemRef *);
// Set the strides array (by copying stride values from array `strides`).
void setStrides(RtMemRef *, int64_t *strides);
#ifdef __cplusplus
}
template <typename T>
void printVector(std::vector<T> vec, std::string _delimiter = ",",
std::ostream &stream = std::cout) {
std::string delimiter;
for (const auto &elem : vec) {
stream << delimiter << elem;
delimiter = _delimiter;
}
}
template <typename T>
RtMemRef *getRndRealRmr(
std::vector<INDEX_TYPE> sizes, T lb = -1.0, T ub = 1.0) {
// Will be used to obtain a seed for the random number engine
std::random_device rd;
// Standard mersenne_twister_engine seeded with rd()
std::mt19937 gen(rd());
std::uniform_real_distribution<> dis(lb, ub);
auto rmr = RtMemRef::create<T>(sizes);
auto ptr = (T *)rmr->data;
std::generate(ptr, ptr + rmr->size(), [&]() { return dis(gen); });
return rmr;
}
template <typename T>
inline bool isRmrClose(
RtMemRef *a, RtMemRef *b, float rtol = 1e-5, float atol = 1e-5) {
// Compare shape.
auto aShape = std::vector<INDEX_TYPE>(a->sizes, a->sizes + a->rank);
auto bShape = std::vector<INDEX_TYPE>(b->sizes, b->sizes + b->rank);
if (aShape != bShape) {
std::cerr << "Shape mismatch ";
printVector(aShape, ",", std::cerr);
std::cerr << " != ";
printVector(bShape, ",", std::cerr);
return false;
}
// Compute absolute difference, verify it's within tolerable range.
std::vector<T> absoluteDiff(a->size());
std::transform(a->typedPtr<T>(), a->typedPtr<T>() + a->size(),
b->typedPtr<T>(), absoluteDiff.begin(), std::minus<>());
std::transform(absoluteDiff.begin(), absoluteDiff.end(), absoluteDiff.begin(),
static_cast<T (*)(T)>(&std::abs));
bool atolSatisfied = std::all_of(
absoluteDiff.begin(), absoluteDiff.end(), [&](T a) { return a < atol; });
// Compute relative difference, verify it's within tolerable range.
std::vector<T> relativeDiff(a->size());
std::transform(absoluteDiff.begin(), absoluteDiff.end(), a->typedPtr<T>(),
relativeDiff.begin(), std::divides<>());
bool rtolSatisfied = std::all_of(
relativeDiff.begin(), relativeDiff.end(), [&](T a) { return a < rtol; });
if (atolSatisfied && rtolSatisfied) {
return true;
} else {
// Figure out where and what went wrong, this can be slow; but hopefully we
// don't need this often.
for (const auto &idx : a->indexSet()) {
T aElem = a->elem<T>(idx);
T bElem = b->elem<T>(idx);
auto elmAbsDiff = std::abs(aElem - bElem);
auto withinRtol = (elmAbsDiff / aElem < rtol);
auto withinAtol = (elmAbsDiff < atol);
if (!withinRtol || !withinAtol) {
std::cerr << "a[";
printVector(idx, ",", std::cerr);
std::cerr << "] = " << aElem << " != ";
std::cerr << "b[";
printVector(idx, ",", std::cerr);
std::cerr << "] = " << bElem << std::endl;
}
}
return false;
}
}
#endif
// Will transition from RtMemRef to RtMemRef soon.
typedef RtMemRef RtMemRef;

View File

@ -7,18 +7,18 @@ if(Java_Development_FOUND AND JNI_FOUND)
# Target for Java runtime jar
add_jar(javaruntime
  src/com/ibm/onnxmlir/DynEntryPoint.java
  src/com/ibm/onnxmlir/OMTensorList.java
  src/com/ibm/onnxmlir/OMTensor.java
  OUTPUT_DIR ${CMAKE_LIBRARY_OUTPUT_DIRECTORY})
# Target for JNI runtime lib
add_library(jniruntime STATIC
  jniwrapper.c jnilog.c jnidummy.c
  com_ibm_onnxmlir_DynEntryPoint.h jnilog.h)
set_target_properties(jniruntime PROPERTIES
  POSITION_INDEPENDENT_CODE TRUE)
target_include_directories(jniruntime PRIVATE
  ${ONNX_MLIR_SRC_ROOT}/include
  ${JAVA_INCLUDE_PATH}
  ${JAVA_INCLUDE_PATH2})

View File

@ -11,7 +11,7 @@ extern "C" {
 * Class:     com_ibm_onnxmlir_DynEntryPoint
 * Method:    main_graph_jni
 * Signature:
 * (Lcom/ibm/onnxmlir/OMTensorList;)Lcom/ibm/onnxmlir/OMTensorList;
 */
JNIEXPORT jobject JNICALL Java_com_ibm_onnxmlir_DynEntryPoint_main_1graph_1jni(
    JNIEnv *, jclass, jobject);

View File

@ -63,26 +63,6 @@ enum { LOG_TRACE, LOG_DEBUG, LOG_INFO, LOG_WARNING, LOG_ERROR, LOG_FATAL };
    sprintf(buf + strlen(buf), n > LOG_MAX_NUM ? " ... " : " ");              \
  } while (0)
enum {
ONNX_TYPE_UNDEFINED, /* 0 */
ONNX_TYPE_FLOAT, /* 1 */
ONNX_TYPE_UINT8, /* 2 */
ONNX_TYPE_INT8, /* 3 */
ONNX_TYPE_UINT16, /* 4 */
ONNX_TYPE_INT16, /* 5 */
ONNX_TYPE_INT32, /* 6 */
ONNX_TYPE_INT64, /* 7 */
ONNX_TYPE_STRING, /* 8 */
ONNX_TYPE_BOOL, /* 9 */
ONNX_TYPE_FLOAT16, /* 10 */
ONNX_TYPE_DOUBLE, /* 11 */
ONNX_TYPE_UINT32, /* 12 */
ONNX_TYPE_UINT64, /* 13 */
ONNX_TYPE_COMPLEX64, /* 14 */
ONNX_TYPE_COMPLEX128, /* 15 */
ONNX_TYPE_BFLOAT16, /* 16 */
};
/* Construct string of up to LOG_MAX_NUM elements of a "type" array */
#define LOG_TYPE_BUF(type, buf, data, n)                                       \
  do {                                                                         \

View File

@ -6,17 +6,20 @@
#endif #endif
#include <string.h> #include <string.h>
#include "RtMemRef.h" #include "OnnxMlirRuntime.h"
#include "com_ibm_onnxmlir_DynEntryPoint.h" #include "com_ibm_onnxmlir_DynEntryPoint.h"
#include "jnilog.h" #include "jnilog.h"
/* Declare type var, make call and assign to var, check against val */ /* Declare type var, make call and assign to var, check against val.
* It's assumed that a Java exception has already been thrown so
* this call simply returns NULL.
*/
#define CHECK_CALL(type, var, call, val) \ #define CHECK_CALL(type, var, call, val) \
type var = call; \ type var = call; \
if (var == val) \ if (var == val) \
return NULL return NULL
/* Make a JNI call and throw Java exception if the call failed */ /* Make a JNI call, log error and throw Java exception if the call failed */
#define JNI_CALL(env, stmt) \ #define JNI_CALL(env, stmt) \
stmt; \ stmt; \
do { \ do { \
@ -29,332 +32,354 @@
} while (0) } while (0)
/* Make a JNI call and assign return value to var, /* Make a JNI call and assign return value to var,
* throw Java exception if the call failed * log error and throw Java exception if the call failed
*/ */
#define JNI_VAR_CALL(env, var, call) JNI_CALL(env, var = call) #define JNI_VAR_CALL(env, var, call) JNI_CALL(env, var = call)
/* Declare type var, make a JNI call and assign return value to var, /* Declare type var, make a JNI call and assign return value to var,
* throw Java exception if the call failed * log error and throw Java exception if the call failed
*/ */
#define JNI_TYPE_VAR_CALL(env, type, var, call) JNI_CALL(env, type var = call); #define JNI_TYPE_VAR_CALL(env, type, var, call) JNI_CALL(env, type var = call);
/* If cond is true (native code failed), log error and throw Java exception */ /* Make a native library call, if cond is true (native code failed),
#define JNI_COND(type, var, call, val, env, cls, ...) \ * log error and throw Java exception
type var = call; \ */
#define LIB_CALL(stmt, check, env, cls, ...) \
stmt; \
do { \ do { \
if (var == val) { \ if (check) { \
LOG_PRINTF(LOG_ERROR, __VA_ARGS__); \ LOG_PRINTF(LOG_ERROR, __VA_ARGS__); \
(*env)->ThrowNew(env, cls, "native code error"); \ (*env)->ThrowNew(env, cls, "native code error"); \
return NULL; \ return NULL; \
} \ } \
} while (0) } while (0)
/* Debug output of RtMemRef fields */ /* Make a native library call and assign return value to var,
#define RMR_DEBUG(i, type, rank, sizes, strides, data, datasize) \ * log error and throw Java exception if the call failed
*/
#define LIB_VAR_CALL(var, call, val, env, cls, ...) \
LIB_CALL(var = call, var == val, env, cls, __VA_ARGS__);
/* Declare type var, make a native library call and assign return value to var,
* log error and throw Java exception if the call failed
*/
#define LIB_TYPE_VAR_CALL(type, var, call, val, env, cls, ...) \
LIB_CALL(type var = call, var == val, env, cls, __VA_ARGS__);
/* Debug output of OMTensor fields */
#define OMT_DEBUG( \
i, n, data, dataSizes, dataStrides, dataType, dataBufferSize, rank, name) \
do { \ do { \
char tmp[1024]; \ char tmp[1024]; \
LOG_PRINTF(LOG_DEBUG, "rmr[%d]:type=%d", i, type); \ LOG_TYPE_BUF(dataType, tmp, data, n); \
LOG_PRINTF(LOG_DEBUG, "rmr[%d]:rank=%d", i, rank); \ LOG_PRINTF(LOG_DEBUG, "omt[%d]:data=[%s]", i, tmp); \
LOG_LONG_BUF(tmp, sizes, rank); \ LOG_LONG_BUF(tmp, dataSizes, rank); \
LOG_PRINTF(LOG_DEBUG, "rmr[%d]:sizes=[%s]", i, tmp); \ LOG_PRINTF(LOG_DEBUG, "omt[%d]:dataSizes=[%s]", i, tmp); \
LOG_LONG_BUF(tmp, strides, rank); \ LOG_LONG_BUF(tmp, dataStrides, rank); \
LOG_PRINTF(LOG_DEBUG, "rmr[%d]:strides=[%s]", i, tmp); \ LOG_PRINTF(LOG_DEBUG, "omt[%d]:dataStrides=[%s]", i, tmp); \
LOG_TYPE_BUF(type, tmp, data, datasize); \ LOG_PRINTF(LOG_DEBUG, "omt[%d]:dataType=%d", i, dataType); \
LOG_PRINTF(LOG_DEBUG, "rmr[%d]:data=[%s]", i, tmp); \ LOG_PRINTF(LOG_DEBUG, "omt[%d]:dataBufferSize=%ld", i, dataBufferSize); \
LOG_PRINTF(LOG_DEBUG, "omt[%d]:rank=%d", i, rank); \
LOG_PRINTF(LOG_DEBUG, "omt[%d]:name=%s", i, name); \
LOG_PRINTF(LOG_DEBUG, "omt[%d]:numOfElems=%ld", i, n); \
} while (0) } while (0)
/* Model shared library entry point */ /* Model shared library entry point */
extern OrderedRtMemRefDict *_dyn_entry_point_main_graph(OrderedRtMemRefDict *); extern OMTensorList *_dyn_entry_point_main_graph(OMTensorList *);
/* ONNX type to size (number of bytes) mapping */
int onnx_type_size[] = {
0, /* UNDEFINED = 0 */
4, /* FLOAT = 1 */
1, /* UINT8 = 2 */
1, /* INT8 = 3 */
2, /* UINT16 = 4 */
2, /* INT16 = 5 */
4, /* INT32 = 6 */
8, /* INT64 = 7 */
0, /* STRING = 8 */
1, /* BOOL = 9 */
2, /* FLOAT16 = 10 */
8, /* DOUBLE = 11 */
4, /* UINT32 = 12 */
8, /* UINT64 = 13 */
8, /* COMPLEX64 = 14 */
16, /* COMPLEX128 = 15 */
2, /* BFLOAT16 = 16 */
};
/* Java classes and methods needed for making various JNI API calls */ /* Java classes and methods needed for making various JNI API calls */
typedef struct { typedef struct {
jclass ecpt_cls; /* java/lang/Exception class */ jclass ecpt_cls; /* java/lang/Exception class */
jclass long_cls; /* java/lang/Long class */ jclass long_cls; /* java/lang/Long class */
jclass string_cls; /* java/lang/String class */ jclass string_cls; /* java/lang/String class */
jclass ormrd_cls; /* com/ibm/onnxmlir/OrderedRtMemRefDict class */ jclass omt_cls; /* com/ibm/onnxmlir/OMTensor class */
jclass rmr_cls; /* com/ibm/onnxmlir/RtMemRef class */ jclass omt_list_cls; /* com/ibm/onnxmlir/OMTensorList class */
jmethodID ormrd_constructor; /* OrderedRtMemRefDict constructor */ jmethodID omt_constructor; /* OMTensor constructor */
jmethodID ormrd_getRmrs; /* OrderedRtMemRefDict getRmrs method */ jmethodID omt_getData; /* OMTensor getData method */
jmethodID ormrd_getNames; /* OrderedRtMemRefDict getNames method */ jmethodID omt_setData; /* OMTensor setData method */
jmethodID omt_getDataSizes; /* OMTensor getDataSizes method */
jmethodID omt_setDataSizes; /* OMTensor setDataSizes method */
jmethodID omt_getDataStrides; /* OMTensor getDataStrides method */
jmethodID omt_setDataStrides; /* OMTensor setDataStrides method */
jmethodID omt_getDataType; /* OMTensor getDataType method */
jmethodID omt_setDataType; /* OMTensor setDataType method */
jmethodID omt_getDataBufferSize; /* OMTensor getDataBufferSize method */
jmethodID omt_getRank; /* OMTensor getRank method */
jmethodID omt_getName; /* OMTensor getName method */
jmethodID omt_setName; /* OMTensor setName method */
jmethodID omt_getNumOfElems; /* OMTensor getNumOfElems method */
jmethodID rmr_constructor; /* RtMemRef constructor */ jmethodID omt_list_constructor; /* OMTensorList constructor */
jmethodID rmr_getType; /* RtMemRef getType method */ jmethodID omt_list_getOmts; /* OMTensorList getOmts method */
jmethodID rmr_setType; /* RtMemRef setType method */
jmethodID rmr_getRank; /* RtMemRef getRank method */
jmethodID rmr_getData; /* RtMemRef getData method */
jmethodID rmr_setData; /* RtMemRef setData method */
jmethodID rmr_getSizes; /* RtMemRef getSizes method */
jmethodID rmr_setSizes; /* RtMemRef setSizes method */
jmethodID rmr_getStrides; /* RtMemRef getStrides method */
jmethodID rmr_setStrides; /* RtMemRef setStrides method */
jmethodID rmr_getDataSize; /* RtMemRef getDataSize method */
} jniapi_t; } jniapi_t;
jniapi_t jniapi; jniapi_t jniapi;
/* Fill in struct jniapi */ /* Fill in struct jniapi */
jniapi_t *fill_jniapi(JNIEnv *env, jniapi_t *japi) { jniapi_t *fill_jniapi(JNIEnv *env, jniapi_t *japi) {
/* Get Java Exception, Long, String, OrderedRtMemRefDict, and RtMemRef classes /* Get Java Exception, Long, String, OMTensor, and OMTensorList classes
*/ */
JNI_VAR_CALL( JNI_VAR_CALL(
env, japi->ecpt_cls, (*env)->FindClass(env, "java/lang/Exception")); env, japi->ecpt_cls, (*env)->FindClass(env, "java/lang/Exception"));
JNI_VAR_CALL(env, japi->long_cls, (*env)->FindClass(env, "java/lang/Long")); JNI_VAR_CALL(env, japi->long_cls, (*env)->FindClass(env, "java/lang/Long"));
JNI_VAR_CALL( JNI_VAR_CALL(
env, japi->string_cls, (*env)->FindClass(env, "java/lang/String")); env, japi->string_cls, (*env)->FindClass(env, "java/lang/String"));
JNI_VAR_CALL(env, japi->ormrd_cls,
(*env)->FindClass(env, "com/ibm/onnxmlir/OrderedRtMemRefDict"));
JNI_VAR_CALL( JNI_VAR_CALL(
env, japi->rmr_cls, (*env)->FindClass(env, "com/ibm/onnxmlir/RtMemRef")); env, japi->omt_cls, (*env)->FindClass(env, "com/ibm/onnxmlir/OMTensor"));
JNI_VAR_CALL(env, japi->omt_list_cls,
(*env)->FindClass(env, "com/ibm/onnxmlir/OMTensorList"));
/* Get method ID of constructor and various methods in OrderedRtMemRefDict */ /* Get method ID of constructor and various methods in OMTensor */
JNI_VAR_CALL(env, japi->ormrd_constructor, JNI_VAR_CALL(env, japi->omt_constructor,
(*env)->GetMethodID(env, japi->omt_cls, "<init>", "(I)V"));
JNI_VAR_CALL(env, japi->omt_getData,
(*env)->GetMethodID( (*env)->GetMethodID(
env, japi->ormrd_cls, "<init>", "([Lcom/ibm/onnxmlir/RtMemRef;)V")); env, japi->omt_cls, "getData", "()Ljava/nio/ByteBuffer;"));
JNI_VAR_CALL(env, japi->ormrd_getRmrs, JNI_VAR_CALL(env, japi->omt_setData,
(*env)->GetMethodID( (*env)->GetMethodID(
env, japi->ormrd_cls, "getRmrs", "()[Lcom/ibm/onnxmlir/RtMemRef;")); env, japi->omt_cls, "setData", "(Ljava/nio/ByteBuffer;)V"));
JNI_VAR_CALL(env, japi->ormrd_getNames, JNI_VAR_CALL(env, japi->omt_getDataSizes,
(*env)->GetMethodID(env, japi->omt_cls, "getDataSizes", "()[J"));
JNI_VAR_CALL(env, japi->omt_setDataSizes,
(*env)->GetMethodID(env, japi->omt_cls, "setDataSizes", "([J)V"));
JNI_VAR_CALL(env, japi->omt_getDataStrides,
(*env)->GetMethodID(env, japi->omt_cls, "getDataStrides", "()[J"));
JNI_VAR_CALL(env, japi->omt_setDataStrides,
(*env)->GetMethodID(env, japi->omt_cls, "setDataStrides", "([J)V"));
JNI_VAR_CALL(env, japi->omt_getDataType,
(*env)->GetMethodID(env, japi->omt_cls, "getDataType", "()I"));
JNI_VAR_CALL(env, japi->omt_setDataType,
(*env)->GetMethodID(env, japi->omt_cls, "setDataType", "(I)V"));
JNI_VAR_CALL(env, japi->omt_getDataBufferSize,
(*env)->GetMethodID(env, japi->omt_cls, "getDataBufferSize", "()J"));
JNI_VAR_CALL(env, japi->omt_getRank,
(*env)->GetMethodID(env, japi->omt_cls, "getRank", "()I"));
JNI_VAR_CALL(env, japi->omt_getName,
(*env)->GetMethodID( (*env)->GetMethodID(
env, japi->ormrd_cls, "getNames", "()[Ljava/lang/String;")); env, japi->omt_cls, "getName", "()Ljava/lang/String;"));
JNI_VAR_CALL(env, japi->omt_setName,
(*env)->GetMethodID(
env, japi->omt_cls, "setName", "(Ljava/lang/String;)V"));
JNI_VAR_CALL(env, japi->omt_getNumOfElems,
(*env)->GetMethodID(env, japi->omt_cls, "getNumOfElems", "()J"));
/* Get method ID of constructor and various methods in RtMemRef */ /* Get method ID of constructor and various methods in OMTensorList */
JNI_VAR_CALL(env, japi->rmr_constructor, JNI_VAR_CALL(env, japi->omt_list_constructor,
(*env)->GetMethodID(env, japi->rmr_cls, "<init>", "(I)V")); (*env)->GetMethodID(env, japi->omt_list_cls, "<init>",
JNI_VAR_CALL(env, japi->rmr_getType, "([Lcom/ibm/onnxmlir/OMTensor;)V"));
(*env)->GetMethodID(env, japi->rmr_cls, "getType", "()I")); JNI_VAR_CALL(env, japi->omt_list_getOmts,
JNI_VAR_CALL(env, japi->rmr_setType, (*env)->GetMethodID(env, japi->omt_list_cls, "getOmts",
(*env)->GetMethodID(env, japi->rmr_cls, "setType", "(I)V")); "()[Lcom/ibm/onnxmlir/OMTensor;"));
JNI_VAR_CALL(env, japi->rmr_getRank,
(*env)->GetMethodID(env, japi->rmr_cls, "getRank", "()I"));
JNI_VAR_CALL(env, japi->rmr_getData,
(*env)->GetMethodID(
env, japi->rmr_cls, "getData", "()Ljava/nio/ByteBuffer;"));
JNI_VAR_CALL(env, japi->rmr_setData,
(*env)->GetMethodID(
env, japi->rmr_cls, "setData", "(Ljava/nio/ByteBuffer;)V"));
JNI_VAR_CALL(env, japi->rmr_getSizes,
(*env)->GetMethodID(env, japi->rmr_cls, "getSizes", "()[J"));
JNI_VAR_CALL(env, japi->rmr_setSizes,
(*env)->GetMethodID(env, japi->rmr_cls, "setSizes", "([J)V"));
JNI_VAR_CALL(env, japi->rmr_getStrides,
(*env)->GetMethodID(env, japi->rmr_cls, "getStrides", "()[J"));
JNI_VAR_CALL(env, japi->rmr_setStrides,
(*env)->GetMethodID(env, japi->rmr_cls, "setStrides", "([J)V"));
JNI_VAR_CALL(env, japi->rmr_getDataSize,
(*env)->GetMethodID(env, japi->rmr_cls, "getDataSize", "()J"));
return japi; return japi;
} }
 /* Convert Java object to native data structure */
-OrderedRtMemRefDict *ormrd_java_to_native(
+OMTensorList *omt_list_java_to_native(
     JNIEnv *env, jclass cls, jobject obj, jniapi_t *japi) {

-  /* Get object array "rmrs" and "names" in OrderedRtMemRefDict */
-  JNI_TYPE_VAR_CALL(env, jobjectArray, ormrd_rmrs,
-      (*env)->CallObjectMethod(env, obj, japi->ormrd_getRmrs));
-  JNI_TYPE_VAR_CALL(env, jobjectArray, ormrd_names,
-      (*env)->CallObjectMethod(env, obj, japi->ormrd_getNames));
-
-  /* Get length of object array "rmrs" and "names" in OrderedRtMemRefDict */
+  /* Get OMTensor array Java object in OMTensorList */
+  JNI_TYPE_VAR_CALL(env, jobjectArray, omt_list_omts,
+      (*env)->CallObjectMethod(env, obj, japi->omt_list_getOmts));
+
+  /* Get the number of OMTensors in the array */
   JNI_TYPE_VAR_CALL(
-      env, jsize, ormrd_rmrs_len, (*env)->GetArrayLength(env, ormrd_rmrs));
-  JNI_TYPE_VAR_CALL(
-      env, jsize, ormrd_names_len, (*env)->GetArrayLength(env, ormrd_names));
+      env, jsize, omt_list_nomt, (*env)->GetArrayLength(env, omt_list_omts));

-  /* Allocate memory for holding each Java rmr object and name string,
-   * and RtMemRef and char pointers for constructing native RtMemRef and name
-   * array
+  /* Allocate memory for holding each Java omt object and OMTensor pointers
+   * for constructing native OMTensor array
    */
-  JNI_COND(jobject *, obj_rmr, malloc(ormrd_rmrs_len * sizeof(jobject)), NULL,
-      env, japi->ecpt_cls, "obj_rmr=null");
-  JNI_COND(jstring *, obj_name, malloc(ormrd_names_len * sizeof(jstring)), NULL,
-      env, japi->ecpt_cls, "obj_name=null");
-  JNI_COND(RtMemRef **, jni_rmr, malloc(ormrd_rmrs_len * sizeof(RtMemRef *)),
-      NULL, env, japi->ecpt_cls, "jni_rmr=null");
-  JNI_COND(const char **, jni_name,
-      malloc(ormrd_names_len * sizeof(const char *)), NULL, env, japi->ecpt_cls,
-      "jni_name=null");
+  LIB_TYPE_VAR_CALL(jobject *, obj_omts,
+      malloc(omt_list_nomt * sizeof(jobject)), NULL, env, japi->ecpt_cls,
+      "obj_omts=null");
+  LIB_TYPE_VAR_CALL(OMTensor **, jni_omts,
+      malloc(omt_list_nomt * sizeof(OMTensor *)), NULL, env, japi->ecpt_cls,
+      "jni_omts=null");

-  /* Create OrderedRtMemRefDict to be constructed and passed to the model shared
-   * library */
-  JNI_COND(OrderedRtMemRefDict *, ormrd, createOrderedRtMemRefDict(), NULL, env,
-      japi->ecpt_cls, "ormrd=null");
-
-  /* Loop through all the ormrd_rmrs and ormrd_names */
-  for (int i = 0; i < ormrd_rmrs_len; i++) {
+  /* Loop through all the omt_list_omts */
+  for (int i = 0; i < omt_list_nomt; i++) {
     JNI_VAR_CALL(
-        env, obj_rmr[i], (*env)->GetObjectArrayElement(env, ormrd_rmrs, i));
-    JNI_VAR_CALL(
-        env, obj_name[i], (*env)->GetObjectArrayElement(env, ormrd_names, i));
+        env, obj_omts[i], (*env)->GetObjectArrayElement(env, omt_list_omts, i));

-    /* Get type, rank, data, sizes, and strides by calling corresponding methods
+    /* Get data, dataSizes, dataStrides, dataType, rank, name and
+     * dataBufferSize by calling corresponding methods
      */
-    JNI_TYPE_VAR_CALL(env, jint, rmr_type,
-        (*env)->CallIntMethod(env, obj_rmr[i], japi->rmr_getType));
-    JNI_TYPE_VAR_CALL(env, jint, rmr_rank,
-        (*env)->CallIntMethod(env, obj_rmr[i], japi->rmr_getRank));
-    JNI_TYPE_VAR_CALL(env, jlong, rmr_datasize,
-        (*env)->CallLongMethod(env, obj_rmr[i], japi->rmr_getDataSize));
-    JNI_TYPE_VAR_CALL(env, jobject, rmr_data,
-        (*env)->CallObjectMethod(env, obj_rmr[i], japi->rmr_getData));
-    JNI_TYPE_VAR_CALL(env, jobject, rmr_sizes,
-        (*env)->CallObjectMethod(env, obj_rmr[i], japi->rmr_getSizes));
-    JNI_TYPE_VAR_CALL(env, jobject, rmr_strides,
-        (*env)->CallObjectMethod(env, obj_rmr[i], japi->rmr_getStrides));
+    JNI_TYPE_VAR_CALL(env, jobject, omt_data,
+        (*env)->CallObjectMethod(env, obj_omts[i], japi->omt_getData));
+    JNI_TYPE_VAR_CALL(env, jobject, omt_dataSizes,
+        (*env)->CallObjectMethod(env, obj_omts[i], japi->omt_getDataSizes));
+    JNI_TYPE_VAR_CALL(env, jobject, omt_dataStrides,
+        (*env)->CallObjectMethod(env, obj_omts[i], japi->omt_getDataStrides));
+    JNI_TYPE_VAR_CALL(env, jint, omt_dataType,
+        (*env)->CallIntMethod(env, obj_omts[i], japi->omt_getDataType));
+    JNI_TYPE_VAR_CALL(env, jlong, omt_dataBufferSize,
+        (*env)->CallLongMethod(env, obj_omts[i], japi->omt_getDataBufferSize));
+    JNI_TYPE_VAR_CALL(env, jint, omt_rank,
+        (*env)->CallIntMethod(env, obj_omts[i], japi->omt_getRank));
+    JNI_TYPE_VAR_CALL(env, jstring, omt_name,
+        (*env)->CallObjectMethod(env, obj_omts[i], japi->omt_getName));
+    JNI_TYPE_VAR_CALL(env, jlong, omt_numOfElems,
+        (*env)->CallLongMethod(env, obj_omts[i], japi->omt_getNumOfElems));

-    /* Primitive type int and long can be directly used */
-    int jni_type = rmr_type, jni_rank = rmr_rank;
-    long jni_datasize = rmr_datasize;
     /* Get direct buffer associated with data */
     JNI_TYPE_VAR_CALL(
-        env, void *, jni_data, (*env)->GetDirectBufferAddress(env, rmr_data));
-    /* Get long array associated with sizes and strides */
-    JNI_TYPE_VAR_CALL(env, long *, jni_sizes,
-        (*env)->GetLongArrayElements(env, rmr_sizes, NULL));
-    JNI_TYPE_VAR_CALL(env, long *, jni_strides,
-        (*env)->GetLongArrayElements(env, rmr_strides, NULL));
+        env, void *, jni_data, (*env)->GetDirectBufferAddress(env, omt_data));
+    /* Get long array associated with data sizes and strides */
+    JNI_TYPE_VAR_CALL(env, long *, jni_dataSizes,
+        (*env)->GetLongArrayElements(env, omt_dataSizes, NULL));
+    JNI_TYPE_VAR_CALL(env, long *, jni_dataStrides,
+        (*env)->GetLongArrayElements(env, omt_dataStrides, NULL));
+    /* Primitive type int and long can be directly used */
+    int jni_dataType = omt_dataType;
+    long jni_dataBufferSize = omt_dataBufferSize;
+    int jni_rank = omt_rank;
+    long jni_numOfElems = omt_numOfElems;
+    /* Get name string */
+    JNI_TYPE_VAR_CALL(env, char *, jni_name,
+        (char *)(*env)->GetStringUTFChars(env, omt_name, NULL));

     /* Print debug info on what we got from the Java side */
-    RMR_DEBUG(
-        i, jni_type, jni_rank, jni_sizes, jni_strides, jni_data, jni_datasize);
+    OMT_DEBUG(i, jni_numOfElems, jni_data, jni_dataSizes, jni_dataStrides,
+        jni_dataType, jni_dataBufferSize, jni_rank, jni_name);

-    /* Create native RtMemRef struct and fill in its fields */
-    jni_rmr[i] = createRtMemRef(jni_rank);
-    setDType(jni_rmr[i], jni_type);
-    setData(jni_rmr[i], jni_data);
-    setSizes(jni_rmr[i], jni_sizes);
-    setStrides(jni_rmr[i], jni_strides);
-
-    /*jni_name[i] = (*env)->GetStringUTFChars(env, obj_name[i], NULL);
-      printf("jni_name=%s\n", jni_name[i]);*/
-
-    /* Install RtMemRef into OrderedRtMemRefDict */
-    setRtMemRef(ormrd, i, jni_rmr[i]);
+    /* Create native OMTensor struct and fill in its fields */
+    LIB_VAR_CALL(jni_omts[i], omt_create(jni_rank), NULL, env, japi->ecpt_cls,
+        "jni_omts[%d]=null", i);
+    omt_setData(jni_omts[i], jni_data);
+    omt_setDataSizes(jni_omts[i], jni_dataSizes);
+    omt_setDataStrides(jni_omts[i], jni_dataStrides);
+    omt_setDataType(jni_omts[i], jni_dataType);
+    omt_setName(jni_omts[i], jni_name);

     /* Release reference to the java objects */
-    JNI_CALL(
-        env, (*env)->ReleaseLongArrayElements(env, rmr_sizes, jni_sizes, 0));
     JNI_CALL(env,
-        (*env)->ReleaseLongArrayElements(env, rmr_strides, jni_strides, 0));
+        (*env)->ReleaseLongArrayElements(env, omt_dataSizes, jni_dataSizes, 0));
+    JNI_CALL(env, (*env)->ReleaseLongArrayElements(
+                      env, omt_dataStrides, jni_dataStrides, 0));
+    JNI_CALL(env, (*env)->ReleaseStringUTFChars(env, omt_name, jni_name));
   }

-  /* setRtMemRef(ormrd, jni_rmr, jni_name); */
-  return ormrd;
+  /* Create OMTensorList to be constructed and passed to the
+   * model shared library
+   */
+  LIB_TYPE_VAR_CALL(OMTensorList *, list,
+      omt_list_create(jni_omts, omt_list_nomt), NULL, env, japi->ecpt_cls,
+      "list=null");
+  return list;
 }
 /* Convert native data structure to Java object */
-jobject ormrd_native_to_java(
-    JNIEnv *env, jclass cls, OrderedRtMemRefDict *dict, jniapi_t *japi) {
-  JNI_COND(int, nrmr, numRtMemRefs(dict), 0, env, japi->ecpt_cls, "nrmr=0");
+jobject omt_list_native_to_java(
+    JNIEnv *env, jclass cls, OMTensorList *dict, jniapi_t *japi) {
+
+  /* Get the OMTensor array in the OMTensorList */
+  LIB_TYPE_VAR_CALL(OMTensor **, jni_omts, omTensorListGetOmts(dict), NULL, env,
+      japi->ecpt_cls, "jni_omts=null");
+
+  /* Get the number of OMTensors in the OMTensorList */
+  LIB_TYPE_VAR_CALL(int, jni_nomt, omTensorListGetNumOfOmts(dict), 0, env,
+      japi->ecpt_cls, "jni_nomt=0");

-  /* Create RtMemRef java object array */
-  JNI_TYPE_VAR_CALL(env, jobjectArray, rmrs,
-      (*env)->NewObjectArray(env, nrmr, japi->rmr_cls, NULL));
+  /* Create OMTensor java object array */
+  JNI_TYPE_VAR_CALL(env, jobjectArray, obj_omts,
+      (*env)->NewObjectArray(env, jni_nomt, japi->omt_cls, NULL));

-  /* Loop through the native RtMemRef structs */
-  for (int i = 0; i < nrmr; i++) {
-    JNI_COND(RtMemRef *, rmr, getRtMemRef(dict, i), NULL, env, japi->ecpt_cls,
-        "rmr[%d]=null", i);
-    JNI_COND(int, jni_type, getDType(rmr), 0, env, japi->ecpt_cls,
-        "rmr[%d]:type=0", i);
-    JNI_COND(int, jni_rank, getRank(rmr), 0, env, japi->ecpt_cls,
-        "rmr[%d]:rank=0", i);
-    JNI_COND(long, jni_datasize, getDataSize(rmr), 0, env, japi->ecpt_cls,
-        "rmr[%d]:datasize=0", i);
-    JNI_COND(void *, jni_data, getData(rmr), NULL, env, japi->ecpt_cls,
-        "rmr[%d]:data=null", i);
-    JNI_COND(long *, jni_sizes, getSizes(rmr), NULL, env, japi->ecpt_cls,
-        "rmr[%d]:sizes=null", i);
-    JNI_COND(long *, jni_strides, getStrides(rmr), NULL, env, japi->ecpt_cls,
-        "rmr[%d]:strides=null", i);
+  /* Loop through the native OMTensor structs */
+  for (int i = 0; i < jni_nomt; i++) {
+    LIB_TYPE_VAR_CALL(void *, jni_data, omTensorGetData(jni_omts[i]), NULL, env,
+        japi->ecpt_cls, "omt[%d]:data=null", i);
+    LIB_TYPE_VAR_CALL(long *, jni_dataSizes, omTensorGetDataSizes(jni_omts[i]),
+        NULL, env, japi->ecpt_cls, "omt[%d]:dataSizes=null", i);
+    LIB_TYPE_VAR_CALL(long *, jni_dataStrides,
+        omTensorGetDataStrides(jni_omts[i]), NULL, env, japi->ecpt_cls,
+        "omt[%d]:dataStrides=null", i);
+    LIB_TYPE_VAR_CALL(int, jni_dataType, omTensorGetDataType(jni_omts[i]), 0,
+        env, japi->ecpt_cls, "omt[%d]:dataType=0", i);
+    LIB_TYPE_VAR_CALL(long, jni_dataBufferSize,
+        omt_getDataBufferSize(jni_omts[i]), 0, env, japi->ecpt_cls,
+        "omt[%d]:dataBufferSize=0", i);
+    LIB_TYPE_VAR_CALL(int, jni_rank, omTensorGetRank(jni_omts[i]), 0, env,
+        japi->ecpt_cls, "omt[%d]:rank=0", i);
+    LIB_TYPE_VAR_CALL(char *, jni_name, omTensorGetName(jni_omts[i]), NULL, env,
+        japi->ecpt_cls, "omt[%d]:name=null", i);
+    LIB_TYPE_VAR_CALL(long, jni_numOfElems, omTensorGetNumOfElems(jni_omts[i]),
+        0, env, japi->ecpt_cls, "omt[%d]:numOfElems=0", i);

     /* Print debug info on what we got from the native side */
-    RMR_DEBUG(
-        i, jni_type, jni_rank, jni_sizes, jni_strides, jni_data, jni_datasize);
+    OMT_DEBUG(i, jni_numOfElems, jni_data, jni_dataSizes, jni_dataStrides,
+        jni_dataType, jni_dataBufferSize, jni_rank, jni_name);

-    /* create the following Java objects:
-     * - RtMemRef
-     * - DirectByteBuffer (from native buffers)
-     * - long array for sizes and strides
-     */
-    JNI_TYPE_VAR_CALL(env, jobject, obj_rmr,
-        (*env)->NewObject(env, japi->rmr_cls, japi->rmr_constructor, jni_rank));
-    JNI_TYPE_VAR_CALL(env, jobject, rmr_data,
-        (*env)->NewDirectByteBuffer(
-            env, jni_data, jni_datasize * onnx_type_size[jni_type]));
-    JNI_TYPE_VAR_CALL(
-        env, jlongArray, rmr_sizes, (*env)->NewLongArray(env, jni_rank));
-    JNI_TYPE_VAR_CALL(
-        env, jlongArray, rmr_strides, (*env)->NewLongArray(env, jni_rank));
-    /* Call setType method */
-    JNI_CALL(env,
-        (*env)->CallObjectMethod(env, obj_rmr, japi->rmr_setType, jni_type));
-    /* Call setData method */
-    JNI_CALL(env,
-        (*env)->CallObjectMethod(env, obj_rmr, japi->rmr_setData, rmr_data));
-    /* Fill in sizes array from native array and call setSizes method */
-    JNI_CALL(env,
-        (*env)->SetLongArrayRegion(env, rmr_sizes, 0, jni_rank, jni_sizes));
-    JNI_CALL(env,
-        (*env)->CallObjectMethod(env, obj_rmr, japi->rmr_setSizes, rmr_sizes));
-    /* Fill in strides array from native array and call setStrides method */
-    JNI_CALL(env,
-        (*env)->SetLongArrayRegion(env, rmr_strides, 0, jni_rank, jni_strides));
-    JNI_CALL(env, (*env)->CallObjectMethod(
-                      env, obj_rmr, japi->rmr_setStrides, rmr_strides));
-    /* Set DynMemRef object in the object array */
-    JNI_CALL(env, (*env)->SetObjectArrayElement(env, rmrs, i, obj_rmr));
+    /* Create the OMTensor Java object */
+    JNI_TYPE_VAR_CALL(env, jobject, obj_omt,
+        (*env)->NewObject(env, japi->omt_cls, japi->omt_constructor, jni_rank));
+
+    /* Create direct byte buffer Java object from native data buffer, and
+     * call setData method
+     */
+    JNI_TYPE_VAR_CALL(env, jobject, omt_data,
+        (*env)->NewDirectByteBuffer(env, jni_data, jni_dataBufferSize));
+    JNI_CALL(env,
+        (*env)->CallObjectMethod(env, obj_omt, japi->omt_setData, omt_data));
+
+    /* Create data sizes array Java object, fill in from native array, and
+     * call setDataSizes method
+     */
+    JNI_TYPE_VAR_CALL(
+        env, jlongArray, omt_dataSizes, (*env)->NewLongArray(env, jni_rank));
+    JNI_CALL(env, (*env)->SetLongArrayRegion(
+                      env, omt_dataSizes, 0, jni_rank, jni_dataSizes));
+    JNI_CALL(env, (*env)->CallObjectMethod(
+                      env, obj_omt, japi->omt_setDataSizes, omt_dataSizes));
+
+    /* Create data strides array Java object, fill in from native array, and
+     * call setDataStrides method
+     */
+    JNI_TYPE_VAR_CALL(
+        env, jlongArray, omt_dataStrides, (*env)->NewLongArray(env, jni_rank));
+    JNI_CALL(env, (*env)->SetLongArrayRegion(
+                      env, omt_dataStrides, 0, jni_rank, jni_dataStrides));
+    JNI_CALL(env, (*env)->CallObjectMethod(
+                      env, obj_omt, japi->omt_setDataStrides, omt_dataStrides));
+
+    /* Primitive type int can be directly used. Call setDataType method */
+    JNI_CALL(env, (*env)->CallIntMethod(
+                      env, obj_omt, japi->omt_setDataType, (jint)jni_dataType));
+
+    /* Create string Java object from native char * and call setName method */
+    JNI_TYPE_VAR_CALL(
+        env, jstring, omt_name, (*env)->NewStringUTF(env, jni_name));
+    JNI_CALL(env,
+        (*env)->CallObjectMethod(env, obj_omt, japi->omt_setName, omt_name));
+
+    /* Set OMTensor object in the object array */
+    JNI_CALL(env, (*env)->SetObjectArrayElement(env, obj_omts, i, obj_omt));
   }

-  /* Create the OrderedRtMemRefDict java object */
-  JNI_TYPE_VAR_CALL(env, jobject, ormrd,
-      (*env)->NewObject(env, japi->ormrd_cls, japi->ormrd_constructor, rmrs));
-  return ormrd;
+  /* Create the OMTensorList java object */
+  JNI_TYPE_VAR_CALL(env, jobject, list,
+      (*env)->NewObject(
+          env, japi->omt_list_cls, japi->omt_list_constructor, obj_omts));
+  return list;
 }
 JNIEXPORT jobject JNICALL Java_com_ibm_onnxmlir_DynEntryPoint_main_1graph_1jni(
     JNIEnv *env, jclass cls, jobject obj) {

   CHECK_CALL(jniapi_t *, japi, fill_jniapi(env, &jniapi), NULL);

-  CHECK_CALL(OrderedRtMemRefDict *, input_ormrd,
-      ormrd_java_to_native(env, cls, obj, japi), NULL);
+  CHECK_CALL(OMTensorList *, input_list,
+      omt_list_java_to_native(env, cls, obj, japi), NULL);

-  CHECK_CALL(OrderedRtMemRefDict *, dict,
-      _dyn_entry_point_main_graph(input_ormrd), NULL);
+  CHECK_CALL(
+      OMTensorList *, dict, _dyn_entry_point_main_graph(input_list), NULL);

-  CHECK_CALL(
-      jobject, output_ormrd, ormrd_native_to_java(env, cls, dict, japi), NULL);
-  return output_ormrd;
+  CHECK_CALL(jobject, output_list,
+      omt_list_native_to_java(env, cls, dict, japi), NULL);
+  return output_list;
 }

View File

@@ -56,9 +56,9 @@ public class DynEntryPoint {
         }
     }

-    private static native OrderedRtMemRefDict main_graph_jni(OrderedRtMemRefDict ormrd);
+    private static native OMTensorList main_graph_jni(OMTensorList list);

-    public static OrderedRtMemRefDict main_graph(OrderedRtMemRefDict ormrd) {
-        return main_graph_jni(ormrd);
+    public static OMTensorList main_graph(OMTensorList list) {
+        return main_graph_jni(list);
     }
 }
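For reference, a minimal Java-side sketch of driving the renamed entry point. This is illustrative only, not part of this diff: it assumes a model compiled by onnx-mlir that takes a single 2x2 float input, and uses only the OMTensor and OMTensorList classes shown below.

    import com.ibm.onnxmlir.DynEntryPoint;
    import com.ibm.onnxmlir.OMTensor;
    import com.ibm.onnxmlir.OMTensorList;

    public class MainGraphSketch {
        public static void main(String[] args) {
            // Build a 2x2 float input; setFloatData allocates the direct
            // ByteBuffer in native byte order and tags ONNX_TYPE_FLOAT.
            OMTensor x = new OMTensor(2);
            x.setFloatData(new float[] {1f, 2f, 3f, 4f});
            x.setDataSizes(new long[] {2, 2});
            x.setDataStrides(new long[] {2, 1});

            // Wrap the input and cross the JNI boundary.
            OMTensorList input = new OMTensorList(new OMTensor[] {x});
            OMTensorList output = DynEntryPoint.main_graph(input);

            // Read back the first output tensor.
            float[] y = output.getOmts()[0].getFloatData();
            System.out.println("got " + y.length + " floats");
        }
    }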

View File

@@ -1,5 +1,6 @@
 package com.ibm.onnxmlir;

+import java.lang.reflect.Array;
 import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
 import java.nio.DoubleBuffer;
@@ -8,7 +9,7 @@ import java.nio.IntBuffer;
 import java.nio.LongBuffer;
 import java.nio.ShortBuffer;

-public class RtMemRef {
+public class OMTensor {
     final ByteOrder endian = ByteOrder.nativeOrder();

     /* We can use enum but that creates another class
@@ -52,47 +53,26 @@ public class RtMemRef {
         2, /* BFLOAT16 */
     };

-    private ByteBuffer _data;
-    private int _type;
+    private ByteBuffer _allocatedPtr;
+    private long[] _shape;
+    private long[] _stride;
+    private int _dataType;
     private int _rank;
-    private long[] _sizes;
-    private long[] _strides;
+    private String _name;

     /**
      * Constructor
      */
-    public RtMemRef(int rank) {
+    public OMTensor(int rank) {
         if (rank <= 0)
             throw new IllegalArgumentException(
                     "invalid rank " + rank);
-        _data = null;
-        _type = ONNX_TYPE_UNDEFINED;
+        _allocatedPtr = null;
+        _shape = new long[rank];
+        _stride = new long[rank];
+        _dataType = ONNX_TYPE_UNDEFINED;
         _rank = rank;
-        _sizes = new long[rank];
-        _strides = new long[rank];
+        _name = "";
     }
-
-    /* ---------- Data type getter and setter ---------- */
-    /* For JNI wrapper only. Not intended for end user. */
-
-    /**
-     * Type getter
-     *
-     * @return data type
-     */
-    @SuppressWarnings("unused")
-    private int getType() {
-        return _type;
-    }
-
-    /**
-     * Type setter
-     *
-     * @param type data type to be set
-     */
-    @SuppressWarnings("unused")
-    private void setType(int type) {
-        _type = type;
-    }

     /* ---------- Raw data getter and setter ---------- */
@@ -105,7 +85,7 @@ public class RtMemRef {
      */
     @SuppressWarnings("unused")
     private ByteBuffer getData() {
-        return _data;
+        return _allocatedPtr;
     }

     /**
@@ -115,7 +95,7 @@ public class RtMemRef {
      */
     @SuppressWarnings("unused")
     private void setData(ByteBuffer data) {
-        _data = data.order(endian);
+        _allocatedPtr = data.order(endian);
     }

     /* ---------- Byte data getter and setter ---------- */
@@ -126,14 +106,14 @@ public class RtMemRef {
      * @return byte data array
      */
     public byte[] getByteData() {
-        if (_data == null) return null;
+        if (_allocatedPtr == null) return null;

         /* asReadOnlyBuffer() creates a new view so the position of the
          * original data will stay at 0 for subsequent getByteData()
          * after get(b).
          */
-        byte[] b = new byte[_data.limit()];
-        _data.asReadOnlyBuffer().get(b);
+        byte[] b = new byte[_allocatedPtr.limit()];
+        _allocatedPtr.asReadOnlyBuffer().get(b);
         return b;
     }
@@ -146,9 +126,9 @@ public class RtMemRef {
         /* slice() creates a new view so the position of the
          * original data will stay at 0 for getByteData() after put(data).
          */
-        _data = ByteBuffer.allocateDirect(data.length);
-        _data.slice().put(data);
-        _type = ONNX_TYPE_INT8;
+        _allocatedPtr = ByteBuffer.allocateDirect(data.length);
+        _allocatedPtr.slice().put(data);
+        _dataType = ONNX_TYPE_INT8;
     }

     /* ---------- Short data getter and setter ---------- */
@@ -159,13 +139,13 @@ public class RtMemRef {
      * @return short data array
      */
     public short[] getShortData() {
-        if (_data == null) return null;
+        if (_allocatedPtr == null) return null;

         /* asShortBuffer() creates a new view so the position of the
          * original data will stay at 0 for subsequent getShortData()
          * after get(s).
          */
-        ShortBuffer sb = _data.asShortBuffer();
+        ShortBuffer sb = _allocatedPtr.asShortBuffer();
         short[] s = new short[sb.limit()];
         sb.get(s);
         return s;
@@ -180,9 +160,9 @@ public class RtMemRef {
         /* asShortBuffer() creates a new view so the position of the
          * original data will stay at 0 for getShortData() after put(data).
          */
-        _data = ByteBuffer.allocateDirect(data.length*2).order(endian);
-        _data.asShortBuffer().put(data);
-        _type = ONNX_TYPE_INT16;
+        _allocatedPtr = ByteBuffer.allocateDirect(data.length*2).order(endian);
+        _allocatedPtr.asShortBuffer().put(data);
+        _dataType = ONNX_TYPE_INT16;
     }

     /* ---------- Int data getter and setter ---------- */
@@ -193,13 +173,13 @@ public class RtMemRef {
      * @return int data array
      */
     public int[] getIntData() {
-        if (_data == null) return null;
+        if (_allocatedPtr == null) return null;

         /* asIntBuffer() creates a new view so the position of the
          * original data will stay at 0 for subsequent getIntData()
          * after get(i).
          */
-        IntBuffer ib = _data.asIntBuffer();
+        IntBuffer ib = _allocatedPtr.asIntBuffer();
         int[] i = new int[ib.limit()];
         ib.get(i);
         return i;
@@ -214,9 +194,9 @@ public class RtMemRef {
         /* asIntBuffer() creates a new view so the position of the
          * original data will stay at 0 for getIntData() after put(data).
          */
-        _data = ByteBuffer.allocateDirect(data.length*4).order(endian);
-        _data.asIntBuffer().put(data);
-        _type = ONNX_TYPE_INT32;
+        _allocatedPtr = ByteBuffer.allocateDirect(data.length*4).order(endian);
+        _allocatedPtr.asIntBuffer().put(data);
+        _dataType = ONNX_TYPE_INT32;
     }

     /* ---------- Long data getter and setter ---------- */
@@ -227,13 +207,13 @@ public class RtMemRef {
      * @return long data array
      */
     public long[] getLongData() {
-        if (_data == null) return null;
+        if (_allocatedPtr == null) return null;

         /* asLongBuffer() creates a new view so the position of the
          * original data will stay at 0 for subsequent getLongData()
          * after get(l).
          */
-        LongBuffer lb = _data.asLongBuffer();
+        LongBuffer lb = _allocatedPtr.asLongBuffer();
         long[] l = new long[lb.limit()];
         lb.get(l);
         return l;
@@ -248,9 +228,9 @@ public class RtMemRef {
         /* asLongBuffer() creates a new view so the position of the
          * original data will stay at 0 for getLongData() after put(data).
          */
-        _data = ByteBuffer.allocateDirect(data.length*8).order(endian);
-        _data.asLongBuffer().put(data);
-        _type = ONNX_TYPE_INT64;
+        _allocatedPtr = ByteBuffer.allocateDirect(data.length*8).order(endian);
+        _allocatedPtr.asLongBuffer().put(data);
+        _dataType = ONNX_TYPE_INT64;
     }

     /* ---------- Float data getter and setter ---------- */
@@ -261,13 +241,13 @@ public class RtMemRef {
      * @return float data array
      */
     public float[] getFloatData() {
-        if (_data == null) return null;
+        if (_allocatedPtr == null) return null;

         /* asFloatBuffer() creates a new view so the position of the
          * original data will stay at 0 for subsequent getFloatData()
          * after get(f).
          */
-        FloatBuffer fb = _data.asFloatBuffer();
+        FloatBuffer fb = _allocatedPtr.asFloatBuffer();
         float[] f = new float[fb.limit()];
         fb.get(f);
         return f;
@@ -282,9 +262,9 @@ public class RtMemRef {
         /* asFloatBuffer() creates a new view so the position of the
          * original data will stay at 0 for getFloatData() after put(data).
          */
-        _data = ByteBuffer.allocateDirect(data.length*4).order(endian);
-        _data.asFloatBuffer().put(data);
-        _type = ONNX_TYPE_FLOAT;
+        _allocatedPtr = ByteBuffer.allocateDirect(data.length*4).order(endian);
+        _allocatedPtr.asFloatBuffer().put(data);
+        _dataType = ONNX_TYPE_FLOAT;
     }

     /* ---------- Double data getter and setter ---------- */
@@ -295,13 +275,13 @@ public class RtMemRef {
      * @return double data array
      */
     public double[] getDoubleData() {
-        if (_data == null) return null;
+        if (_allocatedPtr == null) return null;

         /* asDoubleBuffer() creates a new view so the position of the
          * original data will stay at 0 for subsequent getDoubleData()
          * after get(d).
          */
-        DoubleBuffer db = _data.asDoubleBuffer();
+        DoubleBuffer db = _allocatedPtr.asDoubleBuffer();
         double[] d = new double[db.limit()];
         db.get(d);
         return d;
@@ -316,85 +296,141 @@ public class RtMemRef {
         /* asDoubleBuffer() creates a new view so the position of the
          * original data will stay at 0 for getDoubleData() after put(data).
          */
-        _data = ByteBuffer.allocateDirect(data.length*8).order(endian);
-        _data.asDoubleBuffer().put(data);
-        _type = ONNX_TYPE_DOUBLE;
+        _allocatedPtr = ByteBuffer.allocateDirect(data.length*8).order(endian);
+        _allocatedPtr.asDoubleBuffer().put(data);
+        _dataType = ONNX_TYPE_DOUBLE;
     }

+    /* ---------- Data sizes getter and setter ---------- */
+
+    /**
+     * Data sizes getter
+     *
+     * @return data sizes array
+     */
+    public long[] getDataSizes() {
+        return _shape;
+    }
+
+    /**
+     * Data sizes setter
+     *
+     * @param dataSizes data sizes array to be set
+     */
+    public void setDataSizes(long[] dataSizes) {
+        if (dataSizes.length != _rank)
+            throw new IllegalArgumentException(
+                    "array length " + dataSizes.length + " != rank " + _rank);
+        _shape = dataSizes.clone();
+    }
+
+    /* ---------- Data strides getter and setter ---------- */
+
+    /**
+     * Data strides getter
+     *
+     * @return data strides array
+     */
+    public long[] getDataStrides() {
+        return _stride;
+    }
+
+    /**
+     * Data strides setter
+     *
+     * @param dataStrides data strides array to be set
+     */
+    public void setDataStrides(long[] dataStrides) {
+        if (dataStrides.length != _rank)
+            throw new IllegalArgumentException(
+                    "array length " + dataStrides.length + " != rank " + _rank);
+        _stride = dataStrides.clone();
+    }
+
+    /* ---------- Data type getter and setter ---------- */
+
+    /**
+     * Data type getter
+     *
+     * @return data type
+     */
+    public int getDataType() {
+        return _dataType;
+    }
+
+    /**
+     * Data type setter
+     *
+     * @param dataType data type to be set
+     */
+    public void setDataType(int dataType) {
+        if (dataType < 0 || dataType > ONNX_TYPE_BFLOAT16)
+            throw new IllegalArgumentException(
+                    "data type " + dataType + " unknown");
+        _dataType = dataType;
+    }
+
+    /* ---------- Data buffer size getter ---------- */
+
+    /**
+     * Data buffer size getter
+     *
+     * @return total size of the data buffer in bytes
+     */
+    public long getDataBufferSize() {
+        return _allocatedPtr == null ? 0 : _allocatedPtr.limit();
+    }
+
+    /* ---------- Rank getter ---------- */
+
     /**
      * Rank getter
      *
-     * @return rank
+     * @return rank of the OMTensor
      */
     public int getRank() {
         return _rank;
     }

-    /* ---------- Sizes getter and setter ---------- */
+    /* ---------- Name getter and setter ---------- */

     /**
-     * Sizes getter
+     * Name getter
      *
-     * @return sizes array
+     * @return name of the OMTensor
      */
-    public long[] getSizes() {
-        return _sizes;
+    public String getName() {
+        return _name;
     }

     /**
-     * Sizes setter
+     * Name setter
      *
-     * @param sizes sizes array to be set
+     * @param name name of the OMTensor
      */
-    public void setSizes(long[] sizes) {
-        if (sizes.length != _rank)
-            throw new IllegalArgumentException(
-                    "array length " + sizes.length + " != rank " + _rank);
-        _sizes = sizes.clone();
-    }
-
-    /* ---------- Strides getter and setter ---------- */
-
-    /**
-     * Strides getter
-     *
-     * @return strides array
-     */
-    public long[] getStrides() {
-        return _strides;
+    public void setName(String name) {
+        _name = name == null ? "" : name;
     }

     /**
-     * Strides setter
+     * Number of elements getter
      *
-     * @param strides strides array to be set
+     * @return number of data elements in the data buffer
      */
-    public void setStrides(long[] strides) {
-        if (strides.length != _rank)
-            throw new IllegalArgumentException(
-                    "array length " + strides.length + " != rank " + _rank);
-        _strides = strides.clone();
-    }
-
-    /**
-     * Size getter
-     *
-     * @return product of sizes array, i.e., total number of data elements
-     */
-    public long getDataSize() {
-        long n = _sizes[0];
-        for (int i = 1; i < _sizes.length; i++) n *= _sizes[i];
+    public long getNumOfElems() {
+        long n = _shape[0];
+        for (int i = 1; i < _shape.length; i++) n *= _shape[i];
         return n;
     }

     /**
-     * Check validity of RtMemRef
+     * Check validity of OMTensor
      *
-     * @return true if RtMemRef is valid, false otherwise
+     * @return true if OMTensor is valid, false otherwise
      */
-    public boolean validRmr() {
-        return (_data != null &&
-                _data.limit() != 0 &&
-                _data.limit() == getDataSize() * ONNX_TYPE_SIZE[_type]);
+    public boolean isValidOmt() {
+        return (_allocatedPtr != null &&
+                _allocatedPtr.limit() != 0 &&
+                _allocatedPtr.limit() == getNumOfElems() * ONNX_TYPE_SIZE[_dataType]);
     }
 }
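A quick sketch of the typed-buffer round trip the class above implements (illustrative only, not from the commit):

    OMTensor t = new OMTensor(2);
    t.setFloatData(new float[] {1f, 2f, 3f, 4f}); // direct buffer, native order, ONNX_TYPE_FLOAT
    t.setDataSizes(new long[] {2, 2});            // length must equal rank
    t.setDataStrides(new long[] {2, 1});          // row-major strides
    long bytes = t.getDataBufferSize();           // 16 = 4 elements * 4 bytes
    long elems = t.getNumOfElems();               // 4 = product of data sizes
    boolean ok = t.isValidOmt();                  // true: bytes == elems * ONNX_TYPE_SIZE[_dataType]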

View File

@@ -0,0 +1,40 @@
package com.ibm.onnxmlir;

import java.util.HashMap;

public class OMTensorList {

    private OMTensor[] _omts;
    private HashMap<String, Integer> _n2i;

    /**
     * Constructor
     *
     * @param omts OMTensor array
     */
    public OMTensorList(OMTensor[] omts) {
        /* Go through the OMTensor array, check each for validity,
         * and create name (if not empty) to index mapping.
         */
        _n2i = new HashMap<String, Integer>();
        for (int i = 0; i < omts.length; i++) {
            if (omts[i] == null || !omts[i].isValidOmt())
                throw new IllegalArgumentException(
                        "OMTensor[" + i + "] is invalid");
            String name = omts[i].getName();
            if (!name.isEmpty() && _n2i.put(name, i) != null)
                throw new IllegalArgumentException(
                        "OMTensor[" + i + "] duplicate name: " + name);
        }
        _omts = omts;
    }

    /**
     * OMTensor array getter
     *
     * @return OMTensor array
     */
    public OMTensor[] getOmts() {
        return _omts;
    }
}
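The constructor above both validates each tensor and enforces unique non-empty names; a sketch of the failure mode (illustrative only, not from the commit):

    OMTensor a = new OMTensor(1);
    a.setLongData(new long[] {1, 2, 3});   // 3 * 8 bytes, ONNX_TYPE_INT64
    a.setDataSizes(new long[] {3});
    a.setName("x");

    OMTensor b = new OMTensor(1);
    b.setLongData(new long[] {4, 5, 6});
    b.setDataSizes(new long[] {3});
    b.setName("x");                        // same non-empty name as a

    // Throws IllegalArgumentException: "OMTensor[1] duplicate name: x".
    // Tensors with empty names skip the name-to-index map entirely.
    OMTensorList list = new OMTensorList(new OMTensor[] {a, b});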

View File

@@ -1,118 +0,0 @@
package com.ibm.onnxmlir;
import java.util.HashMap;
public class OrderedRtMemRefDict {
private RtMemRef[] _rmrs;
private String[] _names;
private HashMap<String, Integer> _n2i;
/**
* Constructor
*
* @param rmrs DynMemRef array
*/
public OrderedRtMemRefDict(RtMemRef[] rmrs) {
this(rmrs, null);
}
/**
* Constructor
*
* @param rmrs DynMemRef array
* @param names name array
*/
public OrderedRtMemRefDict(RtMemRef[] rmrs, String[] names) {
/* rmrs cannot be null or empty */
if (rmrs == null || rmrs.length == 0)
throw new IllegalArgumentException(
"Number of dmrs is invalid");
/* If names is null or empty, construct a default one with
* index as name.
*/
if (names == null || names.length == 0) {
names = new String[rmrs.length];
for (int i = 0; i < rmrs.length; i++)
names[i] = Integer.toString(i);
}
/* Number of rmrs and names must match */
if (rmrs.length != names.length)
throw new IllegalArgumentException(
"Number of dmrs and names do not match");
/* Establish name to index mapping. Individual rmr is
* checked for validity.
*/
_n2i = new HashMap<String, Integer>();
for (int i = 0; i < names.length; i++) {
if (rmrs[i] == null || !rmrs[i].validRmr())
throw new IllegalArgumentException(
"rmrs[" + i + "] is invalid");
if (_n2i.put(names[i], i) != null)
throw new IllegalArgumentException(
"name[" + i + "] = " + names[i] + " not unique");
}
_rmrs = rmrs;
_names = names;
}
/**
* RtMemRef getter by index
*
* @param idx index of RtMemRef instance to get
* @return RtMemRef instance
*/
public RtMemRef getRmrbyIndex(int idx) {
return _rmrs[idx];
}
/**
* RtMemRef getter by name
*
* @param name name of RtMemRef instance to get
* @return RtMemRef instance
*/
public RtMemRef getRmrByName(String name) {
return _rmrs[_n2i.get(name)];
}
/**
* RtMemRef array getter
*
* @return RtMemRef array
*/
public RtMemRef[] getRmrs() {
return _rmrs;
}
/**
* Name getter
*
* @param idx index of name to get
* @return name string
*/
public String getName(int idx) {
return _names[idx];
}
/**
* Name array getter
*
* @return name array
*/
public String[] getNames() {
return _names;
}
/**
* RtMemRef array size getter
*
* @return RtMemRef array size
*/
public int size() {
return _rmrs.length;
}
}

View File

@@ -1,3 +1,4 @@
 add_subdirectory(mlir)
 add_subdirectory(backend)
 add_subdirectory(numerical)
+add_subdirectory(unit)

View File

@@ -92,7 +92,7 @@ class DummyBackend(onnx.backend.base.Backend):
         # Call frontend to process temp_model.onnx, bit code will be generated.
         execute_commands([ONNX_MLIR, "temp_model.onnx"])
         return EndiannessAwareExecutionSession("./temp_model.so",
-                                               "_dyn_entry_point_main_graph")
+                                               "run_main_graph")

     @classmethod
     def supports_device(cls, device):
@@ -448,7 +448,6 @@ test_to_enable = [
     "test_shufflenet_cpu",
 ]

-# Extract name of all test cases.
 import inspect
 all_tests = []

View File

@@ -1,4 +1,9 @@
 add_executable(TestConv TestConv.cpp)
+target_compile_definitions(TestConv PRIVATE RTMEMREF_INTERNAL_API)
+target_include_directories(TestConv
+  PRIVATE
+  ${ONNX_MLIR_BIN_ROOT}
+  ${ONNX_MLIR_SRC_ROOT}/include)
 target_link_libraries(TestConv
   ${OMLibs}
   ${MLIRLibs}
@@ -6,11 +11,5 @@ target_link_libraries(TestConv
   rapidcheck
   MainUtils
   ExecutionSession
-  RtMemRefUtils)
-target_include_directories(TestConv
-  PRIVATE
-  ${ONNX_MLIR_SRC_ROOT}
-  ${ONNX_MLIR_BIN_ROOT}
-  ${ONNX_MLIR_SRC_ROOT})
+  OMTensorUtils)
 add_test(NAME OMTestConv COMMAND TestConv)

View File

@@ -11,7 +11,8 @@
 #include "src/Dialect/ONNX/ONNXOps.hpp"
 #include "src/MainUtils.hpp"
-#include "src/Runtime/ExecusionSession.hpp"
+#include "src/Runtime/ExecutionSession.hpp"
+#include "src/Runtime/OMTensorHelper.h"

 #define SHARED_LIB_BASE string("./TestConv_main_graph")
@@ -93,38 +94,39 @@ bool isOMConvTheSameAsNaiveImplFor(const int N, const int C, const int H,
   OwningModuleRef moduleRef(module);

   compileModule(moduleRef, ctx, SHARED_LIB_BASE, EmitLib);
-  onnx_mlir::ExecutionSession sess(
-      SHARED_LIB_BASE + ".so", "_dyn_entry_point_main_graph");
+  onnx_mlir::ExecutionSession sess(SHARED_LIB_BASE + ".so", "run_main_graph");

-  std::vector<unique_ptr<RtMemRef>> inputs;
-  auto xRmr = unique_ptr<RtMemRef>(getRndRealRmr<float>({N, C, H, W}));
-  inputs.emplace_back(move(xRmr));
-  auto wRmr = unique_ptr<RtMemRef>(getRndRealRmr<float>({C, C, kH, kW}));
-  inputs.emplace_back(move(wRmr));
+  std::vector<unique_ptr<OMTensor, decltype(&omTensorDestroy)>> inputs;
+  auto xOmt = unique_ptr<OMTensor, decltype(&omTensorDestroy)>(
+      omTensorCreateWithRandomData<float>({N, C, H, W}), omTensorDestroy);
+  inputs.emplace_back(move(xOmt));
+  auto wOmt = unique_ptr<OMTensor, decltype(&omTensorDestroy)>(
+      omTensorCreateWithRandomData<float>({C, C, kH, kW}), omTensorDestroy);
+  inputs.emplace_back(move(wOmt));

-  auto ref = RtMemRef::create<float>({NOut, COut, HOut, WOut});
+  auto ref = omTensorCreateWithShape<float>({NOut, COut, HOut, WOut});
   auto &img = inputs.at(0);
   auto &filter = inputs.at(1);
   for (int64_t n = 0; n < NOut; n++)
     for (int64_t c = 0; c < COut; c++)
       for (int64_t h = 0; h < HOut; h++)
         for (int64_t w = 0; w < WOut; w++) {
-          ref->elem<float>({n, c, h, w}) = 0;
+          omTensorGetElem<float>(ref, {n, c, h, w}) = 0;
           for (int64_t ci = 0; ci < C; ci++)
             for (int64_t kh = 0; kh < kH; kh++)
               for (int64_t kw = 0; kw < kW; kw++)
                 if ((h + kh - pHBegin >= 0 && h + kh - pHBegin < H) &&
                     (w + kw - pWBegin >= 0 && w + kw - pWBegin < W))
-                  ref->elem<float>({n, c, h, w}) +=
-                      img->elem<float>(
-                          {n, ci, h + kh - pHBegin, w + kw - pWBegin}) *
-                      filter->elem<float>({c, ci, kh, kw});
+                  omTensorGetElem<float>(ref, {n, c, h, w}) +=
+                      omTensorGetElem<float>(img.get(),
+                          {n, ci, h + kh - pHBegin, w + kw - pWBegin}) *
+                      omTensorGetElem<float>(filter.get(), {c, ci, kh, kw});
         }

   auto outputs = sess.run(move(inputs));
   auto &conv = outputs.at(0);

-  return isRmrClose<float>(conv.get(), ref);
+  return omTensorAreTwoOmtsClose<float>(conv.get(), ref);
 }

 int main(int argc, char *argv[]) {

test/unit/CMakeLists.txt (new file)
View File

@@ -0,0 +1,24 @@
include(GoogleTest)

# Adapted from https://cliutils.gitlab.io/modern-cmake/chapters/testing/googletest.html.
macro(add_unit_test TESTNAME)
    # create an executable in which the tests will be stored
    add_executable(${TESTNAME} ${ARGN})
    # link the Google test infrastructure, mocking library, and a default main function to
    # the test executable. Remove gtest_main if writing your own main function.
    target_link_libraries(${TESTNAME} gtest gmock gtest_main)
    # gtest_discover_tests replaces gtest_add_tests,
    # see https://cmake.org/cmake/help/v3.10/module/GoogleTest.html for more options to pass to it
    gtest_discover_tests(${TESTNAME}
        # set the working directory to the project root so that test data can be found via relative paths
        WORKING_DIRECTORY ${ONNX_MLIR_BIN_ROOT}
        PROPERTIES VS_DEBUGGER_WORKING_DIRECTORY "${ONNX_MLIR_BIN_ROOT}")
endmacro()
macro(add_c_unit_test TESTNAME)
add_executable(${TESTNAME} ${ARGN})
add_test(NAME ${TESTNAME}
COMMAND ${TESTNAME})
endmacro()
add_subdirectory(Runtime)

View File

@@ -0,0 +1,7 @@
add_subdirectory(DocExampleTest)
add_c_unit_test(OMTensorTest OMTensorTest.c)
target_include_directories(OMTensorTest PRIVATE
${ONNX_MLIR_SRC_ROOT}/include)
target_link_libraries(OMTensorTest
cruntime)

View File

@@ -0,0 +1,44 @@
# Documentation example requires ONNX package installation, which has been
# flaky on non-x86 platforms, so only perform this test on x86 arch.
if (CMAKE_SYSTEM_PROCESSOR MATCHES "(x86)|(X86)|(amd64)|(AMD64)")
find_package(PythonInterp 3 REQUIRED)
add_custom_target(OMInstallOnnx
COMMAND ${PYTHON_EXECUTABLE} setup.py install --user
WORKING_DIRECTORY ${ONNX_MLIR_SRC_ROOT}/third_party/onnx)
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/gen_add_onnx.py
${CMAKE_CURRENT_BINARY_DIR}/gen_add_onnx.py COPYONLY)
add_custom_command(OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/add.onnx
COMMAND ${PYTHON_EXECUTABLE} gen_add_onnx.py
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
add_custom_target(OMGenerateAddModel
DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/add.onnx)
add_dependencies(OMGenerateAddModel OMInstallOnnx)
add_custom_command(OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/add.so
COMMAND onnx-mlir ${CMAKE_CURRENT_BINARY_DIR}/add.onnx
DEPENDS OMGenerateAddModel)
add_custom_target(OMGenerateAddLibrary
DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/add.so)
add_custom_target(OMCopyAndRename
COMMAND ${CMAKE_COMMAND} -E copy
${CMAKE_CURRENT_BINARY_DIR}/add.so
${CMAKE_CURRENT_BINARY_DIR}/library.so
DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/add.so)
add_executable(OMRuntimeTest
main.c)
add_library(OMRuntimeTestModel SHARED IMPORTED)
file(TOUCH ${CMAKE_CURRENT_BINARY_DIR}/library.so)
set_property(TARGET OMRuntimeTestModel
PROPERTY IMPORTED_LOCATION ${CMAKE_CURRENT_BINARY_DIR}/library.so)
target_link_libraries(OMRuntimeTest
${CMAKE_CURRENT_BINARY_DIR}/library.so)
target_include_directories(OMRuntimeTest
PRIVATE ${ONNX_MLIR_SRC_ROOT}/include)
add_dependencies(OMRuntimeTest OMCopyAndRename)
add_test(NAME OMRuntimeTest
COMMAND OMRuntimeTest)
endif ()

View File

@@ -0,0 +1,33 @@
import onnx
from onnx import helper
from onnx import AttributeProto, TensorProto, GraphProto
# Create one input (ValueInfoProto)
X1 = helper.make_tensor_value_info('X1', TensorProto.FLOAT, [3, 2])
X2 = helper.make_tensor_value_info('X2', TensorProto.FLOAT, [3, 2])
# Create one output (ValueInfoProto)
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [3, 2])
# Create a node (NodeProto) for the Add op
node_def = helper.make_node(
    'Add',          # op type
    ['X1', 'X2'],   # inputs
    ['Y'],          # outputs
)
# Create the graph (GraphProto)
graph_def = helper.make_graph(
[node_def],
'test-model',
[X1, X2],
[Y],
)
# Create the model (ModelProto)
model_def = helper.make_model(graph_def, producer_name='onnx-example')
print('The model is:\n{}'.format(model_def))
onnx.checker.check_model(model_def)
onnx.save(model_def, "add.onnx")
print('The model is checked!')

View File

@@ -0,0 +1,30 @@
#include <OnnxMlirRuntime.h>
#include <stdio.h>

OMTensorList *run_main_graph(OMTensorList *);

int main() {
  // Shared shape & rank: two 2x2 tensors, i.e. 4 elements each.
  int64_t shape[] = {2, 2};
  int64_t rank = 2;
  // Construct x1 omt filled with 1.
  float x1Data[] = {1., 1., 1., 1.};
  OMTensor *x1 = omTensorCreate(x1Data, shape, rank, ONNX_TYPE_FLOAT);
  // Construct x2 omt filled with 2.
  float x2Data[] = {2., 2., 2., 2.};
  OMTensor *x2 = omTensorCreate(x2Data, shape, rank, ONNX_TYPE_FLOAT);
  // Construct a list of omts as input.
  OMTensor *list[2] = {x1, x2};
  OMTensorList *input = omTensorListCreate(list, 2);
  // Call the compiled onnx model function.
  OMTensorList *outputList = run_main_graph(input);
  // Get the first omt as output.
  OMTensor *y = omTensorListGetOmtByIndex(outputList, 0);
  float *outputPtr = (float *)omTensorGetDataPtr(y);
  // Print its content, should be all 3.
  for (int i = 0; i < 4; i++)
    printf("%f ", outputPtr[i]);
  return 0;
}

View File

@@ -0,0 +1,36 @@
//===----------------- OMTensorTest.c - OMTensor Unit Test ----------------===//
//
// Copyright 2019-2020 The IBM Research Authors.
//
// =============================================================================
//
// This file contains unit tests for the OMTensor data structure and its
// helper functions.
//
//===----------------------------------------------------------------------===//
#include <assert.h>
#include <stdio.h>
#include "OnnxMlirRuntime.h"
void testOMTensorCtor() {
float data[4] = {1.f, 1.f};
int64_t shape[2] = {2, 2};
OMTensor *tensor = omTensorCreate(data, shape, 2, ONNX_TYPE_FLOAT);
assert(tensor);
int64_t* shape_ptr = omTensorGetDataShape(tensor);
assert(shape_ptr);
assert(shape_ptr[0] == 2);
assert(shape_ptr[1] == 2);
int64_t* strides_ptr = omTensorGetStrides(tensor);
assert(strides_ptr);
assert(strides_ptr[0] == 2);
assert(strides_ptr[1] == 1);
}
int main() {
testOMTensorCtor();
return 0;
}