Support encoding data type infomration as part of the DMR struct. (#178)
* Support encoding data type infomration as part of the DMR struct. * Support full range of np types. * Report error when encountering unsupported type. * Add gerRank method API. * Add missing API declarations. * DynMemRef -> RtMemRef * Format code. * Missed DynMemRef -> RtMemRef conversion. * More comments for RMR, and rename variable names from dmr -> rmr. * DynMemRef -> RtMemRef. * Format code.
This commit is contained in:
parent
f9cb113a84
commit
e902506ee5
|
@ -1,13 +1,13 @@
|
||||||
# Create shared libcruntime.so since model.so linkage for backend tests
|
# Create shared libcruntime.so since model.so linkage for backend tests
|
||||||
# will fail on x86 Linux if cruntime is statically linked.
|
# will fail on x86 Linux if cruntime is statically linked.
|
||||||
add_library(cruntime STATIC
|
add_library(cruntime STATIC
|
||||||
DynMemRef.cpp
|
RtMemRef.cpp
|
||||||
DynMemRef.h
|
RtMemRef.h
|
||||||
DataType.h)
|
DataType.h)
|
||||||
|
|
||||||
add_library(DynMemRefUtils
|
add_library(RtMemRefUtils
|
||||||
DynMemRef.h
|
RtMemRef.h
|
||||||
DynMemRef.cpp
|
RtMemRef.cpp
|
||||||
DataType.h)
|
DataType.h)
|
||||||
|
|
||||||
add_library(ExecutionSession
|
add_library(ExecutionSession
|
||||||
|
@ -28,7 +28,8 @@ pybind11_add_module(PyRuntime
|
||||||
target_link_libraries(PyRuntime PRIVATE
|
target_link_libraries(PyRuntime PRIVATE
|
||||||
${CMAKE_DL_LIBS}
|
${CMAKE_DL_LIBS}
|
||||||
ExecutionSession
|
ExecutionSession
|
||||||
DynMemRefUtils)
|
RtMemRefUtils
|
||||||
|
onnx)
|
||||||
target_include_directories(PyRuntime PRIVATE
|
target_include_directories(PyRuntime PRIVATE
|
||||||
${ONNX_MLIR_SRC_ROOT}
|
${ONNX_MLIR_SRC_ROOT}
|
||||||
${ONNX_MLIR_BIN_ROOT}
|
${ONNX_MLIR_BIN_ROOT}
|
||||||
|
@ -41,6 +42,6 @@ set_target_properties(EmbeddedDataLoader PROPERTIES
|
||||||
POSITION_INDEPENDENT_CODE TRUE)
|
POSITION_INDEPENDENT_CODE TRUE)
|
||||||
|
|
||||||
add_dependencies(PyRuntime cruntime)
|
add_dependencies(PyRuntime cruntime)
|
||||||
install(FILES DynMemRef.h DESTINATION include)
|
install(FILES RtMemRef.h DESTINATION include)
|
||||||
install(TARGETS cruntime DESTINATION lib)
|
install(TARGETS cruntime DESTINATION lib)
|
||||||
install(TARGETS EmbeddedDataLoader DESTINATION lib)
|
install(TARGETS EmbeddedDataLoader DESTINATION lib)
|
||||||
|
|
|
@ -42,20 +42,19 @@ ExecutionSession::ExecutionSession(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::unique_ptr<DynMemRef>> ExecutionSession::run(
|
std::vector<std::unique_ptr<RtMemRef>> ExecutionSession::run(
|
||||||
std::vector<std::unique_ptr<DynMemRef>> ins) {
|
std::vector<std::unique_ptr<RtMemRef>> ins) {
|
||||||
auto *wrappedInput = createOrderedDynMemRefDict();
|
auto *wrappedInput = createOrderedRtMemRefDict();
|
||||||
for (size_t i = 0; i < ins.size(); i++)
|
for (size_t i = 0; i < ins.size(); i++)
|
||||||
setDynMemRef(wrappedInput, i, ins.at(i).get());
|
setRtMemRef(wrappedInput, i, ins.at(i).get());
|
||||||
|
|
||||||
auto *wrappedOutput = _entryPointFunc(wrappedInput);
|
auto *wrappedOutput = _entryPointFunc(wrappedInput);
|
||||||
|
|
||||||
std::vector<std::unique_ptr<DynMemRef>> outs;
|
std::vector<std::unique_ptr<RtMemRef>> outs;
|
||||||
auto outputSize = getSize(wrappedOutput);
|
auto outputSize = getSize(wrappedOutput);
|
||||||
|
|
||||||
for (size_t i = 0; i < getSize(wrappedOutput); i++) {
|
for (size_t i = 0; i < getSize(wrappedOutput); i++) {
|
||||||
outs.emplace_back(
|
outs.emplace_back(std::unique_ptr<RtMemRef>(getRtMemRef(wrappedOutput, i)));
|
||||||
std::unique_ptr<DynMemRef>(getDynMemRef(wrappedOutput, i)));
|
|
||||||
}
|
}
|
||||||
return std::move(outs);
|
return std::move(outs);
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,18 +15,18 @@
|
||||||
#include <dlfcn.h>
|
#include <dlfcn.h>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
#include "src/Runtime/DynMemRef.h"
|
#include "src/Runtime/RtMemRef.h"
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
typedef OrderedDynMemRefDict *(*entryPointFuncType)(OrderedDynMemRefDict *);
|
typedef OrderedRtMemRefDict *(*entryPointFuncType)(OrderedRtMemRefDict *);
|
||||||
|
|
||||||
class ExecutionSession {
|
class ExecutionSession {
|
||||||
public:
|
public:
|
||||||
ExecutionSession(std::string sharedLibPath, std::string entryPointName);
|
ExecutionSession(std::string sharedLibPath, std::string entryPointName);
|
||||||
|
|
||||||
std::vector<std::unique_ptr<DynMemRef>> run(
|
std::vector<std::unique_ptr<RtMemRef>> run(
|
||||||
std::vector<std::unique_ptr<DynMemRef>>);
|
std::vector<std::unique_ptr<RtMemRef>>);
|
||||||
|
|
||||||
~ExecutionSession();
|
~ExecutionSession();
|
||||||
|
|
||||||
|
|
|
@ -9,6 +9,9 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "onnx/onnx_pb.h"
|
||||||
|
#include <third_party/onnx/onnx/onnx_pb.h>
|
||||||
|
|
||||||
#include "PyExecutionSession.hpp"
|
#include "PyExecutionSession.hpp"
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
@ -16,40 +19,72 @@ namespace onnx_mlir {
|
||||||
std::vector<py::array> PyExecutionSession::pyRun(
|
std::vector<py::array> PyExecutionSession::pyRun(
|
||||||
std::vector<py::array> inputsPyArray) {
|
std::vector<py::array> inputsPyArray) {
|
||||||
assert(_entryPointFunc && "Entry point not loaded.");
|
assert(_entryPointFunc && "Entry point not loaded.");
|
||||||
auto *wrappedInput = createOrderedDynMemRefDict();
|
auto *wrappedInput = createOrderedRtMemRefDict();
|
||||||
int inputIdx = 0;
|
int inputIdx = 0;
|
||||||
for (auto inputPyArray : inputsPyArray) {
|
for (auto inputPyArray : inputsPyArray) {
|
||||||
auto *inputDynMemRef = createDynMemRef(inputPyArray.ndim());
|
auto *inputRtMemRef = createRtMemRef(inputPyArray.ndim());
|
||||||
assert(inputPyArray.flags() && py::array::c_style &&
|
assert(inputPyArray.flags() && py::array::c_style &&
|
||||||
"Expect contiguous python array.");
|
"Expect contiguous python array.");
|
||||||
|
|
||||||
if (inputPyArray.writeable()) {
|
if (inputPyArray.writeable()) {
|
||||||
inputDynMemRef->data = inputPyArray.mutable_data();
|
inputRtMemRef->data = inputPyArray.mutable_data();
|
||||||
inputDynMemRef->alignedData = inputPyArray.mutable_data();
|
inputRtMemRef->alignedData = inputPyArray.mutable_data();
|
||||||
} else {
|
} else {
|
||||||
// If data is not writable, copy them to a writable buffer.
|
// If data is not writable, copy them to a writable buffer.
|
||||||
auto *copiedData = (float *)malloc(inputPyArray.nbytes());
|
auto *copiedData = (float *)malloc(inputPyArray.nbytes());
|
||||||
memcpy(copiedData, inputPyArray.data(), inputPyArray.nbytes());
|
memcpy(copiedData, inputPyArray.data(), inputPyArray.nbytes());
|
||||||
inputDynMemRef->data = copiedData;
|
inputRtMemRef->data = copiedData;
|
||||||
inputDynMemRef->alignedData = copiedData;
|
inputRtMemRef->alignedData = copiedData;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int i = 0; i < inputPyArray.ndim(); i++) {
|
for (int i = 0; i < inputPyArray.ndim(); i++) {
|
||||||
inputDynMemRef->sizes[i] = inputPyArray.shape(i);
|
inputRtMemRef->sizes[i] = inputPyArray.shape(i);
|
||||||
inputDynMemRef->strides[i] = inputPyArray.strides(i);
|
inputRtMemRef->strides[i] = inputPyArray.strides(i);
|
||||||
}
|
}
|
||||||
|
|
||||||
setDynMemRef(wrappedInput, inputIdx++, inputDynMemRef);
|
setRtMemRef(wrappedInput, inputIdx++, inputRtMemRef);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<py::array> outputPyArrays;
|
std::vector<py::array> outputPyArrays;
|
||||||
auto *wrappedOutput = _entryPointFunc(wrappedInput);
|
auto *wrappedOutput = _entryPointFunc(wrappedInput);
|
||||||
for (int i = 0; i < numDynMemRefs(wrappedOutput); i++) {
|
for (int i = 0; i < numRtMemRefs(wrappedOutput); i++) {
|
||||||
auto *dynMemRef = getDynMemRef(wrappedOutput, i);
|
auto *dynMemRef = getRtMemRef(wrappedOutput, i);
|
||||||
auto shape = std::vector<int64_t>(
|
auto shape = std::vector<int64_t>(
|
||||||
dynMemRef->sizes, dynMemRef->sizes + dynMemRef->rank);
|
dynMemRef->sizes, dynMemRef->sizes + dynMemRef->rank);
|
||||||
outputPyArrays.emplace_back(
|
|
||||||
py::array(py::dtype("float32"), shape, dynMemRef->data));
|
// https://numpy.org/devdocs/user/basics.types.html
|
||||||
|
py::dtype dtype;
|
||||||
|
if (dynMemRef->onnx_dtype == onnx::TensorProto::FLOAT)
|
||||||
|
dtype = py::dtype("float32");
|
||||||
|
else if (dynMemRef->onnx_dtype = onnx::TensorProto::UINT8)
|
||||||
|
dtype = py::dtype("uint8");
|
||||||
|
else if (dynMemRef->onnx_dtype = onnx::TensorProto::INT8)
|
||||||
|
dtype = py::dtype("int8");
|
||||||
|
else if (dynMemRef->onnx_dtype = onnx::TensorProto::UINT16)
|
||||||
|
dtype = py::dtype("uint16");
|
||||||
|
else if (dynMemRef->onnx_dtype = onnx::TensorProto::INT16)
|
||||||
|
dtype = py::dtype("int16");
|
||||||
|
else if (dynMemRef->onnx_dtype == onnx::TensorProto::INT32)
|
||||||
|
dtype = py::dtype("int32");
|
||||||
|
else if (dynMemRef->onnx_dtype == onnx::TensorProto::INT64)
|
||||||
|
dtype = py::dtype("int64");
|
||||||
|
// TODO(tjingrant) wait for Tong's input for how to represent string.
|
||||||
|
else if (dynMemRef->onnx_dtype = onnx::TensorProto::BOOL)
|
||||||
|
dtype = py::dtype("bool_");
|
||||||
|
else if (dynMemRef->onnx_dtype = onnx::TensorProto::FLOAT16)
|
||||||
|
dtype = py::dtype("float32");
|
||||||
|
else if (dynMemRef->onnx_dtype = onnx::TensorProto::DOUBLE)
|
||||||
|
dtype = py::dtype("float64");
|
||||||
|
else if (dynMemRef->onnx_dtype == onnx::TensorProto::UINT32)
|
||||||
|
dtype = py::dtype("uint32");
|
||||||
|
else if (dynMemRef->onnx_dtype == onnx::TensorProto::UINT64)
|
||||||
|
dtype = py::dtype("uint64");
|
||||||
|
else {
|
||||||
|
fprintf(stderr, "Unsupported ONNX type in RtMemRef.onnx_dtype.");
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
outputPyArrays.emplace_back(py::array(dtype, shape, dynMemRef->data));
|
||||||
}
|
}
|
||||||
|
|
||||||
return outputPyArrays;
|
return outputPyArrays;
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
//===----------- DynMemRef.cpp - Dynamic MemRef Implementation ------------===//
|
//===----------- RtMemRef.cpp - Dynamic MemRef Implementation ------------===//
|
||||||
//
|
//
|
||||||
// Copyright 2019-2020 The IBM Research Authors.
|
// Copyright 2019-2020 The IBM Research Authors.
|
||||||
//
|
//
|
||||||
|
@ -18,7 +18,7 @@
|
||||||
#include <stdlib.h>
|
#include <stdlib.h>
|
||||||
#include <string.h>
|
#include <string.h>
|
||||||
|
|
||||||
#include "DynMemRef.h"
|
#include "RtMemRef.h"
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
// Helper function to compute cartisian product.
|
// Helper function to compute cartisian product.
|
||||||
|
@ -39,19 +39,19 @@ inline std::vector<std::vector<INDEX_TYPE>> CartProduct(
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
DynMemRef::DynMemRef(int _rank) {
|
RtMemRef::RtMemRef(int _rank) {
|
||||||
rank = _rank;
|
rank = _rank;
|
||||||
sizes = (INDEX_TYPE *)malloc(rank * sizeof(INDEX_TYPE));
|
sizes = (INDEX_TYPE *)malloc(rank * sizeof(INDEX_TYPE));
|
||||||
strides = (int64_t *)malloc(rank * sizeof(int64_t));
|
strides = (int64_t *)malloc(rank * sizeof(int64_t));
|
||||||
}
|
}
|
||||||
|
|
||||||
INDEX_TYPE DynMemRef::size() const {
|
INDEX_TYPE RtMemRef::size() const {
|
||||||
return std::accumulate(sizes, sizes + rank, 1, std::multiplies<>());
|
return std::accumulate(sizes, sizes + rank, 1, std::multiplies<>());
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::vector<INDEX_TYPE>> DynMemRef::indexSet() const {
|
std::vector<std::vector<INDEX_TYPE>> RtMemRef::indexSet() const {
|
||||||
// First, we create index set of each dimension separately.
|
// First, we create index set of each dimension separately.
|
||||||
// i.e., for a tensor/DMR of shape (2, 3), its dimWiseIdxSet will be:
|
// i.e., for a tensor/RMR of shape (2, 3), its dimWiseIdxSet will be:
|
||||||
// {{0,1}, {0,1,2}};
|
// {{0,1}, {0,1,2}};
|
||||||
std::vector<std::vector<INDEX_TYPE>> dimWiseIdxSet;
|
std::vector<std::vector<INDEX_TYPE>> dimWiseIdxSet;
|
||||||
for (auto dimSize : std::vector<INDEX_TYPE>(sizes, sizes + rank)) {
|
for (auto dimSize : std::vector<INDEX_TYPE>(sizes, sizes + rank)) {
|
||||||
|
@ -60,18 +60,18 @@ std::vector<std::vector<INDEX_TYPE>> DynMemRef::indexSet() const {
|
||||||
dimWiseIdxSet.emplace_back(dimIdxSet);
|
dimWiseIdxSet.emplace_back(dimIdxSet);
|
||||||
}
|
}
|
||||||
// Then, the cartesian product of vectors within dimWiseIdxSet will be the
|
// Then, the cartesian product of vectors within dimWiseIdxSet will be the
|
||||||
// index set for the whole DMR.
|
// index set for the whole RMR.
|
||||||
return CartProduct(dimWiseIdxSet);
|
return CartProduct(dimWiseIdxSet);
|
||||||
}
|
}
|
||||||
|
|
||||||
INDEX_TYPE DynMemRef::computeOffset(std::vector<INDEX_TYPE> &idxs) const {
|
INDEX_TYPE RtMemRef::computeOffset(std::vector<INDEX_TYPE> &idxs) const {
|
||||||
auto dimStrides = std::vector<INDEX_TYPE>(strides, strides + rank);
|
auto dimStrides = std::vector<INDEX_TYPE>(strides, strides + rank);
|
||||||
INDEX_TYPE elemOffset = std::inner_product(
|
INDEX_TYPE elemOffset = std::inner_product(
|
||||||
idxs.begin(), idxs.end(), dimStrides.begin(), (INDEX_TYPE)0);
|
idxs.begin(), idxs.end(), dimStrides.begin(), (INDEX_TYPE)0);
|
||||||
return elemOffset;
|
return elemOffset;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<int64_t> DynMemRef::computeStridesFromSizes() const {
|
std::vector<int64_t> RtMemRef::computeStridesFromSizes() const {
|
||||||
// Shift dimension sizes one to the left, fill in the vacated rightmost
|
// Shift dimension sizes one to the left, fill in the vacated rightmost
|
||||||
// element with 1; this gets us a vector that'll be more useful for computing
|
// element with 1; this gets us a vector that'll be more useful for computing
|
||||||
// strides of memory access along each dimension using prefix product (aka
|
// strides of memory access along each dimension using prefix product (aka
|
||||||
|
@ -86,7 +86,7 @@ std::vector<int64_t> DynMemRef::computeStridesFromSizes() const {
|
||||||
return dimStrides;
|
return dimStrides;
|
||||||
}
|
}
|
||||||
|
|
||||||
DynMemRef::~DynMemRef() {
|
RtMemRef::~RtMemRef() {
|
||||||
free(data);
|
free(data);
|
||||||
free(sizes);
|
free(sizes);
|
||||||
free(strides);
|
free(strides);
|
||||||
|
@ -95,27 +95,26 @@ DynMemRef::~DynMemRef() {
|
||||||
// An ordered dynamic MemRef dictionary.
|
// An ordered dynamic MemRef dictionary.
|
||||||
// The goal is to support accessing dynamic memory ref by name and by index.
|
// The goal is to support accessing dynamic memory ref by name and by index.
|
||||||
// Currently, only accessing by index is supported.
|
// Currently, only accessing by index is supported.
|
||||||
struct OrderedDynMemRefDict {
|
struct OrderedRtMemRefDict {
|
||||||
std::map<std::string, DynMemRef *> tensorDict;
|
std::map<std::string, RtMemRef *> tensorDict;
|
||||||
std::vector<std::string> orderedNames;
|
std::vector<std::string> orderedNames;
|
||||||
};
|
};
|
||||||
|
|
||||||
int numDynMemRefs(OrderedDynMemRefDict *dict) {
|
int numRtMemRefs(OrderedRtMemRefDict *dict) {
|
||||||
return dict->orderedNames.size();
|
return dict->orderedNames.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
OrderedDynMemRefDict *createOrderedDynMemRefDict() {
|
OrderedRtMemRefDict *createOrderedRtMemRefDict() {
|
||||||
return new OrderedDynMemRefDict();
|
return new OrderedRtMemRefDict();
|
||||||
}
|
}
|
||||||
|
|
||||||
DynMemRef *createDynMemRef(int rank) { return new DynMemRef(rank); }
|
RtMemRef *createRtMemRef(int rank) { return new RtMemRef(rank); }
|
||||||
|
|
||||||
DynMemRef *getDynMemRef(OrderedDynMemRefDict *tensorDict, int idx) {
|
RtMemRef *getRtMemRef(OrderedRtMemRefDict *tensorDict, int idx) {
|
||||||
return tensorDict->tensorDict[tensorDict->orderedNames[idx]];
|
return tensorDict->tensorDict[tensorDict->orderedNames[idx]];
|
||||||
}
|
}
|
||||||
|
|
||||||
void setDynMemRef(
|
void setRtMemRef(OrderedRtMemRefDict *tensorDict, int idx, RtMemRef *tensor) {
|
||||||
OrderedDynMemRefDict *tensorDict, int idx, DynMemRef *tensor) {
|
|
||||||
if (tensorDict->orderedNames.size() <= idx)
|
if (tensorDict->orderedNames.size() <= idx)
|
||||||
tensorDict->orderedNames.resize(idx + 1);
|
tensorDict->orderedNames.resize(idx + 1);
|
||||||
|
|
||||||
|
@ -130,30 +129,36 @@ void setDynMemRef(
|
||||||
tensorDict->tensorDict[tensorDict->orderedNames[idx]] = tensor;
|
tensorDict->tensorDict[tensorDict->orderedNames[idx]] = tensor;
|
||||||
}
|
}
|
||||||
|
|
||||||
void *getData(DynMemRef *dynMemRef) { return dynMemRef->data; }
|
void *getData(RtMemRef *dynMemRef) { return dynMemRef->data; }
|
||||||
|
|
||||||
void setData(DynMemRef *dynMemRef, void *dataPtr) { dynMemRef->data = dataPtr; }
|
void setData(RtMemRef *dynMemRef, void *dataPtr) { dynMemRef->data = dataPtr; }
|
||||||
|
|
||||||
void *getAlignedData(DynMemRef *dynMemRef) { return dynMemRef->alignedData; }
|
void *getAlignedData(RtMemRef *dynMemRef) { return dynMemRef->alignedData; }
|
||||||
|
|
||||||
void setAlignedData(DynMemRef *dynMemRef, void *dataPtr) {
|
void setAlignedData(RtMemRef *dynMemRef, void *dataPtr) {
|
||||||
dynMemRef->alignedData = dataPtr;
|
dynMemRef->alignedData = dataPtr;
|
||||||
}
|
}
|
||||||
|
|
||||||
INDEX_TYPE *getSizes(DynMemRef *dynMemRef) { return dynMemRef->sizes; }
|
INDEX_TYPE *getSizes(RtMemRef *dynMemRef) { return dynMemRef->sizes; }
|
||||||
|
|
||||||
void setSizes(DynMemRef *dynMemRef, INDEX_TYPE *sizes) {
|
void setSizes(RtMemRef *dynMemRef, INDEX_TYPE *sizes) {
|
||||||
for (int i = 0; i < dynMemRef->rank; i++)
|
for (int i = 0; i < dynMemRef->rank; i++)
|
||||||
dynMemRef->sizes[i] = sizes[i];
|
dynMemRef->sizes[i] = sizes[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
int64_t *getStrides(DynMemRef *dynMemRef) { return dynMemRef->strides; }
|
int64_t *getStrides(RtMemRef *dynMemRef) { return dynMemRef->strides; }
|
||||||
|
|
||||||
int64_t getSize(OrderedDynMemRefDict *dict) {
|
int64_t getSize(OrderedRtMemRefDict *dict) { return dict->orderedNames.size(); }
|
||||||
return dict->orderedNames.size();
|
|
||||||
|
void setDType(RtMemRef *dynMemRef, int onnxType) {
|
||||||
|
dynMemRef->onnx_dtype = onnxType;
|
||||||
}
|
}
|
||||||
|
|
||||||
void setStrides(DynMemRef *dynMemRef, int64_t *strides) {
|
int getDType(RtMemRef *dynMemRef) { return dynMemRef->onnx_dtype; }
|
||||||
|
|
||||||
|
unsigned int getRank(RtMemRef *dynMemRef) { return dynMemRef->rank; }
|
||||||
|
|
||||||
|
void setStrides(RtMemRef *dynMemRef, int64_t *strides) {
|
||||||
for (int i = 0; i < dynMemRef->rank; i++)
|
for (int i = 0; i < dynMemRef->rank; i++)
|
||||||
dynMemRef->sizes[i] = strides[i];
|
dynMemRef->sizes[i] = strides[i];
|
||||||
}
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
//===------------ DynMemRef.h - Dynamic MemRef Implementation -------------===//
|
//===------------ RtMemRef.h - Dynamic MemRef Implementation -------------===//
|
||||||
//
|
//
|
||||||
// Copyright 2019-2020 The IBM Research Authors.
|
// Copyright 2019-2020 The IBM Research Authors.
|
||||||
//
|
//
|
||||||
|
@ -24,40 +24,69 @@
|
||||||
|
|
||||||
typedef int64_t INDEX_TYPE;
|
typedef int64_t INDEX_TYPE;
|
||||||
|
|
||||||
// This is a dynamic version of memref.
|
// Typically, MemRefs in MLIR context are used as a compile-time constructs.
|
||||||
// The same struct can be used to represent memrefs of
|
// Information such as element type and rank of the data payload is statically
|
||||||
// all ranks and type combinations.
|
// encoded, meaning that they are determined and fixed at compile-time. This
|
||||||
// We will refer to it as a DMR (Dynamic MemRef).
|
// presents significant burden for any runtime components trying to interact
|
||||||
struct DynMemRef {
|
// with the compiled executable.
|
||||||
|
//
|
||||||
|
// Thus a version of MemRef struct that is amenable to runtime manipulation is
|
||||||
|
// provided as a basis for building any runtime-related components providing
|
||||||
|
// user-facing programming interfaces. All information are dynamically encoded
|
||||||
|
// as members of this struct so that they can be accessed and modified easily
|
||||||
|
// during runtime.
|
||||||
|
//
|
||||||
|
// We will refer to it as a RMF (Runtime MemRef).
|
||||||
|
struct RtMemRef {
|
||||||
|
|
||||||
|
// Pointer to the raw memory space allocated to host the RMR content. This
|
||||||
|
// pointer should only be acessed for memory management purposes, not for
|
||||||
|
// reading RMR content.
|
||||||
void *data;
|
void *data;
|
||||||
|
|
||||||
|
// Pointer to the properly aligned array of elements stored in this Rmr.
|
||||||
void *alignedData;
|
void *alignedData;
|
||||||
|
|
||||||
|
// Distance between the start of the raw memory space and the first element of
|
||||||
|
// the RMR content.
|
||||||
INDEX_TYPE offset;
|
INDEX_TYPE offset;
|
||||||
|
|
||||||
|
// Number of dimensions of the array represented by the RMR.
|
||||||
unsigned int rank;
|
unsigned int rank;
|
||||||
|
|
||||||
|
// An array recording the per-dimension sizes of the array represented by the
|
||||||
|
// RMR.
|
||||||
INDEX_TYPE *sizes;
|
INDEX_TYPE *sizes;
|
||||||
|
|
||||||
|
// An array recording the per-dimension strides of the array represented by
|
||||||
|
// the RMR.
|
||||||
int64_t *strides;
|
int64_t *strides;
|
||||||
|
|
||||||
|
// Refer to TensorProto_DataType at
|
||||||
|
// https://github.com/onnx/onnx/blob/cc2230603422bae893d5bc900d2d773ab34400a4/onnx/onnx-ml.proto#L451
|
||||||
|
// for enum value interpretation.
|
||||||
|
unsigned int onnx_dtype;
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
explicit DynMemRef(int _rank);
|
explicit RtMemRef(int _rank);
|
||||||
|
|
||||||
// Create a full DMR of type T and shape _sizes, with all data fields
|
// Create a full RMR of type T and shape _sizes, with all data fields
|
||||||
// initialized to proper values and data pointers malloc'ed.
|
// initialized to proper values and data pointers malloc'ed.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static DynMemRef *create(std::vector<INDEX_TYPE> _sizes) {
|
static RtMemRef *create(std::vector<INDEX_TYPE> _sizes) {
|
||||||
auto dmr = new DynMemRef(_sizes.size());
|
auto rmr = new RtMemRef(_sizes.size());
|
||||||
dmr->offset = 0;
|
rmr->offset = 0;
|
||||||
dmr->rank = _sizes.size();
|
rmr->rank = _sizes.size();
|
||||||
dmr->sizes = (INDEX_TYPE *)malloc(dmr->rank * sizeof(INDEX_TYPE));
|
rmr->sizes = (INDEX_TYPE *)malloc(rmr->rank * sizeof(INDEX_TYPE));
|
||||||
std::copy(_sizes.begin(), _sizes.end(), dmr->sizes);
|
std::copy(_sizes.begin(), _sizes.end(), rmr->sizes);
|
||||||
|
|
||||||
dmr->strides = (int64_t *)malloc(dmr->rank * sizeof(int64_t));
|
rmr->strides = (int64_t *)malloc(rmr->rank * sizeof(int64_t));
|
||||||
auto computedStrides = dmr->computeStridesFromSizes();
|
auto computedStrides = rmr->computeStridesFromSizes();
|
||||||
std::copy(computedStrides.begin(), computedStrides.end(), dmr->strides);
|
std::copy(computedStrides.begin(), computedStrides.end(), rmr->strides);
|
||||||
|
|
||||||
dmr->data = malloc(dmr->size() * sizeof(T));
|
rmr->data = malloc(rmr->size() * sizeof(T));
|
||||||
dmr->alignedData = dmr->data;
|
rmr->alignedData = rmr->data;
|
||||||
|
|
||||||
return dmr;
|
return rmr;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Access an element (by reference) at index position idxs.
|
// Access an element (by reference) at index position idxs.
|
||||||
|
@ -75,13 +104,13 @@ struct DynMemRef {
|
||||||
return typedPtr[idx];
|
return typedPtr[idx];
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get a typed ptr to the data content of the DMR.
|
// Get a typed ptr to the data content of the RMR.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
T *typedPtr() {
|
T *typedPtr() {
|
||||||
return (T *)data;
|
return (T *)data;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get how many elements are stored in DMR, as implied by its shape.
|
// Get how many elements are stored in RMR, as implied by its shape.
|
||||||
INDEX_TYPE size() const;
|
INDEX_TYPE size() const;
|
||||||
|
|
||||||
// Helper function to compute strides of access along each dimensions from its
|
// Helper function to compute strides of access along each dimensions from its
|
||||||
|
@ -92,63 +121,77 @@ struct DynMemRef {
|
||||||
INDEX_TYPE computeOffset(std::vector<INDEX_TYPE> &idxs) const;
|
INDEX_TYPE computeOffset(std::vector<INDEX_TYPE> &idxs) const;
|
||||||
|
|
||||||
// Get the index set (i.e., all valid multi-dimensional array indexes that can
|
// Get the index set (i.e., all valid multi-dimensional array indexes that can
|
||||||
// be used to access this DMR's constituent elements).
|
// be used to access this RMR's constituent elements).
|
||||||
std::vector<std::vector<INDEX_TYPE>> indexSet() const;
|
std::vector<std::vector<INDEX_TYPE>> indexSet() const;
|
||||||
|
|
||||||
~DynMemRef();
|
~RtMemRef();
|
||||||
#endif
|
#endif
|
||||||
};
|
};
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
// Ordered DynMemRef Dictionary is a data structure for wrapping the input
|
// Ordered RtMemRef Dictionary is a data structure for wrapping the input
|
||||||
// dynmemrefs so that they can be addressed both by index and by name.
|
// dynmemrefs so that they can be addressed both by index and by name.
|
||||||
struct OrderedDynMemRefDict;
|
struct OrderedRtMemRefDict;
|
||||||
|
|
||||||
#else
|
#else
|
||||||
typedef struct DynMemRef DynMemRef;
|
typedef struct RtMemRef RtMemRef;
|
||||||
typedef struct _OrderedDynMemRefDict OrderedDynMemRefDict;
|
typedef struct _OrderedRtMemRefDict OrderedRtMemRefDict;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
extern "C" {
|
extern "C" {
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// Get number of dynamic memrefs in OrderedDynMemRefDict dict.
|
// Get number of dynamic memrefs in OrderedRtMemRefDict dict.
|
||||||
int numDynMemRefs(OrderedDynMemRefDict *dict);
|
int numRtMemRefs(OrderedRtMemRefDict *dict);
|
||||||
|
|
||||||
// Create an ordered dynamic memref dictionary.
|
// Create an ordered dynamic memref dictionary.
|
||||||
OrderedDynMemRefDict *createOrderedDynMemRefDict();
|
OrderedRtMemRefDict *createOrderedRtMemRefDict();
|
||||||
|
|
||||||
// Get how many dynamic memrefs are in dict.
|
// Get how many dynamic memrefs are in dict.
|
||||||
int64_t getSize(OrderedDynMemRefDict *dict);
|
int64_t getSize(OrderedRtMemRefDict *dict);
|
||||||
|
|
||||||
// Create a dynmemref with a certain rank.
|
// Create a dynmemref with a certain rank.
|
||||||
DynMemRef *createDynMemRef(int rank);
|
RtMemRef *createRtMemRef(int rank);
|
||||||
|
|
||||||
// Get the i-th dynmemref from orderedDict.
|
// Get the i-th dynmemref from orderedDict.
|
||||||
DynMemRef *getDynMemRef(OrderedDynMemRefDict *orderedDict, int i);
|
RtMemRef *getRtMemRef(OrderedRtMemRefDict *orderedDict, int i);
|
||||||
|
|
||||||
// Set the i-th dynmemref in orderedDict to be dynMemRef.
|
// Set the i-th dynmemref in orderedDict to be dynMemRef.
|
||||||
void setDynMemRef(
|
void setRtMemRef(OrderedRtMemRefDict *tensorDict, int idx, RtMemRef *dynMemRef);
|
||||||
OrderedDynMemRefDict *tensorDict, int idx, DynMemRef *dynMemRef);
|
|
||||||
|
|
||||||
// Get data pointer from dynMemRef.
|
// Get data pointer from dynMemRef.
|
||||||
void *getData(DynMemRef *dynMemRef);
|
void *getData(RtMemRef *dynMemRef);
|
||||||
|
|
||||||
// Set data pointer for dynMemRef.
|
// Set data pointer for dynMemRef.
|
||||||
void setData(DynMemRef *dynMemRef, void *data);
|
void setData(RtMemRef *dynMemRef, void *data);
|
||||||
|
|
||||||
// Get algined data pointer from dynMemRef.
|
// Get algined data pointer from dynMemRef.
|
||||||
void *getAlignedData(DynMemRef *);
|
void *getAlignedData(RtMemRef *);
|
||||||
|
|
||||||
// Set aligned data pointer for dynMemRef.
|
// Set aligned data pointer for dynMemRef.
|
||||||
void setAlignedData(DynMemRef *, void *);
|
void setAlignedData(RtMemRef *, void *);
|
||||||
|
|
||||||
|
// Get the data type enum value of the dynMemRef.
|
||||||
|
int getDType(RtMemRef *dynMemRef);
|
||||||
|
|
||||||
|
// Set the data type enum value of the dynMemRef.
|
||||||
|
void setDType(RtMemRef *dynMemRef, int onnxType);
|
||||||
|
|
||||||
|
// Get the rank of the dynMemRef.
|
||||||
|
unsigned int getRank(RtMemRef *dynMemRef);
|
||||||
|
|
||||||
// Get ptr to sizes array.
|
// Get ptr to sizes array.
|
||||||
INDEX_TYPE *getSizes(DynMemRef *);
|
INDEX_TYPE *getSizes(RtMemRef *);
|
||||||
|
|
||||||
|
// Set the sizes array (by copying size values from array `sizes`).
|
||||||
|
void setSizes(RtMemRef *, INDEX_TYPE *sizes);
|
||||||
|
|
||||||
// Get ptr to strides array.
|
// Get ptr to strides array.
|
||||||
int64_t *getStrides(DynMemRef *);
|
int64_t *getStrides(RtMemRef *);
|
||||||
|
|
||||||
|
// Set the strides array (by copying stride values from array `strides`).
|
||||||
|
void setStrides(RtMemRef *, int64_t *strides);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
|
@ -164,22 +207,22 @@ void printVector(std::vector<T> vec, std::string _delimiter = ",",
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
DynMemRef *getRndRealDmr(
|
RtMemRef *getRndRealRmr(
|
||||||
std::vector<INDEX_TYPE> sizes, T lb = -1.0, T ub = 1.0) {
|
std::vector<INDEX_TYPE> sizes, T lb = -1.0, T ub = 1.0) {
|
||||||
// Will be used to obtain a seed for the random number engine
|
// Will be used to obtain a seed for the random number engine
|
||||||
std::random_device rd;
|
std::random_device rd;
|
||||||
// Standard mersenne_twister_engine seeded with rd()
|
// Standard mersenne_twister_engine seeded with rd()
|
||||||
std::mt19937 gen(rd());
|
std::mt19937 gen(rd());
|
||||||
std::uniform_real_distribution<> dis(lb, ub);
|
std::uniform_real_distribution<> dis(lb, ub);
|
||||||
auto dmr = DynMemRef::create<T>(sizes);
|
auto rmr = RtMemRef::create<T>(sizes);
|
||||||
auto ptr = (T *)dmr->data;
|
auto ptr = (T *)rmr->data;
|
||||||
std::generate(ptr, ptr + dmr->size(), [&]() { return dis(gen); });
|
std::generate(ptr, ptr + rmr->size(), [&]() { return dis(gen); });
|
||||||
return dmr;
|
return rmr;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
inline bool isDmrClose(
|
inline bool isRmrClose(
|
||||||
DynMemRef *a, DynMemRef *b, float rtol = 1e-5, float atol = 1e-5) {
|
RtMemRef *a, RtMemRef *b, float rtol = 1e-5, float atol = 1e-5) {
|
||||||
|
|
||||||
// Compare shape.
|
// Compare shape.
|
||||||
auto aShape = std::vector<INDEX_TYPE>(a->sizes, a->sizes + a->rank);
|
auto aShape = std::vector<INDEX_TYPE>(a->sizes, a->sizes + a->rank);
|
||||||
|
@ -232,3 +275,6 @@ inline bool isDmrClose(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
// Will transition from RtMemRef to RtMemRef soon.
|
||||||
|
typedef RtMemRef RtMemRef;
|
|
@ -19,6 +19,7 @@
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
|
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
#include "onnx/onnx_pb.h"
|
||||||
#include "llvm/ADT/Sequence.h"
|
#include "llvm/ADT/Sequence.h"
|
||||||
|
|
||||||
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
|
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
|
||||||
|
@ -29,6 +30,38 @@ using namespace mlir;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
static onnx::TensorProto::DataType llvmTypeToOnnxType(
|
||||||
|
mlir::LLVM::LLVMType elemType) {
|
||||||
|
if (elemType.isFloatTy())
|
||||||
|
return onnx::TensorProto::FLOAT;
|
||||||
|
if (elemType.isUnsignedInteger(8))
|
||||||
|
return onnx::TensorProto::UINT8;
|
||||||
|
if (elemType.isSignedInteger(8))
|
||||||
|
return onnx::TensorProto::INT8;
|
||||||
|
if (elemType.isUnsignedInteger(16))
|
||||||
|
return onnx::TensorProto::UINT16;
|
||||||
|
if (elemType.isSignedInteger(16))
|
||||||
|
return onnx::TensorProto::INT16;
|
||||||
|
if (elemType.isSignedInteger(32))
|
||||||
|
return onnx::TensorProto::INT32;
|
||||||
|
if (elemType.isSignedInteger(64))
|
||||||
|
return onnx::TensorProto::INT64;
|
||||||
|
// TODO, wait for Tong's input about how string is represented in MLIR.
|
||||||
|
if (elemType.isInteger(1))
|
||||||
|
return onnx::TensorProto::BOOL;
|
||||||
|
if (elemType.isHalfTy())
|
||||||
|
return onnx::TensorProto::FLOAT16;
|
||||||
|
if (elemType.isDoubleTy())
|
||||||
|
return onnx::TensorProto::DOUBLE;
|
||||||
|
if (elemType.isUnsignedInteger(32))
|
||||||
|
return onnx::TensorProto::UINT32;
|
||||||
|
if (elemType.isUnsignedInteger(64))
|
||||||
|
return onnx::TensorProto::INT64;
|
||||||
|
// Complex types don't seem to exist in LLVM Dialect.
|
||||||
|
elemType.dump();
|
||||||
|
llvm_unreachable("Unexpected LLVM type, cannot be converted to ONNX type.");
|
||||||
|
}
|
||||||
|
|
||||||
static FlatSymbolRefAttr getOrInsertExternFunc(StringRef funcName,
|
static FlatSymbolRefAttr getOrInsertExternFunc(StringRef funcName,
|
||||||
ModuleOp module, mlir::LLVM::LLVMType funcType, PatternRewriter &rewriter) {
|
ModuleOp module, mlir::LLVM::LLVMType funcType, PatternRewriter &rewriter) {
|
||||||
auto *context = module.getContext();
|
auto *context = module.getContext();
|
||||||
|
@ -374,6 +407,8 @@ public:
|
||||||
SET_DATA,
|
SET_DATA,
|
||||||
GET_SIZES,
|
GET_SIZES,
|
||||||
GET_STRIDES,
|
GET_STRIDES,
|
||||||
|
SET_DTYPE,
|
||||||
|
GET_DTYPE,
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ApiSpec {
|
struct ApiSpec {
|
||||||
|
@ -469,7 +504,7 @@ public:
|
||||||
|
|
||||||
// Fill in the memref underlying ptrToMemRef with information extracted
|
// Fill in the memref underlying ptrToMemRef with information extracted
|
||||||
// from dynMemRef.
|
// from dynMemRef.
|
||||||
fillPtrToMemRefWithDynMemRef(
|
fillPtrToMemRefWithRtMemRef(
|
||||||
dynMemRef, ptrToMemRef, rewriter, loc, apiRegistry, llvmDialect);
|
dynMemRef, ptrToMemRef, rewriter, loc, apiRegistry, llvmDialect);
|
||||||
|
|
||||||
// ptrToMemRef will be an input to main computation graph function.
|
// ptrToMemRef will be an input to main computation graph function.
|
||||||
|
@ -518,14 +553,14 @@ public:
|
||||||
auto outMemRefRank = getRankFromMemRefType(outMemRefTy);
|
auto outMemRefRank = getRankFromMemRefType(outMemRefTy);
|
||||||
auto outMemRefRankVal = rewriter.create<LLVM::ConstantOp>(
|
auto outMemRefRankVal = rewriter.create<LLVM::ConstantOp>(
|
||||||
loc, int32Ty, rewriter.getI32IntegerAttr(outMemRefRank));
|
loc, int32Ty, rewriter.getI32IntegerAttr(outMemRefRank));
|
||||||
auto outDynMemRef = callApi(rewriter, loc, apiRegistry,
|
auto outRtMemRef = callApi(rewriter, loc, apiRegistry,
|
||||||
API::CREATE_DYN_MEM_REF, {outMemRefRankVal});
|
API::CREATE_DYN_MEM_REF, {outMemRefRankVal});
|
||||||
fillDynMemRefWithMemRef(
|
fillRtMemRefWithMemRef(
|
||||||
memRef, outDynMemRef, rewriter, loc, apiRegistry, llvmDialect);
|
memRef, outRtMemRef, rewriter, loc, apiRegistry, llvmDialect);
|
||||||
auto idx = rewriter.create<LLVM::ConstantOp>(
|
auto idx = rewriter.create<LLVM::ConstantOp>(
|
||||||
loc, int32Ty, rewriter.getI32IntegerAttr(i));
|
loc, int32Ty, rewriter.getI32IntegerAttr(i));
|
||||||
callApi(rewriter, loc, apiRegistry, API::SET_DYN_MEM_REF,
|
callApi(rewriter, loc, apiRegistry, API::SET_DYN_MEM_REF,
|
||||||
{wrappedOutput, idx, outDynMemRef});
|
{wrappedOutput, idx, outRtMemRef});
|
||||||
}
|
}
|
||||||
// Return wrapped output.
|
// Return wrapped output.
|
||||||
rewriter.create<LLVM::ReturnOp>(
|
rewriter.create<LLVM::ReturnOp>(
|
||||||
|
@ -549,14 +584,16 @@ private:
|
||||||
// specifying its signature.
|
// specifying its signature.
|
||||||
// clang-format off
|
// clang-format off
|
||||||
std::vector<ApiSpec> apiSpecs = {
|
std::vector<ApiSpec> apiSpecs = {
|
||||||
ApiSpec(API::CREATE_ORDERED_DYN_MEM_REF_DICT, "createOrderedDynMemRefDict", opaquePtrTy, {}),
|
ApiSpec(API::CREATE_ORDERED_DYN_MEM_REF_DICT, "createOrderedRtMemRefDict", opaquePtrTy, {}),
|
||||||
ApiSpec(API::CREATE_DYN_MEM_REF, "createDynMemRef", opaquePtrTy, {int32Ty}),
|
ApiSpec(API::CREATE_DYN_MEM_REF, "createRtMemRef", opaquePtrTy, {int32Ty}),
|
||||||
ApiSpec(API::GET_DATA, "getData", opaquePtrTy, {opaquePtrTy}),
|
ApiSpec(API::GET_DATA, "getData", opaquePtrTy, {opaquePtrTy}),
|
||||||
ApiSpec(API::SET_DATA, "setData", voidTy, {opaquePtrTy, opaquePtrTy}),
|
ApiSpec(API::SET_DATA, "setData", voidTy, {opaquePtrTy, opaquePtrTy}),
|
||||||
ApiSpec(API::GET_DYN_MEM_REF, "getDynMemRef", opaquePtrTy, {opaquePtrTy, int32Ty}),
|
ApiSpec(API::GET_DYN_MEM_REF, "getRtMemRef", opaquePtrTy, {opaquePtrTy, int32Ty}),
|
||||||
ApiSpec(API::SET_DYN_MEM_REF, "setDynMemRef", voidTy, {opaquePtrTy, int32Ty, opaquePtrTy}),
|
ApiSpec(API::SET_DYN_MEM_REF, "setRtMemRef", voidTy, {opaquePtrTy, int32Ty, opaquePtrTy}),
|
||||||
ApiSpec(API::GET_SIZES, "getSizes", int64PtrTy, {opaquePtrTy}),
|
ApiSpec(API::GET_SIZES, "getSizes", int64PtrTy, {opaquePtrTy}),
|
||||||
ApiSpec(API::GET_STRIDES, "getStrides", int64PtrTy, {opaquePtrTy})
|
ApiSpec(API::GET_STRIDES, "getStrides", int64PtrTy, {opaquePtrTy}),
|
||||||
|
ApiSpec(API::GET_DTYPE, "getDType", int32Ty, {opaquePtrTy}),
|
||||||
|
ApiSpec(API::SET_DTYPE, "setDType", voidTy, {opaquePtrTy, int32Ty}),
|
||||||
};
|
};
|
||||||
// clang-format on
|
// clang-format on
|
||||||
|
|
||||||
|
@ -598,7 +635,7 @@ private:
|
||||||
return *entryPointEntryBlock;
|
return *entryPointEntryBlock;
|
||||||
}
|
}
|
||||||
|
|
||||||
void fillPtrToMemRefWithDynMemRef(Value &dynMemRef, Value &ptrToMemRef,
|
void fillPtrToMemRefWithRtMemRef(Value &dynMemRef, Value &ptrToMemRef,
|
||||||
PatternRewriter &rewriter, const Location &loc,
|
PatternRewriter &rewriter, const Location &loc,
|
||||||
const std::map<API, ApiSpec> &apiRegistry,
|
const std::map<API, ApiSpec> &apiRegistry,
|
||||||
LLVM::LLVMDialect *llvmDialect) const {
|
LLVM::LLVMDialect *llvmDialect) const {
|
||||||
|
@ -659,12 +696,13 @@ private:
|
||||||
rewriter.create<LLVM::StoreOp>(loc, memRef, ptrToMemRef);
|
rewriter.create<LLVM::StoreOp>(loc, memRef, ptrToMemRef);
|
||||||
}
|
}
|
||||||
|
|
||||||
void fillDynMemRefWithMemRef(Value &outMemRef, Value &outDynMemRef,
|
void fillRtMemRefWithMemRef(Value &outMemRef, Value &outRtMemRef,
|
||||||
PatternRewriter &rewriter, const Location &loc,
|
PatternRewriter &rewriter, const Location &loc,
|
||||||
const std::map<API, ApiSpec> &apiRegistry,
|
const std::map<API, ApiSpec> &apiRegistry,
|
||||||
LLVM::LLVMDialect *llvmDialect) const {
|
LLVM::LLVMDialect *llvmDialect) const {
|
||||||
auto outMemRefTy = outMemRef.getType().dyn_cast<LLVM::LLVMType>();
|
auto outMemRefTy = outMemRef.getType().dyn_cast<LLVM::LLVMType>();
|
||||||
auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
|
auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
|
||||||
|
auto int32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect);
|
||||||
|
|
||||||
// Extract the data pointer, and record it in dynamic mem ref created.
|
// Extract the data pointer, and record it in dynamic mem ref created.
|
||||||
Value outMemRefDataPtr = rewriter.create<LLVM::ExtractValueOp>(loc,
|
Value outMemRefDataPtr = rewriter.create<LLVM::ExtractValueOp>(loc,
|
||||||
|
@ -673,13 +711,19 @@ private:
|
||||||
outMemRefDataPtr = rewriter.create<LLVM::BitcastOp>(
|
outMemRefDataPtr = rewriter.create<LLVM::BitcastOp>(
|
||||||
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), outMemRefDataPtr);
|
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), outMemRefDataPtr);
|
||||||
callApi(rewriter, loc, apiRegistry, API::SET_DATA,
|
callApi(rewriter, loc, apiRegistry, API::SET_DATA,
|
||||||
{outDynMemRef, outMemRefDataPtr});
|
{outRtMemRef, outMemRefDataPtr});
|
||||||
|
auto elemTy = outMemRefTy.getStructElementType(0).getPointerElementTy();
|
||||||
|
auto onnxTy = llvmTypeToOnnxType(elemTy);
|
||||||
|
auto onnxTyVal = rewriter.create<LLVM::ConstantOp>(
|
||||||
|
loc, int32Ty, rewriter.getI32IntegerAttr(onnxTy));
|
||||||
|
callApi(
|
||||||
|
rewriter, loc, apiRegistry, API::SET_DTYPE, {outRtMemRef, onnxTyVal});
|
||||||
|
|
||||||
auto rank = getRankFromMemRefType(outMemRefTy);
|
auto rank = getRankFromMemRefType(outMemRefTy);
|
||||||
auto sizesArrayPtr =
|
auto sizesArrayPtr =
|
||||||
callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {outDynMemRef});
|
callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {outRtMemRef});
|
||||||
auto stridesArrayPtr =
|
auto stridesArrayPtr =
|
||||||
callApi(rewriter, loc, apiRegistry, API::GET_STRIDES, {outDynMemRef});
|
callApi(rewriter, loc, apiRegistry, API::GET_STRIDES, {outRtMemRef});
|
||||||
|
|
||||||
for (decltype(rank) i = 0; i < rank; i++) {
|
for (decltype(rank) i = 0; i < rank; i++) {
|
||||||
auto dimIdx = rewriter.create<LLVM::ConstantOp>(
|
auto dimIdx = rewriter.create<LLVM::ConstantOp>(
|
||||||
|
|
|
@ -6,7 +6,7 @@ target_link_libraries(TestConv
|
||||||
rapidcheck
|
rapidcheck
|
||||||
MainUtils
|
MainUtils
|
||||||
ExecutionSession
|
ExecutionSession
|
||||||
DynMemRefUtils)
|
RtMemRefUtils)
|
||||||
|
|
||||||
target_include_directories(TestConv
|
target_include_directories(TestConv
|
||||||
PRIVATE
|
PRIVATE
|
||||||
|
|
|
@ -96,13 +96,13 @@ bool isOMConvTheSameAsNaiveImplFor(const int N, const int C, const int H,
|
||||||
onnx_mlir::ExecutionSession sess(
|
onnx_mlir::ExecutionSession sess(
|
||||||
pathStr + ".so", "_dyn_entry_point_main_graph");
|
pathStr + ".so", "_dyn_entry_point_main_graph");
|
||||||
|
|
||||||
std::vector<unique_ptr<DynMemRef>> inputs;
|
std::vector<unique_ptr<RtMemRef>> inputs;
|
||||||
auto xDmr = unique_ptr<DynMemRef>(getRndRealDmr<float>({N, C, H, W}));
|
auto xRmr = unique_ptr<RtMemRef>(getRndRealRmr<float>({N, C, H, W}));
|
||||||
inputs.emplace_back(move(xDmr));
|
inputs.emplace_back(move(xRmr));
|
||||||
auto wDmr = unique_ptr<DynMemRef>(getRndRealDmr<float>({C, C, kH, kW}));
|
auto wRmr = unique_ptr<RtMemRef>(getRndRealRmr<float>({C, C, kH, kW}));
|
||||||
inputs.emplace_back(move(wDmr));
|
inputs.emplace_back(move(wRmr));
|
||||||
|
|
||||||
auto ref = DynMemRef::create<float>({NOut, COut, HOut, WOut});
|
auto ref = RtMemRef::create<float>({NOut, COut, HOut, WOut});
|
||||||
auto &img = inputs.at(0);
|
auto &img = inputs.at(0);
|
||||||
auto &filter = inputs.at(1);
|
auto &filter = inputs.at(1);
|
||||||
for (int64_t n = 0; n < NOut; n++)
|
for (int64_t n = 0; n < NOut; n++)
|
||||||
|
@ -124,7 +124,7 @@ bool isOMConvTheSameAsNaiveImplFor(const int N, const int C, const int H,
|
||||||
auto outputs = sess.run(move(inputs));
|
auto outputs = sess.run(move(inputs));
|
||||||
auto &conv = outputs.at(0);
|
auto &conv = outputs.at(0);
|
||||||
|
|
||||||
return isDmrClose<float>(conv.get(), ref);
|
return isRmrClose<float>(conv.get(), ref);
|
||||||
}
|
}
|
||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
|
|
Loading…
Reference in New Issue