Support encoding data type information as part of the DMR struct. (#178)

* Support encoding data type information as part of the DMR struct.

* Support full range of np types.

* Report error when encountering unsupported type.

* Add getRank method API.

* Add missing API declarations.

* DynMemRef -> RtMemRef

* Format code.

* Missed DynMemRef -> RtMemRef conversion.

* More comments for RMR, and rename variable names from dmr -> rmr.

* DynMemRef -> RtMemRef.

* Format code.
This commit is contained in:
Tian Jin 2020-06-30 10:58:21 +08:00 committed by GitHub
parent f9cb113a84
commit e902506ee5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 262 additions and 132 deletions

View File

@ -1,13 +1,13 @@
# Create shared libcruntime.so since model.so linkage for backend tests # Create shared libcruntime.so since model.so linkage for backend tests
# will fail on x86 Linux if cruntime is statically linked. # will fail on x86 Linux if cruntime is statically linked.
add_library(cruntime STATIC add_library(cruntime STATIC
DynMemRef.cpp RtMemRef.cpp
DynMemRef.h RtMemRef.h
DataType.h) DataType.h)
add_library(DynMemRefUtils add_library(RtMemRefUtils
DynMemRef.h RtMemRef.h
DynMemRef.cpp RtMemRef.cpp
DataType.h) DataType.h)
add_library(ExecutionSession add_library(ExecutionSession
@ -28,7 +28,8 @@ pybind11_add_module(PyRuntime
target_link_libraries(PyRuntime PRIVATE target_link_libraries(PyRuntime PRIVATE
${CMAKE_DL_LIBS} ${CMAKE_DL_LIBS}
ExecutionSession ExecutionSession
DynMemRefUtils) RtMemRefUtils
onnx)
target_include_directories(PyRuntime PRIVATE target_include_directories(PyRuntime PRIVATE
${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_SRC_ROOT}
${ONNX_MLIR_BIN_ROOT} ${ONNX_MLIR_BIN_ROOT}
@ -41,6 +42,6 @@ set_target_properties(EmbeddedDataLoader PROPERTIES
POSITION_INDEPENDENT_CODE TRUE) POSITION_INDEPENDENT_CODE TRUE)
add_dependencies(PyRuntime cruntime) add_dependencies(PyRuntime cruntime)
install(FILES DynMemRef.h DESTINATION include) install(FILES RtMemRef.h DESTINATION include)
install(TARGETS cruntime DESTINATION lib) install(TARGETS cruntime DESTINATION lib)
install(TARGETS EmbeddedDataLoader DESTINATION lib) install(TARGETS EmbeddedDataLoader DESTINATION lib)

View File

@ -42,20 +42,19 @@ ExecutionSession::ExecutionSession(
} }
} }
std::vector<std::unique_ptr<DynMemRef>> ExecutionSession::run( std::vector<std::unique_ptr<RtMemRef>> ExecutionSession::run(
std::vector<std::unique_ptr<DynMemRef>> ins) { std::vector<std::unique_ptr<RtMemRef>> ins) {
auto *wrappedInput = createOrderedDynMemRefDict(); auto *wrappedInput = createOrderedRtMemRefDict();
for (size_t i = 0; i < ins.size(); i++) for (size_t i = 0; i < ins.size(); i++)
setDynMemRef(wrappedInput, i, ins.at(i).get()); setRtMemRef(wrappedInput, i, ins.at(i).get());
auto *wrappedOutput = _entryPointFunc(wrappedInput); auto *wrappedOutput = _entryPointFunc(wrappedInput);
std::vector<std::unique_ptr<DynMemRef>> outs; std::vector<std::unique_ptr<RtMemRef>> outs;
auto outputSize = getSize(wrappedOutput); auto outputSize = getSize(wrappedOutput);
for (size_t i = 0; i < getSize(wrappedOutput); i++) { for (size_t i = 0; i < getSize(wrappedOutput); i++) {
outs.emplace_back( outs.emplace_back(std::unique_ptr<RtMemRef>(getRtMemRef(wrappedOutput, i)));
std::unique_ptr<DynMemRef>(getDynMemRef(wrappedOutput, i)));
} }
return std::move(outs); return std::move(outs);
} }

View File

@ -15,18 +15,18 @@
#include <dlfcn.h> #include <dlfcn.h>
#include <string> #include <string>
#include "src/Runtime/DynMemRef.h" #include "src/Runtime/RtMemRef.h"
namespace onnx_mlir { namespace onnx_mlir {
typedef OrderedDynMemRefDict *(*entryPointFuncType)(OrderedDynMemRefDict *); typedef OrderedRtMemRefDict *(*entryPointFuncType)(OrderedRtMemRefDict *);
class ExecutionSession { class ExecutionSession {
public: public:
ExecutionSession(std::string sharedLibPath, std::string entryPointName); ExecutionSession(std::string sharedLibPath, std::string entryPointName);
std::vector<std::unique_ptr<DynMemRef>> run( std::vector<std::unique_ptr<RtMemRef>> run(
std::vector<std::unique_ptr<DynMemRef>>); std::vector<std::unique_ptr<RtMemRef>>);
~ExecutionSession(); ~ExecutionSession();

View File

@ -9,6 +9,9 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "onnx/onnx_pb.h"
#include <third_party/onnx/onnx/onnx_pb.h>
#include "PyExecutionSession.hpp" #include "PyExecutionSession.hpp"
namespace onnx_mlir { namespace onnx_mlir {
@ -16,40 +19,72 @@ namespace onnx_mlir {
std::vector<py::array> PyExecutionSession::pyRun( std::vector<py::array> PyExecutionSession::pyRun(
std::vector<py::array> inputsPyArray) { std::vector<py::array> inputsPyArray) {
assert(_entryPointFunc && "Entry point not loaded."); assert(_entryPointFunc && "Entry point not loaded.");
auto *wrappedInput = createOrderedDynMemRefDict(); auto *wrappedInput = createOrderedRtMemRefDict();
int inputIdx = 0; int inputIdx = 0;
for (auto inputPyArray : inputsPyArray) { for (auto inputPyArray : inputsPyArray) {
auto *inputDynMemRef = createDynMemRef(inputPyArray.ndim()); auto *inputRtMemRef = createRtMemRef(inputPyArray.ndim());
assert(inputPyArray.flags() && py::array::c_style && assert(inputPyArray.flags() && py::array::c_style &&
"Expect contiguous python array."); "Expect contiguous python array.");
if (inputPyArray.writeable()) { if (inputPyArray.writeable()) {
inputDynMemRef->data = inputPyArray.mutable_data(); inputRtMemRef->data = inputPyArray.mutable_data();
inputDynMemRef->alignedData = inputPyArray.mutable_data(); inputRtMemRef->alignedData = inputPyArray.mutable_data();
} else { } else {
// If data is not writable, copy them to a writable buffer. // If data is not writable, copy them to a writable buffer.
auto *copiedData = (float *)malloc(inputPyArray.nbytes()); auto *copiedData = (float *)malloc(inputPyArray.nbytes());
memcpy(copiedData, inputPyArray.data(), inputPyArray.nbytes()); memcpy(copiedData, inputPyArray.data(), inputPyArray.nbytes());
inputDynMemRef->data = copiedData; inputRtMemRef->data = copiedData;
inputDynMemRef->alignedData = copiedData; inputRtMemRef->alignedData = copiedData;
} }
for (int i = 0; i < inputPyArray.ndim(); i++) { for (int i = 0; i < inputPyArray.ndim(); i++) {
inputDynMemRef->sizes[i] = inputPyArray.shape(i); inputRtMemRef->sizes[i] = inputPyArray.shape(i);
inputDynMemRef->strides[i] = inputPyArray.strides(i); inputRtMemRef->strides[i] = inputPyArray.strides(i);
} }
setDynMemRef(wrappedInput, inputIdx++, inputDynMemRef); setRtMemRef(wrappedInput, inputIdx++, inputRtMemRef);
} }
std::vector<py::array> outputPyArrays; std::vector<py::array> outputPyArrays;
auto *wrappedOutput = _entryPointFunc(wrappedInput); auto *wrappedOutput = _entryPointFunc(wrappedInput);
for (int i = 0; i < numDynMemRefs(wrappedOutput); i++) { for (int i = 0; i < numRtMemRefs(wrappedOutput); i++) {
auto *dynMemRef = getDynMemRef(wrappedOutput, i); auto *dynMemRef = getRtMemRef(wrappedOutput, i);
auto shape = std::vector<int64_t>( auto shape = std::vector<int64_t>(
dynMemRef->sizes, dynMemRef->sizes + dynMemRef->rank); dynMemRef->sizes, dynMemRef->sizes + dynMemRef->rank);
outputPyArrays.emplace_back(
py::array(py::dtype("float32"), shape, dynMemRef->data)); // https://numpy.org/devdocs/user/basics.types.html
// Map the ONNX element type recorded in the RtMemRef to the matching numpy
// dtype string. See https://numpy.org/devdocs/user/basics.types.html.
// Fixed: several branches previously used assignment (=) instead of
// comparison (==), so every non-FLOAT tensor matched the first such branch
// and was misreported as uint8.
py::dtype dtype;
if (dynMemRef->onnx_dtype == onnx::TensorProto::FLOAT)
  dtype = py::dtype("float32");
else if (dynMemRef->onnx_dtype == onnx::TensorProto::UINT8)
  dtype = py::dtype("uint8");
else if (dynMemRef->onnx_dtype == onnx::TensorProto::INT8)
  dtype = py::dtype("int8");
else if (dynMemRef->onnx_dtype == onnx::TensorProto::UINT16)
  dtype = py::dtype("uint16");
else if (dynMemRef->onnx_dtype == onnx::TensorProto::INT16)
  dtype = py::dtype("int16");
else if (dynMemRef->onnx_dtype == onnx::TensorProto::INT32)
  dtype = py::dtype("int32");
else if (dynMemRef->onnx_dtype == onnx::TensorProto::INT64)
  dtype = py::dtype("int64");
// TODO(tjingrant) wait for Tong's input for how to represent string.
else if (dynMemRef->onnx_dtype == onnx::TensorProto::BOOL)
  dtype = py::dtype("bool_");
else if (dynMemRef->onnx_dtype == onnx::TensorProto::FLOAT16)
  // NOTE(review): FLOAT16 payload is exposed as float32 — presumably the
  // data is widened upstream; confirm, otherwise this misinterprets raw
  // fp16 bytes.
  dtype = py::dtype("float32");
else if (dynMemRef->onnx_dtype == onnx::TensorProto::DOUBLE)
  dtype = py::dtype("float64");
else if (dynMemRef->onnx_dtype == onnx::TensorProto::UINT32)
  dtype = py::dtype("uint32");
else if (dynMemRef->onnx_dtype == onnx::TensorProto::UINT64)
  dtype = py::dtype("uint64");
else {
  fprintf(stderr, "Unsupported ONNX type in RtMemRef.onnx_dtype.");
  exit(1);
}
outputPyArrays.emplace_back(py::array(dtype, shape, dynMemRef->data));
} }
return outputPyArrays; return outputPyArrays;

View File

@ -1,4 +1,4 @@
//===----------- DynMemRef.cpp - Dynamic MemRef Implementation ------------===// //===----------- RtMemRef.cpp - Dynamic MemRef Implementation ------------===//
// //
// Copyright 2019-2020 The IBM Research Authors. // Copyright 2019-2020 The IBM Research Authors.
// //
@ -18,7 +18,7 @@
#include <stdlib.h> #include <stdlib.h>
#include <string.h> #include <string.h>
#include "DynMemRef.h" #include "RtMemRef.h"
namespace { namespace {
// Helper function to compute cartisian product. // Helper function to compute cartisian product.
@ -39,19 +39,19 @@ inline std::vector<std::vector<INDEX_TYPE>> CartProduct(
} }
} // namespace } // namespace
DynMemRef::DynMemRef(int _rank) { RtMemRef::RtMemRef(int _rank) {
rank = _rank; rank = _rank;
sizes = (INDEX_TYPE *)malloc(rank * sizeof(INDEX_TYPE)); sizes = (INDEX_TYPE *)malloc(rank * sizeof(INDEX_TYPE));
strides = (int64_t *)malloc(rank * sizeof(int64_t)); strides = (int64_t *)malloc(rank * sizeof(int64_t));
} }
INDEX_TYPE DynMemRef::size() const { INDEX_TYPE RtMemRef::size() const {
return std::accumulate(sizes, sizes + rank, 1, std::multiplies<>()); return std::accumulate(sizes, sizes + rank, 1, std::multiplies<>());
} }
std::vector<std::vector<INDEX_TYPE>> DynMemRef::indexSet() const { std::vector<std::vector<INDEX_TYPE>> RtMemRef::indexSet() const {
// First, we create index set of each dimension separately. // First, we create index set of each dimension separately.
// i.e., for a tensor/DMR of shape (2, 3), its dimWiseIdxSet will be: // i.e., for a tensor/RMR of shape (2, 3), its dimWiseIdxSet will be:
// {{0,1}, {0,1,2}}; // {{0,1}, {0,1,2}};
std::vector<std::vector<INDEX_TYPE>> dimWiseIdxSet; std::vector<std::vector<INDEX_TYPE>> dimWiseIdxSet;
for (auto dimSize : std::vector<INDEX_TYPE>(sizes, sizes + rank)) { for (auto dimSize : std::vector<INDEX_TYPE>(sizes, sizes + rank)) {
@ -60,18 +60,18 @@ std::vector<std::vector<INDEX_TYPE>> DynMemRef::indexSet() const {
dimWiseIdxSet.emplace_back(dimIdxSet); dimWiseIdxSet.emplace_back(dimIdxSet);
} }
// Then, the cartesian product of vectors within dimWiseIdxSet will be the // Then, the cartesian product of vectors within dimWiseIdxSet will be the
// index set for the whole DMR. // index set for the whole RMR.
return CartProduct(dimWiseIdxSet); return CartProduct(dimWiseIdxSet);
} }
INDEX_TYPE DynMemRef::computeOffset(std::vector<INDEX_TYPE> &idxs) const { INDEX_TYPE RtMemRef::computeOffset(std::vector<INDEX_TYPE> &idxs) const {
auto dimStrides = std::vector<INDEX_TYPE>(strides, strides + rank); auto dimStrides = std::vector<INDEX_TYPE>(strides, strides + rank);
INDEX_TYPE elemOffset = std::inner_product( INDEX_TYPE elemOffset = std::inner_product(
idxs.begin(), idxs.end(), dimStrides.begin(), (INDEX_TYPE)0); idxs.begin(), idxs.end(), dimStrides.begin(), (INDEX_TYPE)0);
return elemOffset; return elemOffset;
} }
std::vector<int64_t> DynMemRef::computeStridesFromSizes() const { std::vector<int64_t> RtMemRef::computeStridesFromSizes() const {
// Shift dimension sizes one to the left, fill in the vacated rightmost // Shift dimension sizes one to the left, fill in the vacated rightmost
// element with 1; this gets us a vector that'll be more useful for computing // element with 1; this gets us a vector that'll be more useful for computing
// strides of memory access along each dimension using prefix product (aka // strides of memory access along each dimension using prefix product (aka
@ -86,7 +86,7 @@ std::vector<int64_t> DynMemRef::computeStridesFromSizes() const {
return dimStrides; return dimStrides;
} }
DynMemRef::~DynMemRef() { RtMemRef::~RtMemRef() {
free(data); free(data);
free(sizes); free(sizes);
free(strides); free(strides);
@ -95,27 +95,26 @@ DynMemRef::~DynMemRef() {
// An ordered dynamic MemRef dictionary. // An ordered dynamic MemRef dictionary.
// The goal is to support accessing dynamic memory ref by name and by index. // The goal is to support accessing dynamic memory ref by name and by index.
// Currently, only accessing by index is supported. // Currently, only accessing by index is supported.
struct OrderedDynMemRefDict { struct OrderedRtMemRefDict {
std::map<std::string, DynMemRef *> tensorDict; std::map<std::string, RtMemRef *> tensorDict;
std::vector<std::string> orderedNames; std::vector<std::string> orderedNames;
}; };
int numDynMemRefs(OrderedDynMemRefDict *dict) { int numRtMemRefs(OrderedRtMemRefDict *dict) {
return dict->orderedNames.size(); return dict->orderedNames.size();
} }
OrderedDynMemRefDict *createOrderedDynMemRefDict() { OrderedRtMemRefDict *createOrderedRtMemRefDict() {
return new OrderedDynMemRefDict(); return new OrderedRtMemRefDict();
} }
DynMemRef *createDynMemRef(int rank) { return new DynMemRef(rank); } RtMemRef *createRtMemRef(int rank) { return new RtMemRef(rank); }
DynMemRef *getDynMemRef(OrderedDynMemRefDict *tensorDict, int idx) { RtMemRef *getRtMemRef(OrderedRtMemRefDict *tensorDict, int idx) {
return tensorDict->tensorDict[tensorDict->orderedNames[idx]]; return tensorDict->tensorDict[tensorDict->orderedNames[idx]];
} }
void setDynMemRef( void setRtMemRef(OrderedRtMemRefDict *tensorDict, int idx, RtMemRef *tensor) {
OrderedDynMemRefDict *tensorDict, int idx, DynMemRef *tensor) {
if (tensorDict->orderedNames.size() <= idx) if (tensorDict->orderedNames.size() <= idx)
tensorDict->orderedNames.resize(idx + 1); tensorDict->orderedNames.resize(idx + 1);
@ -130,30 +129,36 @@ void setDynMemRef(
tensorDict->tensorDict[tensorDict->orderedNames[idx]] = tensor; tensorDict->tensorDict[tensorDict->orderedNames[idx]] = tensor;
} }
void *getData(DynMemRef *dynMemRef) { return dynMemRef->data; } void *getData(RtMemRef *dynMemRef) { return dynMemRef->data; }
void setData(DynMemRef *dynMemRef, void *dataPtr) { dynMemRef->data = dataPtr; } void setData(RtMemRef *dynMemRef, void *dataPtr) { dynMemRef->data = dataPtr; }
void *getAlignedData(DynMemRef *dynMemRef) { return dynMemRef->alignedData; } void *getAlignedData(RtMemRef *dynMemRef) { return dynMemRef->alignedData; }
void setAlignedData(DynMemRef *dynMemRef, void *dataPtr) { void setAlignedData(RtMemRef *dynMemRef, void *dataPtr) {
dynMemRef->alignedData = dataPtr; dynMemRef->alignedData = dataPtr;
} }
INDEX_TYPE *getSizes(DynMemRef *dynMemRef) { return dynMemRef->sizes; } INDEX_TYPE *getSizes(RtMemRef *dynMemRef) { return dynMemRef->sizes; }
void setSizes(DynMemRef *dynMemRef, INDEX_TYPE *sizes) { void setSizes(RtMemRef *dynMemRef, INDEX_TYPE *sizes) {
for (int i = 0; i < dynMemRef->rank; i++) for (int i = 0; i < dynMemRef->rank; i++)
dynMemRef->sizes[i] = sizes[i]; dynMemRef->sizes[i] = sizes[i];
} }
int64_t *getStrides(DynMemRef *dynMemRef) { return dynMemRef->strides; } int64_t *getStrides(RtMemRef *dynMemRef) { return dynMemRef->strides; }
int64_t getSize(OrderedDynMemRefDict *dict) { int64_t getSize(OrderedRtMemRefDict *dict) { return dict->orderedNames.size(); }
return dict->orderedNames.size();
void setDType(RtMemRef *dynMemRef, int onnxType) {
dynMemRef->onnx_dtype = onnxType;
} }
void setStrides(DynMemRef *dynMemRef, int64_t *strides) { int getDType(RtMemRef *dynMemRef) { return dynMemRef->onnx_dtype; }
unsigned int getRank(RtMemRef *dynMemRef) { return dynMemRef->rank; }
// Set the strides array (by copying stride values from array `strides`).
// Fixed: the loop body previously wrote into dynMemRef->sizes, silently
// corrupting the shape instead of updating the strides.
void setStrides(RtMemRef *dynMemRef, int64_t *strides) {
  for (int i = 0; i < dynMemRef->rank; i++)
    dynMemRef->strides[i] = strides[i];
}

View File

@ -1,4 +1,4 @@
//===------------ DynMemRef.h - Dynamic MemRef Implementation -------------===// //===------------ RtMemRef.h - Dynamic MemRef Implementation -------------===//
// //
// Copyright 2019-2020 The IBM Research Authors. // Copyright 2019-2020 The IBM Research Authors.
// //
@ -24,40 +24,69 @@
typedef int64_t INDEX_TYPE; typedef int64_t INDEX_TYPE;
// This is a dynamic version of memref. // Typically, MemRefs in MLIR context are used as a compile-time constructs.
// The same struct can be used to represent memrefs of // Information such as element type and rank of the data payload is statically
// all ranks and type combinations. // encoded, meaning that they are determined and fixed at compile-time. This
// We will refer to it as a DMR (Dynamic MemRef). // presents significant burden for any runtime components trying to interact
struct DynMemRef { // with the compiled executable.
//
// Thus a version of MemRef struct that is amenable to runtime manipulation is
// provided as a basis for building any runtime-related components providing
// user-facing programming interfaces. All information are dynamically encoded
// as members of this struct so that they can be accessed and modified easily
// during runtime.
//
// We will refer to it as an RMR (Runtime MemRef).
struct RtMemRef {
// Pointer to the raw memory space allocated to host the RMR content. This
// pointer should only be accessed for memory management purposes, not for
// reading RMR content.
void *data; void *data;
// Pointer to the properly aligned array of elements stored in this RMR.
void *alignedData; void *alignedData;
// Distance between the start of the raw memory space and the first element of
// the RMR content.
INDEX_TYPE offset; INDEX_TYPE offset;
// Number of dimensions of the array represented by the RMR.
unsigned int rank; unsigned int rank;
// An array recording the per-dimension sizes of the array represented by the
// RMR.
INDEX_TYPE *sizes; INDEX_TYPE *sizes;
// An array recording the per-dimension strides of the array represented by
// the RMR.
int64_t *strides; int64_t *strides;
// Refer to TensorProto_DataType at
// https://github.com/onnx/onnx/blob/cc2230603422bae893d5bc900d2d773ab34400a4/onnx/onnx-ml.proto#L451
// for enum value interpretation.
unsigned int onnx_dtype;
#ifdef __cplusplus #ifdef __cplusplus
explicit DynMemRef(int _rank); explicit RtMemRef(int _rank);
// Create a full DMR of type T and shape _sizes, with all data fields // Create a full RMR of type T and shape _sizes, with all data fields
// initialized to proper values and data pointers malloc'ed. // initialized to proper values and data pointers malloc'ed.
template <typename T> template <typename T>
static DynMemRef *create(std::vector<INDEX_TYPE> _sizes) { static RtMemRef *create(std::vector<INDEX_TYPE> _sizes) {
auto dmr = new DynMemRef(_sizes.size()); auto rmr = new RtMemRef(_sizes.size());
dmr->offset = 0; rmr->offset = 0;
dmr->rank = _sizes.size(); rmr->rank = _sizes.size();
dmr->sizes = (INDEX_TYPE *)malloc(dmr->rank * sizeof(INDEX_TYPE)); rmr->sizes = (INDEX_TYPE *)malloc(rmr->rank * sizeof(INDEX_TYPE));
std::copy(_sizes.begin(), _sizes.end(), dmr->sizes); std::copy(_sizes.begin(), _sizes.end(), rmr->sizes);
dmr->strides = (int64_t *)malloc(dmr->rank * sizeof(int64_t)); rmr->strides = (int64_t *)malloc(rmr->rank * sizeof(int64_t));
auto computedStrides = dmr->computeStridesFromSizes(); auto computedStrides = rmr->computeStridesFromSizes();
std::copy(computedStrides.begin(), computedStrides.end(), dmr->strides); std::copy(computedStrides.begin(), computedStrides.end(), rmr->strides);
dmr->data = malloc(dmr->size() * sizeof(T)); rmr->data = malloc(rmr->size() * sizeof(T));
dmr->alignedData = dmr->data; rmr->alignedData = rmr->data;
return dmr; return rmr;
} }
// Access an element (by reference) at index position idxs. // Access an element (by reference) at index position idxs.
@ -75,13 +104,13 @@ struct DynMemRef {
return typedPtr[idx]; return typedPtr[idx];
} }
// Get a typed ptr to the data content of the DMR. // Get a typed ptr to the data content of the RMR.
template <typename T> template <typename T>
T *typedPtr() { T *typedPtr() {
return (T *)data; return (T *)data;
} }
// Get how many elements are stored in DMR, as implied by its shape. // Get how many elements are stored in RMR, as implied by its shape.
INDEX_TYPE size() const; INDEX_TYPE size() const;
// Helper function to compute strides of access along each dimensions from its // Helper function to compute strides of access along each dimensions from its
@ -92,63 +121,77 @@ struct DynMemRef {
INDEX_TYPE computeOffset(std::vector<INDEX_TYPE> &idxs) const; INDEX_TYPE computeOffset(std::vector<INDEX_TYPE> &idxs) const;
// Get the index set (i.e., all valid multi-dimensional array indexes that can // Get the index set (i.e., all valid multi-dimensional array indexes that can
// be used to access this DMR's constituent elements). // be used to access this RMR's constituent elements).
std::vector<std::vector<INDEX_TYPE>> indexSet() const; std::vector<std::vector<INDEX_TYPE>> indexSet() const;
~DynMemRef(); ~RtMemRef();
#endif #endif
}; };
#ifdef __cplusplus #ifdef __cplusplus
// Ordered DynMemRef Dictionary is a data structure for wrapping the input // Ordered RtMemRef Dictionary is a data structure for wrapping the input
// dynmemrefs so that they can be addressed both by index and by name. // dynmemrefs so that they can be addressed both by index and by name.
struct OrderedDynMemRefDict; struct OrderedRtMemRefDict;
#else #else
typedef struct DynMemRef DynMemRef; typedef struct RtMemRef RtMemRef;
typedef struct _OrderedDynMemRefDict OrderedDynMemRefDict; typedef struct _OrderedRtMemRefDict OrderedRtMemRefDict;
#endif #endif
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
#endif #endif
// Get number of dynamic memrefs in OrderedDynMemRefDict dict. // Get number of dynamic memrefs in OrderedRtMemRefDict dict.
int numDynMemRefs(OrderedDynMemRefDict *dict); int numRtMemRefs(OrderedRtMemRefDict *dict);
// Create an ordered dynamic memref dictionary. // Create an ordered dynamic memref dictionary.
OrderedDynMemRefDict *createOrderedDynMemRefDict(); OrderedRtMemRefDict *createOrderedRtMemRefDict();
// Get how many dynamic memrefs are in dict. // Get how many dynamic memrefs are in dict.
int64_t getSize(OrderedDynMemRefDict *dict); int64_t getSize(OrderedRtMemRefDict *dict);
// Create a dynmemref with a certain rank. // Create a dynmemref with a certain rank.
DynMemRef *createDynMemRef(int rank); RtMemRef *createRtMemRef(int rank);
// Get the i-th dynmemref from orderedDict. // Get the i-th dynmemref from orderedDict.
DynMemRef *getDynMemRef(OrderedDynMemRefDict *orderedDict, int i); RtMemRef *getRtMemRef(OrderedRtMemRefDict *orderedDict, int i);
// Set the i-th dynmemref in orderedDict to be dynMemRef. // Set the i-th dynmemref in orderedDict to be dynMemRef.
void setDynMemRef( void setRtMemRef(OrderedRtMemRefDict *tensorDict, int idx, RtMemRef *dynMemRef);
OrderedDynMemRefDict *tensorDict, int idx, DynMemRef *dynMemRef);
// Get data pointer from dynMemRef. // Get data pointer from dynMemRef.
void *getData(DynMemRef *dynMemRef); void *getData(RtMemRef *dynMemRef);
// Set data pointer for dynMemRef. // Set data pointer for dynMemRef.
void setData(DynMemRef *dynMemRef, void *data); void setData(RtMemRef *dynMemRef, void *data);
// Get algined data pointer from dynMemRef. // Get algined data pointer from dynMemRef.
void *getAlignedData(DynMemRef *); void *getAlignedData(RtMemRef *);
// Set aligned data pointer for dynMemRef. // Set aligned data pointer for dynMemRef.
void setAlignedData(DynMemRef *, void *); void setAlignedData(RtMemRef *, void *);
// Get the data type enum value of the dynMemRef.
int getDType(RtMemRef *dynMemRef);
// Set the data type enum value of the dynMemRef.
void setDType(RtMemRef *dynMemRef, int onnxType);
// Get the rank of the dynMemRef.
unsigned int getRank(RtMemRef *dynMemRef);
// Get ptr to sizes array. // Get ptr to sizes array.
INDEX_TYPE *getSizes(DynMemRef *); INDEX_TYPE *getSizes(RtMemRef *);
// Set the sizes array (by copying size values from array `sizes`).
void setSizes(RtMemRef *, INDEX_TYPE *sizes);
// Get ptr to strides array. // Get ptr to strides array.
int64_t *getStrides(DynMemRef *); int64_t *getStrides(RtMemRef *);
// Set the strides array (by copying stride values from array `strides`).
void setStrides(RtMemRef *, int64_t *strides);
#ifdef __cplusplus #ifdef __cplusplus
} }
@ -164,22 +207,22 @@ void printVector(std::vector<T> vec, std::string _delimiter = ",",
} }
template <typename T> template <typename T>
DynMemRef *getRndRealDmr( RtMemRef *getRndRealRmr(
std::vector<INDEX_TYPE> sizes, T lb = -1.0, T ub = 1.0) { std::vector<INDEX_TYPE> sizes, T lb = -1.0, T ub = 1.0) {
// Will be used to obtain a seed for the random number engine // Will be used to obtain a seed for the random number engine
std::random_device rd; std::random_device rd;
// Standard mersenne_twister_engine seeded with rd() // Standard mersenne_twister_engine seeded with rd()
std::mt19937 gen(rd()); std::mt19937 gen(rd());
std::uniform_real_distribution<> dis(lb, ub); std::uniform_real_distribution<> dis(lb, ub);
auto dmr = DynMemRef::create<T>(sizes); auto rmr = RtMemRef::create<T>(sizes);
auto ptr = (T *)dmr->data; auto ptr = (T *)rmr->data;
std::generate(ptr, ptr + dmr->size(), [&]() { return dis(gen); }); std::generate(ptr, ptr + rmr->size(), [&]() { return dis(gen); });
return dmr; return rmr;
} }
template <typename T> template <typename T>
inline bool isDmrClose( inline bool isRmrClose(
DynMemRef *a, DynMemRef *b, float rtol = 1e-5, float atol = 1e-5) { RtMemRef *a, RtMemRef *b, float rtol = 1e-5, float atol = 1e-5) {
// Compare shape. // Compare shape.
auto aShape = std::vector<INDEX_TYPE>(a->sizes, a->sizes + a->rank); auto aShape = std::vector<INDEX_TYPE>(a->sizes, a->sizes + a->rank);
@ -232,3 +275,6 @@ inline bool isDmrClose(
} }
} }
#endif #endif
// Will transition from DynMemRef to RtMemRef soon.
// Compatibility alias so existing code referring to DynMemRef keeps
// compiling during the rename. (The previous `typedef RtMemRef RtMemRef;`
// was a self-typedef left over from an over-eager search-and-replace and
// provided no alias at all.)
typedef RtMemRef DynMemRef;

View File

@ -19,6 +19,7 @@
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h" #include "mlir/Target/LLVMIR/ModuleTranslation.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "onnx/onnx_pb.h"
#include "llvm/ADT/Sequence.h" #include "llvm/ADT/Sequence.h"
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp" #include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
@ -29,6 +30,38 @@ using namespace mlir;
namespace { namespace {
// Map an LLVM dialect element type to the corresponding ONNX
// TensorProto::DataType enum value. Dumps the type and aborts via
// llvm_unreachable when no ONNX equivalent exists (e.g. complex types,
// which don't seem to exist in the LLVM dialect).
static onnx::TensorProto::DataType llvmTypeToOnnxType(
    mlir::LLVM::LLVMType elemType) {
  if (elemType.isFloatTy())
    return onnx::TensorProto::FLOAT;
  if (elemType.isUnsignedInteger(8))
    return onnx::TensorProto::UINT8;
  if (elemType.isSignedInteger(8))
    return onnx::TensorProto::INT8;
  if (elemType.isUnsignedInteger(16))
    return onnx::TensorProto::UINT16;
  if (elemType.isSignedInteger(16))
    return onnx::TensorProto::INT16;
  if (elemType.isSignedInteger(32))
    return onnx::TensorProto::INT32;
  if (elemType.isSignedInteger(64))
    return onnx::TensorProto::INT64;
  // TODO, wait for Tong's input about how string is represented in MLIR.
  if (elemType.isInteger(1))
    return onnx::TensorProto::BOOL;
  if (elemType.isHalfTy())
    return onnx::TensorProto::FLOAT16;
  if (elemType.isDoubleTy())
    return onnx::TensorProto::DOUBLE;
  if (elemType.isUnsignedInteger(32))
    return onnx::TensorProto::UINT32;
  if (elemType.isUnsignedInteger(64))
    // Fixed: previously returned INT64 for unsigned 64-bit integers.
    return onnx::TensorProto::UINT64;
  // Complex types don't seem to exist in LLVM Dialect.
  elemType.dump();
  llvm_unreachable("Unexpected LLVM type, cannot be converted to ONNX type.");
}
static FlatSymbolRefAttr getOrInsertExternFunc(StringRef funcName, static FlatSymbolRefAttr getOrInsertExternFunc(StringRef funcName,
ModuleOp module, mlir::LLVM::LLVMType funcType, PatternRewriter &rewriter) { ModuleOp module, mlir::LLVM::LLVMType funcType, PatternRewriter &rewriter) {
auto *context = module.getContext(); auto *context = module.getContext();
@ -374,6 +407,8 @@ public:
SET_DATA, SET_DATA,
GET_SIZES, GET_SIZES,
GET_STRIDES, GET_STRIDES,
SET_DTYPE,
GET_DTYPE,
}; };
struct ApiSpec { struct ApiSpec {
@ -469,7 +504,7 @@ public:
// Fill in the memref underlying ptrToMemRef with information extracted // Fill in the memref underlying ptrToMemRef with information extracted
// from dynMemRef. // from dynMemRef.
fillPtrToMemRefWithDynMemRef( fillPtrToMemRefWithRtMemRef(
dynMemRef, ptrToMemRef, rewriter, loc, apiRegistry, llvmDialect); dynMemRef, ptrToMemRef, rewriter, loc, apiRegistry, llvmDialect);
// ptrToMemRef will be an input to main computation graph function. // ptrToMemRef will be an input to main computation graph function.
@ -518,14 +553,14 @@ public:
auto outMemRefRank = getRankFromMemRefType(outMemRefTy); auto outMemRefRank = getRankFromMemRefType(outMemRefTy);
auto outMemRefRankVal = rewriter.create<LLVM::ConstantOp>( auto outMemRefRankVal = rewriter.create<LLVM::ConstantOp>(
loc, int32Ty, rewriter.getI32IntegerAttr(outMemRefRank)); loc, int32Ty, rewriter.getI32IntegerAttr(outMemRefRank));
auto outDynMemRef = callApi(rewriter, loc, apiRegistry, auto outRtMemRef = callApi(rewriter, loc, apiRegistry,
API::CREATE_DYN_MEM_REF, {outMemRefRankVal}); API::CREATE_DYN_MEM_REF, {outMemRefRankVal});
fillDynMemRefWithMemRef( fillRtMemRefWithMemRef(
memRef, outDynMemRef, rewriter, loc, apiRegistry, llvmDialect); memRef, outRtMemRef, rewriter, loc, apiRegistry, llvmDialect);
auto idx = rewriter.create<LLVM::ConstantOp>( auto idx = rewriter.create<LLVM::ConstantOp>(
loc, int32Ty, rewriter.getI32IntegerAttr(i)); loc, int32Ty, rewriter.getI32IntegerAttr(i));
callApi(rewriter, loc, apiRegistry, API::SET_DYN_MEM_REF, callApi(rewriter, loc, apiRegistry, API::SET_DYN_MEM_REF,
{wrappedOutput, idx, outDynMemRef}); {wrappedOutput, idx, outRtMemRef});
} }
// Return wrapped output. // Return wrapped output.
rewriter.create<LLVM::ReturnOp>( rewriter.create<LLVM::ReturnOp>(
@ -549,14 +584,16 @@ private:
// specifying its signature. // specifying its signature.
// clang-format off // clang-format off
std::vector<ApiSpec> apiSpecs = { std::vector<ApiSpec> apiSpecs = {
ApiSpec(API::CREATE_ORDERED_DYN_MEM_REF_DICT, "createOrderedDynMemRefDict", opaquePtrTy, {}), ApiSpec(API::CREATE_ORDERED_DYN_MEM_REF_DICT, "createOrderedRtMemRefDict", opaquePtrTy, {}),
ApiSpec(API::CREATE_DYN_MEM_REF, "createDynMemRef", opaquePtrTy, {int32Ty}), ApiSpec(API::CREATE_DYN_MEM_REF, "createRtMemRef", opaquePtrTy, {int32Ty}),
ApiSpec(API::GET_DATA, "getData", opaquePtrTy, {opaquePtrTy}), ApiSpec(API::GET_DATA, "getData", opaquePtrTy, {opaquePtrTy}),
ApiSpec(API::SET_DATA, "setData", voidTy, {opaquePtrTy, opaquePtrTy}), ApiSpec(API::SET_DATA, "setData", voidTy, {opaquePtrTy, opaquePtrTy}),
ApiSpec(API::GET_DYN_MEM_REF, "getDynMemRef", opaquePtrTy, {opaquePtrTy, int32Ty}), ApiSpec(API::GET_DYN_MEM_REF, "getRtMemRef", opaquePtrTy, {opaquePtrTy, int32Ty}),
ApiSpec(API::SET_DYN_MEM_REF, "setDynMemRef", voidTy, {opaquePtrTy, int32Ty, opaquePtrTy}), ApiSpec(API::SET_DYN_MEM_REF, "setRtMemRef", voidTy, {opaquePtrTy, int32Ty, opaquePtrTy}),
ApiSpec(API::GET_SIZES, "getSizes", int64PtrTy, {opaquePtrTy}), ApiSpec(API::GET_SIZES, "getSizes", int64PtrTy, {opaquePtrTy}),
ApiSpec(API::GET_STRIDES, "getStrides", int64PtrTy, {opaquePtrTy}) ApiSpec(API::GET_STRIDES, "getStrides", int64PtrTy, {opaquePtrTy}),
ApiSpec(API::GET_DTYPE, "getDType", int32Ty, {opaquePtrTy}),
ApiSpec(API::SET_DTYPE, "setDType", voidTy, {opaquePtrTy, int32Ty}),
}; };
// clang-format on // clang-format on
@ -598,7 +635,7 @@ private:
return *entryPointEntryBlock; return *entryPointEntryBlock;
} }
void fillPtrToMemRefWithDynMemRef(Value &dynMemRef, Value &ptrToMemRef, void fillPtrToMemRefWithRtMemRef(Value &dynMemRef, Value &ptrToMemRef,
PatternRewriter &rewriter, const Location &loc, PatternRewriter &rewriter, const Location &loc,
const std::map<API, ApiSpec> &apiRegistry, const std::map<API, ApiSpec> &apiRegistry,
LLVM::LLVMDialect *llvmDialect) const { LLVM::LLVMDialect *llvmDialect) const {
@ -659,12 +696,13 @@ private:
rewriter.create<LLVM::StoreOp>(loc, memRef, ptrToMemRef); rewriter.create<LLVM::StoreOp>(loc, memRef, ptrToMemRef);
} }
void fillDynMemRefWithMemRef(Value &outMemRef, Value &outDynMemRef, void fillRtMemRefWithMemRef(Value &outMemRef, Value &outRtMemRef,
PatternRewriter &rewriter, const Location &loc, PatternRewriter &rewriter, const Location &loc,
const std::map<API, ApiSpec> &apiRegistry, const std::map<API, ApiSpec> &apiRegistry,
LLVM::LLVMDialect *llvmDialect) const { LLVM::LLVMDialect *llvmDialect) const {
auto outMemRefTy = outMemRef.getType().dyn_cast<LLVM::LLVMType>(); auto outMemRefTy = outMemRef.getType().dyn_cast<LLVM::LLVMType>();
auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect); auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
auto int32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect);
// Extract the data pointer, and record it in dynamic mem ref created. // Extract the data pointer, and record it in dynamic mem ref created.
Value outMemRefDataPtr = rewriter.create<LLVM::ExtractValueOp>(loc, Value outMemRefDataPtr = rewriter.create<LLVM::ExtractValueOp>(loc,
@ -673,13 +711,19 @@ private:
outMemRefDataPtr = rewriter.create<LLVM::BitcastOp>( outMemRefDataPtr = rewriter.create<LLVM::BitcastOp>(
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), outMemRefDataPtr); loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), outMemRefDataPtr);
callApi(rewriter, loc, apiRegistry, API::SET_DATA, callApi(rewriter, loc, apiRegistry, API::SET_DATA,
{outDynMemRef, outMemRefDataPtr}); {outRtMemRef, outMemRefDataPtr});
auto elemTy = outMemRefTy.getStructElementType(0).getPointerElementTy();
auto onnxTy = llvmTypeToOnnxType(elemTy);
auto onnxTyVal = rewriter.create<LLVM::ConstantOp>(
loc, int32Ty, rewriter.getI32IntegerAttr(onnxTy));
callApi(
rewriter, loc, apiRegistry, API::SET_DTYPE, {outRtMemRef, onnxTyVal});
auto rank = getRankFromMemRefType(outMemRefTy); auto rank = getRankFromMemRefType(outMemRefTy);
auto sizesArrayPtr = auto sizesArrayPtr =
callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {outDynMemRef}); callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {outRtMemRef});
auto stridesArrayPtr = auto stridesArrayPtr =
callApi(rewriter, loc, apiRegistry, API::GET_STRIDES, {outDynMemRef}); callApi(rewriter, loc, apiRegistry, API::GET_STRIDES, {outRtMemRef});
for (decltype(rank) i = 0; i < rank; i++) { for (decltype(rank) i = 0; i < rank; i++) {
auto dimIdx = rewriter.create<LLVM::ConstantOp>( auto dimIdx = rewriter.create<LLVM::ConstantOp>(

View File

@ -6,7 +6,7 @@ target_link_libraries(TestConv
rapidcheck rapidcheck
MainUtils MainUtils
ExecutionSession ExecutionSession
DynMemRefUtils) RtMemRefUtils)
target_include_directories(TestConv target_include_directories(TestConv
PRIVATE PRIVATE

View File

@ -96,13 +96,13 @@ bool isOMConvTheSameAsNaiveImplFor(const int N, const int C, const int H,
onnx_mlir::ExecutionSession sess( onnx_mlir::ExecutionSession sess(
pathStr + ".so", "_dyn_entry_point_main_graph"); pathStr + ".so", "_dyn_entry_point_main_graph");
std::vector<unique_ptr<DynMemRef>> inputs; std::vector<unique_ptr<RtMemRef>> inputs;
auto xDmr = unique_ptr<DynMemRef>(getRndRealDmr<float>({N, C, H, W})); auto xRmr = unique_ptr<RtMemRef>(getRndRealRmr<float>({N, C, H, W}));
inputs.emplace_back(move(xDmr)); inputs.emplace_back(move(xRmr));
auto wDmr = unique_ptr<DynMemRef>(getRndRealDmr<float>({C, C, kH, kW})); auto wRmr = unique_ptr<RtMemRef>(getRndRealRmr<float>({C, C, kH, kW}));
inputs.emplace_back(move(wDmr)); inputs.emplace_back(move(wRmr));
auto ref = DynMemRef::create<float>({NOut, COut, HOut, WOut}); auto ref = RtMemRef::create<float>({NOut, COut, HOut, WOut});
auto &img = inputs.at(0); auto &img = inputs.at(0);
auto &filter = inputs.at(1); auto &filter = inputs.at(1);
for (int64_t n = 0; n < NOut; n++) for (int64_t n = 0; n < NOut; n++)
@ -124,7 +124,7 @@ bool isOMConvTheSameAsNaiveImplFor(const int N, const int C, const int H,
auto outputs = sess.run(move(inputs)); auto outputs = sess.run(move(inputs));
auto &conv = outputs.at(0); auto &conv = outputs.at(0);
return isDmrClose<float>(conv.get(), ref); return isRmrClose<float>(conv.get(), ref);
} }
int main() { int main() {