Cleanup rtmemref api (#238)
* Detect llvm-project commit change in utils/clone-mlir.sh and rebuild llvm-project for zLinux Jenkins build bot * Cleanup RtMemRef API - use forward declaration to hide private data fields - RtMemRef.h: external user header, C/C++ - _RtMemRef.h: internal user header, C++ only - RtMemRef.hpp and RtMemRef.cpp: implementation header and file - add external APIs OrderedRtMemRefDict *ormrd_create(RtMemRef **rmrs, int n) RtMemRef **ormrd_getRmrs(OrderedRtMemRefDict *ormrd) int ormrd_getNumOfRmrs(OrderedRtMemRefDict *ormrd) for creating and querying OrderedRtMemRefDict with RtMemRef arrays - data buffer installed by rmr_setData() will be managed by user - unique_ptr<RtMemRef> must use custom deleter <RtMemRef,decltype(&rmr_destroy)> * See if I have write access. * Remove test CMake code. * Use new API. * Format code. * Format code & rename variables for readability. * Remove used API spec. * Rename OrderedRtMemRefDict -> RtMemRefList, _dataMalloc -> _owningData. * OrderedRtMemRefDict -> RtMemRefList * Update KrnlToLLVM.cpp * Trigger Jenkins * Restart Jenkins * OrderedRtMemRefDict -> RtRmrRefList * More OrderedRtMemRefDict -> RtMemRefList. * Format jni wrapper. * Rename API functions to maintain stylistic consistency. * Bug fix. * Bug fix. * Format code. * Fix RtMemRefUtils. * Format code. * Using llvm function naming scheme. * Rename runtime api file name to project name (onnx-mlir) as per convention. * Include the new runtime header file. * Reflect api header file name change in build script. * Bug fix. * Remove C++ code. * Revert "Remove C++ code." This reverts commit b217dfabae99e42db30721600cb5507866d4dc98. * Clarify memory management responsibility. * Add constructor to specify name & data ownership. * Include stdbool. * Remove dictionary semantics from RtMemRefList * Bug fix. * Format code. * Format code. * Use macro to define database of metadata. * Prevent formatter from acting on metadata decl. * Nit. * Restore backend unit tests. 
* Use spaces instead of tabs for better formatting. * Use explicit template instantiation. * Update RtMemRef struct doc. * Make runtime compilable both in c and c++ mode. Build two versions of the runtime library, one c version as the user-facing c runtime, and one c++ version as the one used inside this project. * Bug fix, avoid stack allocation for output rmr list. * Change _dyn_entry_point_main_graph -> run_main_graph for better memorability. * Write a complete introductory tutorial on c99 Runtime and a test for it. * Add onnx installation as dependency. * Use target_include_directory to avoid installation. * Format code. * Fix cmake target_include_directories. * Address compiler warning. * First pass of RtMemRef->OMTensor. * Second pass of RtMemRef -> OMTensor. * nit, omtList -> omTensorList. * omt -> omTensor for clarity. * Rename OnnxMlirInternal.h -> OnnxMlirRuntime.hpp because there's no internal/external API, only C/C++ API. * Format code. * Restructure Runtime source code and move header -> /include and test -> /test/unit. * Bugfix. * Format code. * Add unit test for OMTensor ctor. * Update JNI CMake include directory. * Bugfix. * No need to re-declare ONNX types. * Disable runtime doc test on non-x86 platforms. * Rename OMTensor fields to be more sensible. * Fix data type mismatch. * size_t -> int64_t, prefer fixed width integers. * Use consistent header guard style. * Further tweak OMTensor API. * Bugfix. * Bugfix. * Format code. * Add doxygen config file. * Tweak OMTensor API. * Tweak API doc, hide OMTensorList implementation. * Format code. * Add new documentation item for Runtime API. * Hide internal use only API declarations, move their comments to their implementations. * Clarify ownership semantics in relevant API documentations. * Fix PyRuntime. * Remove alignment concerns from public API and include explaination of alignment issue in struct OMTensor definition. * Print out unsupported numpy dtype. 
* Use preferred way of type comparison in pybind11. * Debug s390x issue. * Remove debug code. * Clarify semantics of strides/shape setter/getter, use \brief to include short description of API function. * Improve documentation. * Single out unpolished C++ API declarations. * Clarify OMTensorList API. * Bugfix. * Bugfix. * Assert after malloc. * Handle malloc failures. * Nit. * Tweak legal notices. * Format code. * Remove doxygen generated files. * Tweak legal notice format. * Upgrade Cython Numpy installation depends on Cython. Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
parent
533d47acf1
commit
81c774ba5b
|
@ -8,7 +8,7 @@ jobs:
|
|||
steps:
|
||||
- run:
|
||||
name: Installing GCC, CMake, Ninja, Protobuf
|
||||
command: sudo apt-get update && sudo apt-get install -y gcc g++ cmake ninja-build protobuf-compiler
|
||||
command: sudo apt-get update && sudo apt-get install -y gcc g++ cmake ninja-build protobuf-compiler && pip install --upgrade -q cython
|
||||
- checkout:
|
||||
path: onnx-mlir
|
||||
- run:
|
||||
|
|
|
@ -27,6 +27,7 @@ if (MSVC)
|
|||
endif()
|
||||
add_subdirectory(third_party/onnx)
|
||||
add_subdirectory(third_party/googletest)
|
||||
SET(BENCHMARK_ENABLE_GTEST_TESTS OFF)
|
||||
add_subdirectory(third_party/benchmark)
|
||||
add_subdirectory(third_party/pybind11)
|
||||
add_subdirectory(third_party/variant)
|
||||
|
@ -35,6 +36,7 @@ add_subdirectory(third_party/rapidcheck)
|
|||
set(CMAKE_CXX_STANDARD 14)
|
||||
|
||||
add_subdirectory(utils)
|
||||
add_subdirectory(include)
|
||||
add_subdirectory(src)
|
||||
add_subdirectory(docs)
|
||||
add_subdirectory(test)
|
||||
|
|
|
@ -11,10 +11,10 @@ toc:
|
|||
# subfolderitems:
|
||||
# - page: Placeholder
|
||||
# url: /piece1.html
|
||||
# - title: How-tos
|
||||
# subfolderitems:
|
||||
# - page: Placeholder
|
||||
# url: /piece1.html
|
||||
- title: How-tos
|
||||
subfolderitems:
|
||||
- page: Runtime API
|
||||
url: /doxygen_html/index.html
|
||||
- title: References
|
||||
subfolderitems:
|
||||
- page: ONNX Dialect
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
add_subdirectory(onnx-mlir)
|
||||
|
||||
install(FILES OnnxMlirRuntime.h DESTINATION include)
|
|
@ -0,0 +1,116 @@
|
|||
//===------- OnnxMlirRuntime.h - ONNX-MLIR Runtime API Declarations -------===//
|
||||
//
|
||||
// Copyright 2019-2020 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// This file contains declaration of external OMTensor data structures and
|
||||
// helper functions.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
#ifndef ONNX_MLIR_ONNXMLIRRUNTIME_H
|
||||
#define ONNX_MLIR_ONNXMLIRRUNTIME_H
|
||||
|
||||
#ifdef __cplusplus
|
||||
#include <cstdint>
|
||||
#else
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#endif
|
||||
|
||||
#include <onnx-mlir/Runtime/OMTensor.h>
|
||||
#include <onnx-mlir/Runtime/OMTensorList.h>
|
||||
|
||||
/*! \mainpage ONNX-MLIR Runtime API documentation
|
||||
*
|
||||
* \section intro_sec Introduction
|
||||
*
|
||||
* ONNX-MLIR project comes with an executable `onnx-mlir` capable
|
||||
* of compiling onnx models to a shared library. In this documentation, we
|
||||
* demonstrate how to interact programmatically with the compiled
|
||||
* shared library using ONNX-MLIR's Runtime API.
|
||||
*
|
||||
* \section c-runtime-api C Runtime API
|
||||
*
|
||||
* \subsection data-structures Data Structures
|
||||
*
|
||||
* `OMTensor` is the data structure used to describe the runtime information
|
||||
* (rank, shape, data type, etc) associated with a tensor input or output.
|
||||
*
|
||||
* `OMTensorList` is the data structure used to hold a list of pointers to
|
||||
* OMTensor so that they can be passed into and out of the compiled model as
|
||||
* inputs and outputs.
|
||||
*
|
||||
* \subsection model-entry-point-signature Model Entry Point Signature
|
||||
*
|
||||
* All compiled model will have the same exact C function signature equivalent
|
||||
* to:
|
||||
*
|
||||
* ```c
|
||||
* OMTensorList* run_main_graph(OMTensorList*);
|
||||
* ```
|
||||
*
|
||||
* Intuitively, the model takes a list of tensors as input and returns a list of
|
||||
* tensors as output.
|
||||
*
|
||||
* \subsection invoke-models-using-c-runtime-api Invoke Models Using C Runtime
|
||||
* API
|
||||
*
|
||||
* We demonstrate using the API functions to run a simple ONNX model consisting
|
||||
* of an add operation. To create such an onnx model, use this
|
||||
* <a href="gen_add_onnx.py" target="_blank"><b>python script</b></a>
|
||||
*
|
||||
* To compile the above model, run `onnx-mlir add.onnx` and a binary library
|
||||
* "add.so" should appear. We can use the following C code to call into the
|
||||
* compiled function computing the sum of two inputs:
|
||||
*
|
||||
* ```c
|
||||
* #include <OnnxMlirRuntime.h>
|
||||
* #include <stdio.h>
|
||||
*
|
||||
* OMTensorList *run_main_graph(OMTensorList *);
|
||||
*
|
||||
* int main() {
|
||||
* // Shared shape & rank.
|
||||
* int64_t shape[] = {2, 3};
|
||||
* int64_t rank = 2;
|
||||
* // Construct x1 omt filled with 1.
|
||||
* float x1Data[] = {1., 1., 1., 1., 1., 1.};
|
||||
* int64_t x1Shape[] = {2, 3};
|
||||
* OMTensor *x1 = omTensorCreate(x1Data, shape, rank, ONNX_TYPE_FLOAT);
|
||||
* // Construct x2 omt filled with 2.
|
||||
* float x2Data[] = {2., 2., 2., 2., 2., 2.};
|
||||
* int64_t x2Shape[] = {2, 3};
|
||||
* OMTensor *x2 = omTensorCreate(x2Data, shape, rank, ONNX_TYPE_FLOAT);
|
||||
* // Construct a list of omts as input.
|
||||
* OMTensor *list[2] = {x1, x2};
|
||||
* OMTensorList *input = omTensorListCreate(list, 2);
|
||||
* // Call the compiled onnx model function.
|
||||
* OMTensorList *outputList = run_main_graph(input);
|
||||
* // Get the first omt as output.
|
||||
* OMTensor *y = omTensorListGetOmtByIndex(outputList, 0);
|
||||
* float *outputPtr = (float *)omTensorGetDataPtr(y);
|
||||
* // Print its content, should be all 3.
|
||||
* for (int i = 0; i < 6; i++)
|
||||
* printf("%f ", outputPtr[i]);
|
||||
* return 0;
|
||||
* }
|
||||
* ```
|
||||
*
|
||||
* Compile with `gcc main.c add.so -o add`, you should see an executable `add`
|
||||
* appearing. Run it, and the output should be:
|
||||
*
|
||||
* ```
|
||||
* 3.000000 3.000000 3.000000 3.000000 3.000000 3.000000
|
||||
* ```
|
||||
* Exactly as it should be.
|
||||
*
|
||||
* \subsection reference Reference
|
||||
*
|
||||
* For full reference to available C Runtime API, refer to
|
||||
*`include/onnx-mlir/Runtime/OMTensor.h` and
|
||||
* `include/onnx-mlir/Runtime/OMTensorList.h`.
|
||||
*
|
||||
*/
|
||||
|
||||
#endif // ONNX_MLIR_ONNXMLIRRUNTIME_H
|
|
@ -0,0 +1 @@
|
|||
add_subdirectory(Runtime)
|
|
@ -0,0 +1,2 @@
|
|||
install(FILES OnnxDataType.h DESTINATION include/onnx-mlir/Runtime)
|
||||
install(FILES OnnxDataTypeMetaData.inc DESTINATION include/onnx-mlir/Runtime)
|
|
@ -0,0 +1,263 @@
|
|||
//===-------------- OMTensor.h - OMTensor Declaration header --------------===//
|
||||
//
|
||||
// Copyright 2019-2020 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// This file contains declaration of OMTensor data structures and
|
||||
// API functions.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef ONNX_MLIR_OMTENSOR_H
|
||||
#define ONNX_MLIR_OMTENSOR_H
|
||||
|
||||
#ifdef __cplusplus
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <numeric>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#else
|
||||
#include <stdbool.h>
|
||||
#endif // #ifdef __cplusplus
|
||||
|
||||
#ifdef __APPLE__
|
||||
#include <stdlib.h>
|
||||
#else
|
||||
#include <malloc.h>
|
||||
#endif // #ifdef __APPLE__
|
||||
|
||||
#include "onnx-mlir/Runtime/OnnxDataType.h"
|
||||
|
||||
/* Typically, MemRefs in MLIR context are used as a compile-time constructs.
|
||||
* Information such as element type and rank of the data payload is statically
|
||||
* encoded, meaning that they are determined and fixed at compile-time. This
|
||||
* presents significant burden for any runtime components trying to interact
|
||||
* with the compiled executable.
|
||||
*
|
||||
* Thus a version of MemRef struct that is amenable to runtime manipulation is
|
||||
* provided as a basis for building any runtime-related components providing
|
||||
* user-facing programming interfaces. All information are dynamically encoded
|
||||
* as members of this struct so that they can be accessed and modified easily
|
||||
* during runtime.
|
||||
*
|
||||
* We will refer to it as an OMTensor.
|
||||
*/
|
||||
|
||||
struct OMTensor;
|
||||
|
||||
#ifndef __cplusplus
|
||||
typedef struct OMTensor OMTensor;
|
||||
#endif
|
||||
|
||||
/* Helper: total number of data elements, i.e. the product of all
 * per-dimension sizes. A rank of 0 (scalar) yields 1. */
static inline int64_t getNumOfElems(int64_t *dataSizes, int rank) {
  int64_t total = 1;
  int dim = 0;
  while (dim < rank) {
    total *= dataSizes[dim];
    dim++;
  }
  return total;
}
|
||||
|
||||
/**
|
||||
* \brief Create an OMTensor with specified data pointer, shape, rank and element
|
||||
* type.
|
||||
*
|
||||
* The call will not create a copy of the data. By default, caller is
|
||||
* responsible for managing the memory this pointer refers to. Namely, the
|
||||
* OMTensor is not the owner of the data. To indicate OMTensor's ownership of
|
||||
* data, use `omTensorCreateWithOwnership`. Ownership determines what happens
|
||||
* when the OMTensor is destroyed. With ownership of the data, the destruction
|
||||
* of the OMTensor will also free the data.
|
||||
*
|
||||
* @param data_ptr pointer to tensor data. By default, caller is responsible for
|
||||
* managing the memory this pointer refers to.
|
||||
* @param shape list of integers indicating the tensor shape.
|
||||
* @param rank tensor rank.
|
||||
* @param dtype tensor element data type.
|
||||
* @return pointer to OMTensor created, NULL if creation failed.
|
||||
*
|
||||
*/
|
||||
OMTensor *omTensorCreate(
|
||||
void *data_ptr, int64_t *shape, int64_t rank, OM_DATA_TYPE dtype);
|
||||
|
||||
/**
|
||||
* \brief Create an OMTensor with the specified shape, rank and element type,
|
||||
* allocate uninitialized data for the specified shape.
|
||||
*
|
||||
* The OMTensor created using this constructor owns the underlying memory
|
||||
* space allocated to the content of the tensor.
|
||||
*
|
||||
* @param shape list of integers indicating the tensor shape.
|
||||
* @param rank tensor rank.
|
||||
* @param dtype tensor element data type.
|
||||
* @return pointer to OMTensor created, NULL if creation failed.
|
||||
*
|
||||
*/
|
||||
OMTensor *omTensorCreateEmpty(int64_t *shape, int64_t rank, OM_DATA_TYPE dtype);
|
||||
|
||||
/**
|
||||
* \brief Create an OMTensor with specified data pointer, shape, rank and
|
||||
* element type, manually setting data ptr ownership.
|
||||
*
|
||||
* Using this constructor, users can
|
||||
* specify whether OMTensor owns the data, which subsequently determines whether
|
||||
* the memory space underlying the data will be freed or not when OMTensor gets
|
||||
* destroyed.
|
||||
*
|
||||
* @param data_ptr pointer to tensor data.
|
||||
* @param shape list of integers indicating the tensor shape.
|
||||
* @param rank tensor rank.
|
||||
* @param dtype tensor element data type.
|
||||
* @param owning whether OMTensor owns the data, if set to true, OMTensor will
|
||||
* release the data_ptr upon destruction.
|
||||
* @return pointer to OMTensor created, NULL if creation failed.
|
||||
*
|
||||
*/
|
||||
OMTensor *omTensorCreateWithOwnership(void *data_ptr, int64_t *shape,
|
||||
int64_t rank, OM_DATA_TYPE dtype, int owning);
|
||||
|
||||
/**
|
||||
* \brief Create an empty OMTensor with specified rank.
|
||||
*
|
||||
* This constructor returns a
|
||||
* partially filled omTensor; prefer using the new omTensorCreateEmpty()
|
||||
* function to fill shape & stride fields automatically.
|
||||
*
|
||||
* @param rank tensor rank
|
||||
* @return pointer to OMTensor created, NULL if creation failed.
|
||||
*
|
||||
*/
|
||||
OMTensor *omTensorCreateEmptyDeprecated(int rank);
|
||||
|
||||
/**
|
||||
* \brief Destroy the OMTensor struct.
|
||||
*
|
||||
* If OMTensor does not own the data, destroying the omTensor does not free up
|
||||
* the memory occupied by the tensor content. If OMTensor owns the data, this
|
||||
* function will free up the memory space underlying the tensor as well. The
|
||||
* documentation of OMTensor constructors clarifies the ownership semantics.
|
||||
*
|
||||
* @param tensor pointer to the OMTensor
|
||||
*
|
||||
*/
|
||||
void omTensorDestroy(OMTensor *tensor);
|
||||
|
||||
/**
|
||||
* \brief OMTensor data pointer getter.
|
||||
*
|
||||
* @param tensor pointer to the OMTensor
|
||||
* @return pointer to the data buffer of the OMTensor,
|
||||
* NULL if the data buffer is not set.
|
||||
*/
|
||||
void *omTensorGetDataPtr(OMTensor *tensor);
|
||||
|
||||
/**
|
||||
* \brief OMTensor data shape getter.
|
||||
*
|
||||
* The data shape is returned as a pointer pointing to an array of
|
||||
* n 64-bit integers where n is the rank of the tensor.
|
||||
*
|
||||
* The shape array is returned without copying, so caller should
|
||||
* not free the returned pointer.
|
||||
*
|
||||
* @param tensor pointer to the OMTensor
|
||||
* @return pointer to the data shape array.
|
||||
*/
|
||||
int64_t *omTensorGetDataShape(OMTensor *tensor);
|
||||
|
||||
/**
|
||||
* \brief OMTensor data shape setter.
|
||||
*
|
||||
* n int64 elements are copied from the shape array to indicate the shape of the
|
||||
* tensor, where n is the rank of the tensor.
|
||||
*
|
||||
* The shape array is copied without being freed, so caller is expected to
|
||||
* manage the shape array oneself.
|
||||
*
|
||||
* @param tensor pointer to the OMTensor
|
||||
* @param shape data sizes array to be set
|
||||
*
|
||||
* Set the data sizes array of the OMTensor to the values in the input array.
|
||||
*/
|
||||
void omTensorSetShape(OMTensor *tensor, int64_t *shape);
|
||||
|
||||
/**
|
||||
* \brief OMTensor data strides getter
|
||||
*
|
||||
* The data strides are returned as a pointer pointing to an array of
|
||||
* n 64-bit integers where n is the rank of the tensor.
|
||||
*
|
||||
* The strides array is returned without copying, so caller should
|
||||
* not free the returned pointer.
|
||||
*
|
||||
* @param tensor pointer to the OMTensor
|
||||
* @return pointer to the data strides array.
|
||||
*/
|
||||
int64_t *omTensorGetStrides(OMTensor *tensor);
|
||||
|
||||
/**
|
||||
* \brief OMTensor data strides setter
|
||||
*
|
||||
* n int64 elements are copied from the strides array to indicate the
|
||||
* per-dimension stride of the tensor, where n is the rank of the tensor.
|
||||
*
|
||||
* The strides array is copied without being freed, so caller is expected to
|
||||
* manage the strides array oneself.
|
||||
*
|
||||
* @param tensor pointer to the OMTensor
|
||||
* @param strides tensor strides array to be set.
|
||||
*
|
||||
* Set the data strides array of the OMTensor to the values in the input array.
|
||||
*/
|
||||
void omTensorSetStrides(OMTensor *tensor, int64_t *strides);
|
||||
|
||||
/**
|
||||
* \brief OMTensor data type getter
|
||||
*
|
||||
* @param tensor pointer to the OMTensor
|
||||
* @return ONNX data type of the data buffer elements.
|
||||
*/
|
||||
OM_DATA_TYPE omTensorGetDataType(OMTensor *tensor);
|
||||
|
||||
/**
|
||||
* \brief OMTensor data type setter
|
||||
*
|
||||
* @param tensor pointer to the OMTensor
|
||||
* @param dataType ONNX data type to be set
|
||||
*
|
||||
* Set the ONNX data type of the data buffer elements.
|
||||
*/
|
||||
void omTensorSetDataType(OMTensor *tensor, OM_DATA_TYPE dataType);
|
||||
|
||||
/* Helper function to get the ONNX data type size in bytes */
|
||||
static inline int getDataTypeSize(OM_DATA_TYPE dataType) {
|
||||
return OM_DATA_TYPE_SIZE[dataType];
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief OMTensor data buffer size getter
|
||||
*
|
||||
* @param tensor pointer to the OMTensor
|
||||
* @return the total size of the data buffer in bytes.
|
||||
*/
|
||||
int64_t omTensorGetDataBufferSize(OMTensor *tensor);
|
||||
|
||||
/**
|
||||
* \brief OMTensor rank getter
|
||||
*
|
||||
* @param tensor pointer to the OMTensor
|
||||
* @return rank of data sizes and strides of the OMTensor.
|
||||
*/
|
||||
int omTensorGetRank(OMTensor *tensor);
|
||||
|
||||
/**
|
||||
* \brief OMTensor number of elements getter
|
||||
*
|
||||
* @param tensor pointer to the OMTensor
|
||||
* @return the number of elements in the data buffer.
|
||||
*/
|
||||
int64_t omTensorGetNumElems(OMTensor *tensor);
|
||||
|
||||
#endif // ONNX_MLIR_OMTENSOR_H
|
|
@ -0,0 +1,77 @@
|
|||
//===-------- OMTensorList.h - OMTensorList Declaration header-------------===//
|
||||
//
|
||||
// Copyright 2019-2020 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// This file contains declaration of OMTensorList data structures and
|
||||
// API functions.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef ONNX_MLIR_OMTENSORLIST_H
|
||||
#define ONNX_MLIR_OMTENSORLIST_H
|
||||
|
||||
#include "onnx-mlir/Runtime/OMTensor.h"
|
||||
|
||||
struct OMTensorList;
|
||||
|
||||
#ifndef __cplusplus
|
||||
typedef struct OMTensorList OMTensorList;
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \brief OMTensorList creator
|
||||
*
|
||||
* Create an OMTensorList with specified OMTensor array. The array of pointers
|
||||
* to OMTensor pointers is used without copying, so caller should not free the
|
||||
* `tensors` ptr.
|
||||
*
|
||||
* @param tensors array of pointers to OMTensor
|
||||
* @param n number of elements in tensors array
|
||||
* @return pointer to the OMTensorList created, NULL if creation failed.
|
||||
*
|
||||
*/
|
||||
OMTensorList *omTensorListCreate(OMTensor **tensors, int n);
|
||||
|
||||
/**
|
||||
* \brief OMTensorList destroyer
|
||||
*
|
||||
* Destroy the OMTensorList struct recursively. That is to say, both the
|
||||
* ptr to the OMTensor pointers AND the OMTensor pointers are freed.
|
||||
*
|
||||
* @param list pointer to the OMTensorList to be destroyed
|
||||
*
|
||||
*/
|
||||
void omTensorListDestroy(OMTensorList *list);
|
||||
|
||||
/**
|
||||
* \brief OMTensorList OMTensor array getter
|
||||
*
|
||||
* The pointer to OMTensor pointers are returned without copying, so caller
|
||||
* should not free the returned pointer.
|
||||
*
|
||||
* @param list pointer to the OMTensorList
|
||||
* @return pointer to the array of OMTensor pointers.
|
||||
*/
|
||||
OMTensor **omTensorListGetPtrToOmts(OMTensorList *list);
|
||||
|
||||
/**
|
||||
* \brief OMTensorList size getter
|
||||
*
|
||||
*
|
||||
* @param list pointer to the OMTensorList
|
||||
* @return number of elements in the OMTensor array.
|
||||
*/
|
||||
int omTensorListGetSize(OMTensorList *list);
|
||||
|
||||
/**
|
||||
* \brief OMTensorList OMTensor getter by index
|
||||
*
|
||||
* @param list pointer to the OMTensorList
|
||||
* @param index index of the OMTensor
|
||||
* @return pointer to the OMTensor, NULL if not found.
|
||||
*/
|
||||
OMTensor *omTensorListGetOmtByIndex(OMTensorList *list, size_t index);
|
||||
|
||||
#endif // ONNX_MLIR_OMTENSORLIST_H
|
|
@ -0,0 +1,57 @@
|
|||
//===---------------------- DataType.h - ONNX DataTypes -------------------===//
|
||||
//
|
||||
// Copyright 2019-2020 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// This file contains declaration of ONNX data types and type size mapping.
|
||||
// It is provided as a convenience and not used by OMTensor implementation.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef ONNX_MLIR_ONNXDATATYPE_H
|
||||
#define ONNX_MLIR_ONNXDATATYPE_H
|
||||
|
||||
#ifdef __cplusplus
|
||||
#include <cstdint>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#else
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#endif
|
||||
|
||||
enum OM_DATA_TYPE {
|
||||
#define OM_TYPE_METADATA_DEF(ENUM_NAME, ENUM_VAL, DTYPE_SIZE) \
|
||||
ENUM_NAME = ENUM_VAL,
|
||||
#include "OnnxDataTypeMetaData.inc"
|
||||
|
||||
#undef OM_TYPE_METADATA_DEF
|
||||
};
|
||||
|
||||
#ifndef __cplusplus
|
||||
typedef enum OM_DATA_TYPE OM_DATA_TYPE;
|
||||
#endif
|
||||
|
||||
extern const int OM_DATA_TYPE_SIZE[];
|
||||
|
||||
#ifdef __cplusplus
|
||||
// Note by design const map has no [] operator since [] creates a default
|
||||
// key value mapping when the key is not found which changes the map
|
||||
const std::map<std::string, OM_DATA_TYPE> OM_DATA_TYPE_CPP_TO_ONNX = {
|
||||
{"b", ONNX_TYPE_BOOL}, // bool -> BOOL
|
||||
{"c", ONNX_TYPE_INT8}, // char -> INT8 (platform dependent, can be UINT8)
|
||||
{"a", ONNX_TYPE_INT8}, // int8_t -> INT8
|
||||
{"h", ONNX_TYPE_UINT8}, // uint8_t -> UINT8, unsigned char -> UINT8
|
||||
{"s", ONNX_TYPE_INT16}, // int16_t -> INT16, short -> INT16
|
||||
{"t", ONNX_TYPE_UINT16}, // uint16_t -> UINT16, unsigned short -> UINT16
|
||||
{"i", ONNX_TYPE_INT32}, // int32_t -> INT32, int -> INT32
|
||||
{"j", ONNX_TYPE_UINT32}, // uint32_t -> UINT32, unsigned int -> UINT32
|
||||
{"l", ONNX_TYPE_INT64}, // int64_t -> INT64, long -> INT64
|
||||
{"m", ONNX_TYPE_UINT64}, // uint64_t -> UINT64, unsigned long -> UINT64
|
||||
{"f", ONNX_TYPE_FLOAT}, // float -> FLOAT
|
||||
{"d", ONNX_TYPE_DOUBLE}, // double -> DOUBLE
|
||||
};
|
||||
#endif //__cplusplus
|
||||
|
||||
#endif // ONNX_MLIR_ONNXDATATYPE_H
|
|
@ -0,0 +1,25 @@
|
|||
#if defined(OM_TYPE_METADATA_DEF)
|
||||
// Data type metadata declared in the following format:
|
||||
// OM_TYPE_METADATA_DEF( dtype enum name, dtype enum value, dtype size)
|
||||
// clang-format off
|
||||
OM_TYPE_METADATA_DEF(ONNX_TYPE_UNDEFINED, 0, 0)
|
||||
OM_TYPE_METADATA_DEF(ONNX_TYPE_FLOAT, 1, sizeof(float))
|
||||
OM_TYPE_METADATA_DEF(ONNX_TYPE_UINT8, 2, sizeof(uint8_t))
|
||||
OM_TYPE_METADATA_DEF(ONNX_TYPE_INT8, 3, sizeof(int8_t))
|
||||
OM_TYPE_METADATA_DEF(ONNX_TYPE_UINT16, 4, sizeof(uint16_t))
|
||||
OM_TYPE_METADATA_DEF(ONNX_TYPE_INT16, 5, sizeof(int16_t))
|
||||
OM_TYPE_METADATA_DEF(ONNX_TYPE_INT32, 6, sizeof(int32_t))
|
||||
OM_TYPE_METADATA_DEF(ONNX_TYPE_INT64, 7, sizeof(int64_t))
|
||||
OM_TYPE_METADATA_DEF(ONNX_TYPE_STRING, 8, 0)
|
||||
OM_TYPE_METADATA_DEF(ONNX_TYPE_BOOL, 9, sizeof(bool))
|
||||
OM_TYPE_METADATA_DEF(ONNX_TYPE_FLOAT16, 10, 2)
|
||||
OM_TYPE_METADATA_DEF(ONNX_TYPE_DOUBLE, 11, sizeof(double))
|
||||
OM_TYPE_METADATA_DEF(ONNX_TYPE_UINT32, 12, sizeof(uint32_t))
|
||||
OM_TYPE_METADATA_DEF(ONNX_TYPE_UINT64, 13, sizeof(uint64_t))
|
||||
OM_TYPE_METADATA_DEF(ONNX_TYPE_COMPLEX64, 14, 8)
|
||||
OM_TYPE_METADATA_DEF(ONNX_TYPE_COMPLEX128, 15, 16)
|
||||
OM_TYPE_METADATA_DEF(ONNX_TYPE_BFLOAT16, 16, 2)
|
||||
// clang-format on
|
||||
#else
|
||||
#error "Must define OM_TYPE_METADATA_DEF macro."
|
||||
#endif
|
|
@ -131,6 +131,26 @@ static FlatSymbolRefAttr getOrInsertMemcpy(
|
|||
return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context);
|
||||
}
|
||||
|
||||
static FlatSymbolRefAttr getOrInsertMalloc(
|
||||
PatternRewriter &rewriter, ModuleOp module) {
|
||||
// Insert the malloc/aligned_alloc declaration if it is not already present.
|
||||
auto allocFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("malloc");
|
||||
auto ctx = rewriter.getContext();
|
||||
LLVMTypeConverter converter(ctx);
|
||||
if (!allocFunc) {
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(module.getBody());
|
||||
SmallVector<LLVM::LLVMType, 2> callArgTypes = {converter.getIndexType()};
|
||||
// aligned_alloc(size_t alignment, size_t size)
|
||||
auto voidPtrType = LLVM::LLVMType::getInt8PtrTy(&converter.getContext());
|
||||
allocFunc =
|
||||
rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(), "malloc",
|
||||
LLVM::LLVMType::getFunctionTy(voidPtrType, callArgTypes,
|
||||
/*isVarArg=*/false));
|
||||
}
|
||||
return SymbolRefAttr::get("malloc", ctx);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// KRNL to LLVM: KrnlGetRefOpLowering
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -399,16 +419,15 @@ public:
|
|||
using OpRewritePattern<KrnlEntryPointOp>::OpRewritePattern;
|
||||
|
||||
enum class API {
|
||||
CREATE_ORDERED_DYN_MEM_REF_DICT,
|
||||
CREATE_DYN_MEM_REF,
|
||||
GET_DYN_MEM_REF,
|
||||
SET_DYN_MEM_REF,
|
||||
CREATE_OMTENSOR_LIST,
|
||||
CREATE_OMTENSOR,
|
||||
GET_DATA,
|
||||
SET_DATA,
|
||||
GET_SIZES,
|
||||
GET_STRIDES,
|
||||
SET_DTYPE,
|
||||
GET_DTYPE,
|
||||
GET_DATA_SIZES,
|
||||
GET_DATA_STRIDES,
|
||||
SET_DATA_TYPE,
|
||||
GET_DATA_TYPE,
|
||||
GET_OMTS,
|
||||
};
|
||||
|
||||
struct ApiSpec {
|
||||
|
@ -443,6 +462,7 @@ public:
|
|||
using LLVMType = LLVM::LLVMType;
|
||||
auto opaquePtrTy = LLVMType::getInt8PtrTy(context);
|
||||
auto int32Ty = LLVMType::getInt32Ty(context);
|
||||
auto int64Ty = LLVMType::getInt64Ty(context);
|
||||
|
||||
// Rewrite Krnl Entry Point Operation to an LLVM function with a dynamic
|
||||
// signature. The signature is dynamic because it remains the same no matter
|
||||
|
@ -455,7 +475,7 @@ public:
|
|||
op.getAttrOfType<SymbolRefAttr>(
|
||||
KrnlEntryPointOp::getEntryPointFuncAttrName())
|
||||
.getLeafReference();
|
||||
auto dynEntryPointName = "_dyn_entry_point_" + staticEntryPointFuncName;
|
||||
auto dynEntryPointName = "run_" + staticEntryPointFuncName;
|
||||
assert(module.lookupSymbol(dynEntryPointName.str()) == nullptr &&
|
||||
"dynamic entry point name is not unique");
|
||||
rewriter.eraseOp(op);
|
||||
|
@ -484,12 +504,22 @@ public:
|
|||
// them to static mem refs.
|
||||
SmallVector<Value, 4> staticInputs;
|
||||
auto wrappedInput = entryPointEntryBlock.getArgument(0);
|
||||
|
||||
auto omTensorPtrArr =
|
||||
callApi(rewriter, loc, apiRegistry, API::GET_OMTS, {wrappedInput});
|
||||
for (size_t i = 0; i < staticEntryPointTy.getFunctionNumParams(); i++) {
|
||||
// Call API function to retrieve the i-th dynamic memref.
|
||||
auto idxVal = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, int32Ty, rewriter.getI32IntegerAttr(i));
|
||||
auto dynMemRef = callApi(rewriter, loc, apiRegistry, API::GET_DYN_MEM_REF,
|
||||
{wrappedInput, idxVal});
|
||||
|
||||
auto omTensorPtrAddrTy = opaquePtrTy.getPointerTo();
|
||||
auto omTensorPtrAddr = rewriter
|
||||
.create<LLVM::GEPOp>(loc, omTensorPtrAddrTy,
|
||||
omTensorPtrArr, ArrayRef<Value>({idxVal}))
|
||||
.getResult();
|
||||
auto omTensorPtr =
|
||||
rewriter.create<LLVM::LoadOp>(loc, opaquePtrTy, omTensorPtrAddr)
|
||||
.getResult();
|
||||
|
||||
// Create a (static) memref type corresponding to the i-th memref input to
|
||||
// the inference function on stack, and load it to memRef.
|
||||
|
@ -501,9 +531,9 @@ public:
|
|||
/*alignment=*/0);
|
||||
|
||||
// Fill in the memref underlying ptrToMemRef with information extracted
|
||||
// from dynMemRef.
|
||||
fillPtrToMemRefWithRtMemRef(
|
||||
dynMemRef, ptrToMemRef, rewriter, loc, apiRegistry, module);
|
||||
// from omTensorPtr.
|
||||
fillPtrToMemRefWithOMTensor(
|
||||
omTensorPtr, ptrToMemRef, rewriter, loc, apiRegistry, module);
|
||||
|
||||
// ptrToMemRef will be an input to main computation graph function.
|
||||
staticInputs.emplace_back(ptrToMemRef);
|
||||
|
@ -539,27 +569,57 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
// Create wrapped output.
|
||||
auto wrappedOutput = callApi(
|
||||
rewriter, loc, apiRegistry, API::CREATE_ORDERED_DYN_MEM_REF_DICT, {});
|
||||
auto numOutput = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, int32Ty, rewriter.getI64IntegerAttr(outMemRefList.size()));
|
||||
|
||||
auto mallocSym = getOrInsertMalloc(rewriter, module);
|
||||
// TODO(tjingrant): get pointer size from data layout.
|
||||
size_t kPtrSize = 8;
|
||||
auto outputOmtPtrsArraySizeInByte = rewriter.create<LLVM::ConstantOp>(loc,
|
||||
int64Ty, rewriter.getI64IntegerAttr(outMemRefList.size() * kPtrSize));
|
||||
auto outOmtPtrsArr =
|
||||
rewriter
|
||||
.create<LLVM::CallOp>(loc,
|
||||
LLVM::LLVMType::getInt8PtrTy(module.getContext()), mallocSym,
|
||||
ArrayRef<Value>(outputOmtPtrsArraySizeInByte))
|
||||
.getResult(0);
|
||||
outOmtPtrsArr = rewriter
|
||||
.create<LLVM::BitcastOp>(loc,
|
||||
LLVM::LLVMType::getInt8PtrTy(module.getContext())
|
||||
.getPointerTo(0),
|
||||
outOmtPtrsArr)
|
||||
.getResult();
|
||||
|
||||
for (decltype(numOutputs) i = 0; i < outMemRefList.size(); i++) {
|
||||
// Get the i-th memref returned, convert to a dynamic memref and store it
|
||||
// in the wrappedOutput.
|
||||
|
||||
auto memRef = outMemRefList.at(i);
|
||||
auto outMemRefTy = memRef.getType().dyn_cast<LLVMType>();
|
||||
auto outMemRefRank = getRankFromMemRefType(outMemRefTy);
|
||||
auto outMemRefRankVal = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, int32Ty, rewriter.getI32IntegerAttr(outMemRefRank));
|
||||
auto outRtMemRef = callApi(rewriter, loc, apiRegistry,
|
||||
API::CREATE_DYN_MEM_REF, {outMemRefRankVal});
|
||||
fillRtMemRefWithMemRef(
|
||||
memRef, outRtMemRef, rewriter, loc, apiRegistry, module);
|
||||
auto idx = rewriter.create<LLVM::ConstantOp>(
|
||||
auto outOMTensor = callApi(
|
||||
rewriter, loc, apiRegistry, API::CREATE_OMTENSOR, {outMemRefRankVal});
|
||||
fillOMTensorWithMemRef(
|
||||
memRef, outOMTensor, rewriter, loc, apiRegistry, module);
|
||||
|
||||
auto idxVal = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, int32Ty, rewriter.getI32IntegerAttr(i));
|
||||
callApi(rewriter, loc, apiRegistry, API::SET_DYN_MEM_REF,
|
||||
{wrappedOutput, idx, outRtMemRef});
|
||||
|
||||
auto omTensorPtrAddrTy = opaquePtrTy.getPointerTo();
|
||||
auto omTensorPtrAddr = rewriter
|
||||
.create<LLVM::GEPOp>(loc, omTensorPtrAddrTy,
|
||||
outOmtPtrsArr, ArrayRef<Value>({idxVal}))
|
||||
.getResult();
|
||||
|
||||
rewriter.create<LLVM::StoreOp>(loc, outOMTensor, omTensorPtrAddr);
|
||||
}
|
||||
|
||||
// Create wrapped output.
|
||||
auto wrappedOutput = callApi(rewriter, loc, apiRegistry,
|
||||
API::CREATE_OMTENSOR_LIST, {outOmtPtrsArr, numOutput});
|
||||
|
||||
// Return wrapped output.
|
||||
rewriter.create<LLVM::ReturnOp>(
|
||||
loc, SmallVector<Value, 1>({wrappedOutput}));
|
||||
|
@ -575,6 +635,7 @@ private:
|
|||
using LLVMType = LLVM::LLVMType;
|
||||
auto voidTy = LLVMType::getVoidTy(context);
|
||||
auto opaquePtrTy = LLVMType::getInt8PtrTy(context);
|
||||
auto opaquePtrPtrTy = opaquePtrTy.getPointerTo();
|
||||
auto int32Ty = LLVMType::getInt32Ty(context);
|
||||
auto int64Ty = LLVMType::getInt64Ty(context);
|
||||
auto int64PtrTy = int64Ty.getPointerTo();
|
||||
|
@ -583,16 +644,15 @@ private:
|
|||
// specifying its signature.
|
||||
// clang-format off
|
||||
std::vector<ApiSpec> apiSpecs = {
|
||||
ApiSpec(API::CREATE_ORDERED_DYN_MEM_REF_DICT, "createOrderedRtMemRefDict", opaquePtrTy, {}),
|
||||
ApiSpec(API::CREATE_DYN_MEM_REF, "createRtMemRef", opaquePtrTy, {int32Ty}),
|
||||
ApiSpec(API::GET_DATA, "getData", opaquePtrTy, {opaquePtrTy}),
|
||||
ApiSpec(API::SET_DATA, "setData", voidTy, {opaquePtrTy, opaquePtrTy}),
|
||||
ApiSpec(API::GET_DYN_MEM_REF, "getRtMemRef", opaquePtrTy, {opaquePtrTy, int32Ty}),
|
||||
ApiSpec(API::SET_DYN_MEM_REF, "setRtMemRef", voidTy, {opaquePtrTy, int32Ty, opaquePtrTy}),
|
||||
ApiSpec(API::GET_SIZES, "getSizes", int64PtrTy, {opaquePtrTy}),
|
||||
ApiSpec(API::GET_STRIDES, "getStrides", int64PtrTy, {opaquePtrTy}),
|
||||
ApiSpec(API::GET_DTYPE, "getDType", int32Ty, {opaquePtrTy}),
|
||||
ApiSpec(API::SET_DTYPE, "setDType", voidTy, {opaquePtrTy, int32Ty}),
|
||||
ApiSpec(API::CREATE_OMTENSOR_LIST, "omTensorListCreate", opaquePtrTy, {opaquePtrPtrTy, int32Ty}),
|
||||
ApiSpec(API::CREATE_OMTENSOR, "omTensorCreateEmptyDeprecated", opaquePtrTy, {int32Ty}),
|
||||
ApiSpec(API::GET_DATA, "omTensorGetDataPtr", opaquePtrTy, {opaquePtrTy}),
|
||||
ApiSpec(API::SET_DATA, "omTensorSetPtr", voidTy, {opaquePtrTy, int32Ty, opaquePtrTy, opaquePtrTy}),
|
||||
ApiSpec(API::GET_DATA_SIZES, "omTensorGetDataShape", int64PtrTy, {opaquePtrTy}),
|
||||
ApiSpec(API::GET_DATA_STRIDES, "omTensorGetStrides", int64PtrTy, {opaquePtrTy}),
|
||||
ApiSpec(API::GET_DATA_TYPE, "omTensorGetDataType", int32Ty, {opaquePtrTy}),
|
||||
ApiSpec(API::SET_DATA_TYPE, "omTensorSetDataType", voidTy, {opaquePtrTy, int32Ty}),
|
||||
ApiSpec(API::GET_OMTS, "omTensorListGetPtrToOmts", opaquePtrPtrTy, {opaquePtrTy}),
|
||||
};
|
||||
// clang-format on
|
||||
|
||||
|
@ -645,7 +705,7 @@ private:
|
|||
return *entryPointEntryBlock;
|
||||
}
|
||||
|
||||
void fillPtrToMemRefWithRtMemRef(Value &dynMemRef, Value &ptrToMemRef,
|
||||
void fillPtrToMemRefWithOMTensor(Value &rtMemRef, Value &ptrToMemRef,
|
||||
PatternRewriter &rewriter, const Location &loc,
|
||||
const std::map<API, ApiSpec> &apiRegistry, ModuleOp &module) const {
|
||||
auto *context = module.getContext();
|
||||
|
@ -657,7 +717,7 @@ private:
|
|||
|
||||
// Set dataPtr and alignedDataPtr;
|
||||
auto dataPtr =
|
||||
callApi(rewriter, loc, apiRegistry, API::GET_DATA, {dynMemRef});
|
||||
callApi(rewriter, loc, apiRegistry, API::GET_DATA, {rtMemRef});
|
||||
dataPtr = rewriter.create<LLVM::BitcastOp>(
|
||||
loc, memRefTy.getStructElementType(0), dataPtr);
|
||||
memRef = rewriter.create<LLVM::InsertValueOp>(loc, memRefTy, memRef,
|
||||
|
@ -674,9 +734,9 @@ private:
|
|||
// Get rank, sizes array ptr and strides array ptr.
|
||||
auto rank = getRankFromMemRefType(memRefTy);
|
||||
auto sizesArrayPtr =
|
||||
callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {dynMemRef});
|
||||
callApi(rewriter, loc, apiRegistry, API::GET_DATA_SIZES, {rtMemRef});
|
||||
auto stridesArrayPtr =
|
||||
callApi(rewriter, loc, apiRegistry, API::GET_STRIDES, {dynMemRef});
|
||||
callApi(rewriter, loc, apiRegistry, API::GET_DATA_STRIDES, {rtMemRef});
|
||||
|
||||
for (decltype(rank) i = 0; i < rank; i++) {
|
||||
auto dimIdx = rewriter.create<LLVM::ConstantOp>(
|
||||
|
@ -706,7 +766,7 @@ private:
|
|||
rewriter.create<LLVM::StoreOp>(loc, memRef, ptrToMemRef);
|
||||
}
|
||||
|
||||
void fillRtMemRefWithMemRef(Value &outMemRef, Value &outRtMemRef,
|
||||
void fillOMTensorWithMemRef(Value &outMemRef, Value &outOMTensor,
|
||||
PatternRewriter &rewriter, const Location &loc,
|
||||
const std::map<API, ApiSpec> &apiRegistry, ModuleOp &module) const {
|
||||
auto *context = module.getContext();
|
||||
|
@ -714,26 +774,40 @@ private:
|
|||
auto int64Ty = LLVM::LLVMType::getInt64Ty(context);
|
||||
auto int32Ty = LLVM::LLVMType::getInt32Ty(context);
|
||||
|
||||
// Extract the data pointer, and record it in dynamic mem ref created.
|
||||
Value outMemRefDataPtr = rewriter.create<LLVM::ExtractValueOp>(loc,
|
||||
// Set ownership to true, i.e., free after OMTensor is destroyed.
|
||||
Value owning = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, int32Ty, rewriter.getI32IntegerAttr(1));
|
||||
|
||||
// Extract the allocated pointer.
|
||||
Value outMemRefAllocatedPtr = rewriter.create<LLVM::ExtractValueOp>(loc,
|
||||
outMemRefTy.getStructElementType(0), outMemRef,
|
||||
rewriter.getArrayAttr({rewriter.getI64IntegerAttr(0)}));
|
||||
outMemRefDataPtr = rewriter.create<LLVM::BitcastOp>(
|
||||
loc, LLVM::LLVMType::getInt8PtrTy(context), outMemRefDataPtr);
|
||||
outMemRefAllocatedPtr = rewriter.create<LLVM::BitcastOp>(
|
||||
loc, LLVM::LLVMType::getInt8PtrTy(context), outMemRefAllocatedPtr);
|
||||
|
||||
// Extract the aligned pointer.
|
||||
Value outMemRefAlignedPtr = rewriter.create<LLVM::ExtractValueOp>(loc,
|
||||
outMemRefTy.getStructElementType(1), outMemRef,
|
||||
rewriter.getArrayAttr({rewriter.getI64IntegerAttr(1)}));
|
||||
outMemRefAlignedPtr = rewriter.create<LLVM::BitcastOp>(
|
||||
loc, LLVM::LLVMType::getInt8PtrTy(context), outMemRefAlignedPtr);
|
||||
|
||||
// Set ownership, allocated and aligned pointer.
|
||||
callApi(rewriter, loc, apiRegistry, API::SET_DATA,
|
||||
{outRtMemRef, outMemRefDataPtr});
|
||||
{outOMTensor, owning, outMemRefAllocatedPtr, outMemRefAlignedPtr});
|
||||
|
||||
auto elemTy = outMemRefTy.getStructElementType(0).getPointerElementTy();
|
||||
auto onnxTy = llvmTypeToOnnxType(elemTy);
|
||||
auto onnxTyVal = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, int32Ty, rewriter.getI32IntegerAttr(onnxTy));
|
||||
callApi(
|
||||
rewriter, loc, apiRegistry, API::SET_DTYPE, {outRtMemRef, onnxTyVal});
|
||||
callApi(rewriter, loc, apiRegistry, API::SET_DATA_TYPE,
|
||||
{outOMTensor, onnxTyVal});
|
||||
|
||||
auto rank = getRankFromMemRefType(outMemRefTy);
|
||||
auto sizesArrayPtr =
|
||||
callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {outRtMemRef});
|
||||
auto stridesArrayPtr =
|
||||
callApi(rewriter, loc, apiRegistry, API::GET_STRIDES, {outRtMemRef});
|
||||
callApi(rewriter, loc, apiRegistry, API::GET_DATA_SIZES, {outOMTensor});
|
||||
auto stridesArrayPtr = callApi(
|
||||
rewriter, loc, apiRegistry, API::GET_DATA_STRIDES, {outOMTensor});
|
||||
|
||||
for (decltype(rank) i = 0; i < rank; i++) {
|
||||
auto dimIdx = rewriter.create<LLVM::ConstantOp>(
|
||||
|
|
|
@ -5,43 +5,55 @@ add_subdirectory(jni)
|
|||
# such static library in a shared library can cause runtime failure on some architectures,
|
||||
# such as z. So we override the default and explicitly compile with -fPIC.
|
||||
add_library(cruntime STATIC
|
||||
RtMemRef.cpp
|
||||
RtMemRef.h
|
||||
DataType.h)
|
||||
OMTensor.c
|
||||
OMTensor.inc
|
||||
OMTensorList.c
|
||||
OMTensorList.inc
|
||||
OnnxDataType.cpp)
|
||||
set_target_properties(cruntime PROPERTIES
|
||||
LANGUAGE C)
|
||||
set_target_properties(cruntime PROPERTIES
|
||||
POSITION_INDEPENDENT_CODE TRUE)
|
||||
target_include_directories(cruntime PRIVATE
|
||||
${ONNX_MLIR_SRC_ROOT}
|
||||
${ONNX_MLIR_SRC_ROOT}/include)
|
||||
|
||||
add_library(RtMemRefUtils
|
||||
RtMemRef.h
|
||||
RtMemRef.cpp
|
||||
DataType.h)
|
||||
set_target_properties(RtMemRefUtils PROPERTIES
|
||||
add_library(OMTensorUtils
|
||||
OMTensor.cpp
|
||||
OMTensor.inc
|
||||
OMTensorList.cpp
|
||||
OMTensorList.inc
|
||||
OnnxDataType.cpp)
|
||||
set_target_properties(OMTensorUtils PROPERTIES
|
||||
POSITION_INDEPENDENT_CODE TRUE)
|
||||
target_compile_definitions(OMTensorUtils PRIVATE RTMEMREF_INTERNAL_API)
|
||||
target_include_directories(OMTensorUtils PRIVATE
|
||||
${ONNX_MLIR_SRC_ROOT}
|
||||
${ONNX_MLIR_SRC_ROOT}/include)
|
||||
|
||||
add_library(ExecutionSession
|
||||
ExecusionSession.hpp
|
||||
ExecusionSession.cpp)
|
||||
ExecutionSession.hpp
|
||||
ExecutionSession.cpp)
|
||||
target_include_directories(ExecutionSession PRIVATE
|
||||
${ONNX_MLIR_SRC_ROOT}/src/Runtime
|
||||
${ONNX_MLIR_SRC_ROOT}/include)
|
||||
target_link_libraries(ExecutionSession
|
||||
${CMAKE_DL_LIBS})
|
||||
target_include_directories(ExecutionSession PRIVATE
|
||||
${ONNX_MLIR_SRC_ROOT}
|
||||
${ONNX_MLIR_BIN_ROOT}
|
||||
${ONNX_MLIR_SRC_ROOT})
|
||||
set_target_properties(ExecutionSession PROPERTIES
|
||||
POSITION_INDEPENDENT_CODE TRUE)
|
||||
|
||||
pybind11_add_module(PyRuntime
|
||||
PyExecutionSession.cpp
|
||||
PyExecutionSession.hpp)
|
||||
target_include_directories(PyRuntime PRIVATE
|
||||
${ONNX_MLIR_SRC_ROOT}
|
||||
${ONNX_MLIR_SRC_ROOT}/src/Runtime
|
||||
${ONNX_MLIR_SRC_ROOT}/include)
|
||||
target_link_libraries(PyRuntime PRIVATE
|
||||
${CMAKE_DL_LIBS}
|
||||
ExecutionSession
|
||||
RtMemRefUtils
|
||||
OMTensorUtils
|
||||
onnx)
|
||||
target_include_directories(PyRuntime PRIVATE
|
||||
${ONNX_MLIR_SRC_ROOT}
|
||||
${ONNX_MLIR_BIN_ROOT}
|
||||
${ONNX_MLIR_SRC_ROOT})
|
||||
|
||||
# See comments above about libcruntime.a
|
||||
add_library(EmbeddedDataLoader STATIC
|
||||
|
@ -51,6 +63,6 @@ set_target_properties(EmbeddedDataLoader PROPERTIES
|
|||
POSITION_INDEPENDENT_CODE TRUE)
|
||||
|
||||
add_dependencies(PyRuntime cruntime)
|
||||
install(FILES RtMemRef.h DESTINATION include)
|
||||
|
||||
install(TARGETS cruntime DESTINATION lib)
|
||||
install(TARGETS EmbeddedDataLoader DESTINATION lib)
|
|
@ -1,40 +0,0 @@
|
|||
//===---------------------- DataType.h - ONNX DataTypes -------------------===//
|
||||
//
|
||||
// Copyright 2019-2020 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// This file contains declaration of ONNX data types.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
enum DYN_MEMREF_DATA_TYPE {
|
||||
UNDEFINED = 0;
|
||||
// Basic types.
|
||||
FLOAT = 1; // float
|
||||
UINT8 = 2; // uint8_t
|
||||
INT8 = 3; // int8_t
|
||||
UINT16 = 4; // uint16_t
|
||||
INT16 = 5; // int16_t
|
||||
INT32 = 6; // int32_t
|
||||
INT64 = 7; // int64_t
|
||||
STRING = 8; // string
|
||||
BOOL = 9; // bool
|
||||
|
||||
// IEEE754 half-precision floating-point format (16 bits wide).
|
||||
// This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits.
|
||||
FLOAT16 = 10;
|
||||
|
||||
DOUBLE = 11;
|
||||
UINT32 = 12;
|
||||
UINT64 = 13;
|
||||
COMPLEX64 = 14; // complex with float32 real and imaginary components
|
||||
COMPLEX128 = 15; // complex with float64 real and imaginary components
|
||||
|
||||
// Non-IEEE floating-point format based on IEEE754 single-precision
|
||||
// floating-point number truncated to 16 bits.
|
||||
// This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits.
|
||||
BFLOAT16 = 16;
|
||||
|
||||
// Future extensions go here.
|
||||
};
|
|
@ -1,10 +1,10 @@
|
|||
//===------- ExecusionSession.cpp - ExecutionSession Implementation -------===//
|
||||
//===------- ExecutionSession.cpp - ExecutionSession Implementation -------===//
|
||||
//
|
||||
// Copyright 2019-2020 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// This file contains implementations of ExecusionSession class, which helps C++
|
||||
// This file contains implementations of ExecutionSession class, which helps C++
|
||||
// programs interact with compiled binary model libraries.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -14,7 +14,7 @@
|
|||
#include <sstream>
|
||||
#include <vector>
|
||||
|
||||
#include "ExecusionSession.hpp"
|
||||
#include "ExecutionSession.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
|
@ -42,19 +42,22 @@ ExecutionSession::ExecutionSession(
|
|||
}
|
||||
}
|
||||
|
||||
std::vector<std::unique_ptr<RtMemRef>> ExecutionSession::run(
|
||||
std::vector<std::unique_ptr<RtMemRef>> ins) {
|
||||
auto *wrappedInput = createOrderedRtMemRefDict();
|
||||
for (size_t i = 0; i < ins.size(); i++)
|
||||
setRtMemRef(wrappedInput, i, ins.at(i).get());
|
||||
std::vector<std::unique_ptr<OMTensor, decltype(&omTensorDestroy)>>
|
||||
ExecutionSession::run(
|
||||
std::vector<std::unique_ptr<OMTensor, decltype(&omTensorDestroy)>> ins) {
|
||||
|
||||
std::vector<OMTensor *> omts;
|
||||
for (const auto &inOmt : ins)
|
||||
omts.emplace_back(inOmt.get());
|
||||
auto *wrappedInput = omTensorListCreate(&omts[0], omts.size());
|
||||
|
||||
auto *wrappedOutput = _entryPointFunc(wrappedInput);
|
||||
|
||||
std::vector<std::unique_ptr<RtMemRef>> outs;
|
||||
auto outputSize = getSize(wrappedOutput);
|
||||
std::vector<std::unique_ptr<OMTensor, decltype(&omTensorDestroy)>> outs;
|
||||
|
||||
for (size_t i = 0; i < getSize(wrappedOutput); i++) {
|
||||
outs.emplace_back(std::unique_ptr<RtMemRef>(getRtMemRef(wrappedOutput, i)));
|
||||
for (size_t i = 0; i < omTensorListGetSize(wrappedOutput); i++) {
|
||||
outs.emplace_back(std::unique_ptr<OMTensor, decltype(&omTensorDestroy)>(
|
||||
omTensorListGetOmtByIndex(wrappedOutput, i), omTensorDestroy));
|
||||
}
|
||||
return std::move(outs);
|
||||
}
|
|
@ -1,10 +1,10 @@
|
|||
//===--------- ExecusionSession.hpp - ExecutionSession Declaration --------===//
|
||||
//===--------- ExecutionSession.hpp - ExecutionSession Declaration --------===//
|
||||
//
|
||||
// Copyright 2019-2020 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// This file contains declarations of ExecusionSession class, which helps C++
|
||||
// This file contains declarations of ExecutionSession class, which helps C++
|
||||
// programs interact with compiled binary model libraries.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -15,18 +15,19 @@
|
|||
#include <dlfcn.h>
|
||||
#include <string>
|
||||
|
||||
#include "src/Runtime/RtMemRef.h"
|
||||
#include "OnnxMlirRuntime.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
typedef OrderedRtMemRefDict *(*entryPointFuncType)(OrderedRtMemRefDict *);
|
||||
typedef OMTensorList *(*entryPointFuncType)(OMTensorList *);
|
||||
|
||||
class ExecutionSession {
|
||||
public:
|
||||
ExecutionSession(std::string sharedLibPath, std::string entryPointName);
|
||||
|
||||
std::vector<std::unique_ptr<RtMemRef>> run(
|
||||
std::vector<std::unique_ptr<RtMemRef>>);
|
||||
// Use custom deleter since forward declared OMTensor hides destructor
|
||||
std::vector<std::unique_ptr<OMTensor, decltype(&omTensorDestroy)>> run(
|
||||
std::vector<std::unique_ptr<OMTensor, decltype(&omTensorDestroy)>>);
|
||||
|
||||
~ExecutionSession();
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
//===-------------- OMTensor.cpp - OMTensor C Implementation --------------===//
|
||||
//
|
||||
// Copyright 2019-2020 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// This file contains implementation of OMTensor and data structures and
|
||||
// helper functions.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "OMTensor.inc"
|
|
@ -0,0 +1,12 @@
|
|||
//===------------- OMTensor.cpp - OMTensor C++ Implementation -------------===//
|
||||
//
|
||||
// Copyright 2019-2020 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// This file contains implementation of OMTensor and data structures and
|
||||
// helper functions.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "OMTensor.inc"
|
|
@ -0,0 +1,499 @@
|
|||
//===--------- OMTensor.inc - C/C++ Neutral OMTensor Implementation--------===//
|
||||
//
|
||||
// Copyright 2019-2020 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// This file contains implementations of OMTensor data structures
|
||||
// and helper functions.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifdef __cplusplus
|
||||
#include <cassert>
|
||||
#include <map>
|
||||
#include <numeric>
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <typeinfo>
|
||||
#include <vector>
|
||||
#else
|
||||
#include <assert.h>
|
||||
#endif
|
||||
|
||||
#ifdef __APPLE__
|
||||
#include <stdlib.h>
|
||||
#else
|
||||
#include <malloc.h>
|
||||
#endif
|
||||
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "onnx-mlir/Runtime/OMTensor.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
#include "src/Runtime/OMTensorHelper.h"
|
||||
#endif
|
||||
|
||||
struct OMTensor {
|
||||
#ifdef __cplusplus
|
||||
/**
|
||||
* Constructor
|
||||
*
|
||||
* @param rank, rank of data sizes and strides
|
||||
*
|
||||
* Create a OMTensor with specified rank. Memory for data sizes and strides
|
||||
* are allocated.
|
||||
*/
|
||||
OMTensor(int rank) {
|
||||
if ((_shape = (int64_t *)malloc(rank * sizeof(int64_t))) &&
|
||||
(_stride = (int64_t *)malloc(rank * sizeof(int64_t)))) {
|
||||
assert(_shape);
|
||||
assert(_stride);
|
||||
_allocatedPtr = NULL;
|
||||
_alignedPtr = NULL;
|
||||
_offset = 0;
|
||||
_dataType = ONNX_TYPE_UNDEFINED;
|
||||
_rank = rank;
|
||||
_owning = false;
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"OMTensor(" + std::to_string(rank) + ") malloc error");
|
||||
}
|
||||
};
|
||||
|
||||
OMTensor() = default;
|
||||
|
||||
/**
|
||||
* Destructor
|
||||
*
|
||||
* Destroy the OMTensor struct.
|
||||
*/
|
||||
~OMTensor() {
|
||||
if (_owning)
|
||||
free(_allocatedPtr);
|
||||
free(_shape);
|
||||
free(_stride);
|
||||
};
|
||||
#endif
|
||||
// Fields are named according to:
|
||||
// https://mlir.llvm.org/docs/Dialects/SPIR-V/#lowering-memrefs-to-spvarray-and-spvrtarray
|
||||
|
||||
// On machine without alignment constraints the allocated and aligned pointers
|
||||
// are the same. However, on machines with alignment constraints not supported
|
||||
// by the memory allocation system, the allocated ptr points to the chunk of
|
||||
// memory that is allocated, and the aligned pointer points to a chunk of
|
||||
// memory that is allocated and also satisfy the alignment constraints on the
|
||||
// machine. For example, on a machine for which malloc returns chunks aligned
|
||||
// at 16 byte boundaries, but where tensor must be allocated at 1K boundaries
|
||||
// for performance reason, allocated pointer may return 0x1000f and aligned
|
||||
// pointer may return 0x10400.
|
||||
void *_allocatedPtr; // data buffer
|
||||
void *_alignedPtr; // aligned data buffer that the omt indexes.
|
||||
|
||||
int64_t _offset; // offset of 1st element
|
||||
int64_t *_shape; // sizes array
|
||||
int64_t *_stride; // strides array
|
||||
int64_t _rank; // rank
|
||||
|
||||
OM_DATA_TYPE _dataType; // ONNX data type
|
||||
|
||||
int _owning; // indicates whether the Omt owns the memory space
|
||||
// referenced by _allocatedPtr. Omt struct will release the
|
||||
// memory space referred to by _allocatedPtr upon destruction if
|
||||
// and only if it owns it.
|
||||
};
|
||||
|
||||
// Create a OMTensor.
|
||||
OMTensor *omTensorCreate(
|
||||
void *data_ptr, int64_t *shape, int64_t rank, OM_DATA_TYPE dtype) {
|
||||
OMTensor *tensor = (OMTensor *)malloc(sizeof(OMTensor));
|
||||
if (!tensor)
|
||||
return NULL;
|
||||
if ((tensor->_shape = (int64_t *)malloc(rank * sizeof(int64_t))) &&
|
||||
(tensor->_stride = (int64_t *)malloc(rank * sizeof(int64_t)))) {
|
||||
// If malloc for _shape or _stride fails, free them and return NULL.
|
||||
if (!tensor->_shape || !tensor->_stride) {
|
||||
if (tensor->_shape)
|
||||
free(tensor->_shape);
|
||||
if (tensor->_stride)
|
||||
free(tensor->_stride);
|
||||
return NULL;
|
||||
}
|
||||
tensor->_allocatedPtr = data_ptr;
|
||||
tensor->_alignedPtr = data_ptr;
|
||||
tensor->_rank = rank;
|
||||
tensor->_dataType = dtype;
|
||||
tensor->_owning = false;
|
||||
}
|
||||
// Using signed indices helps detect when index falls below 0.
|
||||
for (int64_t i = rank - 1; i >= 0; i--) {
|
||||
tensor->_shape[i] = shape[i];
|
||||
if (i == rank - 1)
|
||||
tensor->_stride[i] = 1;
|
||||
else
|
||||
tensor->_stride[i] = tensor->_stride[i + 1] * tensor->_shape[i + 1];
|
||||
}
|
||||
return tensor;
|
||||
}
|
||||
|
||||
// Create a OMTensor.
|
||||
OMTensor *omTensorCreateWithOwnership(void *data_ptr, int64_t *shape,
|
||||
int64_t rank, OM_DATA_TYPE dtype, int owning) {
|
||||
OMTensor *tensor = omTensorCreate(data_ptr, shape, rank, dtype);
|
||||
// If ctor fails, return NULL.
|
||||
if (!tensor)
|
||||
return NULL;
|
||||
tensor->_owning = owning;
|
||||
return tensor;
|
||||
}
|
||||
|
||||
// Create a OMTensor.
|
||||
OMTensor *omTensorCreateEmptyDeprecated(int rank) {
|
||||
OMTensor *omt = (OMTensor *)malloc(sizeof(struct OMTensor));
|
||||
if (!omt)
|
||||
return NULL;
|
||||
if ((omt->_shape = (int64_t *)malloc(rank * sizeof(int64_t))) &&
|
||||
(omt->_stride = (int64_t *)malloc(rank * sizeof(int64_t)))) {
|
||||
// If malloc for _shape or _stride fails, free them and return NULL.
|
||||
if (!omt->_shape || !omt->_stride) {
|
||||
if (omt->_shape)
|
||||
free(omt->_shape);
|
||||
if (omt->_stride)
|
||||
free(omt->_stride);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
omt->_allocatedPtr = NULL;
|
||||
omt->_alignedPtr = NULL;
|
||||
omt->_offset = 0;
|
||||
omt->_dataType = ONNX_TYPE_UNDEFINED;
|
||||
omt->_rank = rank;
|
||||
omt->_owning = false;
|
||||
}
|
||||
return omt;
|
||||
}
|
||||
|
||||
OMTensor *omTensorCreateEmpty(
|
||||
int64_t *shape, int64_t rank, OM_DATA_TYPE dtype) {
|
||||
OMTensor *tensor =
|
||||
omTensorCreateWithOwnership(NULL, shape, rank, dtype, /*owning=*/true);
|
||||
// If ctor fails, return null.
|
||||
if (!tensor)
|
||||
return NULL;
|
||||
void *dataPtr = malloc(omTensorGetNumElems(tensor) * getDataTypeSize(dtype));
|
||||
if (!dataPtr)
|
||||
return NULL;
|
||||
tensor->_alignedPtr = dataPtr;
|
||||
tensor->_allocatedPtr = dataPtr;
|
||||
return tensor;
|
||||
}
|
||||
|
||||
/* OMTensor destroyer */
|
||||
void omTensorDestroy(OMTensor *tensor) {
|
||||
if (tensor->_owning) {
|
||||
free(tensor->_allocatedPtr);
|
||||
tensor->_allocatedPtr = NULL;
|
||||
tensor->_alignedPtr = NULL;
|
||||
}
|
||||
free(tensor);
|
||||
}
|
||||
|
||||
/* OMTensor data getter */
|
||||
void *omTensorGetDataPtr(OMTensor *tensor) { return tensor->_alignedPtr; }
|
||||
|
||||
/**
|
||||
* OMTensor allocated and aligned pointer setter.
|
||||
* This function is intentionally left out from the header because it is only
|
||||
* used by the wrapper code we emit around inference function that converts
|
||||
* MemRefs to OMTensors for user convenience.
|
||||
*
|
||||
* @param tensor pointer to the OMTensor
|
||||
* @param owning whether allocatedPtr should be freed after tensor is destroyed.
|
||||
* @param allocatedPtr allocated pointer to tensor content.
|
||||
* @param alignedPtr aligned pointer to tensor content. If NULL will be set to
|
||||
* allocatedPtr.
|
||||
*
|
||||
*/
|
||||
void omTensorSetPtr(
|
||||
OMTensor *tensor, int owning, void *allocatedPtr, void *alignedPtr) {
|
||||
if (tensor->_owning) {
|
||||
/* If we own the allocated buffer, free it first. */
|
||||
free(tensor->_allocatedPtr);
|
||||
}
|
||||
tensor->_owning = owning;
|
||||
tensor->_allocatedPtr = allocatedPtr;
|
||||
if (alignedPtr)
|
||||
tensor->_alignedPtr = alignedPtr;
|
||||
else
|
||||
tensor->_alignedPtr = allocatedPtr;
|
||||
}
|
||||
|
||||
/* OMTensor data sizes getter */
|
||||
int64_t *omTensorGetDataShape(OMTensor *tensor) { return tensor->_shape; }
|
||||
|
||||
/* OMTensor data sizes setter */
|
||||
void omTensorSetShape(OMTensor *tensor, int64_t *shape) {
|
||||
for (int i = 0; i < tensor->_rank; i++)
|
||||
tensor->_shape[i] = shape[i];
|
||||
}
|
||||
|
||||
/* OMTensor data strides getter */
|
||||
int64_t *omTensorGetStrides(OMTensor *tensor) { return tensor->_stride; }
|
||||
|
||||
/* OMTensor data strides setter */
|
||||
void omTensorSetStrides(OMTensor *tensor, int64_t *strides) {
|
||||
for (int i = 0; i < tensor->_rank; i++)
|
||||
tensor->_stride[i] = strides[i];
|
||||
}
|
||||
|
||||
/* OMTensor data type getter */
|
||||
OM_DATA_TYPE omTensorGetDataType(OMTensor *tensor) { return tensor->_dataType; }
|
||||
|
||||
/* OMTensor data type setter */
|
||||
void omTensorSetDataType(OMTensor *tensor, OM_DATA_TYPE dataType) {
|
||||
tensor->_dataType = dataType;
|
||||
}
|
||||
|
||||
/* OMTensor data buffer size getter */
|
||||
int64_t omTensorGetDataBufferSize(OMTensor *tensor) {
|
||||
return getNumOfElems(tensor->_shape, tensor->_rank) *
|
||||
getDataTypeSize(tensor->_dataType);
|
||||
}
|
||||
|
||||
/* OMTensor rank getter */
|
||||
int omTensorGetRank(OMTensor *tensor) { return tensor->_rank; }
|
||||
|
||||
/* OMTensor number of elements getter */
|
||||
int64_t omTensorGetNumElems(OMTensor *tensor) {
|
||||
// Using signed indices helps detect when index falls below 0.
|
||||
// Verify that strides are dense, meaning that there're
|
||||
// no skipping elements.
|
||||
for (int64_t i = tensor->_rank - 1; i >= 0; i--) {
|
||||
int64_t strideIfNotSkipping = 1;
|
||||
for (int64_t j = i + 1; j < tensor->_rank; j++) {
|
||||
strideIfNotSkipping *= tensor->_shape[j];
|
||||
}
|
||||
assert(tensor->_stride[i] == strideIfNotSkipping);
|
||||
}
|
||||
return getNumOfElems(tensor->_shape, tensor->_rank);
|
||||
}
|
||||
|
||||
/**
|
||||
* OMTensor allocated ptr getter.
|
||||
* Note that this function is intentionally left out from the header
|
||||
* because it is only used by the wrapper code we emit around inference
|
||||
* function that converts OMTensor into MemRefs for user convenience.
|
||||
*
|
||||
* @param tensor pointer to the OMTensor
|
||||
* @return pointer to the allocated data buffer of the OMTensor,
|
||||
* NULL if the allocated data buffer is not set.
|
||||
*/
|
||||
void *omTensorGetAllocatedPtr(OMTensor *tensor) {
|
||||
return tensor->_allocatedPtr;
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
/* OMTensor creator with data sizes and element type */
|
||||
template <typename T>
|
||||
OMTensor *omTensorCreateWithShape(std::vector<int64_t> dataSizes) {
|
||||
/* Create a OMTensor with data sizes and strides allocated */
|
||||
auto omt = omTensorCreateEmptyDeprecated(dataSizes.size());
|
||||
if (omt == NULL)
|
||||
return NULL;
|
||||
|
||||
/* Allocate data buffer */
|
||||
omt->_rank = dataSizes.size();
|
||||
if ((omt->_allocatedPtr = malloc(
|
||||
getNumOfElems(dataSizes.data(), omt->_rank) * sizeof(T))) == NULL) {
|
||||
omTensorDestroy(omt);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
omt->_alignedPtr = omt->_allocatedPtr;
|
||||
omt->_offset = 0;
|
||||
|
||||
/* Copy dataSizes, _shape already allocated by omTensorCreate */
|
||||
copy(dataSizes.begin(), dataSizes.end(), omt->_shape);
|
||||
|
||||
/* Compute and copy dataStrides, _stride already allocated by
|
||||
* omTensorCreateEmptyDeprecated
|
||||
*/
|
||||
auto computedStrides = computeStridesFromSizes(omt->_shape, omt->_rank);
|
||||
copy(computedStrides.begin(), computedStrides.end(), omt->_stride);
|
||||
|
||||
/* Convert CPP type to ONNX type */
|
||||
try {
|
||||
omt->_dataType = OM_DATA_TYPE_CPP_TO_ONNX.at(std::string(typeid(T).name()));
|
||||
} catch (const std::out_of_range &e) {
|
||||
omt->_dataType = ONNX_TYPE_UNDEFINED;
|
||||
}
|
||||
|
||||
/* Set flag for destructor */
|
||||
omt->_owning = true;
|
||||
|
||||
return omt;
|
||||
}
|
||||
|
||||
/* OMTensor creator with data sizes, element type and random data */
|
||||
template <typename T>
|
||||
OMTensor *omTensorCreateWithRandomData(
|
||||
std::vector<int64_t> dataSizes, T lbound, T ubound) {
|
||||
// Will be used to obtain a seed for the random number engine
|
||||
std::random_device rd;
|
||||
// Standard mersenne_twister_engine seeded with rd()
|
||||
std::mt19937 gen(rd());
|
||||
std::uniform_real_distribution<> dis(lbound, ubound);
|
||||
|
||||
auto omt = omTensorCreateWithShape<T>(dataSizes);
|
||||
if (omt == NULL)
|
||||
return NULL;
|
||||
|
||||
std::generate((T *)omt->_allocatedPtr,
|
||||
(T *)omt->_allocatedPtr + getNumOfElems(omt->_shape, omt->_rank),
|
||||
[&]() { return dis(gen); });
|
||||
return omt;
|
||||
}
|
||||
|
||||
/* Access an element (by reference) at offset computed by index array */
|
||||
template <typename T>
|
||||
T &omTensorGetElem(OMTensor *omt, std::vector<int64_t> indexes) {
|
||||
int64_t elemOffset = omTensorComputeElemOffset(omt, indexes);
|
||||
return ((T *)omt->_allocatedPtr)[elemOffset];
|
||||
}
|
||||
|
||||
/* Access an element (by reference) at linear offset */
|
||||
template <typename T>
|
||||
T &omTensorGetElemByOffset(OMTensor *omt, int64_t index) {
|
||||
return ((T *)omt->_allocatedPtr)[index];
|
||||
}
|
||||
|
||||
/* Compute strides vector from sizes vector */
/* Thin wrapper: delegates to the computeStridesFromSizes helper using this
 * tensor's own shape and rank. Produces row-major strides for a densely
 * packed layout (rightmost dimension has stride 1). */
std::vector<int64_t> omTensorComputeStridesFromShape(OMTensor *omt) {
  return computeStridesFromSizes(omt->_shape, omt->_rank);
}
|
||||
|
||||
/* Compute linear element offset from multi-dimensional index array */
/* Thin wrapper: dot product of `indexes` with this tensor's strides via the
 * computeElemOffset helper. Assumes indexes.size() matches the tensor's
 * rank — TODO confirm callers guarantee this; no bounds check here. */
int64_t omTensorComputeElemOffset(
    OMTensor *omt, std::vector<int64_t> &indexes) {
  return computeElemOffset(omt->_stride, omt->_rank, indexes);
}
|
||||
|
||||
/* Compute index set for the whole OMTensor */
|
||||
std::vector<std::vector<int64_t>> omTensorComputeIndexSet(OMTensor *omt) {
|
||||
// First, we create index set of each dimension separately.
|
||||
// i.e., for a tensor/OMT of shape (2, 3), its dimWiseIdxSet will be:
|
||||
// {{0,1}, {0,1,2}};
|
||||
std::vector<std::vector<int64_t>> dimWiseIdxSet;
|
||||
for (auto dimSize :
|
||||
std::vector<int64_t>(omt->_shape, omt->_shape + omt->_rank)) {
|
||||
std::vector<int64_t> dimIdxSet(dimSize);
|
||||
iota(begin(dimIdxSet), end(dimIdxSet), 0);
|
||||
dimWiseIdxSet.emplace_back(dimIdxSet);
|
||||
}
|
||||
// Then, the cartesian product of vectors within dimWiseIdxSet will be the
|
||||
// index set for the whole OMT.
|
||||
return CartProduct(dimWiseIdxSet);
|
||||
}
|
||||
|
||||
/* Check whether two OMTensor data are "close" to each other */
/* Returns true iff a and b have identical shapes and every element pair is
 * within BOTH the absolute tolerance `atol` and the relative tolerance
 * `rtol` (relative to a's element). On failure, offending elements are
 * reported to stderr.
 * NOTE(review): the relative check divides by a's elements — a zero element
 * in `a` divides by zero; confirm inputs exclude zeros or that inf/NaN
 * fall-through is acceptable.
 * NOTE(review): taking the address of std::abs via static_cast is fragile
 * (overload-set address-of of standard functions); verify it resolves for
 * every instantiated T. */
template <typename T>
inline bool omTensorAreTwoOmtsClose(
    OMTensor *a, OMTensor *b, float rtol, float atol) {

  // Compare shape; mismatched shapes can never be "close".
  auto aShape = std::vector<int64_t>(a->_shape, a->_shape + a->_rank);
  auto bShape = std::vector<int64_t>(b->_shape, b->_shape + b->_rank);
  if (aShape != bShape) {
    std::cerr << "Shape mismatch ";
    printVector(aShape, ",", std::cerr);
    std::cerr << " != ";
    printVector(bShape, ",", std::cerr);
    return false;
  }

  // Compute absolute difference, verify it's within tolerable range.
  auto anum = omTensorGetNumElems(a);
  std::vector<T> absoluteDiff(anum);
  std::transform((T *)a->_allocatedPtr, (T *)a->_allocatedPtr + anum,
      (T *)b->_allocatedPtr, absoluteDiff.begin(), std::minus<>());
  std::transform(absoluteDiff.begin(), absoluteDiff.end(), absoluteDiff.begin(),
      static_cast<T (*)(T)>(&std::abs));
  bool atolSatisfied = std::all_of(
      absoluteDiff.begin(), absoluteDiff.end(), [&](T a) { return a < atol; });

  // Compute relative difference (|a-b| / a), verify it's within tolerable
  // range. Division by a's element: see divide-by-zero note above.
  std::vector<T> relativeDiff(anum);
  std::transform(absoluteDiff.begin(), absoluteDiff.end(),
      (T *)a->_allocatedPtr, relativeDiff.begin(), std::divides<>());
  bool rtolSatisfied = all_of(
      relativeDiff.begin(), relativeDiff.end(), [&](T a) { return a < rtol; });

  if (atolSatisfied && rtolSatisfied) {
    return true;
  } else {
    // Figure out where and what went wrong, this can be slow; but hopefully we
    // don't need this often. Re-walks every element to print the offenders.
    for (const auto &idx : omTensorComputeIndexSet(a)) {
      T aElem = omTensorGetElem<T>(a, idx);
      T bElem = omTensorGetElem<T>(b, idx);
      auto elmAbsDiff = abs(aElem - bElem);
      auto withinRtol = (elmAbsDiff / aElem < rtol);
      auto withinAtol = (elmAbsDiff < atol);
      if (!withinRtol || !withinAtol) {
        std::cerr << "a[";
        printVector(idx, ",", std::cerr);
        std::cerr << "] = " << aElem << " != ";
        std::cerr << "b[";
        printVector(idx, ",", std::cerr);
        std::cerr << "] = " << bElem << std::endl;
      }
    }
    return false;
  }
}
|
||||
|
||||
// Explicit instantiation of all templated API functions.
|
||||
|
||||
template OMTensor *omTensorCreateWithShape<int32_t>(
|
||||
std::vector<int64_t> dataSizes);
|
||||
template OMTensor *omTensorCreateWithShape<int64_t>(
|
||||
std::vector<int64_t> dataSizes);
|
||||
template OMTensor *omTensorCreateWithShape<float>(
|
||||
std::vector<int64_t> dataSizes);
|
||||
template OMTensor *omTensorCreateWithShape<double>(
|
||||
std::vector<int64_t> dataSizes);
|
||||
|
||||
template OMTensor *omTensorCreateWithRandomData<int32_t>(
|
||||
std::vector<int64_t> dataSizes, int32_t lbound, int32_t ubound);
|
||||
template OMTensor *omTensorCreateWithRandomData<int64_t>(
|
||||
std::vector<int64_t> dataSizes, int64_t lbound, int64_t ubound);
|
||||
template OMTensor *omTensorCreateWithRandomData<float>(
|
||||
std::vector<int64_t> dataSizes, float lbound, float ubound);
|
||||
template OMTensor *omTensorCreateWithRandomData<double>(
|
||||
std::vector<int64_t> dataSizes, double lbound, double ubound);
|
||||
|
||||
template int32_t &omTensorGetElem<int32_t>(
|
||||
OMTensor *, std::vector<int64_t> indexes);
|
||||
template int64_t &omTensorGetElem<int64_t>(
|
||||
OMTensor *, std::vector<int64_t> indexes);
|
||||
template float &omTensorGetElem<float>(
|
||||
OMTensor *, std::vector<int64_t> indexes);
|
||||
template double &omTensorGetElem<double>(
|
||||
OMTensor *, std::vector<int64_t> indexes);
|
||||
|
||||
template int32_t &omTensorGetElemByOffset<int32_t>(OMTensor *, int64_t index);
|
||||
template int64_t &omTensorGetElemByOffset<int64_t>(OMTensor *, int64_t index);
|
||||
template float &omTensorGetElemByOffset<float>(OMTensor *, int64_t indexs);
|
||||
template double &omTensorGetElemByOffset<double>(OMTensor *, int64_t index);
|
||||
|
||||
template bool omTensorAreTwoOmtsClose<int32_t>(
|
||||
OMTensor *a, OMTensor *b, float rtol, float atol);
|
||||
template bool omTensorAreTwoOmtsClose<int64_t>(
|
||||
OMTensor *a, OMTensor *b, float rtol, float atol);
|
||||
template bool omTensorAreTwoOmtsClose<float>(
|
||||
OMTensor *a, OMTensor *b, float rtol, float atol);
|
||||
template bool omTensorAreTwoOmtsClose<double>(
|
||||
OMTensor *a, OMTensor *b, float rtol, float atol);
|
||||
#endif
|
|
@ -0,0 +1,160 @@
|
|||
//===----------- OMTensorHelper.h - OMTensor Helper Func header -----------===//
|
||||
//
|
||||
// Copyright 2019-2020 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// This file contains declaration of OMTensor C++ helper functions. At some
|
||||
// point, this file needs to be merged into the OMTensor.h along with other C++
|
||||
// APIs operating on OMTensor.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef ONNX_MLIR_OMTENSORHELPER_H
|
||||
#define ONNX_MLIR_OMTENSORHELPER_H
|
||||
|
||||
/* Helper function to compute the cartesian product of the index sets in v.
 * Returns every tuple formed by picking one value from each inner vector, in
 * lexicographic order of positions; the product of zero sets is a single
 * empty tuple, so CartProduct({}) == {{}}. */
static inline std::vector<std::vector<int64_t>> CartProduct(
    const std::vector<std::vector<int64_t>> &v) {
  // Start from the product of zero sets: one empty tuple.
  std::vector<std::vector<int64_t>> s = {{}};
  for (const auto &u : v) {
    // Extend every partial tuple accumulated so far with each candidate
    // value of the current dimension.
    std::vector<std::vector<int64_t>> r;
    r.reserve(s.size() * u.size()); // avoid repeated reallocation
    for (const auto &x : s) {
      for (const auto y : u) {
        r.push_back(x);
        r.back().push_back(y);
      }
    }
    // Qualify std::move explicitly; the original's bare `move` only compiled
    // via ADL on the std::vector argument.
    s = std::move(r);
  }
  return s;
}
|
||||
|
||||
/* Helper function to compute row-major (dense) data strides from sizes.
 * Returns a vector of `rank` strides; the rightmost dimension has stride 1. */
static inline std::vector<int64_t> computeStridesFromSizes(
    int64_t *dataSizes, int rank) {
  // Guard rank <= 0 explicitly: a rank-0 (scalar) tensor has no strides, and
  // the range constructor below would otherwise be handed the invalid range
  // [dataSizes + 1, dataSizes) — undefined behavior.
  if (rank <= 0)
    return {};

  // Shift dimension sizes one to the left, fill in the vacated rightmost
  // element with 1; this gets us a vector that'll be more useful for computing
  // strides of memory access along each dimension using prefix product (aka
  // partial_sum with a multiply operator below). The intuition is that the
  // size of the leading dimension does not matter when computing strides.
  std::vector<int64_t> sizesVec(dataSizes + 1, dataSizes + rank);
  sizesVec.push_back(1);

  // Suffix product, computed by prefix-product over the reversed sequence.
  std::vector<int64_t> dimStrides(rank);
  std::partial_sum(sizesVec.rbegin(), sizesVec.rend(), dimStrides.rbegin(),
      std::multiplies<>());
  return dimStrides;
}
|
||||
|
||||
/* Helper function to compute linear offset from a multi-dimensional index
 * array: the dot product of `indexes` with the per-dimension strides. */
static inline int64_t computeElemOffset(
    int64_t *dataStrides, int rank, std::vector<int64_t> &indexes) {
  // The original copied the strides into a temporary std::vector before the
  // inner product; iterating the raw pointer directly avoids that allocation
  // with identical results.
  (void)rank; // Number of valid entries in dataStrides; callers are assumed
              // to pass indexes.size() == rank, which bounds the walk.
  return std::inner_product(
      indexes.begin(), indexes.end(), dataStrides, (int64_t)0);
}
|
||||
|
||||
/* Helper function to print a vector with delimiter */
template <typename T>
static inline void printVector(std::vector<T> vec, std::string _delimiter = ",",
    std::ostream &stream = std::cout) {
  // Print the delimiter only BETWEEN elements: suppress it before the first
  // element, emit it before every subsequent one.
  bool firstElem = true;
  for (const auto &elem : vec) {
    if (!firstElem)
      stream << _delimiter;
    stream << elem;
    firstElem = false;
  }
}
|
||||
|
||||
/**
|
||||
* OMTensor creator with data sizes and element type
|
||||
*
|
||||
* @param dataSizes, data sizes array
|
||||
* @return pointer to OMTensor created, NULL if creation failed.
|
||||
*
|
||||
* Create a full OMTensor of data type T and shape dataSizes, with all
|
||||
* data fields initialized to proper values and data pointers malloc'ed.
|
||||
*/
|
||||
template <typename T>
|
||||
OMTensor *omTensorCreateWithShape(std::vector<int64_t> dataSizes);
|
||||
|
||||
/**
|
||||
* OMTensor creator with data sizes, element type and random data
|
||||
*
|
||||
* @param dataSizes, data sizes array
|
||||
* @param lbound (optional), lower bound of the random distribution
|
||||
* @param ubound (optional), upper bound of the random distribution
|
||||
* @return pointer to OMTensor created, NULL if creation failed.
|
||||
*
|
||||
* Create a full OMTensor like what omTensorCreateWithShape does
|
||||
* and also fill the OMTensor data buffer with randomly generated
|
||||
* real numbers from a uniform distribution between lbound and ubound.
|
||||
*/
|
||||
template <typename T>
|
||||
OMTensor *omTensorCreateWithRandomData(
|
||||
std::vector<int64_t> dataSizes, T lbound = -1.0, T ubound = 1.0);
|
||||
|
||||
/**
|
||||
* OMTensor data element getter by offset
|
||||
*
|
||||
* @param omt, pointer to the OMTensor
|
||||
* @param indexes, multi-dimensional index array of the element
|
||||
* @return typed element by reference at the offset computed by the index array.
|
||||
*/
|
||||
template <typename T>
|
||||
T &omTensorGetElem(OMTensor *omt, std::vector<int64_t> indexes);
|
||||
|
||||
/**
|
||||
* OMTensor data element getter by index
|
||||
*
|
||||
* @param omt, pointer to the OMTensor
|
||||
* @param index, index of the element
|
||||
* @return typed element by reference at the linear offset.
|
||||
*/
|
||||
template <typename T>
|
||||
T &omTensorGetElemByOffset(OMTensor *omt, int64_t index);
|
||||
|
||||
/**
|
||||
* OMTensor strides computation
|
||||
*
|
||||
* @param omt, pointer to the OMTensor
|
||||
* @return data strides of the OMTensor computed from the data sizes.
|
||||
*/
|
||||
std::vector<int64_t> omTensorComputeStridesFromShape(OMTensor *omt);
|
||||
|
||||
/**
|
||||
* OMTensor linear offset computation
|
||||
*
|
||||
* @param omt, pointer to the OMTensor
|
||||
* @param indexes, multi-dimensional index array
|
||||
* @return linear offset.
|
||||
*/
|
||||
int64_t omTensorComputeElemOffset(OMTensor *omt, std::vector<int64_t> &indexes);
|
||||
|
||||
/**
|
||||
* OMTensor index set computation
|
||||
*
|
||||
* @param omt, pointer to the OMTensor
|
||||
* @return index set (i.e., all valid multi-dimensional array indexes
|
||||
* that can be used to access this OMTensor's constituent elements)
|
||||
* for the whole OMTensor.
|
||||
*/
|
||||
std::vector<std::vector<int64_t>> omTensorComputeIndexSet(OMTensor *omt);
|
||||
|
||||
/**
|
||||
* OMTensor "distance" computation
|
||||
*
|
||||
* @param a, 1st OMTensor
|
||||
* @param b, 2nd OMTensor
|
||||
* @param rtol (optional), relative difference tolerance
|
||||
* @param atol (optional), absolute difference tolerance
|
||||
* @return true if both relative and absolute difference are within the
|
||||
* specified tolerance, respectively, false otherwise.
|
||||
*/
|
||||
template <typename T>
|
||||
bool omTensorAreTwoOmtsClose(
|
||||
OMTensor *a, OMTensor *b, float rtol = 1e-5, float atol = 1e-5);
|
||||
|
||||
#endif // ONNX_MLIR_OMTENSORHELPER_H
|
|
@ -0,0 +1,12 @@
|
|||
//===------------- OMTensorList.c - OMTensor C Implementation -------------===//
|
||||
//
|
||||
// Copyright 2019-2020 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// This file contains implementation of OMTensorList data structures and
|
||||
// helper functions.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "OMTensorList.inc"
|
|
@ -0,0 +1,12 @@
|
|||
//===------------ OMTensorList.cpp - OMTensor C++ Implementation ----------===//
|
||||
//
|
||||
// Copyright 2019-2020 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// This file contains implementation of OMTensorList data structures and
|
||||
// helper functions.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "OMTensorList.inc"
|
|
@ -0,0 +1,93 @@
|
|||
//===---------- OMTensorList.cpp - OMTensor C/C++ Implementation ----------===//
|
||||
//
|
||||
// Copyright 2019-2020 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// This file contains C/C++ neutral implementation of OMTensorList data
|
||||
// structures and helper functions.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifdef __APPLE__
|
||||
#include <stdlib.h>
|
||||
#else
|
||||
#include <malloc.h>
|
||||
#endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
#include <cassert>
|
||||
#else
|
||||
#include <assert.h>
|
||||
#endif
|
||||
|
||||
#include "onnx-mlir/Runtime/OMTensorList.h"
|
||||
|
||||
/* A flat, ordered list of OMTensor pointers plus its element count. The
 * struct body is only visible to the implementation; users go through the
 * omTensorList* accessor functions. */
struct OMTensorList {
#ifdef __cplusplus
  /**
   * Constructor
   *
   * Create an OMTensorList with specified OMTensor pointer array
   * and the size of the array. The list takes ownership of the
   * OMTensors: the destructor below destroys each non-null element.
   */
  OMTensorList(OMTensor *omts[], int n) : _omts(omts), _size(n){};

  /**
   * Constructor
   *
   * Create an empty OMTensorList for internal API calls.
   * NOTE(review): default-constructed _omts/_size are uninitialized here —
   * confirm every internal caller assigns them before use.
   */
  OMTensorList() = default;

  /**
   * Destructor
   *
   * Destroy the OMTensorList struct.
   */
  ~OMTensorList() {
    /* Destroy all the OMTensors */
    for (int i = 0; i < _size; i++)
      if (_omts[i])
        omTensorDestroy(_omts[i]);
  };
#endif

  /* OMTensors are stored as a plain pointer array plus an element count so
   * the user-facing getter can return them directly. (Dictionary-style,
   * name-based access was removed from this type; an earlier comment here
   * referenced a name-to-index map that no longer exists.)
   */
  OMTensor **_omts; // OMTensor array

  size_t _size; // Number of elements in _omts.
};
|
||||
|
||||
/* OMTensorList creator */
|
||||
OMTensorList *omTensorListCreate(OMTensor **tensors, int n) {
|
||||
OMTensorList *list = (OMTensorList *)malloc(sizeof(struct OMTensorList));
|
||||
if (!list)
|
||||
return NULL;
|
||||
list->_omts = tensors;
|
||||
list->_size = n;
|
||||
return list;
|
||||
}
|
||||
|
||||
/* OMTensorList destroyer */
/* Frees every element's OMTensor struct, then the list itself.
 * NOTE(review): elements are released with plain free(), while the C++
 * ~OMTensorList destructor uses omTensorDestroy(); if OMTensor owns further
 * allocations (e.g. its shape/stride arrays), free() here leaks them and
 * skips the null check the destructor performs — confirm which teardown
 * path is intended. */
void omTensorListDestroy(OMTensorList *list) {
  for (int i = 0; i < list->_size; i++)
    free(list->_omts[i]);
  free(list);
}
|
||||
|
||||
/* OMTensorList OMTensor array getter */
/* Returns the internal pointer array; the list still owns it (and destroys
 * the tensors in its teardown paths), so callers must not free it. */
OMTensor **omTensorListGetPtrToOmts(OMTensorList *list) { return list->_omts; }

/* OMTensorList number of OMTensor getter */
/* NOTE(review): _size is a size_t implicitly narrowed to int here — fine for
 * realistic list lengths, but confirm the API intends an int return. */
int omTensorListGetSize(OMTensorList *list) { return list->_size; }
|
||||
|
||||
/* Return OMTensor at specified index in the OMTensorList */
|
||||
OMTensor *omTensorListGetOmtByIndex(OMTensorList *rlist, size_t index) {
|
||||
assert(index >= 0);
|
||||
assert(index < rlist->_size);
|
||||
return rlist->_omts[index];
|
||||
}
|
|
@ -0,0 +1,8 @@
|
|||
#include "onnx-mlir/Runtime/OnnxDataType.h"

/* Per-type element sizes in bytes, expanded from the X-macro metadata list;
 * each OM_TYPE_METADATA_DEF(name, value, size) entry contributes its size.
 * NOTE(review): entries appear in the .inc file's order — presumably this
 * matches the OM_DATA_TYPE enum values used to index this table; confirm
 * against OnnxDataTypeMetaData.inc. */
const int OM_DATA_TYPE_SIZE[] = {
#define OM_TYPE_METADATA_DEF(ENUM_NAME, ENUM_VAL, DTYPE_SIZE) DTYPE_SIZE,
#include "onnx-mlir/Runtime/OnnxDataTypeMetaData.inc"

#undef OM_TYPE_METADATA_DEF
};
|
|
@ -17,74 +17,109 @@
|
|||
namespace onnx_mlir {
|
||||
|
||||
std::vector<py::array> PyExecutionSession::pyRun(
|
||||
std::vector<py::array> inputsPyArray) {
|
||||
const std::vector<py::array> &inputsPyArray) {
|
||||
assert(_entryPointFunc && "Entry point not loaded.");
|
||||
auto *wrappedInput = createOrderedRtMemRefDict();
|
||||
int inputIdx = 0;
|
||||
|
||||
std::vector<OMTensor *> omts;
|
||||
for (auto inputPyArray : inputsPyArray) {
|
||||
auto *inputRtMemRef = createRtMemRef(inputPyArray.ndim());
|
||||
assert(inputPyArray.flags() && py::array::c_style &&
|
||||
"Expect contiguous python array.");
|
||||
|
||||
void *dataPtr;
|
||||
int ownData = 0;
|
||||
if (inputPyArray.writeable()) {
|
||||
inputRtMemRef->data = inputPyArray.mutable_data();
|
||||
inputRtMemRef->alignedData = inputPyArray.mutable_data();
|
||||
dataPtr = inputPyArray.mutable_data();
|
||||
} else {
|
||||
// If data is not writable, copy them to a writable buffer.
|
||||
auto *copiedData = (float *)malloc(inputPyArray.nbytes());
|
||||
memcpy(copiedData, inputPyArray.data(), inputPyArray.nbytes());
|
||||
inputRtMemRef->data = copiedData;
|
||||
inputRtMemRef->alignedData = copiedData;
|
||||
dataPtr = copiedData;
|
||||
// We want OMTensor to free up the memory space upon destruction.
|
||||
ownData = 1;
|
||||
}
|
||||
|
||||
for (int i = 0; i < inputPyArray.ndim(); i++) {
|
||||
inputRtMemRef->sizes[i] = inputPyArray.shape(i);
|
||||
inputRtMemRef->strides[i] = inputPyArray.strides(i);
|
||||
}
|
||||
|
||||
setRtMemRef(wrappedInput, inputIdx++, inputRtMemRef);
|
||||
}
|
||||
|
||||
std::vector<py::array> outputPyArrays;
|
||||
auto *wrappedOutput = _entryPointFunc(wrappedInput);
|
||||
for (int i = 0; i < numRtMemRefs(wrappedOutput); i++) {
|
||||
auto *dynMemRef = getRtMemRef(wrappedOutput, i);
|
||||
auto shape = std::vector<int64_t>(
|
||||
dynMemRef->sizes, dynMemRef->sizes + dynMemRef->rank);
|
||||
|
||||
// https://numpy.org/devdocs/user/basics.types.html
|
||||
py::dtype dtype;
|
||||
if (dynMemRef->onnx_dtype == onnx::TensorProto::FLOAT)
|
||||
dtype = py::dtype("float32");
|
||||
else if (dynMemRef->onnx_dtype = onnx::TensorProto::UINT8)
|
||||
dtype = py::dtype("uint8");
|
||||
else if (dynMemRef->onnx_dtype = onnx::TensorProto::INT8)
|
||||
dtype = py::dtype("int8");
|
||||
else if (dynMemRef->onnx_dtype = onnx::TensorProto::UINT16)
|
||||
dtype = py::dtype("uint16");
|
||||
else if (dynMemRef->onnx_dtype = onnx::TensorProto::INT16)
|
||||
dtype = py::dtype("int16");
|
||||
else if (dynMemRef->onnx_dtype == onnx::TensorProto::INT32)
|
||||
dtype = py::dtype("int32");
|
||||
else if (dynMemRef->onnx_dtype == onnx::TensorProto::INT64)
|
||||
dtype = py::dtype("int64");
|
||||
// TODO(tjingrant) wait for Tong's input for how to represent string.
|
||||
else if (dynMemRef->onnx_dtype = onnx::TensorProto::BOOL)
|
||||
dtype = py::dtype("bool_");
|
||||
else if (dynMemRef->onnx_dtype = onnx::TensorProto::FLOAT16)
|
||||
dtype = py::dtype("float32");
|
||||
else if (dynMemRef->onnx_dtype = onnx::TensorProto::DOUBLE)
|
||||
dtype = py::dtype("float64");
|
||||
else if (dynMemRef->onnx_dtype == onnx::TensorProto::UINT32)
|
||||
dtype = py::dtype("uint32");
|
||||
else if (dynMemRef->onnx_dtype == onnx::TensorProto::UINT64)
|
||||
dtype = py::dtype("uint64");
|
||||
// Borrowed from:
|
||||
// https://github.com/pybind/pybind11/issues/563#issuecomment-267835542
|
||||
OM_DATA_TYPE dtype;
|
||||
if (py::isinstance<py::array_t<float>>(inputPyArray))
|
||||
dtype = ONNX_TYPE_FLOAT;
|
||||
else if (py::isinstance<py::array_t<std::uint8_t>>(inputPyArray))
|
||||
dtype = ONNX_TYPE_UINT8;
|
||||
else if (py::isinstance<py::array_t<std::int8_t>>(inputPyArray))
|
||||
dtype = ONNX_TYPE_INT8;
|
||||
else if (py::isinstance<py::array_t<std::uint16_t>>(inputPyArray))
|
||||
dtype = ONNX_TYPE_UINT16;
|
||||
else if (py::isinstance<py::array_t<std::int16_t>>(inputPyArray))
|
||||
dtype = ONNX_TYPE_INT16;
|
||||
else if (py::isinstance<py::array_t<std::int32_t>>(inputPyArray))
|
||||
dtype = ONNX_TYPE_INT32;
|
||||
else if (py::isinstance<py::array_t<std::int64_t>>(inputPyArray))
|
||||
dtype = ONNX_TYPE_INT64;
|
||||
else if (py::isinstance<py::array_t<bool>>(inputPyArray))
|
||||
dtype = ONNX_TYPE_BOOL;
|
||||
// Missing fp16 support.
|
||||
else if (py::isinstance<py::array_t<double>>(inputPyArray))
|
||||
dtype = ONNX_TYPE_DOUBLE;
|
||||
else if (py::isinstance<py::array_t<std::uint32_t>>(inputPyArray))
|
||||
dtype = ONNX_TYPE_UINT32;
|
||||
else if (py::isinstance<py::array_t<std::uint64_t>>(inputPyArray))
|
||||
dtype = ONNX_TYPE_UINT64;
|
||||
else {
|
||||
fprintf(stderr, "Unsupported ONNX type in RtMemRef.onnx_dtype.");
|
||||
std::cerr << "Numpy type not supported: " << inputPyArray.dtype()
|
||||
<< ".\n";
|
||||
exit(1);
|
||||
}
|
||||
|
||||
outputPyArrays.emplace_back(py::array(dtype, shape, dynMemRef->data));
|
||||
auto *inputOMTensor = omTensorCreateWithOwnership(dataPtr,
|
||||
(int64_t *)inputPyArray.shape(), inputPyArray.ndim(), dtype, ownData);
|
||||
omTensorSetStrides(inputOMTensor, (int64_t *)inputPyArray.strides());
|
||||
|
||||
omts.emplace_back(inputOMTensor);
|
||||
}
|
||||
|
||||
auto *wrappedInput = omTensorListCreate(&omts[0], omts.size());
|
||||
auto *wrappedOutput = _entryPointFunc(wrappedInput);
|
||||
|
||||
std::vector<py::array> outputPyArrays;
|
||||
for (int i = 0; i < omTensorListGetSize(wrappedOutput); i++) {
|
||||
auto *omt = omTensorListGetOmtByIndex(wrappedOutput, i);
|
||||
auto shape = std::vector<int64_t>(omTensorGetDataShape(omt),
|
||||
omTensorGetDataShape(omt) + omTensorGetRank(omt));
|
||||
|
||||
// https://numpy.org/devdocs/user/basics.types.html
|
||||
py::dtype dtype;
|
||||
if (omTensorGetDataType(omt) == onnx::TensorProto::FLOAT)
|
||||
dtype = py::dtype("float32");
|
||||
else if (omTensorGetDataType(omt) == onnx::TensorProto::UINT8)
|
||||
dtype = py::dtype("uint8");
|
||||
else if (omTensorGetDataType(omt) == onnx::TensorProto::INT8)
|
||||
dtype = py::dtype("int8");
|
||||
else if (omTensorGetDataType(omt) == onnx::TensorProto::UINT16)
|
||||
dtype = py::dtype("uint16");
|
||||
else if (omTensorGetDataType(omt) == onnx::TensorProto::INT16)
|
||||
dtype = py::dtype("int16");
|
||||
else if (omTensorGetDataType(omt) == onnx::TensorProto::INT32)
|
||||
dtype = py::dtype("int32");
|
||||
else if (omTensorGetDataType(omt) == onnx::TensorProto::INT64)
|
||||
dtype = py::dtype("int64");
|
||||
// TODO(tjingrant) wait for Tong's input for how to represent string.
|
||||
else if (omTensorGetDataType(omt) == onnx::TensorProto::BOOL)
|
||||
dtype = py::dtype("bool_");
|
||||
else if (omTensorGetDataType(omt) == onnx::TensorProto::FLOAT16)
|
||||
dtype = py::dtype("float32");
|
||||
else if (omTensorGetDataType(omt) == onnx::TensorProto::DOUBLE)
|
||||
dtype = py::dtype("float64");
|
||||
else if (omTensorGetDataType(omt) == onnx::TensorProto::UINT32)
|
||||
dtype = py::dtype("uint32");
|
||||
else if (omTensorGetDataType(omt) == onnx::TensorProto::UINT64)
|
||||
dtype = py::dtype("uint64");
|
||||
else {
|
||||
fprintf(stderr, "Unsupported ONNX type in OMTensor.");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
outputPyArrays.emplace_back(
|
||||
py::array(dtype, shape, omTensorGetDataPtr(omt)));
|
||||
}
|
||||
|
||||
return outputPyArrays;
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
|
||||
namespace py = pybind11;
|
||||
|
||||
#include "src/Runtime/ExecusionSession.hpp"
|
||||
#include "ExecutionSession.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
|
@ -24,7 +24,7 @@ public:
|
|||
PyExecutionSession(std::string sharedLibPath, std::string entryPointName)
|
||||
: onnx_mlir::ExecutionSession(sharedLibPath, entryPointName){};
|
||||
|
||||
std::vector<py::array> pyRun(std::vector<py::array> inputsPyArray);
|
||||
std::vector<py::array> pyRun(const std::vector<py::array> &inputsPyArray);
|
||||
};
|
||||
} // namespace onnx_mlir
|
||||
|
||||
|
|
|
@ -1,171 +0,0 @@
|
|||
//===----------- RtMemRef.cpp - Dynamic MemRef Implementation ------------===//
|
||||
//
|
||||
// Copyright 2019-2020 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// This file contains implementations of Dynamic MemRef data structures and
|
||||
// helper functions.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include <cassert>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "RtMemRef.h"
|
||||
|
||||
namespace {
|
||||
// Helper function to compute cartisian product.
|
||||
inline std::vector<std::vector<INDEX_TYPE>> CartProduct(
|
||||
const std::vector<std::vector<INDEX_TYPE>> &v) {
|
||||
std::vector<std::vector<INDEX_TYPE>> s = {{}};
|
||||
for (const auto &u : v) {
|
||||
std::vector<std::vector<INDEX_TYPE>> r;
|
||||
for (const auto &x : s) {
|
||||
for (const auto y : u) {
|
||||
r.push_back(x);
|
||||
r.back().push_back(y);
|
||||
}
|
||||
}
|
||||
s = move(r);
|
||||
}
|
||||
return s;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
RtMemRef::RtMemRef(int _rank) {
|
||||
rank = _rank;
|
||||
sizes = (INDEX_TYPE *)malloc(rank * sizeof(INDEX_TYPE));
|
||||
strides = (int64_t *)malloc(rank * sizeof(int64_t));
|
||||
}
|
||||
|
||||
INDEX_TYPE RtMemRef::size() const {
|
||||
return std::accumulate(sizes, sizes + rank, 1, std::multiplies<>());
|
||||
}
|
||||
|
||||
std::vector<std::vector<INDEX_TYPE>> RtMemRef::indexSet() const {
|
||||
// First, we create index set of each dimension separately.
|
||||
// i.e., for a tensor/RMR of shape (2, 3), its dimWiseIdxSet will be:
|
||||
// {{0,1}, {0,1,2}};
|
||||
std::vector<std::vector<INDEX_TYPE>> dimWiseIdxSet;
|
||||
for (auto dimSize : std::vector<INDEX_TYPE>(sizes, sizes + rank)) {
|
||||
std::vector<INDEX_TYPE> dimIdxSet(dimSize);
|
||||
std::iota(std::begin(dimIdxSet), std::end(dimIdxSet), 0);
|
||||
dimWiseIdxSet.emplace_back(dimIdxSet);
|
||||
}
|
||||
// Then, the cartesian product of vectors within dimWiseIdxSet will be the
|
||||
// index set for the whole RMR.
|
||||
return CartProduct(dimWiseIdxSet);
|
||||
}
|
||||
|
||||
INDEX_TYPE RtMemRef::computeOffset(std::vector<INDEX_TYPE> &idxs) const {
|
||||
auto dimStrides = std::vector<INDEX_TYPE>(strides, strides + rank);
|
||||
INDEX_TYPE elemOffset = std::inner_product(
|
||||
idxs.begin(), idxs.end(), dimStrides.begin(), (INDEX_TYPE)0);
|
||||
return elemOffset;
|
||||
}
|
||||
|
||||
std::vector<int64_t> RtMemRef::computeStridesFromSizes() const {
|
||||
// Shift dimension sizes one to the left, fill in the vacated rightmost
|
||||
// element with 1; this gets us a vector that'll be more useful for computing
|
||||
// strides of memory access along each dimension using prefix product (aka
|
||||
// partial_sum with a multiply operator below). The intuition is that the size
|
||||
// of the leading dimension does not matter when computing strides.
|
||||
std::vector<int64_t> sizesVec(sizes + 1, sizes + rank);
|
||||
sizesVec.push_back(1);
|
||||
|
||||
std::vector<int64_t> dimStrides(rank);
|
||||
std::partial_sum(sizesVec.rbegin(), sizesVec.rend(), dimStrides.rbegin(),
|
||||
std::multiplies<>());
|
||||
return dimStrides;
|
||||
}
|
||||
|
||||
RtMemRef::~RtMemRef() {
|
||||
free(data);
|
||||
free(sizes);
|
||||
free(strides);
|
||||
}
|
||||
|
||||
// An ordered dynamic MemRef dictionary.
|
||||
// The goal is to support accessing dynamic memory ref by name and by index.
|
||||
// Currently, only accessing by index is supported.
|
||||
struct OrderedRtMemRefDict {
|
||||
std::map<std::string, RtMemRef *> tensorDict;
|
||||
std::vector<std::string> orderedNames;
|
||||
};
|
||||
|
||||
int numRtMemRefs(OrderedRtMemRefDict *dict) {
|
||||
return dict->orderedNames.size();
|
||||
}
|
||||
|
||||
OrderedRtMemRefDict *createOrderedRtMemRefDict() {
|
||||
return new OrderedRtMemRefDict();
|
||||
}
|
||||
|
||||
RtMemRef *createRtMemRef(int rank) { return new RtMemRef(rank); }
|
||||
|
||||
RtMemRef *getRtMemRef(OrderedRtMemRefDict *tensorDict, int idx) {
|
||||
return tensorDict->tensorDict[tensorDict->orderedNames[idx]];
|
||||
}
|
||||
|
||||
void setRtMemRef(OrderedRtMemRefDict *tensorDict, int idx, RtMemRef *tensor) {
|
||||
if (tensorDict->orderedNames.size() <= idx)
|
||||
tensorDict->orderedNames.resize(idx + 1);
|
||||
|
||||
// The dynamic memref is essentially anonymous, since we are storing it by
|
||||
// indexed position.
|
||||
// TODO: can use random string as names to reduce chance of collision.
|
||||
auto unique_name = std::to_string(idx);
|
||||
assert(tensorDict->tensorDict.count(unique_name) == 0 &&
|
||||
"duplicate dynamic mem ref name");
|
||||
|
||||
tensorDict->orderedNames[idx] = unique_name;
|
||||
tensorDict->tensorDict[tensorDict->orderedNames[idx]] = tensor;
|
||||
}
|
||||
|
||||
void *getData(RtMemRef *dynMemRef) { return dynMemRef->data; }
|
||||
|
||||
void setData(RtMemRef *dynMemRef, void *dataPtr) { dynMemRef->data = dataPtr; }
|
||||
|
||||
void *getAlignedData(RtMemRef *dynMemRef) { return dynMemRef->alignedData; }
|
||||
|
||||
void setAlignedData(RtMemRef *dynMemRef, void *dataPtr) {
|
||||
dynMemRef->alignedData = dataPtr;
|
||||
}
|
||||
|
||||
INDEX_TYPE *getSizes(RtMemRef *dynMemRef) { return dynMemRef->sizes; }
|
||||
|
||||
void setSizes(RtMemRef *dynMemRef, INDEX_TYPE *sizes) {
|
||||
for (int i = 0; i < dynMemRef->rank; i++)
|
||||
dynMemRef->sizes[i] = sizes[i];
|
||||
}
|
||||
|
||||
int64_t *getStrides(RtMemRef *dynMemRef) { return dynMemRef->strides; }
|
||||
|
||||
int64_t getSize(OrderedRtMemRefDict *dict) { return dict->orderedNames.size(); }
|
||||
|
||||
INDEX_TYPE getDataSize(RtMemRef *rtMemRef) {
|
||||
INDEX_TYPE n = rtMemRef->sizes[0];
|
||||
for (int i = 1; i < rtMemRef->rank; i++)
|
||||
n *= rtMemRef->sizes[i];
|
||||
return n;
|
||||
}
|
||||
|
||||
void setDType(RtMemRef *dynMemRef, int onnxType) {
|
||||
dynMemRef->onnx_dtype = onnxType;
|
||||
}
|
||||
|
||||
int getDType(RtMemRef *dynMemRef) { return dynMemRef->onnx_dtype; }
|
||||
|
||||
unsigned int getRank(RtMemRef *dynMemRef) { return dynMemRef->rank; }
|
||||
|
||||
void setStrides(RtMemRef *dynMemRef, int64_t *strides) {
|
||||
for (int i = 0; i < dynMemRef->rank; i++)
|
||||
dynMemRef->strides[i] = strides[i];
|
||||
}
|
|
@ -1,283 +0,0 @@
|
|||
//===------------ RtMemRef.h - Dynamic MemRef Implementation -------------===//
|
||||
//
|
||||
// Copyright 2019-2020 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// This file contains declaration of Dynamic MemRef data structures and helper
|
||||
// functions.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifdef __cplusplus
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <random>
|
||||
#include <vector>
|
||||
#else
|
||||
#include <stdint.h>
|
||||
#endif
|
||||
|
||||
typedef int64_t INDEX_TYPE;
|
||||
|
||||
// Typically, MemRefs in MLIR context are used as a compile-time constructs.
|
||||
// Information such as element type and rank of the data payload is statically
|
||||
// encoded, meaning that they are determined and fixed at compile-time. This
|
||||
// presents significant burden for any runtime components trying to interact
|
||||
// with the compiled executable.
|
||||
//
|
||||
// Thus a version of MemRef struct that is amenable to runtime manipulation is
|
||||
// provided as a basis for building any runtime-related components providing
|
||||
// user-facing programming interfaces. All information are dynamically encoded
|
||||
// as members of this struct so that they can be accessed and modified easily
|
||||
// during runtime.
|
||||
//
|
||||
// We will refer to it as a RMF (Runtime MemRef).
|
||||
struct RtMemRef {
|
||||
|
||||
// Pointer to the raw memory space allocated to host the RMR content. This
|
||||
// pointer should only be acessed for memory management purposes, not for
|
||||
// reading RMR content.
|
||||
void *data;
|
||||
|
||||
// Pointer to the properly aligned array of elements stored in this Rmr.
|
||||
void *alignedData;
|
||||
|
||||
// Distance between the start of the raw memory space and the first element of
|
||||
// the RMR content.
|
||||
INDEX_TYPE offset;
|
||||
|
||||
// Number of dimensions of the array represented by the RMR.
|
||||
unsigned int rank;
|
||||
|
||||
// An array recording the per-dimension sizes of the array represented by the
|
||||
// RMR.
|
||||
INDEX_TYPE *sizes;
|
||||
|
||||
// An array recording the per-dimension strides of the array represented by
|
||||
// the RMR.
|
||||
int64_t *strides;
|
||||
|
||||
// Refer to TensorProto_DataType at
|
||||
// https://github.com/onnx/onnx/blob/cc2230603422bae893d5bc900d2d773ab34400a4/onnx/onnx-ml.proto#L451
|
||||
// for enum value interpretation.
|
||||
unsigned int onnx_dtype;
|
||||
#ifdef __cplusplus
|
||||
explicit RtMemRef(int _rank);
|
||||
|
||||
// Create a full RMR of type T and shape _sizes, with all data fields
|
||||
// initialized to proper values and data pointers malloc'ed.
|
||||
template <typename T>
|
||||
static RtMemRef *create(std::vector<INDEX_TYPE> _sizes) {
|
||||
auto rmr = new RtMemRef(_sizes.size());
|
||||
rmr->offset = 0;
|
||||
rmr->rank = _sizes.size();
|
||||
rmr->sizes = (INDEX_TYPE *)malloc(rmr->rank * sizeof(INDEX_TYPE));
|
||||
std::copy(_sizes.begin(), _sizes.end(), rmr->sizes);
|
||||
|
||||
rmr->strides = (int64_t *)malloc(rmr->rank * sizeof(int64_t));
|
||||
auto computedStrides = rmr->computeStridesFromSizes();
|
||||
std::copy(computedStrides.begin(), computedStrides.end(), rmr->strides);
|
||||
|
||||
rmr->data = malloc(rmr->size() * sizeof(T));
|
||||
rmr->alignedData = rmr->data;
|
||||
|
||||
return rmr;
|
||||
}
|
||||
|
||||
// Access an element (by reference) at index position idxs.
|
||||
template <typename T>
|
||||
T &elem(std::vector<INDEX_TYPE> idxs) {
|
||||
INDEX_TYPE elemOffset = computeOffset(idxs);
|
||||
T *typedPtr = (T *)data;
|
||||
return typedPtr[elemOffset];
|
||||
}
|
||||
|
||||
// Access an element (by reference) at *flattened* index position idx.
|
||||
template <typename T>
|
||||
T &elem(INDEX_TYPE idx) {
|
||||
T *typedPtr = (T *)data;
|
||||
return typedPtr[idx];
|
||||
}
|
||||
|
||||
// Get a typed ptr to the data content of the RMR.
|
||||
template <typename T>
|
||||
T *typedPtr() {
|
||||
return (T *)data;
|
||||
}
|
||||
|
||||
// Get how many elements are stored in RMR, as implied by its shape.
|
||||
INDEX_TYPE size() const;
|
||||
|
||||
// Helper function to compute strides of access along each dimensions from its
|
||||
// shape.
|
||||
std::vector<int64_t> computeStridesFromSizes() const;
|
||||
|
||||
// Compute flattened array idx from a multi-dimensional array idx.
|
||||
INDEX_TYPE computeOffset(std::vector<INDEX_TYPE> &idxs) const;
|
||||
|
||||
// Get the index set (i.e., all valid multi-dimensional array indexes that can
|
||||
// be used to access this RMR's constituent elements).
|
||||
std::vector<std::vector<INDEX_TYPE>> indexSet() const;
|
||||
|
||||
~RtMemRef();
|
||||
#endif
|
||||
};
|
||||
|
||||
#ifdef __cplusplus
|
||||
// Ordered RtMemRef Dictionary is a data structure for wrapping the input
|
||||
// dynmemrefs so that they can be addressed both by index and by name.
|
||||
struct OrderedRtMemRefDict;
|
||||
|
||||
#else
|
||||
typedef struct RtMemRef RtMemRef;
|
||||
typedef struct _OrderedRtMemRefDict OrderedRtMemRefDict;
|
||||
#endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// Get number of dynamic memrefs in OrderedRtMemRefDict dict.
|
||||
int numRtMemRefs(OrderedRtMemRefDict *dict);
|
||||
|
||||
// Create an ordered dynamic memref dictionary.
|
||||
OrderedRtMemRefDict *createOrderedRtMemRefDict();
|
||||
|
||||
// Get how many dynamic memrefs are in dict.
|
||||
int64_t getSize(OrderedRtMemRefDict *dict);
|
||||
|
||||
// Get how many data elements are in RtMemRef.
|
||||
INDEX_TYPE getDataSize(RtMemRef *rtMemRef);
|
||||
|
||||
// Create a dynmemref with a certain rank.
|
||||
RtMemRef *createRtMemRef(int rank);
|
||||
|
||||
// Get the i-th dynmemref from orderedDict.
|
||||
RtMemRef *getRtMemRef(OrderedRtMemRefDict *orderedDict, int i);
|
||||
|
||||
// Set the i-th dynmemref in orderedDict to be dynMemRef.
|
||||
void setRtMemRef(OrderedRtMemRefDict *tensorDict, int idx, RtMemRef *dynMemRef);
|
||||
|
||||
// Get data pointer from dynMemRef.
|
||||
void *getData(RtMemRef *dynMemRef);
|
||||
|
||||
// Set data pointer for dynMemRef.
|
||||
void setData(RtMemRef *dynMemRef, void *data);
|
||||
|
||||
// Get algined data pointer from dynMemRef.
|
||||
void *getAlignedData(RtMemRef *);
|
||||
|
||||
// Set aligned data pointer for dynMemRef.
|
||||
void setAlignedData(RtMemRef *, void *);
|
||||
|
||||
// Get the data type enum value of the dynMemRef.
|
||||
int getDType(RtMemRef *dynMemRef);
|
||||
|
||||
// Set the data type enum value of the dynMemRef.
|
||||
void setDType(RtMemRef *dynMemRef, int onnxType);
|
||||
|
||||
// Get the rank of the dynMemRef.
|
||||
unsigned int getRank(RtMemRef *dynMemRef);
|
||||
|
||||
// Get ptr to sizes array.
|
||||
INDEX_TYPE *getSizes(RtMemRef *);
|
||||
|
||||
// Set the sizes array (by copying size values from array `sizes`).
|
||||
void setSizes(RtMemRef *, INDEX_TYPE *sizes);
|
||||
|
||||
// Get ptr to strides array.
|
||||
int64_t *getStrides(RtMemRef *);
|
||||
|
||||
// Set the strides array (by copying stride values from array `strides`).
|
||||
void setStrides(RtMemRef *, int64_t *strides);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void printVector(std::vector<T> vec, std::string _delimiter = ",",
|
||||
std::ostream &stream = std::cout) {
|
||||
std::string delimiter;
|
||||
for (const auto &elem : vec) {
|
||||
stream << delimiter << elem;
|
||||
delimiter = _delimiter;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
RtMemRef *getRndRealRmr(
|
||||
std::vector<INDEX_TYPE> sizes, T lb = -1.0, T ub = 1.0) {
|
||||
// Will be used to obtain a seed for the random number engine
|
||||
std::random_device rd;
|
||||
// Standard mersenne_twister_engine seeded with rd()
|
||||
std::mt19937 gen(rd());
|
||||
std::uniform_real_distribution<> dis(lb, ub);
|
||||
auto rmr = RtMemRef::create<T>(sizes);
|
||||
auto ptr = (T *)rmr->data;
|
||||
std::generate(ptr, ptr + rmr->size(), [&]() { return dis(gen); });
|
||||
return rmr;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline bool isRmrClose(
|
||||
RtMemRef *a, RtMemRef *b, float rtol = 1e-5, float atol = 1e-5) {
|
||||
|
||||
// Compare shape.
|
||||
auto aShape = std::vector<INDEX_TYPE>(a->sizes, a->sizes + a->rank);
|
||||
auto bShape = std::vector<INDEX_TYPE>(b->sizes, b->sizes + b->rank);
|
||||
if (aShape != bShape) {
|
||||
std::cerr << "Shape mismatch ";
|
||||
printVector(aShape, ",", std::cerr);
|
||||
std::cerr << " != ";
|
||||
printVector(bShape, ",", std::cerr);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Compute absolute difference, verify it's within tolerable range.
|
||||
std::vector<T> absoluteDiff(a->size());
|
||||
std::transform(a->typedPtr<T>(), a->typedPtr<T>() + a->size(),
|
||||
b->typedPtr<T>(), absoluteDiff.begin(), std::minus<>());
|
||||
std::transform(absoluteDiff.begin(), absoluteDiff.end(), absoluteDiff.begin(),
|
||||
static_cast<T (*)(T)>(&std::abs));
|
||||
bool atolSatisfied = std::all_of(
|
||||
absoluteDiff.begin(), absoluteDiff.end(), [&](T a) { return a < atol; });
|
||||
|
||||
// Compute relative difference, verify it's within tolerable range.
|
||||
std::vector<T> relativeDiff(a->size());
|
||||
std::transform(absoluteDiff.begin(), absoluteDiff.end(), a->typedPtr<T>(),
|
||||
relativeDiff.begin(), std::divides<>());
|
||||
bool rtolSatisfied = std::all_of(
|
||||
relativeDiff.begin(), relativeDiff.end(), [&](T a) { return a < rtol; });
|
||||
|
||||
if (atolSatisfied && rtolSatisfied) {
|
||||
return true;
|
||||
} else {
|
||||
// Figure out where and what went wrong, this can be slow; but hopefully we
|
||||
// don't need this often.
|
||||
for (const auto &idx : a->indexSet()) {
|
||||
T aElem = a->elem<T>(idx);
|
||||
T bElem = b->elem<T>(idx);
|
||||
auto elmAbsDiff = std::abs(aElem - bElem);
|
||||
auto withinRtol = (elmAbsDiff / aElem < rtol);
|
||||
auto withinAtol = (elmAbsDiff < atol);
|
||||
if (!withinRtol || !withinAtol) {
|
||||
std::cerr << "a[";
|
||||
printVector(idx, ",", std::cerr);
|
||||
std::cerr << "] = " << aElem << " != ";
|
||||
std::cerr << "b[";
|
||||
printVector(idx, ",", std::cerr);
|
||||
std::cerr << "] = " << bElem << std::endl;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// Will transition from RtMemRef to RtMemRef soon.
|
||||
typedef RtMemRef RtMemRef;
|
|
@ -7,18 +7,18 @@ if(Java_Development_FOUND AND JNI_FOUND)
|
|||
# Target for Java runtime jar
|
||||
add_jar(javaruntime
|
||||
src/com/ibm/onnxmlir/DynEntryPoint.java
|
||||
src/com/ibm/onnxmlir/OrderedRtMemRefDict.java
|
||||
src/com/ibm/onnxmlir/RtMemRef.java
|
||||
src/com/ibm/onnxmlir/OMTensorList.java
|
||||
src/com/ibm/onnxmlir/OMTensor.java
|
||||
OUTPUT_DIR ${CMAKE_LIBRARY_OUTPUT_DIRECTORY})
|
||||
|
||||
# Target for JNI runtime lib
|
||||
add_library(jniruntime STATIC
|
||||
jniwrapper.c jnilog.c jnidummy.c
|
||||
com_ibm_onnxmlir_DynEntryPoint.h jnilog.h ../RtMemRef.h)
|
||||
com_ibm_onnxmlir_DynEntryPoint.h jnilog.h)
|
||||
set_target_properties(jniruntime PROPERTIES
|
||||
POSITION_INDEPENDENT_CODE TRUE)
|
||||
target_include_directories(jniruntime PRIVATE
|
||||
${ONNX_MLIR_SRC_ROOT}/src/Runtime
|
||||
${ONNX_MLIR_SRC_ROOT}/include
|
||||
${JAVA_INCLUDE_PATH}
|
||||
${JAVA_INCLUDE_PATH2})
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@ extern "C" {
|
|||
* Class: com_ibm_onnxmlir_DynEntryPoint
|
||||
* Method: main_graph_jni
|
||||
* Signature:
|
||||
* (Lcom/ibm/onnxmlir/OrderedRtMemRefDict;)Lcom/ibm/onnxmlir/OrderedRtMemRefDict;
|
||||
* (Lcom/ibm/onnxmlir/OMTensorList;)Lcom/ibm/onnxmlir/OMTensorList;
|
||||
*/
|
||||
JNIEXPORT jobject JNICALL Java_com_ibm_onnxmlir_DynEntryPoint_main_1graph_1jni(
|
||||
JNIEnv *, jclass, jobject);
|
||||
|
|
|
@ -63,26 +63,6 @@ enum { LOG_TRACE, LOG_DEBUG, LOG_INFO, LOG_WARNING, LOG_ERROR, LOG_FATAL };
|
|||
sprintf(buf + strlen(buf), n > LOG_MAX_NUM ? " ... " : " "); \
|
||||
} while (0)
|
||||
|
||||
enum {
|
||||
ONNX_TYPE_UNDEFINED, /* 0 */
|
||||
ONNX_TYPE_FLOAT, /* 1 */
|
||||
ONNX_TYPE_UINT8, /* 2 */
|
||||
ONNX_TYPE_INT8, /* 3 */
|
||||
ONNX_TYPE_UINT16, /* 4 */
|
||||
ONNX_TYPE_INT16, /* 5 */
|
||||
ONNX_TYPE_INT32, /* 6 */
|
||||
ONNX_TYPE_INT64, /* 7 */
|
||||
ONNX_TYPE_STRING, /* 8 */
|
||||
ONNX_TYPE_BOOL, /* 9 */
|
||||
ONNX_TYPE_FLOAT16, /* 10 */
|
||||
ONNX_TYPE_DOUBLE, /* 11 */
|
||||
ONNX_TYPE_UINT32, /* 12 */
|
||||
ONNX_TYPE_UINT64, /* 13 */
|
||||
ONNX_TYPE_COMPLEX64, /* 14 */
|
||||
ONNX_TYPE_COMPLEX128, /* 15 */
|
||||
ONNX_TYPE_BFLOAT16, /* 16 */
|
||||
};
|
||||
|
||||
/* Construct string of up to LOG_MAX_NUM elements of a "type" array */
|
||||
#define LOG_TYPE_BUF(type, buf, data, n) \
|
||||
do { \
|
||||
|
|
|
@ -6,17 +6,20 @@
|
|||
#endif
|
||||
#include <string.h>
|
||||
|
||||
#include "RtMemRef.h"
|
||||
#include "OnnxMlirRuntime.h"
|
||||
#include "com_ibm_onnxmlir_DynEntryPoint.h"
|
||||
#include "jnilog.h"
|
||||
|
||||
/* Declare type var, make call and assign to var, check against val */
|
||||
/* Declare type var, make call and assign to var, check against val.
|
||||
* It's assumed that a Java exception has already been thrown so
|
||||
* this call simply returns NULL.
|
||||
*/
|
||||
#define CHECK_CALL(type, var, call, val) \
|
||||
type var = call; \
|
||||
if (var == val) \
|
||||
return NULL
|
||||
|
||||
/* Make a JNI call and throw Java exception if the call failed */
|
||||
/* Make a JNI call, log error and throw Java exception if the call failed */
|
||||
#define JNI_CALL(env, stmt) \
|
||||
stmt; \
|
||||
do { \
|
||||
|
@ -29,332 +32,354 @@
|
|||
} while (0)
|
||||
|
||||
/* Make a JNI call and assign return value to var,
|
||||
* throw Java exception if the call failed
|
||||
* log error and throw Java exception if the call failed
|
||||
*/
|
||||
#define JNI_VAR_CALL(env, var, call) JNI_CALL(env, var = call)
|
||||
|
||||
/* Declare type var, make a JNI call and assign return value to var,
|
||||
* throw Java exception if the call failed
|
||||
* log error and throw Java exception if the call failed
|
||||
*/
|
||||
#define JNI_TYPE_VAR_CALL(env, type, var, call) JNI_CALL(env, type var = call);
|
||||
|
||||
/* If cond is true (native code failed), log error and throw Java exception */
|
||||
#define JNI_COND(type, var, call, val, env, cls, ...) \
|
||||
type var = call; \
|
||||
/* Make a native library call, if cond is true (native code failed),
|
||||
* log error and throw Java exception
|
||||
*/
|
||||
#define LIB_CALL(stmt, check, env, cls, ...) \
|
||||
stmt; \
|
||||
do { \
|
||||
if (var == val) { \
|
||||
if (check) { \
|
||||
LOG_PRINTF(LOG_ERROR, __VA_ARGS__); \
|
||||
(*env)->ThrowNew(env, cls, "native code error"); \
|
||||
return NULL; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
/* Debug output of RtMemRef fields */
|
||||
#define RMR_DEBUG(i, type, rank, sizes, strides, data, datasize) \
|
||||
/* Make a native library call and assign return value to var,
|
||||
* log error and throw Java exception if the call failed
|
||||
*/
|
||||
#define LIB_VAR_CALL(var, call, val, env, cls, ...) \
|
||||
LIB_CALL(var = call, var == val, env, cls, __VA_ARGS__);
|
||||
|
||||
/* Declare type var, make a native library call and assign return value to var,
|
||||
* log error and throw Java exception if the call failed
|
||||
*/
|
||||
#define LIB_TYPE_VAR_CALL(type, var, call, val, env, cls, ...) \
|
||||
LIB_CALL(type var = call, var == val, env, cls, __VA_ARGS__);
|
||||
|
||||
/* Debug output of OMTensor fields */
|
||||
#define OMT_DEBUG( \
|
||||
i, n, data, dataSizes, dataStrides, dataType, dataBufferSize, rank, name) \
|
||||
do { \
|
||||
char tmp[1024]; \
|
||||
LOG_PRINTF(LOG_DEBUG, "rmr[%d]:type=%d", i, type); \
|
||||
LOG_PRINTF(LOG_DEBUG, "rmr[%d]:rank=%d", i, rank); \
|
||||
LOG_LONG_BUF(tmp, sizes, rank); \
|
||||
LOG_PRINTF(LOG_DEBUG, "rmr[%d]:sizes=[%s]", i, tmp); \
|
||||
LOG_LONG_BUF(tmp, strides, rank); \
|
||||
LOG_PRINTF(LOG_DEBUG, "rmr[%d]:strides=[%s]", i, tmp); \
|
||||
LOG_TYPE_BUF(type, tmp, data, datasize); \
|
||||
LOG_PRINTF(LOG_DEBUG, "rmr[%d]:data=[%s]", i, tmp); \
|
||||
LOG_TYPE_BUF(dataType, tmp, data, n); \
|
||||
LOG_PRINTF(LOG_DEBUG, "omt[%d]:data=[%s]", i, tmp); \
|
||||
LOG_LONG_BUF(tmp, dataSizes, rank); \
|
||||
LOG_PRINTF(LOG_DEBUG, "omt[%d]:dataSizes=[%s]", i, tmp); \
|
||||
LOG_LONG_BUF(tmp, dataStrides, rank); \
|
||||
LOG_PRINTF(LOG_DEBUG, "omt[%d]:dataStrides=[%s]", i, tmp); \
|
||||
LOG_PRINTF(LOG_DEBUG, "omt[%d]:dataType=%d", i, dataType); \
|
||||
LOG_PRINTF(LOG_DEBUG, "omt[%d]:dataBufferSize=%ld", i, dataBufferSize); \
|
||||
LOG_PRINTF(LOG_DEBUG, "omt[%d]:rank=%d", i, rank); \
|
||||
LOG_PRINTF(LOG_DEBUG, "omt[%d]:name=%s", i, name); \
|
||||
LOG_PRINTF(LOG_DEBUG, "omt[%d]:numOfElems=%ld", i, n); \
|
||||
} while (0)
|
||||
|
||||
/* Model shared library entry point */
|
||||
extern OrderedRtMemRefDict *_dyn_entry_point_main_graph(OrderedRtMemRefDict *);
|
||||
|
||||
/* ONNX type to size (number of bytes) mapping */
|
||||
int onnx_type_size[] = {
|
||||
0, /* UNDEFINED = 0 */
|
||||
4, /* FLOAT = 1 */
|
||||
1, /* UINT8 = 2 */
|
||||
1, /* INT8 = 3 */
|
||||
2, /* UINT16 = 4 */
|
||||
2, /* INT16 = 5 */
|
||||
4, /* INT32 = 6 */
|
||||
8, /* INT64 = 7 */
|
||||
0, /* STRING = 8 */
|
||||
1, /* BOOL = 9 */
|
||||
2, /* FLOAT16 = 10 */
|
||||
8, /* DOUBLE = 11 */
|
||||
4, /* UINT32 = 12 */
|
||||
8, /* UINT64 = 13 */
|
||||
8, /* COMPLEX64 = 14 */
|
||||
16, /* COMPLEX128 = 15 */
|
||||
2, /* BFLOAT16 = 16 */
|
||||
};
|
||||
extern OMTensorList *_dyn_entry_point_main_graph(OMTensorList *);
|
||||
|
||||
/* Java classes and methods needed for making various JNI API calls */
|
||||
typedef struct {
|
||||
jclass ecpt_cls; /* java/lang/Exception class */
|
||||
jclass long_cls; /* java/lang/Long class */
|
||||
jclass string_cls; /* java/lang/String class */
|
||||
jclass ormrd_cls; /* com/ibm/onnxmlir/OrderedRtMemRefDict class */
|
||||
jclass rmr_cls; /* com/ibm/onnxmlir/RtMemRef class */
|
||||
jclass ecpt_cls; /* java/lang/Exception class */
|
||||
jclass long_cls; /* java/lang/Long class */
|
||||
jclass string_cls; /* java/lang/String class */
|
||||
jclass omt_cls; /* com/ibm/onnxmlir/OMTensor class */
|
||||
jclass omt_list_cls; /* com/ibm/onnxmlir/OMTensorList class */
|
||||
|
||||
jmethodID ormrd_constructor; /* OrderedRtMemRefDict constructor */
|
||||
jmethodID ormrd_getRmrs; /* OrderedRtMemRefDict getRmrs method */
|
||||
jmethodID ormrd_getNames; /* OrderedRtMemRefDict getNames method */
|
||||
jmethodID omt_constructor; /* OMTensor constructor */
|
||||
jmethodID omt_getData; /* OMTensor getData method */
|
||||
jmethodID omt_setData; /* OMTensor setData method */
|
||||
jmethodID omt_getDataSizes; /* OMTensor getDataSizes method */
|
||||
jmethodID omt_setDataSizes; /* OMTensor setDataSizes method */
|
||||
jmethodID omt_getDataStrides; /* OMTensor getDataStrides method */
|
||||
jmethodID omt_setDataStrides; /* OMTensor setDataStrides method */
|
||||
jmethodID omt_getDataType; /* OMTensor getDataType method */
|
||||
jmethodID omt_setDataType; /* OMTensor setDataType method */
|
||||
jmethodID omt_getDataBufferSize; /* OMTensor getDataBufferSize method */
|
||||
jmethodID omt_getRank; /* OMTensor getRank method */
|
||||
jmethodID omt_getName; /* OMTensor getName method */
|
||||
jmethodID omt_setName; /* OMTensor setName method */
|
||||
jmethodID omt_getNumOfElems; /* OMTensor getNumOfElems method */
|
||||
|
||||
jmethodID rmr_constructor; /* RtMemRef constructor */
|
||||
jmethodID rmr_getType; /* RtMemRef getType method */
|
||||
jmethodID rmr_setType; /* RtMemRef setType method */
|
||||
jmethodID rmr_getRank; /* RtMemRef getRank method */
|
||||
jmethodID rmr_getData; /* RtMemRef getData method */
|
||||
jmethodID rmr_setData; /* RtMemRef setData method */
|
||||
jmethodID rmr_getSizes; /* RtMemRef getSizes method */
|
||||
jmethodID rmr_setSizes; /* RtMemRef setSizes method */
|
||||
jmethodID rmr_getStrides; /* RtMemRef getStrides method */
|
||||
jmethodID rmr_setStrides; /* RtMemRef setStrides method */
|
||||
jmethodID rmr_getDataSize; /* RtMemRef getDataSize method */
|
||||
jmethodID omt_list_constructor; /* OMTensorList constructor */
|
||||
jmethodID omt_list_getOmts; /* OMTensorList getOmts method */
|
||||
} jniapi_t;
|
||||
|
||||
jniapi_t jniapi;
|
||||
|
||||
/* Fill in struct jniapi */
|
||||
jniapi_t *fill_jniapi(JNIEnv *env, jniapi_t *japi) {
|
||||
/* Get Java Exception, Long, String, OrderedRtMemRefDict, and RtMemRef classes
|
||||
/* Get Java Exception, Long, String, OMTensor, and OMTensorList classes
|
||||
*/
|
||||
JNI_VAR_CALL(
|
||||
env, japi->ecpt_cls, (*env)->FindClass(env, "java/lang/Exception"));
|
||||
JNI_VAR_CALL(env, japi->long_cls, (*env)->FindClass(env, "java/lang/Long"));
|
||||
JNI_VAR_CALL(
|
||||
env, japi->string_cls, (*env)->FindClass(env, "java/lang/String"));
|
||||
JNI_VAR_CALL(env, japi->ormrd_cls,
|
||||
(*env)->FindClass(env, "com/ibm/onnxmlir/OrderedRtMemRefDict"));
|
||||
JNI_VAR_CALL(
|
||||
env, japi->rmr_cls, (*env)->FindClass(env, "com/ibm/onnxmlir/RtMemRef"));
|
||||
env, japi->omt_cls, (*env)->FindClass(env, "com/ibm/onnxmlir/OMTensor"));
|
||||
JNI_VAR_CALL(env, japi->omt_list_cls,
|
||||
(*env)->FindClass(env, "com/ibm/onnxmlir/OMTensorList"));
|
||||
|
||||
/* Get method ID of constructor and various methods in OrderedRtMemRefDict */
|
||||
JNI_VAR_CALL(env, japi->ormrd_constructor,
|
||||
/* Get method ID of constructor and various methods in OMTensor */
|
||||
JNI_VAR_CALL(env, japi->omt_constructor,
|
||||
(*env)->GetMethodID(env, japi->omt_cls, "<init>", "(I)V"));
|
||||
JNI_VAR_CALL(env, japi->omt_getData,
|
||||
(*env)->GetMethodID(
|
||||
env, japi->ormrd_cls, "<init>", "([Lcom/ibm/onnxmlir/RtMemRef;)V"));
|
||||
JNI_VAR_CALL(env, japi->ormrd_getRmrs,
|
||||
env, japi->omt_cls, "getData", "()Ljava/nio/ByteBuffer;"));
|
||||
JNI_VAR_CALL(env, japi->omt_setData,
|
||||
(*env)->GetMethodID(
|
||||
env, japi->ormrd_cls, "getRmrs", "()[Lcom/ibm/onnxmlir/RtMemRef;"));
|
||||
JNI_VAR_CALL(env, japi->ormrd_getNames,
|
||||
env, japi->omt_cls, "setData", "(Ljava/nio/ByteBuffer;)V"));
|
||||
JNI_VAR_CALL(env, japi->omt_getDataSizes,
|
||||
(*env)->GetMethodID(env, japi->omt_cls, "getDataSizes", "()[J"));
|
||||
JNI_VAR_CALL(env, japi->omt_setDataSizes,
|
||||
(*env)->GetMethodID(env, japi->omt_cls, "setDataSizes", "([J)V"));
|
||||
JNI_VAR_CALL(env, japi->omt_getDataStrides,
|
||||
(*env)->GetMethodID(env, japi->omt_cls, "getDataStrides", "()[J"));
|
||||
JNI_VAR_CALL(env, japi->omt_setDataStrides,
|
||||
(*env)->GetMethodID(env, japi->omt_cls, "setDataStrides", "([J)V"));
|
||||
JNI_VAR_CALL(env, japi->omt_getDataType,
|
||||
(*env)->GetMethodID(env, japi->omt_cls, "getDataType", "()I"));
|
||||
JNI_VAR_CALL(env, japi->omt_setDataType,
|
||||
(*env)->GetMethodID(env, japi->omt_cls, "setDataType", "(I)V"));
|
||||
JNI_VAR_CALL(env, japi->omt_getDataBufferSize,
|
||||
(*env)->GetMethodID(env, japi->omt_cls, "getDataBufferSize", "()J"));
|
||||
JNI_VAR_CALL(env, japi->omt_getRank,
|
||||
(*env)->GetMethodID(env, japi->omt_cls, "getRank", "()I"));
|
||||
JNI_VAR_CALL(env, japi->omt_getName,
|
||||
(*env)->GetMethodID(
|
||||
env, japi->ormrd_cls, "getNames", "()[Ljava/lang/String;"));
|
||||
env, japi->omt_cls, "getName", "()Ljava/lang/String;"));
|
||||
JNI_VAR_CALL(env, japi->omt_setName,
|
||||
(*env)->GetMethodID(
|
||||
env, japi->omt_cls, "setName", "(Ljava/lang/String;)V"));
|
||||
JNI_VAR_CALL(env, japi->omt_getNumOfElems,
|
||||
(*env)->GetMethodID(env, japi->omt_cls, "getNumOfElems", "()J"));
|
||||
|
||||
/* Get method ID of constructor and various methods in RtMemRef */
|
||||
JNI_VAR_CALL(env, japi->rmr_constructor,
|
||||
(*env)->GetMethodID(env, japi->rmr_cls, "<init>", "(I)V"));
|
||||
JNI_VAR_CALL(env, japi->rmr_getType,
|
||||
(*env)->GetMethodID(env, japi->rmr_cls, "getType", "()I"));
|
||||
JNI_VAR_CALL(env, japi->rmr_setType,
|
||||
(*env)->GetMethodID(env, japi->rmr_cls, "setType", "(I)V"));
|
||||
JNI_VAR_CALL(env, japi->rmr_getRank,
|
||||
(*env)->GetMethodID(env, japi->rmr_cls, "getRank", "()I"));
|
||||
JNI_VAR_CALL(env, japi->rmr_getData,
|
||||
(*env)->GetMethodID(
|
||||
env, japi->rmr_cls, "getData", "()Ljava/nio/ByteBuffer;"));
|
||||
JNI_VAR_CALL(env, japi->rmr_setData,
|
||||
(*env)->GetMethodID(
|
||||
env, japi->rmr_cls, "setData", "(Ljava/nio/ByteBuffer;)V"));
|
||||
JNI_VAR_CALL(env, japi->rmr_getSizes,
|
||||
(*env)->GetMethodID(env, japi->rmr_cls, "getSizes", "()[J"));
|
||||
JNI_VAR_CALL(env, japi->rmr_setSizes,
|
||||
(*env)->GetMethodID(env, japi->rmr_cls, "setSizes", "([J)V"));
|
||||
JNI_VAR_CALL(env, japi->rmr_getStrides,
|
||||
(*env)->GetMethodID(env, japi->rmr_cls, "getStrides", "()[J"));
|
||||
JNI_VAR_CALL(env, japi->rmr_setStrides,
|
||||
(*env)->GetMethodID(env, japi->rmr_cls, "setStrides", "([J)V"));
|
||||
JNI_VAR_CALL(env, japi->rmr_getDataSize,
|
||||
(*env)->GetMethodID(env, japi->rmr_cls, "getDataSize", "()J"));
|
||||
/* Get method ID of constructor and various methods in OMTensorList */
|
||||
JNI_VAR_CALL(env, japi->omt_list_constructor,
|
||||
(*env)->GetMethodID(env, japi->omt_list_cls, "<init>",
|
||||
"([Lcom/ibm/onnxmlir/OMTensor;)V"));
|
||||
JNI_VAR_CALL(env, japi->omt_list_getOmts,
|
||||
(*env)->GetMethodID(env, japi->omt_list_cls, "getOmts",
|
||||
"()[Lcom/ibm/onnxmlir/OMTensor;"));
|
||||
|
||||
return japi;
|
||||
}
|
||||
|
||||
/* Convert Java object to native data structure */
|
||||
OrderedRtMemRefDict *ormrd_java_to_native(
|
||||
OMTensorList *omt_list_java_to_native(
|
||||
JNIEnv *env, jclass cls, jobject obj, jniapi_t *japi) {
|
||||
/* Get object array "rmrs" and "names" in OrderedRtMemRefDict */
|
||||
JNI_TYPE_VAR_CALL(env, jobjectArray, ormrd_rmrs,
|
||||
(*env)->CallObjectMethod(env, obj, japi->ormrd_getRmrs));
|
||||
JNI_TYPE_VAR_CALL(env, jobjectArray, ormrd_names,
|
||||
(*env)->CallObjectMethod(env, obj, japi->ormrd_getNames));
|
||||
|
||||
/* Get length of object array "rmrs" and "names" in OrderedRtMemRefDict */
|
||||
JNI_TYPE_VAR_CALL(
|
||||
env, jsize, ormrd_rmrs_len, (*env)->GetArrayLength(env, ormrd_rmrs));
|
||||
JNI_TYPE_VAR_CALL(
|
||||
env, jsize, ormrd_names_len, (*env)->GetArrayLength(env, ormrd_names));
|
||||
/* Get OMTensor array Java object in OMTensorList */
|
||||
JNI_TYPE_VAR_CALL(env, jobjectArray, omt_list_omts,
|
||||
(*env)->CallObjectMethod(env, obj, japi->omt_list_getOmts));
|
||||
|
||||
/* Allocate memory for holding each Java rmr object and name string,
|
||||
* and RtMemRef and char pointers for constructing native RtMemRef and name
|
||||
* array
|
||||
/* Get the number of OMTensors in the array */
|
||||
JNI_TYPE_VAR_CALL(
|
||||
env, jsize, omt_list_nomt, (*env)->GetArrayLength(env, omt_list_omts));
|
||||
|
||||
/* Allocate memory for holding each Java omt object and OMTensor pointers
|
||||
* for constructing native OMTensor array
|
||||
*/
|
||||
JNI_COND(jobject *, obj_rmr, malloc(ormrd_rmrs_len * sizeof(jobject)), NULL,
|
||||
env, japi->ecpt_cls, "obj_rmr=null");
|
||||
JNI_COND(jstring *, obj_name, malloc(ormrd_names_len * sizeof(jstring)), NULL,
|
||||
env, japi->ecpt_cls, "obj_name=null");
|
||||
JNI_COND(RtMemRef **, jni_rmr, malloc(ormrd_rmrs_len * sizeof(RtMemRef *)),
|
||||
NULL, env, japi->ecpt_cls, "jni_rmr=null");
|
||||
JNI_COND(const char **, jni_name,
|
||||
malloc(ormrd_names_len * sizeof(const char *)), NULL, env, japi->ecpt_cls,
|
||||
"jni_name=null");
|
||||
LIB_TYPE_VAR_CALL(jobject *, obj_omts,
|
||||
malloc(omt_list_nomt * sizeof(jobject)), NULL, env, japi->ecpt_cls,
|
||||
"obj_omts=null");
|
||||
LIB_TYPE_VAR_CALL(OMTensor **, jni_omts,
|
||||
malloc(omt_list_nomt * sizeof(OMTensor *)), NULL, env, japi->ecpt_cls,
|
||||
"jni_omts=null");
|
||||
|
||||
/* Create OrderedRtMemRefDict to be constructed and passed to the model shared
|
||||
* library */
|
||||
JNI_COND(OrderedRtMemRefDict *, ormrd, createOrderedRtMemRefDict(), NULL, env,
|
||||
japi->ecpt_cls, "ormrd=null");
|
||||
|
||||
/* Loop through all the ormrd_rmrs and ormrd_names */
|
||||
for (int i = 0; i < ormrd_rmrs_len; i++) {
|
||||
/* Loop through all the omt_list_omts */
|
||||
for (int i = 0; i < omt_list_nomt; i++) {
|
||||
JNI_VAR_CALL(
|
||||
env, obj_rmr[i], (*env)->GetObjectArrayElement(env, ormrd_rmrs, i));
|
||||
JNI_VAR_CALL(
|
||||
env, obj_name[i], (*env)->GetObjectArrayElement(env, ormrd_names, i));
|
||||
env, obj_omts[i], (*env)->GetObjectArrayElement(env, omt_list_omts, i));
|
||||
|
||||
/* Get type, rank, data, sizes, and strides by calling corresponding methods
|
||||
/* Get data, dataSizes, dataStrides, dataType, rank, name and
|
||||
* dataBufferSize by calling corresponding methods
|
||||
*/
|
||||
JNI_TYPE_VAR_CALL(env, jint, rmr_type,
|
||||
(*env)->CallIntMethod(env, obj_rmr[i], japi->rmr_getType));
|
||||
JNI_TYPE_VAR_CALL(env, jint, rmr_rank,
|
||||
(*env)->CallIntMethod(env, obj_rmr[i], japi->rmr_getRank));
|
||||
JNI_TYPE_VAR_CALL(env, jlong, rmr_datasize,
|
||||
(*env)->CallLongMethod(env, obj_rmr[i], japi->rmr_getDataSize));
|
||||
JNI_TYPE_VAR_CALL(env, jobject, rmr_data,
|
||||
(*env)->CallObjectMethod(env, obj_rmr[i], japi->rmr_getData));
|
||||
JNI_TYPE_VAR_CALL(env, jobject, rmr_sizes,
|
||||
(*env)->CallObjectMethod(env, obj_rmr[i], japi->rmr_getSizes));
|
||||
JNI_TYPE_VAR_CALL(env, jobject, rmr_strides,
|
||||
(*env)->CallObjectMethod(env, obj_rmr[i], japi->rmr_getStrides));
|
||||
|
||||
/* Primitive type int and long can be directly used */
|
||||
int jni_type = rmr_type, jni_rank = rmr_rank;
|
||||
long jni_datasize = rmr_datasize;
|
||||
JNI_TYPE_VAR_CALL(env, jobject, omt_data,
|
||||
(*env)->CallObjectMethod(env, obj_omts[i], japi->omt_getData));
|
||||
JNI_TYPE_VAR_CALL(env, jobject, omt_dataSizes,
|
||||
(*env)->CallObjectMethod(env, obj_omts[i], japi->omt_getDataSizes));
|
||||
JNI_TYPE_VAR_CALL(env, jobject, omt_dataStrides,
|
||||
(*env)->CallObjectMethod(env, obj_omts[i], japi->omt_getDataStrides));
|
||||
JNI_TYPE_VAR_CALL(env, jint, omt_dataType,
|
||||
(*env)->CallIntMethod(env, obj_omts[i], japi->omt_getDataType));
|
||||
JNI_TYPE_VAR_CALL(env, jlong, omt_dataBufferSize,
|
||||
(*env)->CallLongMethod(env, obj_omts[i], japi->omt_getDataBufferSize));
|
||||
JNI_TYPE_VAR_CALL(env, jint, omt_rank,
|
||||
(*env)->CallIntMethod(env, obj_omts[i], japi->omt_getRank));
|
||||
JNI_TYPE_VAR_CALL(env, jstring, omt_name,
|
||||
(*env)->CallObjectMethod(env, obj_omts[i], japi->omt_getName));
|
||||
JNI_TYPE_VAR_CALL(env, jlong, omt_numOfElems,
|
||||
(*env)->CallLongMethod(env, obj_omts[i], japi->omt_getNumOfElems));
|
||||
|
||||
/* Get direct buffer associated with data */
|
||||
JNI_TYPE_VAR_CALL(
|
||||
env, void *, jni_data, (*env)->GetDirectBufferAddress(env, rmr_data));
|
||||
env, void *, jni_data, (*env)->GetDirectBufferAddress(env, omt_data));
|
||||
|
||||
/* Get long array associated with sizes and strides */
|
||||
JNI_TYPE_VAR_CALL(env, long *, jni_sizes,
|
||||
(*env)->GetLongArrayElements(env, rmr_sizes, NULL));
|
||||
JNI_TYPE_VAR_CALL(env, long *, jni_strides,
|
||||
(*env)->GetLongArrayElements(env, rmr_strides, NULL));
|
||||
/* Get long array associated with data sizes and strides */
|
||||
JNI_TYPE_VAR_CALL(env, long *, jni_dataSizes,
|
||||
(*env)->GetLongArrayElements(env, omt_dataSizes, NULL));
|
||||
JNI_TYPE_VAR_CALL(env, long *, jni_dataStrides,
|
||||
(*env)->GetLongArrayElements(env, omt_dataStrides, NULL));
|
||||
|
||||
/* Primitive type int and long can be directly used */
|
||||
int jni_dataType = omt_dataType;
|
||||
long jni_dataBufferSize = omt_dataBufferSize;
|
||||
int jni_rank = omt_rank;
|
||||
long jni_numOfElems = omt_numOfElems;
|
||||
|
||||
/* Get name string */
|
||||
JNI_TYPE_VAR_CALL(env, char *, jni_name,
|
||||
(char *)(*env)->GetStringUTFChars(env, omt_name, NULL));
|
||||
|
||||
/* Print debug info on what we got from the Java side */
|
||||
RMR_DEBUG(
|
||||
i, jni_type, jni_rank, jni_sizes, jni_strides, jni_data, jni_datasize);
|
||||
OMT_DEBUG(i, jni_numOfElems, jni_data, jni_dataSizes, jni_dataStrides,
|
||||
jni_dataType, jni_dataBufferSize, jni_rank, jni_name);
|
||||
|
||||
/* Create native RtMemRef struct and fill in its fields */
|
||||
jni_rmr[i] = createRtMemRef(jni_rank);
|
||||
setDType(jni_rmr[i], jni_type);
|
||||
setData(jni_rmr[i], jni_data);
|
||||
setSizes(jni_rmr[i], jni_sizes);
|
||||
setStrides(jni_rmr[i], jni_strides);
|
||||
|
||||
/*jni_name[i] = (*env)->GetStringUTFChars(env, obj_name[i], NULL);
|
||||
printf("jni_name=%s\n", jni_name[i]);*/
|
||||
|
||||
/* Install RtMemRef into OrderedRtMemRefDict */
|
||||
setRtMemRef(ormrd, i, jni_rmr[i]);
|
||||
/* Create native OMTensor struct and fill in its fields */
|
||||
LIB_VAR_CALL(jni_omts[i], omt_create(jni_rank), NULL, env, japi->ecpt_cls,
|
||||
"jni_omts[%d]=null", i);
|
||||
omt_setData(jni_omts[i], jni_data);
|
||||
omt_setDataSizes(jni_omts[i], jni_dataSizes);
|
||||
omt_setDataStrides(jni_omts[i], jni_dataStrides);
|
||||
omt_setDataType(jni_omts[i], jni_dataType);
|
||||
omt_setName(jni_omts[i], jni_name);
|
||||
|
||||
/* Release reference to the java objects */
|
||||
JNI_CALL(
|
||||
env, (*env)->ReleaseLongArrayElements(env, rmr_sizes, jni_sizes, 0));
|
||||
JNI_CALL(env,
|
||||
(*env)->ReleaseLongArrayElements(env, rmr_strides, jni_strides, 0));
|
||||
(*env)->ReleaseLongArrayElements(env, omt_dataSizes, jni_dataSizes, 0));
|
||||
JNI_CALL(env, (*env)->ReleaseLongArrayElements(
|
||||
env, omt_dataStrides, jni_dataStrides, 0));
|
||||
JNI_CALL(env, (*env)->ReleaseStringUTFChars(env, omt_name, jni_name));
|
||||
}
|
||||
|
||||
/* setRtMemRef(ormrd, jni_rmr, jni_name); */
|
||||
return ormrd;
|
||||
/* Create OMTensorList to be constructed and passed to the
|
||||
* model shared library
|
||||
*/
|
||||
LIB_TYPE_VAR_CALL(OMTensorList *, list,
|
||||
omt_list_create(jni_omts, omt_list_nomt), NULL, env, japi->ecpt_cls,
|
||||
"list=null");
|
||||
|
||||
return list;
|
||||
}
|
||||
|
||||
/* Convert native data structure to Java object */
|
||||
jobject ormrd_native_to_java(
|
||||
JNIEnv *env, jclass cls, OrderedRtMemRefDict *dict, jniapi_t *japi) {
|
||||
JNI_COND(int, nrmr, numRtMemRefs(dict), 0, env, japi->ecpt_cls, "nrmr=0");
|
||||
jobject omt_list_native_to_java(
|
||||
JNIEnv *env, jclass cls, OMTensorList *dict, jniapi_t *japi) {
|
||||
|
||||
/* Create RtMemRef java object array */
|
||||
JNI_TYPE_VAR_CALL(env, jobjectArray, rmrs,
|
||||
(*env)->NewObjectArray(env, nrmr, japi->rmr_cls, NULL));
|
||||
/* Get the OMTensor array in the OMTensorList */
|
||||
LIB_TYPE_VAR_CALL(OMTensor **, jni_omts, omTensorListGetOmts(dict), NULL, env,
|
||||
japi->ecpt_cls, "jni_omts=null");
|
||||
/* Get the number of OMTensors in the OMTensorList */
|
||||
LIB_TYPE_VAR_CALL(int, jni_nomt, omTensorListGetNumOfOmts(dict), 0, env,
|
||||
japi->ecpt_cls, "jni_nomt=0");
|
||||
|
||||
/* Loop through the native RtMemRef structs */
|
||||
for (int i = 0; i < nrmr; i++) {
|
||||
JNI_COND(RtMemRef *, rmr, getRtMemRef(dict, i), NULL, env, japi->ecpt_cls,
|
||||
"rmr[%d]=null", i);
|
||||
/* Create OMTensor java object array */
|
||||
JNI_TYPE_VAR_CALL(env, jobjectArray, obj_omts,
|
||||
(*env)->NewObjectArray(env, jni_nomt, japi->omt_cls, NULL));
|
||||
|
||||
JNI_COND(int, jni_type, getDType(rmr), 0, env, japi->ecpt_cls,
|
||||
"rmr[%d]:type=0", i);
|
||||
JNI_COND(int, jni_rank, getRank(rmr), 0, env, japi->ecpt_cls,
|
||||
"rmr[%d]:rank=0", i);
|
||||
JNI_COND(long, jni_datasize, getDataSize(rmr), 0, env, japi->ecpt_cls,
|
||||
"rmr[%d]:datasize=0", i);
|
||||
JNI_COND(void *, jni_data, getData(rmr), NULL, env, japi->ecpt_cls,
|
||||
"rmr[%d]:data=null", i);
|
||||
JNI_COND(long *, jni_sizes, getSizes(rmr), NULL, env, japi->ecpt_cls,
|
||||
"rmr[%d]:sizes=null", i);
|
||||
JNI_COND(long *, jni_strides, getStrides(rmr), NULL, env, japi->ecpt_cls,
|
||||
"rmr[%d]:strides=null", i);
|
||||
/* Loop through the native OMTensor structs */
|
||||
for (int i = 0; i < jni_nomt; i++) {
|
||||
|
||||
LIB_TYPE_VAR_CALL(void *, jni_data, omTensorGetData(jni_omts[i]), NULL, env,
|
||||
japi->ecpt_cls, "omt[%d]:data=null", i);
|
||||
LIB_TYPE_VAR_CALL(long *, jni_dataSizes, omTensorGetDataSizes(jni_omts[i]),
|
||||
NULL, env, japi->ecpt_cls, "omt[%d]:dataSizes=null", i);
|
||||
LIB_TYPE_VAR_CALL(long *, jni_dataStrides,
|
||||
omTensorGetDataStrides(jni_omts[i]), NULL, env, japi->ecpt_cls,
|
||||
"omt[%d]:dataStrides=null", i);
|
||||
LIB_TYPE_VAR_CALL(int, jni_dataType, omTensorGetDataType(jni_omts[i]), 0,
|
||||
env, japi->ecpt_cls, "omt[%d]:dataType=0", i);
|
||||
LIB_TYPE_VAR_CALL(long, jni_dataBufferSize,
|
||||
omt_getDataBufferSize(jni_omts[i]), 0, env, japi->ecpt_cls,
|
||||
"omt[%ld]:dataBufferSize=0", i);
|
||||
LIB_TYPE_VAR_CALL(int, jni_rank, omTensorGetRank(jni_omts[i]), 0, env,
|
||||
japi->ecpt_cls, "omt[%d]:rank=0", i);
|
||||
LIB_TYPE_VAR_CALL(char *, jni_name, omTensorGetName(jni_omts[i]), NULL, env,
|
||||
japi->ecpt_cls, "omt[%d]:name=null", i);
|
||||
LIB_TYPE_VAR_CALL(long, jni_numOfElems, omTensorGetNumOfElems(jni_omts[i]),
|
||||
0, env, japi->ecpt_cls, "omt[%d]:numOfElems=0", i);
|
||||
|
||||
/* Print debug info on what we got from the native side */
|
||||
RMR_DEBUG(
|
||||
i, jni_type, jni_rank, jni_sizes, jni_strides, jni_data, jni_datasize);
|
||||
OMT_DEBUG(i, jni_numOfElems, jni_data, jni_dataSizes, jni_dataStrides,
|
||||
jni_dataType, jni_dataBufferSize, jni_rank, jni_name);
|
||||
|
||||
/* create the following Java objects:
|
||||
* - RtMemRef
|
||||
* - DirectByteBuffer (from native buffers)
|
||||
* - long array for sizes and strides
|
||||
/* Create the OMTensor Java object */
|
||||
JNI_TYPE_VAR_CALL(env, jobject, obj_omt,
|
||||
(*env)->NewObject(env, japi->omt_cls, japi->omt_constructor, jni_rank));
|
||||
|
||||
/* Create direct byte buffer Java object from native data buffer, and
|
||||
* call setData method
|
||||
*/
|
||||
JNI_TYPE_VAR_CALL(env, jobject, omt_data,
|
||||
(*env)->NewDirectByteBuffer(env, jni_data, jni_dataBufferSize));
|
||||
JNI_CALL(env,
|
||||
(*env)->CallObjectMethod(env, obj_omt, japi->omt_setData, omt_data));
|
||||
|
||||
/* Create data sizes array Java object, fill in from native array, and
|
||||
* call setDataSizes method
|
||||
*/
|
||||
JNI_TYPE_VAR_CALL(env, jobject, obj_rmr,
|
||||
(*env)->NewObject(env, japi->rmr_cls, japi->rmr_constructor, jni_rank));
|
||||
JNI_TYPE_VAR_CALL(env, jobject, rmr_data,
|
||||
(*env)->NewDirectByteBuffer(
|
||||
env, jni_data, jni_datasize * onnx_type_size[jni_type]));
|
||||
JNI_TYPE_VAR_CALL(
|
||||
env, jlongArray, rmr_sizes, (*env)->NewLongArray(env, jni_rank));
|
||||
JNI_TYPE_VAR_CALL(
|
||||
env, jlongArray, rmr_strides, (*env)->NewLongArray(env, jni_rank));
|
||||
|
||||
/* Call setType method */
|
||||
JNI_CALL(env,
|
||||
(*env)->CallObjectMethod(env, obj_rmr, japi->rmr_setType, jni_type));
|
||||
|
||||
/* Call setData method */
|
||||
JNI_CALL(env,
|
||||
(*env)->CallObjectMethod(env, obj_rmr, japi->rmr_setData, rmr_data));
|
||||
|
||||
/* Fill in sizes array from native array and call setSizes method */
|
||||
JNI_CALL(env,
|
||||
(*env)->SetLongArrayRegion(env, rmr_sizes, 0, jni_rank, jni_sizes));
|
||||
JNI_CALL(env,
|
||||
(*env)->CallObjectMethod(env, obj_rmr, japi->rmr_setSizes, rmr_sizes));
|
||||
|
||||
/* Fill in strides array from native array and call setStrides method */
|
||||
JNI_CALL(env,
|
||||
(*env)->SetLongArrayRegion(env, rmr_strides, 0, jni_rank, jni_strides));
|
||||
env, jlongArray, omt_dataSizes, (*env)->NewLongArray(env, jni_rank));
|
||||
JNI_CALL(env, (*env)->SetLongArrayRegion(
|
||||
env, omt_dataSizes, 0, jni_rank, jni_dataSizes));
|
||||
JNI_CALL(env, (*env)->CallObjectMethod(
|
||||
env, obj_rmr, japi->rmr_setStrides, rmr_strides));
|
||||
env, obj_omt, japi->omt_setDataSizes, omt_dataSizes));
|
||||
|
||||
/* Set DynMemRef object in the object array */
|
||||
JNI_CALL(env, (*env)->SetObjectArrayElement(env, rmrs, i, obj_rmr));
|
||||
/* Create data strides array Java object, fill in from native array, and
|
||||
* call setStrides method
|
||||
*/
|
||||
JNI_TYPE_VAR_CALL(
|
||||
env, jlongArray, omt_dataStrides, (*env)->NewLongArray(env, jni_rank));
|
||||
JNI_CALL(env, (*env)->SetLongArrayRegion(
|
||||
env, omt_dataStrides, 0, jni_rank, jni_dataStrides));
|
||||
JNI_CALL(env, (*env)->CallObjectMethod(
|
||||
env, obj_omt, japi->omt_setDataStrides, omt_dataStrides));
|
||||
|
||||
/* Primitive type int can be directly used. Call setDataType method */
|
||||
JNI_CALL(env, (*env)->CallIntMethod(
|
||||
env, obj_omt, japi->omt_setDataType, (jint)jni_dataType));
|
||||
|
||||
/* Create string Java object from native char * and call setName method */
|
||||
JNI_TYPE_VAR_CALL(
|
||||
env, jstring, omt_name, (*env)->NewStringUTF(env, jni_name));
|
||||
JNI_CALL(env,
|
||||
(*env)->CallObjectMethod(env, obj_omt, japi->omt_setName, omt_name));
|
||||
|
||||
/* Set OMTensor object in the object array */
|
||||
JNI_CALL(env, (*env)->SetObjectArrayElement(env, obj_omts, i, obj_omt));
|
||||
}
|
||||
|
||||
/* Create the OrderedRtMemRefDict java object */
|
||||
JNI_TYPE_VAR_CALL(env, jobject, ormrd,
|
||||
(*env)->NewObject(env, japi->ormrd_cls, japi->ormrd_constructor, rmrs));
|
||||
/* Create the OMTensorList java object */
|
||||
JNI_TYPE_VAR_CALL(env, jobject, list,
|
||||
(*env)->NewObject(
|
||||
env, japi->omt_list_cls, japi->omt_list_constructor, obj_omts));
|
||||
|
||||
return ormrd;
|
||||
return list;
|
||||
}
|
||||
|
||||
JNIEXPORT jobject JNICALL Java_com_ibm_onnxmlir_DynEntryPoint_main_1graph_1jni(
|
||||
JNIEnv *env, jclass cls, jobject obj) {
|
||||
CHECK_CALL(jniapi_t *, japi, fill_jniapi(env, &jniapi), NULL);
|
||||
|
||||
CHECK_CALL(OrderedRtMemRefDict *, input_ormrd,
|
||||
ormrd_java_to_native(env, cls, obj, japi), NULL);
|
||||
|
||||
CHECK_CALL(OrderedRtMemRefDict *, dict,
|
||||
_dyn_entry_point_main_graph(input_ormrd), NULL);
|
||||
CHECK_CALL(OMTensorList *, input_list,
|
||||
omt_list_java_to_native(env, cls, obj, japi), NULL);
|
||||
|
||||
CHECK_CALL(
|
||||
jobject, output_ormrd, ormrd_native_to_java(env, cls, dict, japi), NULL);
|
||||
OMTensorList *, dict, _dyn_entry_point_main_graph(input_list), NULL);
|
||||
|
||||
return output_ormrd;
|
||||
CHECK_CALL(jobject, output_list,
|
||||
omt_list_native_to_java(env, cls, dict, japi), NULL);
|
||||
|
||||
return output_list;
|
||||
}
|
||||
|
|
|
@ -56,9 +56,9 @@ public class DynEntryPoint {
|
|||
}
|
||||
}
|
||||
|
||||
private static native OrderedRtMemRefDict main_graph_jni(OrderedRtMemRefDict ormrd);
|
||||
private static native OMTensorList main_graph_jni(OMTensorList list);
|
||||
|
||||
public static OrderedRtMemRefDict main_graph(OrderedRtMemRefDict ormrd) {
|
||||
return main_graph_jni(ormrd);
|
||||
public static OMTensorList main_graph(OMTensorList list) {
|
||||
return main_graph_jni(list);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
package com.ibm.onnxmlir;
|
||||
|
||||
import java.lang.reflect.Array;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.nio.DoubleBuffer;
|
||||
|
@ -8,7 +9,7 @@ import java.nio.IntBuffer;
|
|||
import java.nio.LongBuffer;
|
||||
import java.nio.ShortBuffer;
|
||||
|
||||
public class RtMemRef {
|
||||
public class OMTensor {
|
||||
final ByteOrder endian = ByteOrder.nativeOrder();
|
||||
|
||||
/* We can use enum but that creates another class
|
||||
|
@ -52,47 +53,26 @@ public class RtMemRef {
|
|||
2, /* BFLOAT16 */
|
||||
};
|
||||
|
||||
private ByteBuffer _data;
|
||||
private int _type;
|
||||
private ByteBuffer _allocatedPtr;
|
||||
private long[] _shape;
|
||||
private long[] _stride;
|
||||
private int _dataType;
|
||||
private int _rank;
|
||||
private long[] _sizes;
|
||||
private long[] _strides;
|
||||
private String _name;
|
||||
|
||||
/**
|
||||
* Constructor
|
||||
*/
|
||||
public RtMemRef(int rank) {
|
||||
public OMTensor(int rank) {
|
||||
if (rank <= 0)
|
||||
throw new IllegalArgumentException(
|
||||
"invalid rank " + rank);
|
||||
_data = null;
|
||||
_type = ONNX_TYPE_UNDEFINED;
|
||||
_allocatedPtr = null;
|
||||
_shape = new long[rank];
|
||||
_stride = new long[rank];
|
||||
_dataType = ONNX_TYPE_UNDEFINED;
|
||||
_rank = rank;
|
||||
_sizes = new long[rank];
|
||||
_strides = new long[rank];
|
||||
}
|
||||
|
||||
/* ---------- Data type getter and setter ---------- */
|
||||
/* For JNI wrapper only. Not intended for end user. */
|
||||
|
||||
/**
|
||||
* Type getter
|
||||
*
|
||||
* @return data type
|
||||
*/
|
||||
@SuppressWarnings("unused")
|
||||
private int getType() {
|
||||
return _type;
|
||||
}
|
||||
|
||||
/**
|
||||
* Type setter
|
||||
*
|
||||
* @param type data type to be set
|
||||
*/
|
||||
@SuppressWarnings("unused")
|
||||
private void setType(int type) {
|
||||
_type = type;
|
||||
_name = "";
|
||||
}
|
||||
|
||||
/* ---------- Raw data getter and setter ---------- */
|
||||
|
@ -105,7 +85,7 @@ public class RtMemRef {
|
|||
*/
|
||||
@SuppressWarnings("unused")
|
||||
private ByteBuffer getData() {
|
||||
return _data;
|
||||
return _allocatedPtr;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -115,7 +95,7 @@ public class RtMemRef {
|
|||
*/
|
||||
@SuppressWarnings("unused")
|
||||
private void setData(ByteBuffer data) {
|
||||
_data = data.order(endian);
|
||||
_allocatedPtr = data.order(endian);
|
||||
}
|
||||
|
||||
/* ---------- Byte data getter and setter ---------- */
|
||||
|
@ -126,14 +106,14 @@ public class RtMemRef {
|
|||
* @return byte data array
|
||||
*/
|
||||
public byte[] getByteData() {
|
||||
if (_data == null) return null;
|
||||
if (_allocatedPtr == null) return null;
|
||||
|
||||
/* asReadOnlyBuffer() creates a new view so the position of the
|
||||
* original data will stay at 0 for subsequent getByteData()
|
||||
* after get(b).
|
||||
*/
|
||||
byte[] b = new byte[_data.limit()];
|
||||
_data.asReadOnlyBuffer().get(b);
|
||||
byte[] b = new byte[_allocatedPtr.limit()];
|
||||
_allocatedPtr.asReadOnlyBuffer().get(b);
|
||||
return b;
|
||||
}
|
||||
|
||||
|
@ -146,9 +126,9 @@ public class RtMemRef {
|
|||
/* slice() creates a new view so the position of the
|
||||
* original data will stay at 0 for getByteData() after put(data).
|
||||
*/
|
||||
_data = ByteBuffer.allocateDirect(data.length);
|
||||
_data.slice().put(data);
|
||||
_type = ONNX_TYPE_INT8;
|
||||
_allocatedPtr = ByteBuffer.allocateDirect(data.length);
|
||||
_allocatedPtr.slice().put(data);
|
||||
_dataType = ONNX_TYPE_INT8;
|
||||
}
|
||||
|
||||
/* ---------- Short data getter and setter ---------- */
|
||||
|
@ -159,13 +139,13 @@ public class RtMemRef {
|
|||
* @return short data array
|
||||
*/
|
||||
public short[] getShortData() {
|
||||
if (_data == null) return null;
|
||||
if (_allocatedPtr == null) return null;
|
||||
|
||||
/* asShortBuffer() creates a new view so the position of the
|
||||
* original data will stay at 0 for subsequent getShortData()
|
||||
* after get(s).
|
||||
*/
|
||||
ShortBuffer sb = _data.asShortBuffer();
|
||||
ShortBuffer sb = _allocatedPtr.asShortBuffer();
|
||||
short[] s = new short[sb.limit()];
|
||||
sb.get(s);
|
||||
return s;
|
||||
|
@ -180,9 +160,9 @@ public class RtMemRef {
|
|||
/* asShortBuffer() creates a new view so the position of the
|
||||
* original data will stay at 0 for getShortData() after put(data).
|
||||
*/
|
||||
_data = ByteBuffer.allocateDirect(data.length*2).order(endian);
|
||||
_data.asShortBuffer().put(data);
|
||||
_type = ONNX_TYPE_INT16;
|
||||
_allocatedPtr = ByteBuffer.allocateDirect(data.length*2).order(endian);
|
||||
_allocatedPtr.asShortBuffer().put(data);
|
||||
_dataType = ONNX_TYPE_INT16;
|
||||
}
|
||||
|
||||
/* ---------- Int data getter and setter ---------- */
|
||||
|
@ -193,13 +173,13 @@ public class RtMemRef {
|
|||
* @return int data array
|
||||
*/
|
||||
public int[] getIntData() {
|
||||
if (_data == null) return null;
|
||||
if (_allocatedPtr == null) return null;
|
||||
|
||||
/* asIntBuffer() creates a new view so the position of the
|
||||
* original data will stay at 0 for subsequent getIntData()
|
||||
* after get(i).
|
||||
*/
|
||||
IntBuffer ib = _data.asIntBuffer();
|
||||
IntBuffer ib = _allocatedPtr.asIntBuffer();
|
||||
int[] i = new int[ib.limit()];
|
||||
ib.get(i);
|
||||
return i;
|
||||
|
@ -214,9 +194,9 @@ public class RtMemRef {
|
|||
/* asIntBuffer() creates a new view so the position of the
|
||||
* original data will stay at 0 for getIntData() after put(data).
|
||||
*/
|
||||
_data = ByteBuffer.allocateDirect(data.length*4).order(endian);
|
||||
_data.asIntBuffer().put(data);
|
||||
_type = ONNX_TYPE_INT32;
|
||||
_allocatedPtr = ByteBuffer.allocateDirect(data.length*4).order(endian);
|
||||
_allocatedPtr.asIntBuffer().put(data);
|
||||
_dataType = ONNX_TYPE_INT32;
|
||||
}
|
||||
|
||||
/* ---------- Long data getter and setter ---------- */
|
||||
|
@ -227,13 +207,13 @@ public class RtMemRef {
|
|||
* @return long data array
|
||||
*/
|
||||
public long[] getLongData() {
|
||||
if (_data == null) return null;
|
||||
if (_allocatedPtr == null) return null;
|
||||
|
||||
/* asLongBuffer() creates a new view so the position of the
|
||||
* original data will stay at 0 for subsequent getLongData()
|
||||
* after get(l).
|
||||
*/
|
||||
LongBuffer lb = _data.asLongBuffer();
|
||||
LongBuffer lb = _allocatedPtr.asLongBuffer();
|
||||
long[] l = new long[lb.limit()];
|
||||
lb.get(l);
|
||||
return l;
|
||||
|
@ -248,9 +228,9 @@ public class RtMemRef {
|
|||
/* asLongBuffer() creates a new view so the position of the
|
||||
* original data will stay at 0 for getLongData() after put(data).
|
||||
*/
|
||||
_data = ByteBuffer.allocateDirect(data.length*8).order(endian);
|
||||
_data.asLongBuffer().put(data);
|
||||
_type = ONNX_TYPE_INT64;
|
||||
_allocatedPtr = ByteBuffer.allocateDirect(data.length*8).order(endian);
|
||||
_allocatedPtr.asLongBuffer().put(data);
|
||||
_dataType = ONNX_TYPE_INT64;
|
||||
}
|
||||
|
||||
/* ---------- Float data getter and setter ---------- */
|
||||
|
@ -261,13 +241,13 @@ public class RtMemRef {
|
|||
* @return float data array
|
||||
*/
|
||||
public float[] getFloatData() {
|
||||
if (_data == null) return null;
|
||||
if (_allocatedPtr == null) return null;
|
||||
|
||||
/* asFloatBuffer() creates a new view so the position of the
|
||||
* original data will stay at 0 for subsequent getFloatData()
|
||||
* after get(f).
|
||||
*/
|
||||
FloatBuffer fb = _data.asFloatBuffer();
|
||||
FloatBuffer fb = _allocatedPtr.asFloatBuffer();
|
||||
float[] f = new float[fb.limit()];
|
||||
fb.get(f);
|
||||
return f;
|
||||
|
@ -282,9 +262,9 @@ public class RtMemRef {
|
|||
/* asFloatBuffer() creates a new view so the position of the
|
||||
* original data will stay at 0 for getFloatData() after put(data).
|
||||
*/
|
||||
_data = ByteBuffer.allocateDirect(data.length*4).order(endian);
|
||||
_data.asFloatBuffer().put(data);
|
||||
_type = ONNX_TYPE_FLOAT;
|
||||
_allocatedPtr = ByteBuffer.allocateDirect(data.length*4).order(endian);
|
||||
_allocatedPtr.asFloatBuffer().put(data);
|
||||
_dataType = ONNX_TYPE_FLOAT;
|
||||
}
|
||||
|
||||
/* ---------- Double data getter and setter ---------- */
|
||||
|
@ -295,13 +275,13 @@ public class RtMemRef {
|
|||
* @return double data array
|
||||
*/
|
||||
public double[] getDoubleData() {
|
||||
if (_data == null) return null;
|
||||
if (_allocatedPtr == null) return null;
|
||||
|
||||
/* asDoubleBuffer() creates a new view so the position of the
|
||||
* original data will stay at 0 for subsequent getDoubleData()
|
||||
* after get(d).
|
||||
*/
|
||||
DoubleBuffer db = _data.asDoubleBuffer();
|
||||
DoubleBuffer db = _allocatedPtr.asDoubleBuffer();
|
||||
double[] d = new double[db.limit()];
|
||||
db.get(d);
|
||||
return d;
|
||||
|
@ -316,85 +296,141 @@ public class RtMemRef {
|
|||
/* asDoubleBuffer() creates a new view so the position of the
|
||||
* original data will stay at 0 for getDoubleData() after put(data).
|
||||
*/
|
||||
_data = ByteBuffer.allocateDirect(data.length*8).order(endian);
|
||||
_data.asDoubleBuffer().put(data);
|
||||
_type = ONNX_TYPE_DOUBLE;
|
||||
_allocatedPtr = ByteBuffer.allocateDirect(data.length*8).order(endian);
|
||||
_allocatedPtr.asDoubleBuffer().put(data);
|
||||
_dataType = ONNX_TYPE_DOUBLE;
|
||||
}
|
||||
|
||||
/* ---------- Data sizes getter and setter ---------- */
|
||||
|
||||
/**
|
||||
* Data sizes getter
|
||||
*
|
||||
* @return data sizes array
|
||||
*/
|
||||
public long[] getDataSizes() {
|
||||
return _shape;
|
||||
}
|
||||
|
||||
/**
|
||||
* Data sizes setter
|
||||
*
|
||||
* @param dataSizes data sizes array to be set
|
||||
*/
|
||||
public void setDataSizes(long[] dataSizes) {
|
||||
if (dataSizes.length != _rank)
|
||||
throw new IllegalArgumentException(
|
||||
"array length " + dataSizes.length + " != rank " + _rank);
|
||||
_shape = dataSizes.clone();
|
||||
}
|
||||
|
||||
/* ---------- Data strides getter and setter ---------- */
|
||||
|
||||
/**
|
||||
* Data strides getter
|
||||
*
|
||||
* @return data strides array
|
||||
*/
|
||||
public long[] getDataStrides() {
|
||||
return _stride;
|
||||
}
|
||||
|
||||
/**
|
||||
* Data strides setter
|
||||
*
|
||||
* @param dataStrides data strides array to be set
|
||||
*/
|
||||
public void setDataStrides(long[] dataStrides) {
|
||||
if (dataStrides.length != _rank)
|
||||
throw new IllegalArgumentException(
|
||||
"array length " + dataStrides.length + " != rank " + _rank);
|
||||
_stride = dataStrides.clone();
|
||||
}
|
||||
|
||||
/* ---------- Data type getter and setter ---------- */
|
||||
|
||||
/**
|
||||
* Data type getter
|
||||
*
|
||||
* @return data type
|
||||
*/
|
||||
public int getDataType() {
|
||||
return _dataType;
|
||||
}
|
||||
|
||||
/**
|
||||
* Data type setter
|
||||
*
|
||||
* @param type data type to be set
|
||||
*/
|
||||
public void setDataType(int dataType) {
|
||||
if (dataType < 0 || dataType > ONNX_TYPE_BFLOAT16)
|
||||
throw new IllegalArgumentException(
|
||||
"data type " + dataType + " unknown");
|
||||
_dataType = dataType;
|
||||
}
|
||||
|
||||
/* ---------- Data buffer size getter ---------- */
|
||||
|
||||
/**
|
||||
* Data buffer size getter
|
||||
*
|
||||
* @return total size of the data buffer in bytes
|
||||
*/
|
||||
public long getDataBufferSize() {
|
||||
return _allocatedPtr == null ? 0 : _allocatedPtr.limit();
|
||||
}
|
||||
|
||||
/* ---------- Rank getter ---------- */
|
||||
|
||||
/**
|
||||
* Rank getter
|
||||
*
|
||||
* @return rank
|
||||
* @return rank of the OMTensor
|
||||
*/
|
||||
public int getRank() {
|
||||
return _rank;
|
||||
}
|
||||
|
||||
/* ---------- Sizes getter and setter ---------- */
|
||||
/* ---------- Name getter and setter ---------- */
|
||||
|
||||
/**
|
||||
* Sizes getter
|
||||
* Name getter
|
||||
*
|
||||
* @return sizes array
|
||||
* @return name of the OMTensor
|
||||
*/
|
||||
public long[] getSizes() {
|
||||
return _sizes;
|
||||
public String getName() {
|
||||
return _name;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sizes setter
|
||||
* Name setter
|
||||
*
|
||||
* @param sizes sizes array to be set
|
||||
* @param name name of the OMTensor
|
||||
*/
|
||||
public void setSizes(long[] sizes) {
|
||||
if (sizes.length != _rank)
|
||||
throw new IllegalArgumentException(
|
||||
"array length " + sizes.length + " != rank " + _rank);
|
||||
_sizes = sizes.clone();
|
||||
}
|
||||
|
||||
/* ---------- Strides getter and setter ---------- */
|
||||
|
||||
/**
|
||||
* Strides getter
|
||||
*
|
||||
* @return strides array
|
||||
*/
|
||||
public long[] getStrides() {
|
||||
return _strides;
|
||||
public void setName(String name) {
|
||||
_name = name == null ? "" : name;
|
||||
}
|
||||
|
||||
/**
|
||||
* Strides setter
|
||||
* Number of elements getter
|
||||
*
|
||||
* @param strides strides array to be set
|
||||
* @return number of data elements in the data buffer
|
||||
*/
|
||||
public void setStrides(long[] strides) {
|
||||
if (strides.length != _rank)
|
||||
throw new IllegalArgumentException(
|
||||
"array length " + strides.length + " != rank " + _rank);
|
||||
_strides = strides.clone();
|
||||
}
|
||||
|
||||
/**
|
||||
* Size getter
|
||||
*
|
||||
* @return product of sizes array, i.e., total number of data elements
|
||||
*/
|
||||
public long getDataSize() {
|
||||
long n = _sizes[0];
|
||||
for (int i = 1; i < _sizes.length; i++) n *= _sizes[i];
|
||||
public long getNumOfElems() {
|
||||
long n = _shape[0];
|
||||
for (int i = 1; i < _shape.length; i++) n *= _shape[i];
|
||||
return n;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check validity of RtMemRef
|
||||
* Check validity of OMTensor
|
||||
*
|
||||
* @return true if RtMemRef is valid, false otherwise
|
||||
* @return true if OMTensor is valid, false otherwise
|
||||
*/
|
||||
public boolean validRmr() {
|
||||
return (_data != null &&
|
||||
_data.limit() != 0 &&
|
||||
_data.limit() == getDataSize() * ONNX_TYPE_SIZE[_type]);
|
||||
public boolean isValidOmt() {
|
||||
return (_allocatedPtr != null &&
|
||||
_allocatedPtr.limit() != 0 &&
|
||||
_allocatedPtr.limit() == getNumOfElems() * ONNX_TYPE_SIZE[_dataType]);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,40 @@
|
|||
package com.ibm.onnxmlir;
|
||||
|
||||
import java.util.HashMap;
|
||||
|
||||
public class OMTensorList {
|
||||
|
||||
private OMTensor[] _omts;
|
||||
private HashMap<String, Integer> _n2i;
|
||||
|
||||
/**
|
||||
* Constructor
|
||||
*
|
||||
* @param omts DynMemRef array
|
||||
*/
|
||||
public OMTensorList(OMTensor[] omts) {
|
||||
/* Go through the OMTensor array, check each for validity,
|
||||
* and create name (if not empty) to index mapping.
|
||||
*/
|
||||
for (int i = 0; i < omts.length; i++) {
|
||||
if (omts[i] == null || !omts[i].isValidOmt())
|
||||
throw new IllegalArgumentException(
|
||||
"OMTensor[" + i + "] is invalid");
|
||||
String name = omts[i].getName();
|
||||
if (!name.isEmpty() && _n2i.put(name, i) != null)
|
||||
throw new IllegalArgumentException(
|
||||
"OMTensor[" + i + "] duplicate name: " + name);
|
||||
}
|
||||
|
||||
_omts = omts;
|
||||
}
|
||||
|
||||
/**
|
||||
* OMTensor array getter
|
||||
*
|
||||
* @return OMTensor array
|
||||
*/
|
||||
public OMTensor[] getOmts() {
|
||||
return _omts;
|
||||
}
|
||||
}
|
|
@ -1,118 +0,0 @@
|
|||
package com.ibm.onnxmlir;
|
||||
|
||||
import java.util.HashMap;
|
||||
|
||||
public class OrderedRtMemRefDict {
|
||||
|
||||
private RtMemRef[] _rmrs;
|
||||
private String[] _names;
|
||||
private HashMap<String, Integer> _n2i;
|
||||
|
||||
/**
|
||||
* Constructor
|
||||
*
|
||||
* @param rmrs DynMemRef array
|
||||
*/
|
||||
public OrderedRtMemRefDict(RtMemRef[] rmrs) {
|
||||
this(rmrs, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructor
|
||||
*
|
||||
* @param rmrs DynMemRef array
|
||||
* @param names name array
|
||||
*/
|
||||
public OrderedRtMemRefDict(RtMemRef[] rmrs, String[] names) {
|
||||
/* rmrs cannot be null or empty */
|
||||
if (rmrs == null || rmrs.length == 0)
|
||||
throw new IllegalArgumentException(
|
||||
"Number of dmrs is invalid");
|
||||
|
||||
/* If names is null or empty, construct a default one with
|
||||
* index as name.
|
||||
*/
|
||||
if (names == null || names.length == 0) {
|
||||
names = new String[rmrs.length];
|
||||
for (int i = 0; i < rmrs.length; i++)
|
||||
names[i] = Integer.toString(i);
|
||||
}
|
||||
|
||||
/* Number of rmrs and names must match */
|
||||
if (rmrs.length != names.length)
|
||||
throw new IllegalArgumentException(
|
||||
"Number of dmrs and names do not match");
|
||||
|
||||
/* Establish name to index mapping. Individual rmr is
|
||||
* checked for validity.
|
||||
*/
|
||||
_n2i = new HashMap<String, Integer>();
|
||||
for (int i = 0; i < names.length; i++) {
|
||||
if (rmrs[i] == null || !rmrs[i].validRmr())
|
||||
throw new IllegalArgumentException(
|
||||
"rmrs[" + i + "] is invalid");
|
||||
if (_n2i.put(names[i], i) != null)
|
||||
throw new IllegalArgumentException(
|
||||
"name[" + i + "] = " + names[i] + " not unique");
|
||||
}
|
||||
_rmrs = rmrs;
|
||||
_names = names;
|
||||
}
|
||||
|
||||
/**
|
||||
* RtMemRef getter by index
|
||||
*
|
||||
* @param idx index of RtMemRef instance to get
|
||||
* @return RtMemRef instance
|
||||
*/
|
||||
public RtMemRef getRmrbyIndex(int idx) {
|
||||
return _rmrs[idx];
|
||||
}
|
||||
|
||||
/**
|
||||
* RtMemRef getter by name
|
||||
*
|
||||
* @param name name of RtMemRef instance to get
|
||||
* @return RtMemRef instance
|
||||
*/
|
||||
public RtMemRef getRmrByName(String name) {
|
||||
return _rmrs[_n2i.get(name)];
|
||||
}
|
||||
|
||||
/**
|
||||
* RtMemRef array getter
|
||||
*
|
||||
* @return RtMemRef array
|
||||
*/
|
||||
public RtMemRef[] getRmrs() {
|
||||
return _rmrs;
|
||||
}
|
||||
|
||||
/**
|
||||
* Name getter
|
||||
*
|
||||
* @param idx index of name to get
|
||||
* @return name string
|
||||
*/
|
||||
public String getName(int idx) {
|
||||
return _names[idx];
|
||||
}
|
||||
|
||||
/**
|
||||
* Name array getter
|
||||
*
|
||||
* @return name array
|
||||
*/
|
||||
public String[] getNames() {
|
||||
return _names;
|
||||
}
|
||||
|
||||
/**
|
||||
* RtMemRef array size getter
|
||||
*
|
||||
* @return RtMemRef array size
|
||||
*/
|
||||
public int size() {
|
||||
return _rmrs.length;
|
||||
}
|
||||
}
|
|
@ -1,3 +1,4 @@
|
|||
add_subdirectory(mlir)
|
||||
add_subdirectory(backend)
|
||||
add_subdirectory(numerical)
|
||||
add_subdirectory(unit)
|
|
@ -92,7 +92,7 @@ class DummyBackend(onnx.backend.base.Backend):
|
|||
# Call frontend to process temp_model.onnx, bit code will be generated.
|
||||
execute_commands([ONNX_MLIR, "temp_model.onnx"])
|
||||
return EndiannessAwareExecutionSession("./temp_model.so",
|
||||
"_dyn_entry_point_main_graph")
|
||||
"run_main_graph")
|
||||
|
||||
@classmethod
|
||||
def supports_device(cls, device):
|
||||
|
@ -448,7 +448,6 @@ test_to_enable = [
|
|||
"test_shufflenet_cpu",
|
||||
]
|
||||
|
||||
|
||||
# Extract name of all test cases.
|
||||
import inspect
|
||||
all_tests = []
|
||||
|
|
|
@ -1,4 +1,9 @@
|
|||
add_executable(TestConv TestConv.cpp)
|
||||
target_compile_definitions(TestConv PRIVATE RTMEMREF_INTERNAL_API)
|
||||
target_include_directories(TestConv
|
||||
PRIVATE
|
||||
${ONNX_MLIR_BIN_ROOT}
|
||||
${ONNX_MLIR_SRC_ROOT}/include)
|
||||
target_link_libraries(TestConv
|
||||
${OMLibs}
|
||||
${MLIRLibs}
|
||||
|
@ -6,11 +11,5 @@ target_link_libraries(TestConv
|
|||
rapidcheck
|
||||
MainUtils
|
||||
ExecutionSession
|
||||
RtMemRefUtils)
|
||||
|
||||
target_include_directories(TestConv
|
||||
PRIVATE
|
||||
${ONNX_MLIR_SRC_ROOT}
|
||||
${ONNX_MLIR_BIN_ROOT}
|
||||
${ONNX_MLIR_SRC_ROOT})
|
||||
OMTensorUtils)
|
||||
add_test(NAME OMTestConv COMMAND TestConv)
|
|
@ -11,7 +11,8 @@
|
|||
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
#include "src/MainUtils.hpp"
|
||||
#include "src/Runtime/ExecusionSession.hpp"
|
||||
#include "src/Runtime/ExecutionSession.hpp"
|
||||
#include "src/Runtime/OMTensorHelper.h"
|
||||
|
||||
#define SHARED_LIB_BASE string("./TestConv_main_graph")
|
||||
|
||||
|
@ -93,38 +94,39 @@ bool isOMConvTheSameAsNaiveImplFor(const int N, const int C, const int H,
|
|||
OwningModuleRef moduleRef(module);
|
||||
|
||||
compileModule(moduleRef, ctx, SHARED_LIB_BASE, EmitLib);
|
||||
onnx_mlir::ExecutionSession sess(
|
||||
SHARED_LIB_BASE + ".so", "_dyn_entry_point_main_graph");
|
||||
onnx_mlir::ExecutionSession sess(SHARED_LIB_BASE + ".so", "run_main_graph");
|
||||
|
||||
std::vector<unique_ptr<RtMemRef>> inputs;
|
||||
auto xRmr = unique_ptr<RtMemRef>(getRndRealRmr<float>({N, C, H, W}));
|
||||
inputs.emplace_back(move(xRmr));
|
||||
auto wRmr = unique_ptr<RtMemRef>(getRndRealRmr<float>({C, C, kH, kW}));
|
||||
inputs.emplace_back(move(wRmr));
|
||||
std::vector<unique_ptr<OMTensor, decltype(&omTensorDestroy)>> inputs;
|
||||
auto xOmt = unique_ptr<OMTensor, decltype(&omTensorDestroy)>(
|
||||
omTensorCreateWithRandomData<float>({N, C, H, W}), omTensorDestroy);
|
||||
inputs.emplace_back(move(xOmt));
|
||||
auto wOmt = unique_ptr<OMTensor, decltype(&omTensorDestroy)>(
|
||||
omTensorCreateWithRandomData<float>({C, C, kH, kW}), omTensorDestroy);
|
||||
inputs.emplace_back(move(wOmt));
|
||||
|
||||
auto ref = RtMemRef::create<float>({NOut, COut, HOut, WOut});
|
||||
auto ref = omTensorCreateWithShape<float>({NOut, COut, HOut, WOut});
|
||||
auto &img = inputs.at(0);
|
||||
auto &filter = inputs.at(1);
|
||||
for (int64_t n = 0; n < NOut; n++)
|
||||
for (int64_t c = 0; c < COut; c++)
|
||||
for (int64_t h = 0; h < HOut; h++)
|
||||
for (int64_t w = 0; w < WOut; w++) {
|
||||
ref->elem<float>({n, c, h, w}) = 0;
|
||||
omTensorGetElem<float>(ref, {n, c, h, w}) = 0;
|
||||
for (int64_t ci = 0; ci < C; ci++)
|
||||
for (int64_t kh = 0; kh < kH; kh++)
|
||||
for (int64_t kw = 0; kw < kW; kw++)
|
||||
if ((h + kh - pHBegin >= 0 && h + kh - pHBegin < H) &&
|
||||
(w + kw - pWBegin >= 0 && w + kw - pWBegin < W))
|
||||
ref->elem<float>({n, c, h, w}) +=
|
||||
img->elem<float>(
|
||||
omTensorGetElem<float>(ref, {n, c, h, w}) +=
|
||||
omTensorGetElem<float>(img.get(),
|
||||
{n, ci, h + kh - pHBegin, w + kw - pWBegin}) *
|
||||
filter->elem<float>({c, ci, kh, kw});
|
||||
omTensorGetElem<float>(filter.get(), {c, ci, kh, kw});
|
||||
}
|
||||
|
||||
auto outputs = sess.run(move(inputs));
|
||||
auto &conv = outputs.at(0);
|
||||
|
||||
return isRmrClose<float>(conv.get(), ref);
|
||||
return omTensorAreTwoOmtsClose<float>(conv.get(), ref);
|
||||
}
|
||||
|
||||
int main(int argc, char *argv[]) {
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
include(GoogleTest)

# Adapted from
# https://cliutils.gitlab.io/modern-cmake/chapters/testing/googletest.html.
# Registers a GoogleTest-based C++ unit test executable.
macro(add_unit_test TESTNAME)
  # Create an executable in which the tests will be stored.
  add_executable(${TESTNAME} ${ARGN})
  # Link the Google test infrastructure, mocking library, and a default main
  # function to the test executable. Remove gtest_main if writing your own
  # main function.
  target_link_libraries(${TESTNAME} gtest gmock gtest_main)
  # gtest_discover_tests replaces gtest_add_tests; see
  # https://cmake.org/cmake/help/v3.10/module/GoogleTest.html for more
  # options to pass to it.
  gtest_discover_tests(${TESTNAME}
    # Set the working directory to the project root so that test data can be
    # found via paths relative to the project root.
    WORKING_DIRECTORY ${ONNX_MLIR_BIN_ROOT}
    PROPERTIES VS_DEBUGGER_WORKING_DIRECTORY "${ONNX_MLIR_BIN_ROOT}")
endmacro()

# Registers a plain C unit test executable (no GoogleTest dependency).
macro(add_c_unit_test TESTNAME)
  add_executable(${TESTNAME} ${ARGN})
  add_test(NAME ${TESTNAME}
    COMMAND ${TESTNAME})
endmacro()

add_subdirectory(Runtime)
|
|
@ -0,0 +1,7 @@
|
|||
add_subdirectory(DocExampleTest)

# C unit test for the OMTensor runtime API; links against the C runtime.
add_c_unit_test(OMTensorTest OMTensorTest.c)
target_include_directories(OMTensorTest PRIVATE
  ${ONNX_MLIR_SRC_ROOT}/include)
target_link_libraries(OMTensorTest
  cruntime)
|
|
@ -0,0 +1,44 @@
|
|||
# Documentation example requires ONNX package installation, which has been
# flaky on non-x86 platforms, so only perform this test on x86 arch.
if (CMAKE_SYSTEM_PROCESSOR MATCHES "(x86)|(X86)|(amd64)|(AMD64)")
  find_package(PythonInterp 3 REQUIRED)

  # Install the in-tree onnx python package, needed by gen_add_onnx.py.
  add_custom_target(OMInstallOnnx
    COMMAND ${PYTHON_EXECUTABLE} setup.py install --user
    WORKING_DIRECTORY ${ONNX_MLIR_SRC_ROOT}/third_party/onnx)

  # Generate add.onnx in the build tree.
  configure_file(${CMAKE_CURRENT_SOURCE_DIR}/gen_add_onnx.py
    ${CMAKE_CURRENT_BINARY_DIR}/gen_add_onnx.py COPYONLY)
  add_custom_command(OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/add.onnx
    COMMAND ${PYTHON_EXECUTABLE} gen_add_onnx.py
    WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
  add_custom_target(OMGenerateAddModel
    DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/add.onnx)
  add_dependencies(OMGenerateAddModel OMInstallOnnx)

  # Compile the model to a shared library, then copy it to the fixed name
  # library.so that the test links against.
  add_custom_command(OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/add.so
    COMMAND onnx-mlir ${CMAKE_CURRENT_BINARY_DIR}/add.onnx
    DEPENDS OMGenerateAddModel)
  add_custom_target(OMGenerateAddLibrary
    DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/add.so)
  add_custom_target(OMCopyAndRename
    COMMAND ${CMAKE_COMMAND} -E copy
      ${CMAKE_CURRENT_BINARY_DIR}/add.so
      ${CMAKE_CURRENT_BINARY_DIR}/library.so
    DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/add.so)

  add_executable(OMRuntimeTest
    main.c)

  # Import the generated model library. TOUCH a placeholder so the configure
  # step succeeds before the model library has actually been built.
  add_library(OMRuntimeTestModel SHARED IMPORTED)
  file(TOUCH ${CMAKE_CURRENT_BINARY_DIR}/library.so)
  set_property(TARGET OMRuntimeTestModel
    PROPERTY IMPORTED_LOCATION ${CMAKE_CURRENT_BINARY_DIR}/library.so)
  target_link_libraries(OMRuntimeTest
    ${CMAKE_CURRENT_BINARY_DIR}/library.so)
  target_include_directories(OMRuntimeTest
    PRIVATE ${ONNX_MLIR_SRC_ROOT}/include)
  add_dependencies(OMRuntimeTest OMCopyAndRename)
  add_test(NAME OMRuntimeTest
    COMMAND OMRuntimeTest)
endif ()
|
|
@ -0,0 +1,33 @@
|
|||
import onnx
from onnx import helper
from onnx import AttributeProto, TensorProto, GraphProto

# Inputs: two 3x2 float tensors (ValueInfoProto).
X1 = helper.make_tensor_value_info('X1', TensorProto.FLOAT, [3, 2])
X2 = helper.make_tensor_value_info('X2', TensorProto.FLOAT, [3, 2])

# Output: one 3x2 float tensor (ValueInfoProto).
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [3, 2])

# Single element-wise Add node (NodeProto) computing Y = X1 + X2.
node_def = helper.make_node(
    'Add',         # op type
    ['X1', 'X2'],  # inputs
    ['Y'],         # outputs
)

# Assemble the graph (GraphProto) around the single node.
graph_def = helper.make_graph(
    [node_def],
    'test-model',
    [X1, X2],
    [Y],
)

# Wrap the graph in a model (ModelProto), validate it, and save to disk.
model_def = helper.make_model(graph_def, producer_name='onnx-example')

print('The model is:\n{}'.format(model_def))
onnx.checker.check_model(model_def)
onnx.save(model_def, "add.onnx")
print('The model is checked!')
|
|
@ -0,0 +1,30 @@
|
|||
#include <OnnxMlirRuntime.h>
#include <stdio.h>

// Entry point of the compiled ONNX model (generated by onnx-mlir from
// add.onnx); consumes a list of input tensors and returns the outputs.
OMTensorList *run_main_graph(OMTensorList *);

int main() {
  // The add.onnx model takes two 3x2 float tensors, so shape is {3, 2}
  // (6 elements), matching the 6-element data arrays below.
  int64_t shape[] = {3, 2};
  int64_t rank = 2;
  // Construct x1 omt filled with 1.
  // Note: omTensorCreate does not take ownership of the data buffer, so
  // stack arrays are fine here.
  float x1Data[] = {1., 1., 1., 1., 1., 1.};
  OMTensor *x1 = omTensorCreate(x1Data, shape, rank, ONNX_TYPE_FLOAT);
  // Construct x2 omt filled with 2.
  float x2Data[] = {2., 2., 2., 2., 2., 2.};
  OMTensor *x2 = omTensorCreate(x2Data, shape, rank, ONNX_TYPE_FLOAT);
  // Construct a list of omts as input.
  OMTensor *list[2] = {x1, x2};
  OMTensorList *input = omTensorListCreate(list, 2);
  // Call the compiled onnx model function.
  OMTensorList *outputList = run_main_graph(input);
  // Get the first omt as output.
  OMTensor *y = omTensorListGetOmtByIndex(outputList, 0);
  float *outputPtr = (float *)omTensorGetDataPtr(y);
  // Print its content, should be all 3.
  for (int i = 0; i < 6; i++)
    printf("%f ", outputPtr[i]);
  return 0;
}
|
|
@ -0,0 +1,36 @@
|
|||
//===----------------- OMTensorTest.c - OMTensor Unit Test -----------------===//
|
||||
//
|
||||
// Copyright 2019-2020 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// This file contains unit tests for the OMTensor data structure
// and its runtime API.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
#include <assert.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "OnnxMlirRuntime.h"
|
||||
|
||||
void testOMTensorCtor() {
|
||||
float data[4] = {1.f, 1.f};
|
||||
int64_t shape[2] = {2, 2};
|
||||
OMTensor *tensor = omTensorCreate(data, shape, 2, ONNX_TYPE_FLOAT);
|
||||
assert(tensor);
|
||||
|
||||
int64_t* shape_ptr = omTensorGetDataShape(tensor);
|
||||
assert(shape_ptr);
|
||||
assert(shape_ptr[0] == 2);
|
||||
assert(shape_ptr[1] == 2);
|
||||
|
||||
int64_t* strides_ptr = omTensorGetStrides(tensor);
|
||||
assert(strides_ptr);
|
||||
assert(strides_ptr[0] == 2);
|
||||
assert(strides_ptr[1] == 1);
|
||||
}
|
||||
|
||||
// Test driver: run each unit test in turn.
int main() {
  testOMTensorCtor();
  return 0;
}
|
Loading…
Reference in New Issue