Support encoding data type infomration as part of the DMR struct. (#178)
* Support encoding data type infomration as part of the DMR struct. * Support full range of np types. * Report error when encountering unsupported type. * Add gerRank method API. * Add missing API declarations. * DynMemRef -> RtMemRef * Format code. * Missed DynMemRef -> RtMemRef conversion. * More comments for RMR, and rename variable names from dmr -> rmr. * DynMemRef -> RtMemRef. * Format code.
This commit is contained in:
parent
f9cb113a84
commit
e902506ee5
|
@ -1,13 +1,13 @@
|
|||
# Create shared libcruntime.so since model.so linkage for backend tests
|
||||
# will fail on x86 Linux if cruntime is statically linked.
|
||||
add_library(cruntime STATIC
|
||||
DynMemRef.cpp
|
||||
DynMemRef.h
|
||||
RtMemRef.cpp
|
||||
RtMemRef.h
|
||||
DataType.h)
|
||||
|
||||
add_library(DynMemRefUtils
|
||||
DynMemRef.h
|
||||
DynMemRef.cpp
|
||||
add_library(RtMemRefUtils
|
||||
RtMemRef.h
|
||||
RtMemRef.cpp
|
||||
DataType.h)
|
||||
|
||||
add_library(ExecutionSession
|
||||
|
@ -28,7 +28,8 @@ pybind11_add_module(PyRuntime
|
|||
target_link_libraries(PyRuntime PRIVATE
|
||||
${CMAKE_DL_LIBS}
|
||||
ExecutionSession
|
||||
DynMemRefUtils)
|
||||
RtMemRefUtils
|
||||
onnx)
|
||||
target_include_directories(PyRuntime PRIVATE
|
||||
${ONNX_MLIR_SRC_ROOT}
|
||||
${ONNX_MLIR_BIN_ROOT}
|
||||
|
@ -41,6 +42,6 @@ set_target_properties(EmbeddedDataLoader PROPERTIES
|
|||
POSITION_INDEPENDENT_CODE TRUE)
|
||||
|
||||
add_dependencies(PyRuntime cruntime)
|
||||
install(FILES DynMemRef.h DESTINATION include)
|
||||
install(FILES RtMemRef.h DESTINATION include)
|
||||
install(TARGETS cruntime DESTINATION lib)
|
||||
install(TARGETS EmbeddedDataLoader DESTINATION lib)
|
||||
|
|
|
@ -42,20 +42,19 @@ ExecutionSession::ExecutionSession(
|
|||
}
|
||||
}
|
||||
|
||||
std::vector<std::unique_ptr<DynMemRef>> ExecutionSession::run(
|
||||
std::vector<std::unique_ptr<DynMemRef>> ins) {
|
||||
auto *wrappedInput = createOrderedDynMemRefDict();
|
||||
std::vector<std::unique_ptr<RtMemRef>> ExecutionSession::run(
|
||||
std::vector<std::unique_ptr<RtMemRef>> ins) {
|
||||
auto *wrappedInput = createOrderedRtMemRefDict();
|
||||
for (size_t i = 0; i < ins.size(); i++)
|
||||
setDynMemRef(wrappedInput, i, ins.at(i).get());
|
||||
setRtMemRef(wrappedInput, i, ins.at(i).get());
|
||||
|
||||
auto *wrappedOutput = _entryPointFunc(wrappedInput);
|
||||
|
||||
std::vector<std::unique_ptr<DynMemRef>> outs;
|
||||
std::vector<std::unique_ptr<RtMemRef>> outs;
|
||||
auto outputSize = getSize(wrappedOutput);
|
||||
|
||||
for (size_t i = 0; i < getSize(wrappedOutput); i++) {
|
||||
outs.emplace_back(
|
||||
std::unique_ptr<DynMemRef>(getDynMemRef(wrappedOutput, i)));
|
||||
outs.emplace_back(std::unique_ptr<RtMemRef>(getRtMemRef(wrappedOutput, i)));
|
||||
}
|
||||
return std::move(outs);
|
||||
}
|
||||
|
|
|
@ -15,18 +15,18 @@
|
|||
#include <dlfcn.h>
|
||||
#include <string>
|
||||
|
||||
#include "src/Runtime/DynMemRef.h"
|
||||
#include "src/Runtime/RtMemRef.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
typedef OrderedDynMemRefDict *(*entryPointFuncType)(OrderedDynMemRefDict *);
|
||||
typedef OrderedRtMemRefDict *(*entryPointFuncType)(OrderedRtMemRefDict *);
|
||||
|
||||
class ExecutionSession {
|
||||
public:
|
||||
ExecutionSession(std::string sharedLibPath, std::string entryPointName);
|
||||
|
||||
std::vector<std::unique_ptr<DynMemRef>> run(
|
||||
std::vector<std::unique_ptr<DynMemRef>>);
|
||||
std::vector<std::unique_ptr<RtMemRef>> run(
|
||||
std::vector<std::unique_ptr<RtMemRef>>);
|
||||
|
||||
~ExecutionSession();
|
||||
|
||||
|
|
|
@ -9,6 +9,9 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "onnx/onnx_pb.h"
|
||||
#include <third_party/onnx/onnx/onnx_pb.h>
|
||||
|
||||
#include "PyExecutionSession.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
@ -16,40 +19,72 @@ namespace onnx_mlir {
|
|||
std::vector<py::array> PyExecutionSession::pyRun(
|
||||
std::vector<py::array> inputsPyArray) {
|
||||
assert(_entryPointFunc && "Entry point not loaded.");
|
||||
auto *wrappedInput = createOrderedDynMemRefDict();
|
||||
auto *wrappedInput = createOrderedRtMemRefDict();
|
||||
int inputIdx = 0;
|
||||
for (auto inputPyArray : inputsPyArray) {
|
||||
auto *inputDynMemRef = createDynMemRef(inputPyArray.ndim());
|
||||
auto *inputRtMemRef = createRtMemRef(inputPyArray.ndim());
|
||||
assert(inputPyArray.flags() && py::array::c_style &&
|
||||
"Expect contiguous python array.");
|
||||
|
||||
if (inputPyArray.writeable()) {
|
||||
inputDynMemRef->data = inputPyArray.mutable_data();
|
||||
inputDynMemRef->alignedData = inputPyArray.mutable_data();
|
||||
inputRtMemRef->data = inputPyArray.mutable_data();
|
||||
inputRtMemRef->alignedData = inputPyArray.mutable_data();
|
||||
} else {
|
||||
// If data is not writable, copy them to a writable buffer.
|
||||
auto *copiedData = (float *)malloc(inputPyArray.nbytes());
|
||||
memcpy(copiedData, inputPyArray.data(), inputPyArray.nbytes());
|
||||
inputDynMemRef->data = copiedData;
|
||||
inputDynMemRef->alignedData = copiedData;
|
||||
inputRtMemRef->data = copiedData;
|
||||
inputRtMemRef->alignedData = copiedData;
|
||||
}
|
||||
|
||||
for (int i = 0; i < inputPyArray.ndim(); i++) {
|
||||
inputDynMemRef->sizes[i] = inputPyArray.shape(i);
|
||||
inputDynMemRef->strides[i] = inputPyArray.strides(i);
|
||||
inputRtMemRef->sizes[i] = inputPyArray.shape(i);
|
||||
inputRtMemRef->strides[i] = inputPyArray.strides(i);
|
||||
}
|
||||
|
||||
setDynMemRef(wrappedInput, inputIdx++, inputDynMemRef);
|
||||
setRtMemRef(wrappedInput, inputIdx++, inputRtMemRef);
|
||||
}
|
||||
|
||||
std::vector<py::array> outputPyArrays;
|
||||
auto *wrappedOutput = _entryPointFunc(wrappedInput);
|
||||
for (int i = 0; i < numDynMemRefs(wrappedOutput); i++) {
|
||||
auto *dynMemRef = getDynMemRef(wrappedOutput, i);
|
||||
for (int i = 0; i < numRtMemRefs(wrappedOutput); i++) {
|
||||
auto *dynMemRef = getRtMemRef(wrappedOutput, i);
|
||||
auto shape = std::vector<int64_t>(
|
||||
dynMemRef->sizes, dynMemRef->sizes + dynMemRef->rank);
|
||||
outputPyArrays.emplace_back(
|
||||
py::array(py::dtype("float32"), shape, dynMemRef->data));
|
||||
|
||||
// https://numpy.org/devdocs/user/basics.types.html
|
||||
py::dtype dtype;
|
||||
if (dynMemRef->onnx_dtype == onnx::TensorProto::FLOAT)
|
||||
dtype = py::dtype("float32");
|
||||
else if (dynMemRef->onnx_dtype = onnx::TensorProto::UINT8)
|
||||
dtype = py::dtype("uint8");
|
||||
else if (dynMemRef->onnx_dtype = onnx::TensorProto::INT8)
|
||||
dtype = py::dtype("int8");
|
||||
else if (dynMemRef->onnx_dtype = onnx::TensorProto::UINT16)
|
||||
dtype = py::dtype("uint16");
|
||||
else if (dynMemRef->onnx_dtype = onnx::TensorProto::INT16)
|
||||
dtype = py::dtype("int16");
|
||||
else if (dynMemRef->onnx_dtype == onnx::TensorProto::INT32)
|
||||
dtype = py::dtype("int32");
|
||||
else if (dynMemRef->onnx_dtype == onnx::TensorProto::INT64)
|
||||
dtype = py::dtype("int64");
|
||||
// TODO(tjingrant) wait for Tong's input for how to represent string.
|
||||
else if (dynMemRef->onnx_dtype = onnx::TensorProto::BOOL)
|
||||
dtype = py::dtype("bool_");
|
||||
else if (dynMemRef->onnx_dtype = onnx::TensorProto::FLOAT16)
|
||||
dtype = py::dtype("float32");
|
||||
else if (dynMemRef->onnx_dtype = onnx::TensorProto::DOUBLE)
|
||||
dtype = py::dtype("float64");
|
||||
else if (dynMemRef->onnx_dtype == onnx::TensorProto::UINT32)
|
||||
dtype = py::dtype("uint32");
|
||||
else if (dynMemRef->onnx_dtype == onnx::TensorProto::UINT64)
|
||||
dtype = py::dtype("uint64");
|
||||
else {
|
||||
fprintf(stderr, "Unsupported ONNX type in RtMemRef.onnx_dtype.");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
outputPyArrays.emplace_back(py::array(dtype, shape, dynMemRef->data));
|
||||
}
|
||||
|
||||
return outputPyArrays;
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
//===----------- DynMemRef.cpp - Dynamic MemRef Implementation ------------===//
|
||||
//===----------- RtMemRef.cpp - Dynamic MemRef Implementation ------------===//
|
||||
//
|
||||
// Copyright 2019-2020 The IBM Research Authors.
|
||||
//
|
||||
|
@ -18,7 +18,7 @@
|
|||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "DynMemRef.h"
|
||||
#include "RtMemRef.h"
|
||||
|
||||
namespace {
|
||||
// Helper function to compute cartisian product.
|
||||
|
@ -39,19 +39,19 @@ inline std::vector<std::vector<INDEX_TYPE>> CartProduct(
|
|||
}
|
||||
} // namespace
|
||||
|
||||
DynMemRef::DynMemRef(int _rank) {
|
||||
RtMemRef::RtMemRef(int _rank) {
|
||||
rank = _rank;
|
||||
sizes = (INDEX_TYPE *)malloc(rank * sizeof(INDEX_TYPE));
|
||||
strides = (int64_t *)malloc(rank * sizeof(int64_t));
|
||||
}
|
||||
|
||||
INDEX_TYPE DynMemRef::size() const {
|
||||
INDEX_TYPE RtMemRef::size() const {
|
||||
return std::accumulate(sizes, sizes + rank, 1, std::multiplies<>());
|
||||
}
|
||||
|
||||
std::vector<std::vector<INDEX_TYPE>> DynMemRef::indexSet() const {
|
||||
std::vector<std::vector<INDEX_TYPE>> RtMemRef::indexSet() const {
|
||||
// First, we create index set of each dimension separately.
|
||||
// i.e., for a tensor/DMR of shape (2, 3), its dimWiseIdxSet will be:
|
||||
// i.e., for a tensor/RMR of shape (2, 3), its dimWiseIdxSet will be:
|
||||
// {{0,1}, {0,1,2}};
|
||||
std::vector<std::vector<INDEX_TYPE>> dimWiseIdxSet;
|
||||
for (auto dimSize : std::vector<INDEX_TYPE>(sizes, sizes + rank)) {
|
||||
|
@ -60,18 +60,18 @@ std::vector<std::vector<INDEX_TYPE>> DynMemRef::indexSet() const {
|
|||
dimWiseIdxSet.emplace_back(dimIdxSet);
|
||||
}
|
||||
// Then, the cartesian product of vectors within dimWiseIdxSet will be the
|
||||
// index set for the whole DMR.
|
||||
// index set for the whole RMR.
|
||||
return CartProduct(dimWiseIdxSet);
|
||||
}
|
||||
|
||||
INDEX_TYPE DynMemRef::computeOffset(std::vector<INDEX_TYPE> &idxs) const {
|
||||
INDEX_TYPE RtMemRef::computeOffset(std::vector<INDEX_TYPE> &idxs) const {
|
||||
auto dimStrides = std::vector<INDEX_TYPE>(strides, strides + rank);
|
||||
INDEX_TYPE elemOffset = std::inner_product(
|
||||
idxs.begin(), idxs.end(), dimStrides.begin(), (INDEX_TYPE)0);
|
||||
return elemOffset;
|
||||
}
|
||||
|
||||
std::vector<int64_t> DynMemRef::computeStridesFromSizes() const {
|
||||
std::vector<int64_t> RtMemRef::computeStridesFromSizes() const {
|
||||
// Shift dimension sizes one to the left, fill in the vacated rightmost
|
||||
// element with 1; this gets us a vector that'll be more useful for computing
|
||||
// strides of memory access along each dimension using prefix product (aka
|
||||
|
@ -86,7 +86,7 @@ std::vector<int64_t> DynMemRef::computeStridesFromSizes() const {
|
|||
return dimStrides;
|
||||
}
|
||||
|
||||
DynMemRef::~DynMemRef() {
|
||||
RtMemRef::~RtMemRef() {
|
||||
free(data);
|
||||
free(sizes);
|
||||
free(strides);
|
||||
|
@ -95,27 +95,26 @@ DynMemRef::~DynMemRef() {
|
|||
// An ordered dynamic MemRef dictionary.
|
||||
// The goal is to support accessing dynamic memory ref by name and by index.
|
||||
// Currently, only accessing by index is supported.
|
||||
struct OrderedDynMemRefDict {
|
||||
std::map<std::string, DynMemRef *> tensorDict;
|
||||
struct OrderedRtMemRefDict {
|
||||
std::map<std::string, RtMemRef *> tensorDict;
|
||||
std::vector<std::string> orderedNames;
|
||||
};
|
||||
|
||||
int numDynMemRefs(OrderedDynMemRefDict *dict) {
|
||||
int numRtMemRefs(OrderedRtMemRefDict *dict) {
|
||||
return dict->orderedNames.size();
|
||||
}
|
||||
|
||||
OrderedDynMemRefDict *createOrderedDynMemRefDict() {
|
||||
return new OrderedDynMemRefDict();
|
||||
OrderedRtMemRefDict *createOrderedRtMemRefDict() {
|
||||
return new OrderedRtMemRefDict();
|
||||
}
|
||||
|
||||
DynMemRef *createDynMemRef(int rank) { return new DynMemRef(rank); }
|
||||
RtMemRef *createRtMemRef(int rank) { return new RtMemRef(rank); }
|
||||
|
||||
DynMemRef *getDynMemRef(OrderedDynMemRefDict *tensorDict, int idx) {
|
||||
RtMemRef *getRtMemRef(OrderedRtMemRefDict *tensorDict, int idx) {
|
||||
return tensorDict->tensorDict[tensorDict->orderedNames[idx]];
|
||||
}
|
||||
|
||||
void setDynMemRef(
|
||||
OrderedDynMemRefDict *tensorDict, int idx, DynMemRef *tensor) {
|
||||
void setRtMemRef(OrderedRtMemRefDict *tensorDict, int idx, RtMemRef *tensor) {
|
||||
if (tensorDict->orderedNames.size() <= idx)
|
||||
tensorDict->orderedNames.resize(idx + 1);
|
||||
|
||||
|
@ -130,30 +129,36 @@ void setDynMemRef(
|
|||
tensorDict->tensorDict[tensorDict->orderedNames[idx]] = tensor;
|
||||
}
|
||||
|
||||
void *getData(DynMemRef *dynMemRef) { return dynMemRef->data; }
|
||||
void *getData(RtMemRef *dynMemRef) { return dynMemRef->data; }
|
||||
|
||||
void setData(DynMemRef *dynMemRef, void *dataPtr) { dynMemRef->data = dataPtr; }
|
||||
void setData(RtMemRef *dynMemRef, void *dataPtr) { dynMemRef->data = dataPtr; }
|
||||
|
||||
void *getAlignedData(DynMemRef *dynMemRef) { return dynMemRef->alignedData; }
|
||||
void *getAlignedData(RtMemRef *dynMemRef) { return dynMemRef->alignedData; }
|
||||
|
||||
void setAlignedData(DynMemRef *dynMemRef, void *dataPtr) {
|
||||
void setAlignedData(RtMemRef *dynMemRef, void *dataPtr) {
|
||||
dynMemRef->alignedData = dataPtr;
|
||||
}
|
||||
|
||||
INDEX_TYPE *getSizes(DynMemRef *dynMemRef) { return dynMemRef->sizes; }
|
||||
INDEX_TYPE *getSizes(RtMemRef *dynMemRef) { return dynMemRef->sizes; }
|
||||
|
||||
void setSizes(DynMemRef *dynMemRef, INDEX_TYPE *sizes) {
|
||||
void setSizes(RtMemRef *dynMemRef, INDEX_TYPE *sizes) {
|
||||
for (int i = 0; i < dynMemRef->rank; i++)
|
||||
dynMemRef->sizes[i] = sizes[i];
|
||||
}
|
||||
|
||||
int64_t *getStrides(DynMemRef *dynMemRef) { return dynMemRef->strides; }
|
||||
int64_t *getStrides(RtMemRef *dynMemRef) { return dynMemRef->strides; }
|
||||
|
||||
int64_t getSize(OrderedDynMemRefDict *dict) {
|
||||
return dict->orderedNames.size();
|
||||
int64_t getSize(OrderedRtMemRefDict *dict) { return dict->orderedNames.size(); }
|
||||
|
||||
void setDType(RtMemRef *dynMemRef, int onnxType) {
|
||||
dynMemRef->onnx_dtype = onnxType;
|
||||
}
|
||||
|
||||
void setStrides(DynMemRef *dynMemRef, int64_t *strides) {
|
||||
int getDType(RtMemRef *dynMemRef) { return dynMemRef->onnx_dtype; }
|
||||
|
||||
unsigned int getRank(RtMemRef *dynMemRef) { return dynMemRef->rank; }
|
||||
|
||||
void setStrides(RtMemRef *dynMemRef, int64_t *strides) {
|
||||
for (int i = 0; i < dynMemRef->rank; i++)
|
||||
dynMemRef->sizes[i] = strides[i];
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
//===------------ DynMemRef.h - Dynamic MemRef Implementation -------------===//
|
||||
//===------------ RtMemRef.h - Dynamic MemRef Implementation -------------===//
|
||||
//
|
||||
// Copyright 2019-2020 The IBM Research Authors.
|
||||
//
|
||||
|
@ -24,40 +24,69 @@
|
|||
|
||||
typedef int64_t INDEX_TYPE;
|
||||
|
||||
// This is a dynamic version of memref.
|
||||
// The same struct can be used to represent memrefs of
|
||||
// all ranks and type combinations.
|
||||
// We will refer to it as a DMR (Dynamic MemRef).
|
||||
struct DynMemRef {
|
||||
// Typically, MemRefs in MLIR context are used as a compile-time constructs.
|
||||
// Information such as element type and rank of the data payload is statically
|
||||
// encoded, meaning that they are determined and fixed at compile-time. This
|
||||
// presents significant burden for any runtime components trying to interact
|
||||
// with the compiled executable.
|
||||
//
|
||||
// Thus a version of MemRef struct that is amenable to runtime manipulation is
|
||||
// provided as a basis for building any runtime-related components providing
|
||||
// user-facing programming interfaces. All information are dynamically encoded
|
||||
// as members of this struct so that they can be accessed and modified easily
|
||||
// during runtime.
|
||||
//
|
||||
// We will refer to it as a RMF (Runtime MemRef).
|
||||
struct RtMemRef {
|
||||
|
||||
// Pointer to the raw memory space allocated to host the RMR content. This
|
||||
// pointer should only be acessed for memory management purposes, not for
|
||||
// reading RMR content.
|
||||
void *data;
|
||||
|
||||
// Pointer to the properly aligned array of elements stored in this Rmr.
|
||||
void *alignedData;
|
||||
|
||||
// Distance between the start of the raw memory space and the first element of
|
||||
// the RMR content.
|
||||
INDEX_TYPE offset;
|
||||
|
||||
// Number of dimensions of the array represented by the RMR.
|
||||
unsigned int rank;
|
||||
|
||||
// An array recording the per-dimension sizes of the array represented by the
|
||||
// RMR.
|
||||
INDEX_TYPE *sizes;
|
||||
|
||||
// An array recording the per-dimension strides of the array represented by
|
||||
// the RMR.
|
||||
int64_t *strides;
|
||||
|
||||
// Refer to TensorProto_DataType at
|
||||
// https://github.com/onnx/onnx/blob/cc2230603422bae893d5bc900d2d773ab34400a4/onnx/onnx-ml.proto#L451
|
||||
// for enum value interpretation.
|
||||
unsigned int onnx_dtype;
|
||||
#ifdef __cplusplus
|
||||
explicit DynMemRef(int _rank);
|
||||
explicit RtMemRef(int _rank);
|
||||
|
||||
// Create a full DMR of type T and shape _sizes, with all data fields
|
||||
// Create a full RMR of type T and shape _sizes, with all data fields
|
||||
// initialized to proper values and data pointers malloc'ed.
|
||||
template <typename T>
|
||||
static DynMemRef *create(std::vector<INDEX_TYPE> _sizes) {
|
||||
auto dmr = new DynMemRef(_sizes.size());
|
||||
dmr->offset = 0;
|
||||
dmr->rank = _sizes.size();
|
||||
dmr->sizes = (INDEX_TYPE *)malloc(dmr->rank * sizeof(INDEX_TYPE));
|
||||
std::copy(_sizes.begin(), _sizes.end(), dmr->sizes);
|
||||
static RtMemRef *create(std::vector<INDEX_TYPE> _sizes) {
|
||||
auto rmr = new RtMemRef(_sizes.size());
|
||||
rmr->offset = 0;
|
||||
rmr->rank = _sizes.size();
|
||||
rmr->sizes = (INDEX_TYPE *)malloc(rmr->rank * sizeof(INDEX_TYPE));
|
||||
std::copy(_sizes.begin(), _sizes.end(), rmr->sizes);
|
||||
|
||||
dmr->strides = (int64_t *)malloc(dmr->rank * sizeof(int64_t));
|
||||
auto computedStrides = dmr->computeStridesFromSizes();
|
||||
std::copy(computedStrides.begin(), computedStrides.end(), dmr->strides);
|
||||
rmr->strides = (int64_t *)malloc(rmr->rank * sizeof(int64_t));
|
||||
auto computedStrides = rmr->computeStridesFromSizes();
|
||||
std::copy(computedStrides.begin(), computedStrides.end(), rmr->strides);
|
||||
|
||||
dmr->data = malloc(dmr->size() * sizeof(T));
|
||||
dmr->alignedData = dmr->data;
|
||||
rmr->data = malloc(rmr->size() * sizeof(T));
|
||||
rmr->alignedData = rmr->data;
|
||||
|
||||
return dmr;
|
||||
return rmr;
|
||||
}
|
||||
|
||||
// Access an element (by reference) at index position idxs.
|
||||
|
@ -75,13 +104,13 @@ struct DynMemRef {
|
|||
return typedPtr[idx];
|
||||
}
|
||||
|
||||
// Get a typed ptr to the data content of the DMR.
|
||||
// Get a typed ptr to the data content of the RMR.
|
||||
template <typename T>
|
||||
T *typedPtr() {
|
||||
return (T *)data;
|
||||
}
|
||||
|
||||
// Get how many elements are stored in DMR, as implied by its shape.
|
||||
// Get how many elements are stored in RMR, as implied by its shape.
|
||||
INDEX_TYPE size() const;
|
||||
|
||||
// Helper function to compute strides of access along each dimensions from its
|
||||
|
@ -92,63 +121,77 @@ struct DynMemRef {
|
|||
INDEX_TYPE computeOffset(std::vector<INDEX_TYPE> &idxs) const;
|
||||
|
||||
// Get the index set (i.e., all valid multi-dimensional array indexes that can
|
||||
// be used to access this DMR's constituent elements).
|
||||
// be used to access this RMR's constituent elements).
|
||||
std::vector<std::vector<INDEX_TYPE>> indexSet() const;
|
||||
|
||||
~DynMemRef();
|
||||
~RtMemRef();
|
||||
#endif
|
||||
};
|
||||
|
||||
#ifdef __cplusplus
|
||||
// Ordered DynMemRef Dictionary is a data structure for wrapping the input
|
||||
// Ordered RtMemRef Dictionary is a data structure for wrapping the input
|
||||
// dynmemrefs so that they can be addressed both by index and by name.
|
||||
struct OrderedDynMemRefDict;
|
||||
struct OrderedRtMemRefDict;
|
||||
|
||||
#else
|
||||
typedef struct DynMemRef DynMemRef;
|
||||
typedef struct _OrderedDynMemRefDict OrderedDynMemRefDict;
|
||||
typedef struct RtMemRef RtMemRef;
|
||||
typedef struct _OrderedRtMemRefDict OrderedRtMemRefDict;
|
||||
#endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// Get number of dynamic memrefs in OrderedDynMemRefDict dict.
|
||||
int numDynMemRefs(OrderedDynMemRefDict *dict);
|
||||
// Get number of dynamic memrefs in OrderedRtMemRefDict dict.
|
||||
int numRtMemRefs(OrderedRtMemRefDict *dict);
|
||||
|
||||
// Create an ordered dynamic memref dictionary.
|
||||
OrderedDynMemRefDict *createOrderedDynMemRefDict();
|
||||
OrderedRtMemRefDict *createOrderedRtMemRefDict();
|
||||
|
||||
// Get how many dynamic memrefs are in dict.
|
||||
int64_t getSize(OrderedDynMemRefDict *dict);
|
||||
int64_t getSize(OrderedRtMemRefDict *dict);
|
||||
|
||||
// Create a dynmemref with a certain rank.
|
||||
DynMemRef *createDynMemRef(int rank);
|
||||
RtMemRef *createRtMemRef(int rank);
|
||||
|
||||
// Get the i-th dynmemref from orderedDict.
|
||||
DynMemRef *getDynMemRef(OrderedDynMemRefDict *orderedDict, int i);
|
||||
RtMemRef *getRtMemRef(OrderedRtMemRefDict *orderedDict, int i);
|
||||
|
||||
// Set the i-th dynmemref in orderedDict to be dynMemRef.
|
||||
void setDynMemRef(
|
||||
OrderedDynMemRefDict *tensorDict, int idx, DynMemRef *dynMemRef);
|
||||
void setRtMemRef(OrderedRtMemRefDict *tensorDict, int idx, RtMemRef *dynMemRef);
|
||||
|
||||
// Get data pointer from dynMemRef.
|
||||
void *getData(DynMemRef *dynMemRef);
|
||||
void *getData(RtMemRef *dynMemRef);
|
||||
|
||||
// Set data pointer for dynMemRef.
|
||||
void setData(DynMemRef *dynMemRef, void *data);
|
||||
void setData(RtMemRef *dynMemRef, void *data);
|
||||
|
||||
// Get algined data pointer from dynMemRef.
|
||||
void *getAlignedData(DynMemRef *);
|
||||
void *getAlignedData(RtMemRef *);
|
||||
|
||||
// Set aligned data pointer for dynMemRef.
|
||||
void setAlignedData(DynMemRef *, void *);
|
||||
void setAlignedData(RtMemRef *, void *);
|
||||
|
||||
// Get the data type enum value of the dynMemRef.
|
||||
int getDType(RtMemRef *dynMemRef);
|
||||
|
||||
// Set the data type enum value of the dynMemRef.
|
||||
void setDType(RtMemRef *dynMemRef, int onnxType);
|
||||
|
||||
// Get the rank of the dynMemRef.
|
||||
unsigned int getRank(RtMemRef *dynMemRef);
|
||||
|
||||
// Get ptr to sizes array.
|
||||
INDEX_TYPE *getSizes(DynMemRef *);
|
||||
INDEX_TYPE *getSizes(RtMemRef *);
|
||||
|
||||
// Set the sizes array (by copying size values from array `sizes`).
|
||||
void setSizes(RtMemRef *, INDEX_TYPE *sizes);
|
||||
|
||||
// Get ptr to strides array.
|
||||
int64_t *getStrides(DynMemRef *);
|
||||
int64_t *getStrides(RtMemRef *);
|
||||
|
||||
// Set the strides array (by copying stride values from array `strides`).
|
||||
void setStrides(RtMemRef *, int64_t *strides);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
@ -164,22 +207,22 @@ void printVector(std::vector<T> vec, std::string _delimiter = ",",
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
DynMemRef *getRndRealDmr(
|
||||
RtMemRef *getRndRealRmr(
|
||||
std::vector<INDEX_TYPE> sizes, T lb = -1.0, T ub = 1.0) {
|
||||
// Will be used to obtain a seed for the random number engine
|
||||
std::random_device rd;
|
||||
// Standard mersenne_twister_engine seeded with rd()
|
||||
std::mt19937 gen(rd());
|
||||
std::uniform_real_distribution<> dis(lb, ub);
|
||||
auto dmr = DynMemRef::create<T>(sizes);
|
||||
auto ptr = (T *)dmr->data;
|
||||
std::generate(ptr, ptr + dmr->size(), [&]() { return dis(gen); });
|
||||
return dmr;
|
||||
auto rmr = RtMemRef::create<T>(sizes);
|
||||
auto ptr = (T *)rmr->data;
|
||||
std::generate(ptr, ptr + rmr->size(), [&]() { return dis(gen); });
|
||||
return rmr;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline bool isDmrClose(
|
||||
DynMemRef *a, DynMemRef *b, float rtol = 1e-5, float atol = 1e-5) {
|
||||
inline bool isRmrClose(
|
||||
RtMemRef *a, RtMemRef *b, float rtol = 1e-5, float atol = 1e-5) {
|
||||
|
||||
// Compare shape.
|
||||
auto aShape = std::vector<INDEX_TYPE>(a->sizes, a->sizes + a->rank);
|
||||
|
@ -232,3 +275,6 @@ inline bool isDmrClose(
|
|||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// Will transition from RtMemRef to RtMemRef soon.
|
||||
typedef RtMemRef RtMemRef;
|
|
@ -19,6 +19,7 @@
|
|||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "onnx/onnx_pb.h"
|
||||
#include "llvm/ADT/Sequence.h"
|
||||
|
||||
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
|
||||
|
@ -29,6 +30,38 @@ using namespace mlir;
|
|||
|
||||
namespace {
|
||||
|
||||
static onnx::TensorProto::DataType llvmTypeToOnnxType(
|
||||
mlir::LLVM::LLVMType elemType) {
|
||||
if (elemType.isFloatTy())
|
||||
return onnx::TensorProto::FLOAT;
|
||||
if (elemType.isUnsignedInteger(8))
|
||||
return onnx::TensorProto::UINT8;
|
||||
if (elemType.isSignedInteger(8))
|
||||
return onnx::TensorProto::INT8;
|
||||
if (elemType.isUnsignedInteger(16))
|
||||
return onnx::TensorProto::UINT16;
|
||||
if (elemType.isSignedInteger(16))
|
||||
return onnx::TensorProto::INT16;
|
||||
if (elemType.isSignedInteger(32))
|
||||
return onnx::TensorProto::INT32;
|
||||
if (elemType.isSignedInteger(64))
|
||||
return onnx::TensorProto::INT64;
|
||||
// TODO, wait for Tong's input about how string is represented in MLIR.
|
||||
if (elemType.isInteger(1))
|
||||
return onnx::TensorProto::BOOL;
|
||||
if (elemType.isHalfTy())
|
||||
return onnx::TensorProto::FLOAT16;
|
||||
if (elemType.isDoubleTy())
|
||||
return onnx::TensorProto::DOUBLE;
|
||||
if (elemType.isUnsignedInteger(32))
|
||||
return onnx::TensorProto::UINT32;
|
||||
if (elemType.isUnsignedInteger(64))
|
||||
return onnx::TensorProto::INT64;
|
||||
// Complex types don't seem to exist in LLVM Dialect.
|
||||
elemType.dump();
|
||||
llvm_unreachable("Unexpected LLVM type, cannot be converted to ONNX type.");
|
||||
}
|
||||
|
||||
static FlatSymbolRefAttr getOrInsertExternFunc(StringRef funcName,
|
||||
ModuleOp module, mlir::LLVM::LLVMType funcType, PatternRewriter &rewriter) {
|
||||
auto *context = module.getContext();
|
||||
|
@ -374,6 +407,8 @@ public:
|
|||
SET_DATA,
|
||||
GET_SIZES,
|
||||
GET_STRIDES,
|
||||
SET_DTYPE,
|
||||
GET_DTYPE,
|
||||
};
|
||||
|
||||
struct ApiSpec {
|
||||
|
@ -469,7 +504,7 @@ public:
|
|||
|
||||
// Fill in the memref underlying ptrToMemRef with information extracted
|
||||
// from dynMemRef.
|
||||
fillPtrToMemRefWithDynMemRef(
|
||||
fillPtrToMemRefWithRtMemRef(
|
||||
dynMemRef, ptrToMemRef, rewriter, loc, apiRegistry, llvmDialect);
|
||||
|
||||
// ptrToMemRef will be an input to main computation graph function.
|
||||
|
@ -518,14 +553,14 @@ public:
|
|||
auto outMemRefRank = getRankFromMemRefType(outMemRefTy);
|
||||
auto outMemRefRankVal = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, int32Ty, rewriter.getI32IntegerAttr(outMemRefRank));
|
||||
auto outDynMemRef = callApi(rewriter, loc, apiRegistry,
|
||||
auto outRtMemRef = callApi(rewriter, loc, apiRegistry,
|
||||
API::CREATE_DYN_MEM_REF, {outMemRefRankVal});
|
||||
fillDynMemRefWithMemRef(
|
||||
memRef, outDynMemRef, rewriter, loc, apiRegistry, llvmDialect);
|
||||
fillRtMemRefWithMemRef(
|
||||
memRef, outRtMemRef, rewriter, loc, apiRegistry, llvmDialect);
|
||||
auto idx = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, int32Ty, rewriter.getI32IntegerAttr(i));
|
||||
callApi(rewriter, loc, apiRegistry, API::SET_DYN_MEM_REF,
|
||||
{wrappedOutput, idx, outDynMemRef});
|
||||
{wrappedOutput, idx, outRtMemRef});
|
||||
}
|
||||
// Return wrapped output.
|
||||
rewriter.create<LLVM::ReturnOp>(
|
||||
|
@ -549,14 +584,16 @@ private:
|
|||
// specifying its signature.
|
||||
// clang-format off
|
||||
std::vector<ApiSpec> apiSpecs = {
|
||||
ApiSpec(API::CREATE_ORDERED_DYN_MEM_REF_DICT, "createOrderedDynMemRefDict", opaquePtrTy, {}),
|
||||
ApiSpec(API::CREATE_DYN_MEM_REF, "createDynMemRef", opaquePtrTy, {int32Ty}),
|
||||
ApiSpec(API::CREATE_ORDERED_DYN_MEM_REF_DICT, "createOrderedRtMemRefDict", opaquePtrTy, {}),
|
||||
ApiSpec(API::CREATE_DYN_MEM_REF, "createRtMemRef", opaquePtrTy, {int32Ty}),
|
||||
ApiSpec(API::GET_DATA, "getData", opaquePtrTy, {opaquePtrTy}),
|
||||
ApiSpec(API::SET_DATA, "setData", voidTy, {opaquePtrTy, opaquePtrTy}),
|
||||
ApiSpec(API::GET_DYN_MEM_REF, "getDynMemRef", opaquePtrTy, {opaquePtrTy, int32Ty}),
|
||||
ApiSpec(API::SET_DYN_MEM_REF, "setDynMemRef", voidTy, {opaquePtrTy, int32Ty, opaquePtrTy}),
|
||||
ApiSpec(API::GET_DYN_MEM_REF, "getRtMemRef", opaquePtrTy, {opaquePtrTy, int32Ty}),
|
||||
ApiSpec(API::SET_DYN_MEM_REF, "setRtMemRef", voidTy, {opaquePtrTy, int32Ty, opaquePtrTy}),
|
||||
ApiSpec(API::GET_SIZES, "getSizes", int64PtrTy, {opaquePtrTy}),
|
||||
ApiSpec(API::GET_STRIDES, "getStrides", int64PtrTy, {opaquePtrTy})
|
||||
ApiSpec(API::GET_STRIDES, "getStrides", int64PtrTy, {opaquePtrTy}),
|
||||
ApiSpec(API::GET_DTYPE, "getDType", int32Ty, {opaquePtrTy}),
|
||||
ApiSpec(API::SET_DTYPE, "setDType", voidTy, {opaquePtrTy, int32Ty}),
|
||||
};
|
||||
// clang-format on
|
||||
|
||||
|
@ -598,7 +635,7 @@ private:
|
|||
return *entryPointEntryBlock;
|
||||
}
|
||||
|
||||
void fillPtrToMemRefWithDynMemRef(Value &dynMemRef, Value &ptrToMemRef,
|
||||
void fillPtrToMemRefWithRtMemRef(Value &dynMemRef, Value &ptrToMemRef,
|
||||
PatternRewriter &rewriter, const Location &loc,
|
||||
const std::map<API, ApiSpec> &apiRegistry,
|
||||
LLVM::LLVMDialect *llvmDialect) const {
|
||||
|
@ -659,12 +696,13 @@ private:
|
|||
rewriter.create<LLVM::StoreOp>(loc, memRef, ptrToMemRef);
|
||||
}
|
||||
|
||||
void fillDynMemRefWithMemRef(Value &outMemRef, Value &outDynMemRef,
|
||||
void fillRtMemRefWithMemRef(Value &outMemRef, Value &outRtMemRef,
|
||||
PatternRewriter &rewriter, const Location &loc,
|
||||
const std::map<API, ApiSpec> &apiRegistry,
|
||||
LLVM::LLVMDialect *llvmDialect) const {
|
||||
auto outMemRefTy = outMemRef.getType().dyn_cast<LLVM::LLVMType>();
|
||||
auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
|
||||
auto int32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect);
|
||||
|
||||
// Extract the data pointer, and record it in dynamic mem ref created.
|
||||
Value outMemRefDataPtr = rewriter.create<LLVM::ExtractValueOp>(loc,
|
||||
|
@ -673,13 +711,19 @@ private:
|
|||
outMemRefDataPtr = rewriter.create<LLVM::BitcastOp>(
|
||||
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), outMemRefDataPtr);
|
||||
callApi(rewriter, loc, apiRegistry, API::SET_DATA,
|
||||
{outDynMemRef, outMemRefDataPtr});
|
||||
{outRtMemRef, outMemRefDataPtr});
|
||||
auto elemTy = outMemRefTy.getStructElementType(0).getPointerElementTy();
|
||||
auto onnxTy = llvmTypeToOnnxType(elemTy);
|
||||
auto onnxTyVal = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, int32Ty, rewriter.getI32IntegerAttr(onnxTy));
|
||||
callApi(
|
||||
rewriter, loc, apiRegistry, API::SET_DTYPE, {outRtMemRef, onnxTyVal});
|
||||
|
||||
auto rank = getRankFromMemRefType(outMemRefTy);
|
||||
auto sizesArrayPtr =
|
||||
callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {outDynMemRef});
|
||||
callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {outRtMemRef});
|
||||
auto stridesArrayPtr =
|
||||
callApi(rewriter, loc, apiRegistry, API::GET_STRIDES, {outDynMemRef});
|
||||
callApi(rewriter, loc, apiRegistry, API::GET_STRIDES, {outRtMemRef});
|
||||
|
||||
for (decltype(rank) i = 0; i < rank; i++) {
|
||||
auto dimIdx = rewriter.create<LLVM::ConstantOp>(
|
||||
|
|
|
@ -6,7 +6,7 @@ target_link_libraries(TestConv
|
|||
rapidcheck
|
||||
MainUtils
|
||||
ExecutionSession
|
||||
DynMemRefUtils)
|
||||
RtMemRefUtils)
|
||||
|
||||
target_include_directories(TestConv
|
||||
PRIVATE
|
||||
|
|
|
@ -96,13 +96,13 @@ bool isOMConvTheSameAsNaiveImplFor(const int N, const int C, const int H,
|
|||
onnx_mlir::ExecutionSession sess(
|
||||
pathStr + ".so", "_dyn_entry_point_main_graph");
|
||||
|
||||
std::vector<unique_ptr<DynMemRef>> inputs;
|
||||
auto xDmr = unique_ptr<DynMemRef>(getRndRealDmr<float>({N, C, H, W}));
|
||||
inputs.emplace_back(move(xDmr));
|
||||
auto wDmr = unique_ptr<DynMemRef>(getRndRealDmr<float>({C, C, kH, kW}));
|
||||
inputs.emplace_back(move(wDmr));
|
||||
std::vector<unique_ptr<RtMemRef>> inputs;
|
||||
auto xRmr = unique_ptr<RtMemRef>(getRndRealRmr<float>({N, C, H, W}));
|
||||
inputs.emplace_back(move(xRmr));
|
||||
auto wRmr = unique_ptr<RtMemRef>(getRndRealRmr<float>({C, C, kH, kW}));
|
||||
inputs.emplace_back(move(wRmr));
|
||||
|
||||
auto ref = DynMemRef::create<float>({NOut, COut, HOut, WOut});
|
||||
auto ref = RtMemRef::create<float>({NOut, COut, HOut, WOut});
|
||||
auto &img = inputs.at(0);
|
||||
auto &filter = inputs.at(1);
|
||||
for (int64_t n = 0; n < NOut; n++)
|
||||
|
@ -124,7 +124,7 @@ bool isOMConvTheSameAsNaiveImplFor(const int N, const int C, const int H,
|
|||
auto outputs = sess.run(move(inputs));
|
||||
auto &conv = outputs.at(0);
|
||||
|
||||
return isDmrClose<float>(conv.get(), ref);
|
||||
return isRmrClose<float>(conv.get(), ref);
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
|
Loading…
Reference in New Issue