Merge remote-tracking branch 'origin/master' into matmul-shape
This commit is contained in:
commit
0bc07ef661
|
@ -3,27 +3,12 @@ jobs:
|
||||||
build:
|
build:
|
||||||
docker:
|
docker:
|
||||||
- image: circleci/python
|
- image: circleci/python
|
||||||
|
resource_class: medium+
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- run:
|
- run:
|
||||||
name: Installing GCC, CMake, Ninja, Protobuf
|
name: Installing GCC, CMake, Ninja, Protobuf
|
||||||
command: sudo apt-get update && sudo apt-get install -y gcc g++ cmake ninja-build protobuf-compiler
|
command: sudo apt-get update && sudo apt-get install -y gcc g++ cmake ninja-build protobuf-compiler
|
||||||
# Use cached mlir installation if possible.
|
|
||||||
- restore_cache:
|
|
||||||
key: V2-LLVM-PROJECT-{{ arch }}
|
|
||||||
- run:
|
|
||||||
name: Install MLIR
|
|
||||||
command: |
|
|
||||||
# Check whether cache restoration succeeds by checking whether
|
|
||||||
# mlir-opt executable exists.
|
|
||||||
if [ ! -f llvm-project/build/bin/mlir-opt ]; then
|
|
||||||
export MAKEFLAGS=-j4
|
|
||||||
source utils/install-mlir.sh
|
|
||||||
fi
|
|
||||||
- save_cache:
|
|
||||||
key: V2-LLVM-PROJECT-{{ arch }}
|
|
||||||
paths:
|
|
||||||
- llvm-project
|
|
||||||
- checkout:
|
- checkout:
|
||||||
path: ONNF
|
path: ONNF
|
||||||
- run:
|
- run:
|
||||||
|
@ -31,9 +16,30 @@ jobs:
|
||||||
command: |
|
command: |
|
||||||
cd ONNF
|
cd ONNF
|
||||||
git submodule update --init --recursive
|
git submodule update --init --recursive
|
||||||
|
# Use cached mlir installation if possible.
|
||||||
|
- restore_cache:
|
||||||
|
key: V4-LLVM-PROJECT-{{ arch }}
|
||||||
|
- run:
|
||||||
|
name: Install MLIR
|
||||||
|
command: |
|
||||||
|
# Check whether cache restoration succeeds by checking whether
|
||||||
|
# mlir-opt executable exists.
|
||||||
|
if [ ! -f llvm-project/build/bin/mlir-opt ]; then
|
||||||
|
source ONNF/utils/install-mlir.sh
|
||||||
|
fi
|
||||||
|
- save_cache:
|
||||||
|
key: V4-LLVM-PROJECT-{{ arch }}
|
||||||
|
paths:
|
||||||
|
- llvm-project
|
||||||
- run:
|
- run:
|
||||||
name: Install ONNF
|
name: Install ONNF
|
||||||
command: source ONNF/utils/install-onnf.sh
|
command: source ONNF/utils/install-onnf.sh
|
||||||
|
- run:
|
||||||
|
name: Run End-To-End Tests
|
||||||
|
command: |
|
||||||
|
sudo pip install -q onnx
|
||||||
|
cd ONNF/build
|
||||||
|
cmake --build . --target run-onnx-backend-test
|
||||||
- run:
|
- run:
|
||||||
name: Run DocCheck
|
name: Run DocCheck
|
||||||
command: cd ONNF/build && cmake --build . --target check-doc
|
command: cd ONNF/build && cmake --build . --target check-doc
|
||||||
|
|
|
@ -7,3 +7,6 @@
|
||||||
[submodule "third_party/pybind11"]
|
[submodule "third_party/pybind11"]
|
||||||
path = third_party/pybind11
|
path = third_party/pybind11
|
||||||
url = https://github.com/pybind/pybind11.git
|
url = https://github.com/pybind/pybind11.git
|
||||||
|
[submodule "third_party/variant"]
|
||||||
|
path = third_party/variant
|
||||||
|
url = git@github.com:mpark/variant.git
|
||||||
|
|
|
@ -22,9 +22,9 @@ include(MLIR.cmake)
|
||||||
add_subdirectory(third_party/onnx)
|
add_subdirectory(third_party/onnx)
|
||||||
add_subdirectory(third_party/benchmark)
|
add_subdirectory(third_party/benchmark)
|
||||||
add_subdirectory(third_party/pybind11)
|
add_subdirectory(third_party/pybind11)
|
||||||
|
add_subdirectory(third_party/variant)
|
||||||
|
|
||||||
set(CMAKE_CXX_STANDARD 14)
|
set(CMAKE_CXX_STANDARD 14)
|
||||||
add_subdirectory(src)
|
add_subdirectory(src)
|
||||||
add_subdirectory(doc)
|
add_subdirectory(doc)
|
||||||
add_subdirectory(test)
|
add_subdirectory(test)
|
||||||
|
|
||||||
|
|
|
@ -20,7 +20,8 @@ cmake -G Ninja ../llvm \
|
||||||
-DLLVM_ENABLE_ASSERTIONS=ON \
|
-DLLVM_ENABLE_ASSERTIONS=ON \
|
||||||
-DLLVM_ENABLE_RTTI=ON
|
-DLLVM_ENABLE_RTTI=ON
|
||||||
|
|
||||||
cmake --build . --target check-mlir -- ${MAKEFLAGS}
|
cmake --build . --target
|
||||||
|
cmake --build . --target check-mlir
|
||||||
```
|
```
|
||||||
|
|
||||||
Two environment variables need to be set:
|
Two environment variables need to be set:
|
||||||
|
@ -42,6 +43,7 @@ cmake ..
|
||||||
cmake --build . --target onnf
|
cmake --build . --target onnf
|
||||||
|
|
||||||
# Run FileCheck tests:
|
# Run FileCheck tests:
|
||||||
|
export LIT_OPTS=-v
|
||||||
cmake --build . --target check-mlir-lit
|
cmake --build . --target check-mlir-lit
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
@ -27,6 +27,7 @@ ONNX operations for which some work is needed.
|
||||||
| Selu | Tung | v | v | |
|
| Selu | Tung | v | v | |
|
||||||
| Sigmoid | Tung | v | v | |
|
| Sigmoid | Tung | v | v | |
|
||||||
| Sinh | Tung | v | v | |
|
| Sinh | Tung | v | v | |
|
||||||
|
| Softmax | Tung | v | v | |
|
||||||
| Sub | Tung | v | v | M |
|
| Sub | Tung | v | v | M |
|
||||||
| Sum | Tung | v | v | M |
|
| Sum | Tung | v | v | M |
|
||||||
| Tanh | Tung | v | v | |
|
| Tanh | Tung | v | v | |
|
||||||
|
|
|
@ -69,8 +69,9 @@ add_subdirectory(runtime)
|
||||||
add_executable(onnf main.cpp)
|
add_executable(onnf main.cpp)
|
||||||
|
|
||||||
target_link_libraries(onnf builder ${MLIRLibs} onnf_transform onnf_shape_inference onnf_lower_frontend)
|
target_link_libraries(onnf builder ${MLIRLibs} onnf_transform onnf_shape_inference onnf_lower_frontend)
|
||||||
set_target_properties(onnf PROPERTIES LINK_FLAGS "-lz")
|
|
||||||
whole_archive_link_mlir(onnf ${MLIRWholeArchiveLibs})
|
whole_archive_link_mlir(onnf ${MLIRWholeArchiveLibs})
|
||||||
|
find_package(ZLIB REQUIRED)
|
||||||
|
target_link_libraries(onnf ${ZLIB_LIBRARIES})
|
||||||
|
|
||||||
target_include_directories(onnf PRIVATE ${CMAKE_SOURCE_DIR})
|
target_include_directories(onnf PRIVATE ${CMAKE_SOURCE_DIR})
|
||||||
target_include_directories(onnf PRIVATE ${CMAKE_BINARY_DIR})
|
target_include_directories(onnf PRIVATE ${CMAKE_BINARY_DIR})
|
||||||
|
|
|
@ -7,8 +7,9 @@ add_library(builder
|
||||||
target_include_directories(builder PRIVATE ${CMAKE_SOURCE_DIR})
|
target_include_directories(builder PRIVATE ${CMAKE_SOURCE_DIR})
|
||||||
target_include_directories(builder PRIVATE ${CMAKE_BINARY_DIR})
|
target_include_directories(builder PRIVATE ${CMAKE_BINARY_DIR})
|
||||||
|
|
||||||
target_link_libraries(builder compiler onnx ${MLIRLibs} curses)
|
target_link_libraries(builder compiler onnx ${MLIRLibs} curses mpark_variant)
|
||||||
target_include_directories(builder
|
target_include_directories(builder
|
||||||
PRIVATE
|
PRIVATE
|
||||||
${CMAKE_SOURCE_DIR}/third_party/onnx
|
${CMAKE_SOURCE_DIR}/third_party/onnx
|
||||||
|
${CMAKE_SOURCE_DIR}/third_party/variant
|
||||||
${CMAKE_SOURCE_DIR})
|
${CMAKE_SOURCE_DIR})
|
||||||
|
|
|
@ -14,11 +14,16 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include <map>
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include <regex>
|
#include <regex>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
#include <map>
|
|
||||||
|
// Using backported variant.
|
||||||
|
// bstd = backported standard library.
|
||||||
|
#include <mpark/variant.hpp>
|
||||||
|
namespace bstd = mpark;
|
||||||
|
|
||||||
#include "mlir/Analysis/Verifier.h"
|
#include "mlir/Analysis/Verifier.h"
|
||||||
#include "mlir/Dialect/StandardOps/Ops.h"
|
#include "mlir/Dialect/StandardOps/Ops.h"
|
||||||
|
@ -42,15 +47,15 @@
|
||||||
namespace onnf {
|
namespace onnf {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
void replaceAll(
|
void replaceAll(std::string &str, const std::string &from,
|
||||||
std::string& str, const std::string& from, const std::string& to) {
|
const std::string &to) {
|
||||||
if (from.empty())
|
if (from.empty())
|
||||||
return;
|
return;
|
||||||
size_t start_pos = 0;
|
size_t start_pos = 0;
|
||||||
while ((start_pos = str.find(from, start_pos)) != std::string::npos) {
|
while ((start_pos = str.find(from, start_pos)) != std::string::npos) {
|
||||||
str.replace(start_pos, from.length(), to);
|
str.replace(start_pos, from.length(), to);
|
||||||
start_pos += to.length(); // In case 'to' contains 'from', like replacing
|
start_pos += to.length(); // In case 'to' contains 'from', like replacing
|
||||||
// 'x' with 'yx'
|
// 'x' with 'yx'
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -71,10 +76,10 @@ struct OnnxOnnfSymbolMapping {
|
||||||
* @param name onnx tensor name.
|
* @param name onnx tensor name.
|
||||||
* @return onnf tensor corresponding to `name`.
|
* @return onnf tensor corresponding to `name`.
|
||||||
*/
|
*/
|
||||||
mlir::Value GetTensorByOnnxName(std::string name) {
|
mlir::Value GetTensorByOnnxName(const std::string &name) {
|
||||||
assert(onnx_name2onnf_tensor.find(legalize_name(name)) !=
|
assert(onnx_name2onnf_tensor.find(legalize_name(name)) !=
|
||||||
onnx_name2onnf_tensor.end() &&
|
onnx_name2onnf_tensor.end() &&
|
||||||
"Tensor not found");
|
"Tensor not found");
|
||||||
return onnx_name2onnf_tensor.at(legalize_name(name));
|
return onnx_name2onnf_tensor.at(legalize_name(name));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -83,9 +88,9 @@ struct OnnxOnnfSymbolMapping {
|
||||||
* @param name onnx tensor name.
|
* @param name onnx tensor name.
|
||||||
* @param tensor MLIR Value pointer.
|
* @param tensor MLIR Value pointer.
|
||||||
*/
|
*/
|
||||||
void AddMapping(std::string name, mlir::Value tensor) {
|
void AddMapping(const std::string &name, mlir::Value tensor) {
|
||||||
assert(onnx_name2onnf_tensor.count(legalize_name(name)) == 0 &&
|
assert(onnx_name2onnf_tensor.count(legalize_name(name)) == 0 &&
|
||||||
"Tensor already exists.");
|
"Tensor already exists.");
|
||||||
onnx_name2onnf_tensor.emplace(legalize_name(name), tensor);
|
onnx_name2onnf_tensor.emplace(legalize_name(name), tensor);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -124,34 +129,34 @@ private:
|
||||||
// Convert type to MLIR type.
|
// Convert type to MLIR type.
|
||||||
// A complete list of types can be found in:
|
// A complete list of types can be found in:
|
||||||
// <onnf-build-folder>/third_party/onnx/onnx/onnx.pb.h
|
// <onnf-build-folder>/third_party/onnx/onnx/onnx.pb.h
|
||||||
mlir::Type TypeConvert(onnx::TensorProto_DataType intype) {
|
mlir::Type convertONNXTypeToMLIRType(onnx::TensorProto_DataType onnxType) {
|
||||||
switch (intype) {
|
switch (onnxType) {
|
||||||
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16:
|
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16:
|
||||||
return builder_.getF16Type();
|
return builder_.getF16Type();
|
||||||
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT:
|
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT:
|
||||||
return builder_.getF32Type();
|
return builder_.getF32Type();
|
||||||
case onnx::TensorProto_DataType::TensorProto_DataType_DOUBLE:
|
case onnx::TensorProto_DataType::TensorProto_DataType_DOUBLE:
|
||||||
return builder_.getF64Type();
|
return builder_.getF64Type();
|
||||||
case onnx::TensorProto_DataType::TensorProto_DataType_INT8:
|
case onnx::TensorProto_DataType::TensorProto_DataType_INT8:
|
||||||
case onnx::TensorProto_DataType::TensorProto_DataType_UINT8:
|
case onnx::TensorProto_DataType::TensorProto_DataType_UINT8:
|
||||||
return builder_.getIntegerType(8);
|
return builder_.getIntegerType(8);
|
||||||
case onnx::TensorProto_DataType::TensorProto_DataType_INT16:
|
case onnx::TensorProto_DataType::TensorProto_DataType_INT16:
|
||||||
case onnx::TensorProto_DataType::TensorProto_DataType_UINT16:
|
case onnx::TensorProto_DataType::TensorProto_DataType_UINT16:
|
||||||
return builder_.getIntegerType(16);
|
return builder_.getIntegerType(16);
|
||||||
case onnx::TensorProto_DataType::TensorProto_DataType_INT32:
|
case onnx::TensorProto_DataType::TensorProto_DataType_INT32:
|
||||||
case onnx::TensorProto_DataType::TensorProto_DataType_UINT32:
|
case onnx::TensorProto_DataType::TensorProto_DataType_UINT32:
|
||||||
return builder_.getIntegerType(32);
|
return builder_.getIntegerType(32);
|
||||||
case onnx::TensorProto_DataType::TensorProto_DataType_INT64:
|
case onnx::TensorProto_DataType::TensorProto_DataType_INT64:
|
||||||
case onnx::TensorProto_DataType::TensorProto_DataType_UINT64:
|
case onnx::TensorProto_DataType::TensorProto_DataType_UINT64:
|
||||||
return builder_.getIntegerType(64);
|
return builder_.getIntegerType(64);
|
||||||
case onnx::TensorProto_DataType::TensorProto_DataType_BOOL:
|
case onnx::TensorProto_DataType::TensorProto_DataType_BOOL:
|
||||||
return builder_.getI1Type();
|
return builder_.getI1Type();
|
||||||
case onnx::TensorProto_DataType::TensorProto_DataType_STRING:
|
case onnx::TensorProto_DataType::TensorProto_DataType_STRING:
|
||||||
case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX64:
|
case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX64:
|
||||||
case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX128:
|
case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX128:
|
||||||
case onnx::TensorProto_DataType::TensorProto_DataType_UNDEFINED:
|
case onnx::TensorProto_DataType::TensorProto_DataType_UNDEFINED:
|
||||||
assert(false && "Unsupported data type encountered.");
|
assert(false && "Unsupported data type encountered.");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -169,8 +174,8 @@ private:
|
||||||
for (int i = 0; i < shape_proto.dim_size(); i++) {
|
for (int i = 0; i < shape_proto.dim_size(); i++) {
|
||||||
if (shape_proto.dim()[i].dim_value()) {
|
if (shape_proto.dim()[i].dim_value()) {
|
||||||
int dim_numeric_size = shape_proto.dim()[i].dim_value();
|
int dim_numeric_size = shape_proto.dim()[i].dim_value();
|
||||||
assert(
|
assert(dim_numeric_size != 0 &&
|
||||||
dim_numeric_size != 0 && "Parsed an input tensor with a dimension size of zero");
|
"Parsed an input tensor with a dimension size of zero");
|
||||||
if (dim_numeric_size > 0) {
|
if (dim_numeric_size > 0) {
|
||||||
dims.push_back(dim_numeric_size);
|
dims.push_back(dim_numeric_size);
|
||||||
} else { // If dim_value < 0, then dim is parametric.
|
} else { // If dim_value < 0, then dim is parametric.
|
||||||
|
@ -184,7 +189,7 @@ private:
|
||||||
}
|
}
|
||||||
|
|
||||||
mlir::Type elementType =
|
mlir::Type elementType =
|
||||||
TypeConvert(input.type().tensor_type().elem_type());
|
convertONNXTypeToMLIRType(input.type().tensor_type().elem_type());
|
||||||
llvm::ArrayRef<int64_t> tensor_dims(dims.data(), dims.size());
|
llvm::ArrayRef<int64_t> tensor_dims(dims.data(), dims.size());
|
||||||
arg_types.emplace_back(
|
arg_types.emplace_back(
|
||||||
mlir::RankedTensorType::get(tensor_dims, elementType));
|
mlir::RankedTensorType::get(tensor_dims, elementType));
|
||||||
|
@ -200,288 +205,111 @@ private:
|
||||||
void ImportInputTensorSymbol(const onnx::ValueInfoProto &input,
|
void ImportInputTensorSymbol(const onnx::ValueInfoProto &input,
|
||||||
mlir::Value symbol) {
|
mlir::Value symbol) {
|
||||||
auto input_tensor_legalized_name = legalize_name(input.name());
|
auto input_tensor_legalized_name = legalize_name(input.name());
|
||||||
assert(
|
assert(!frontend_symbols_.ContainKey(input_tensor_legalized_name) &&
|
||||||
!frontend_symbols_.ContainKey(input_tensor_legalized_name) &&
|
"Found duplicate legalized input tensor names.");
|
||||||
"Found duplicate legalized input tensor names.");
|
|
||||||
frontend_symbols_.AddMapping(input_tensor_legalized_name, symbol);
|
frontend_symbols_.AddMapping(input_tensor_legalized_name, symbol);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
typedef bstd::variant<int64_t, std::vector<int64_t>, float,
|
||||||
T get_attr_generic(onnx::NodeProto &node, std::string name,
|
std::vector<float>, std::string,
|
||||||
std::function<T(onnx::AttributeProto &)> attr_getter,
|
std::vector<std::string>>
|
||||||
T default_val) {
|
AttrValueType;
|
||||||
|
|
||||||
|
struct ONNXAttrVisitor {
|
||||||
|
ONNXAttrVisitor(std::string name, mlir::OpBuilder &builder)
|
||||||
|
: _builder(builder), _name(std::move(name)) {}
|
||||||
|
|
||||||
|
// Op builder.
|
||||||
|
mlir::OpBuilder &_builder;
|
||||||
|
|
||||||
|
// Name of the attribute being inspected.
|
||||||
|
std::string _name;
|
||||||
|
|
||||||
|
mlir::NamedAttribute operator()(int64_t const &r) {
|
||||||
|
auto val = _builder.getI32IntegerAttr(r);
|
||||||
|
return _builder.getNamedAttr(_name, val);
|
||||||
|
}
|
||||||
|
|
||||||
|
mlir::NamedAttribute operator()(std::vector<int64_t> const &ints) {
|
||||||
|
auto val = _builder.getI64ArrayAttr(ints);
|
||||||
|
return _builder.getNamedAttr(_name, val);
|
||||||
|
}
|
||||||
|
|
||||||
|
mlir::NamedAttribute operator()(float const &r) {
|
||||||
|
auto val = _builder.getF32FloatAttr(r);
|
||||||
|
return _builder.getNamedAttr(_name, val);
|
||||||
|
}
|
||||||
|
|
||||||
|
mlir::NamedAttribute operator()(std::vector<float> const &floats) {
|
||||||
|
auto val = _builder.getF32ArrayAttr(floats);
|
||||||
|
return _builder.getNamedAttr(_name, val);
|
||||||
|
}
|
||||||
|
|
||||||
|
mlir::NamedAttribute operator()(std::string const &s) {
|
||||||
|
auto val = _builder.getStringAttr(s);
|
||||||
|
return _builder.getNamedAttr(_name, val);
|
||||||
|
}
|
||||||
|
|
||||||
|
mlir::NamedAttribute operator()(std::vector<std::string> const &r) {
|
||||||
|
assert(false && "type of attribute value is not implemented");
|
||||||
|
auto val = _builder.getI32IntegerAttr(1);
|
||||||
|
return _builder.getNamedAttr(_name, val);
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
mlir::NamedAttribute convertNameValuePairToNamedAttribute(
|
||||||
|
std::pair<std::string, AttrValueType> nameAndVal) {
|
||||||
|
auto visitor = ONNXAttrVisitor(nameAndVal.first, builder_);
|
||||||
|
return mpark::visit(visitor, nameAndVal.second);
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::pair<std::string, AttrValueType>
|
||||||
|
convertAttributeProtoToNameValuePair(onnx::AttributeProto &attr) {
|
||||||
|
AttrValueType val;
|
||||||
|
switch (attr.type()) {
|
||||||
|
case onnx::AttributeProto::FLOAT:
|
||||||
|
return std::make_pair(attr.name(), AttrValueType(attr.f()));
|
||||||
|
case onnx::AttributeProto::INT:
|
||||||
|
return std::make_pair(attr.name(), AttrValueType(attr.i()));
|
||||||
|
case onnx::AttributeProto::STRING:
|
||||||
|
return std::make_pair(attr.name(), AttrValueType(attr.s()));
|
||||||
|
case onnx::AttributeProto::FLOATS:
|
||||||
|
val = AttrValueType(
|
||||||
|
std::vector<float>(attr.floats().begin(), attr.floats().end()));
|
||||||
|
return std::make_pair(attr.name(), val);
|
||||||
|
case onnx::AttributeProto::INTS:
|
||||||
|
val = AttrValueType(
|
||||||
|
std::vector<int64_t>(attr.ints().begin(), attr.ints().end()));
|
||||||
|
return std::make_pair(attr.name(), val);
|
||||||
|
default:
|
||||||
|
assert(false && "datatype for attribute is not implemented");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<mlir::NamedAttribute> ImportNodeAttributes(
|
||||||
|
const onnx::NodeProto &node,
|
||||||
|
std::initializer_list<std::pair<std::string, AttrValueType>>
|
||||||
|
defaultAttrList) {
|
||||||
|
std::vector<mlir::NamedAttribute> attributes;
|
||||||
|
std::set<std::string> definedAttributeSet;
|
||||||
for (int i = 0; i < node.attribute_size(); ++i) {
|
for (int i = 0; i < node.attribute_size(); ++i) {
|
||||||
auto attr = node.attribute(i);
|
auto attr = node.attribute(i);
|
||||||
if (attr.name() == name) {
|
auto nameValPair = convertAttributeProtoToNameValuePair(attr);
|
||||||
return attr_getter(attr);
|
attributes.push_back(convertNameValuePairToNamedAttribute(nameValPair));
|
||||||
}
|
definedAttributeSet.insert(attr.name());
|
||||||
}
|
}
|
||||||
return default_val;
|
for (const auto &defaultAttr : defaultAttrList) {
|
||||||
}
|
if (definedAttributeSet.find(defaultAttr.first) ==
|
||||||
|
definedAttributeSet.end())
|
||||||
template <typename T>
|
attributes.push_back(convertNameValuePairToNamedAttribute(defaultAttr));
|
||||||
T get_attr_generic(onnx::NodeProto &node, std::string name,
|
|
||||||
std::function<T(onnx::AttributeProto &)> attr_getter) {
|
|
||||||
for (int i = 0; i < node.attribute_size(); ++i) {
|
|
||||||
auto attr = node.attribute(i);
|
|
||||||
if (attr.name() == name) {
|
|
||||||
return attr_getter(attr);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
assert(false && "ONNX Node Attribute Not Found!");
|
return attributes;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto get_attr_ints(onnx::NodeProto &node, std::string name,
|
void ImportNodeGeneric(const onnx::NodeProto &node) {
|
||||||
std::vector<int> default_val) {
|
|
||||||
std::function<std::vector<int>(onnx::AttributeProto &)> attr_getter =
|
|
||||||
[](onnx::AttributeProto &attr) {
|
|
||||||
std::vector<int> ints(attr.ints_size());
|
|
||||||
std::copy(attr.ints().begin(), attr.ints().end(), ints.begin());
|
|
||||||
return ints;
|
|
||||||
};
|
|
||||||
auto r = get_attr_generic(node, name, attr_getter, default_val);
|
|
||||||
auto dataType =
|
|
||||||
mlir::RankedTensorType::get(r.size(), builder_.getIntegerType(32));
|
|
||||||
auto attr_v = mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(r));
|
|
||||||
auto aname = node.op_type() + "." + name;
|
|
||||||
auto attr_output = builder_.getNamedAttr(aname, attr_v);
|
|
||||||
return attr_output;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto get_attr_ints(onnx::NodeProto &node, std::string name) {
|
|
||||||
std::function<std::vector<int>(onnx::AttributeProto &)> attr_getter =
|
|
||||||
[](onnx::AttributeProto &attr) {
|
|
||||||
std::vector<int> ints(attr.ints_size());
|
|
||||||
std::copy(attr.ints().begin(), attr.ints().end(), ints.begin());
|
|
||||||
return ints;
|
|
||||||
};
|
|
||||||
auto r = get_attr_generic(node, name, attr_getter);
|
|
||||||
auto dataType =
|
|
||||||
mlir::RankedTensorType::get(r.size(), builder_.getIntegerType(32));
|
|
||||||
auto attr_v = mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(r));
|
|
||||||
auto aname = node.op_type() + "." + name;
|
|
||||||
auto attr_output = builder_.getNamedAttr(aname, attr_v);
|
|
||||||
return attr_output;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto get_attr_floats(onnx::NodeProto &node, std::string name) {
|
|
||||||
std::function<std::vector<float>(onnx::AttributeProto &)> attr_getter =
|
|
||||||
[](onnx::AttributeProto &attr) {
|
|
||||||
std::vector<float> floats(attr.floats_size());
|
|
||||||
std::copy(attr.floats().begin(), attr.floats().end(), floats.begin());
|
|
||||||
return floats;
|
|
||||||
};
|
|
||||||
auto r = get_attr_generic(node, name, attr_getter);
|
|
||||||
auto dataType =
|
|
||||||
mlir::RankedTensorType::get(r.size(), builder_.getF32Type());
|
|
||||||
auto attr_v = mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(r));
|
|
||||||
auto aname = node.op_type() + "." + name;
|
|
||||||
auto attr_output = builder_.getNamedAttr(aname, attr_v);
|
|
||||||
return attr_output;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto get_attr_floats(onnx::NodeProto &node, std::string name,
|
|
||||||
std::vector<float> default_val) {
|
|
||||||
std::function<std::vector<float>(onnx::AttributeProto &)> attr_getter =
|
|
||||||
[](onnx::AttributeProto &attr) {
|
|
||||||
std::vector<float> floats(attr.floats_size());
|
|
||||||
std::copy(attr.floats().begin(), attr.floats().end(), floats.begin());
|
|
||||||
return floats;
|
|
||||||
};
|
|
||||||
auto r = get_attr_generic(node, name, attr_getter, default_val);
|
|
||||||
auto dataType =
|
|
||||||
mlir::RankedTensorType::get(r.size(), builder_.getF32Type());
|
|
||||||
auto attr_v = mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(r));
|
|
||||||
auto aname = node.op_type() + "." + name;
|
|
||||||
auto attr_output = builder_.getNamedAttr(aname, attr_v);
|
|
||||||
return attr_output;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto get_attr_int(onnx::NodeProto &node, std::string name) {
|
|
||||||
std::function<int(onnx::AttributeProto &)> attr_getter =
|
|
||||||
[](onnx::AttributeProto &attr) { return attr.i(); };
|
|
||||||
int r = get_attr_generic(node, name, attr_getter);
|
|
||||||
auto attr_v = builder_.getI32IntegerAttr(r);
|
|
||||||
auto aname = node.op_type() + "." + name;
|
|
||||||
auto attr_output = builder_.getNamedAttr(aname, attr_v);
|
|
||||||
return attr_output;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto get_attr_int(onnx::NodeProto &node, std::string name, int default_val) {
|
|
||||||
std::function<int(onnx::AttributeProto &)> attr_getter =
|
|
||||||
[](onnx::AttributeProto &attr) { return attr.i(); };
|
|
||||||
int r = get_attr_generic(node, name, attr_getter, default_val);
|
|
||||||
auto attr_v = builder_.getI32IntegerAttr(r);
|
|
||||||
auto aname = node.op_type() + "." + name;
|
|
||||||
auto attr_output = builder_.getNamedAttr(aname, attr_v);
|
|
||||||
return attr_output;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto get_attr_float(onnx::NodeProto &node, std::string name) {
|
|
||||||
std::function<float(onnx::AttributeProto &)> attr_getter =
|
|
||||||
[](onnx::AttributeProto &attr) { return attr.f(); };
|
|
||||||
auto r = get_attr_generic(node, name, attr_getter);
|
|
||||||
auto attr_v = builder_.getF32FloatAttr(r);
|
|
||||||
auto aname = node.op_type() + "." + name;
|
|
||||||
return builder_.getNamedAttr(aname, attr_v);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto get_attr_float(onnx::NodeProto &node, std::string name,
|
|
||||||
float default_val) {
|
|
||||||
std::function<float(onnx::AttributeProto &)> attr_getter =
|
|
||||||
[](onnx::AttributeProto &attr) { return attr.f(); };
|
|
||||||
auto r = get_attr_generic(node, name, attr_getter, default_val);
|
|
||||||
auto attr_v = builder_.getF32FloatAttr(r);
|
|
||||||
auto aname = node.op_type() + "." + name;
|
|
||||||
return builder_.getNamedAttr(aname, attr_v);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto get_attr_string(onnx::NodeProto &node, std::string name) {
|
|
||||||
std::function<std::string(onnx::AttributeProto &)> attr_getter =
|
|
||||||
[](onnx::AttributeProto &attr) { return attr.s(); };
|
|
||||||
auto r = get_attr_generic(node, name, attr_getter);
|
|
||||||
auto attr_v = builder_.getStringAttr(r);
|
|
||||||
auto aname = node.op_type() + "." + name;
|
|
||||||
return builder_.getNamedAttr(aname, attr_v);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto get_attr_string(onnx::NodeProto &node, std::string name,
|
|
||||||
std::string default_val) {
|
|
||||||
std::function<std::string(onnx::AttributeProto &)> attr_getter =
|
|
||||||
[](onnx::AttributeProto &attr) { return attr.s(); };
|
|
||||||
auto r = get_attr_generic(node, name, attr_getter, default_val);
|
|
||||||
auto attr_v = builder_.getStringAttr(r);
|
|
||||||
auto aname = node.op_type() + "." + name;
|
|
||||||
return builder_.getNamedAttr(aname, attr_v);
|
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
|
||||||
auto get_attr_strings(onnx::NodeProto &node, std::string name) {
|
|
||||||
std::function<std::vector<std::string>(onnx::AttributeProto &)>
|
|
||||||
attr_getter =
|
|
||||||
[](onnx::AttributeProto &attr) {
|
|
||||||
std::vector<std::string> strings(attr.strings_size());
|
|
||||||
std::copy(attr.strings().begin(), attr.strings().end(),
|
|
||||||
strings.begin()); return strings;
|
|
||||||
};
|
|
||||||
auto r = get_attr_generic(node, name, attr_getter);
|
|
||||||
return r;
|
|
||||||
return builder_.getNamedAttr(aname, attr_v);
|
|
||||||
auto dataType =
|
|
||||||
mlir::RankedTensorType::get(r.size(), builder_.get???Type());
|
|
||||||
auto attr_v = mlir::DenseElementsAttr::get(dataType,
|
|
||||||
llvm::makeArrayRef(r)); auto aname = node.op_type() + "." + name; auto
|
|
||||||
attr_output = builder_.getNamedAttr(aname, attr_v); return attr_output;
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
auto get_default_ints(std::string default_str) {
|
|
||||||
std::vector<int> r;
|
|
||||||
auto start = default_str.find("{");
|
|
||||||
while (true) {
|
|
||||||
auto end = default_str.find(",", start + 1);
|
|
||||||
if (end == std::string::npos) {
|
|
||||||
end = default_str.find("}", start + 1);
|
|
||||||
if (end != std::string::npos && end > start + 1) {
|
|
||||||
r.push_back(std::stoi(default_str.substr(start + 1, end)));
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
} else {
|
|
||||||
r.push_back(std::stoi(default_str.substr(start + 1, end)));
|
|
||||||
}
|
|
||||||
start = end + 1;
|
|
||||||
}
|
|
||||||
return r;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto get_default_floats(std::string default_str) {
|
|
||||||
std::vector<float> r;
|
|
||||||
auto start = default_str.find("{");
|
|
||||||
while (true) {
|
|
||||||
auto end = default_str.find(",", start + 1);
|
|
||||||
if (end == std::string::npos) {
|
|
||||||
end = default_str.find("}", start + 1);
|
|
||||||
if (end != std::string::npos && end > start + 1) {
|
|
||||||
r.push_back(std::stof(default_str.substr(start + 1, end)));
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
} else {
|
|
||||||
r.push_back(std::stof(default_str.substr(start + 1, end)));
|
|
||||||
}
|
|
||||||
start = end + 1;
|
|
||||||
}
|
|
||||||
return r;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto get_default_strings(std::string default_str) {
|
|
||||||
std::vector<std::string> r;
|
|
||||||
auto start = default_str.find("{");
|
|
||||||
while (true) {
|
|
||||||
auto end = default_str.find(",", start + 1);
|
|
||||||
if (end == std::string::npos) {
|
|
||||||
end = default_str.find("}", start + 1);
|
|
||||||
if (end != std::string::npos && end > start + 1) {
|
|
||||||
r.push_back(default_str.substr(start + 1, end));
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
} else {
|
|
||||||
r.push_back(default_str.substr(start + 1, end));
|
|
||||||
}
|
|
||||||
start = end + 1;
|
|
||||||
}
|
|
||||||
return r;
|
|
||||||
}
|
|
||||||
|
|
||||||
onnx::TensorProto get_attr_tensor(onnx::NodeProto &node, std::string name) {
|
|
||||||
std::function<onnx::TensorProto(onnx::AttributeProto &)> attr_getter =
|
|
||||||
[](onnx::AttributeProto &attr) { return attr.t(); };
|
|
||||||
return get_attr_generic(node, name, attr_getter);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto ImportNodeAttr(onnx::NodeProto node, std::string attr_name,
|
|
||||||
std::string type_name, std::string default_str) {
|
|
||||||
if (default_str == "") {
|
|
||||||
if (type_name == "int") {
|
|
||||||
return get_attr_int(node, attr_name);
|
|
||||||
} else if (type_name == "float") {
|
|
||||||
return get_attr_float(node, attr_name);
|
|
||||||
} else if (type_name == "str") {
|
|
||||||
return get_attr_string(node, attr_name);
|
|
||||||
} else if (type_name == "ints") {
|
|
||||||
return get_attr_ints(node, attr_name);
|
|
||||||
} else if (type_name == "floats") {
|
|
||||||
return get_attr_floats(node, attr_name);
|
|
||||||
} else {
|
|
||||||
assert(
|
|
||||||
false &&
|
|
||||||
"Got an empty initializer or initializer for this "
|
|
||||||
"datatype is not implemented. Something is wrong.");
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// with default value
|
|
||||||
if (type_name == "int") {
|
|
||||||
return get_attr_int(node, attr_name, std::stoi(default_str));
|
|
||||||
} else if (type_name == "float") {
|
|
||||||
return get_attr_float(node, attr_name, std::stof(default_str));
|
|
||||||
} else if (type_name == "str") {
|
|
||||||
return get_attr_string(node, attr_name, default_str);
|
|
||||||
} else if (type_name == "ints") {
|
|
||||||
return get_attr_ints(node, attr_name, get_default_ints(default_str));
|
|
||||||
} else if (type_name == "floats") {
|
|
||||||
return get_attr_floats(node, attr_name,
|
|
||||||
get_default_floats(default_str));
|
|
||||||
} else {
|
|
||||||
assert(
|
|
||||||
false &&
|
|
||||||
"Got an empty initializer or initializer for this "
|
|
||||||
"datatype is not implemented. Something is wrong.");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void ImportNodeGeneric(onnx::NodeProto node) {
|
|
||||||
std::vector<mlir::Value> inputs;
|
std::vector<mlir::Value> inputs;
|
||||||
for (auto item : node.input()) {
|
for (const auto &item : node.input()) {
|
||||||
if (frontend_symbols_.ContainKey(legalize_name(item))) {
|
if (frontend_symbols_.ContainKey(legalize_name(item))) {
|
||||||
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
|
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
|
||||||
}
|
}
|
||||||
|
@ -511,12 +339,12 @@ private:
|
||||||
* default}
|
* default}
|
||||||
*/
|
*/
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void ImportNodeOneOut(
|
void
|
||||||
onnx::NodeProto node, int nIn, int nOut,
|
ImportNodeOneOut(const onnx::NodeProto &node, int nIn, int nOut,
|
||||||
std::initializer_list<std::tuple<std::string, std::string, std::string>>
|
std::initializer_list<std::pair<std::string, AttrValueType>>
|
||||||
attrs) {
|
defaultAttrList) {
|
||||||
std::vector<mlir::Value> inputs;
|
std::vector<mlir::Value> inputs;
|
||||||
for (auto item : node.input()) {
|
for (const auto &item : node.input()) {
|
||||||
if (frontend_symbols_.ContainKey(legalize_name(item))) {
|
if (frontend_symbols_.ContainKey(legalize_name(item))) {
|
||||||
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
|
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
|
||||||
}
|
}
|
||||||
|
@ -528,22 +356,7 @@ private:
|
||||||
mlir::UnrankedTensorType::get(builder_.getF32Type()));
|
mlir::UnrankedTensorType::get(builder_.getF32Type()));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<mlir::NamedAttribute> attributes;
|
auto attributes = ImportNodeAttributes(node, defaultAttrList);
|
||||||
// for (auto [attr_name, attr_type, attr_default] : attrs) {
|
|
||||||
for (auto oneAttr : attrs) {
|
|
||||||
std::string attr_name;
|
|
||||||
std::string attr_type;
|
|
||||||
std::string attr_default;
|
|
||||||
std::tie(attr_name, attr_type, attr_default) = oneAttr;
|
|
||||||
if (attr_type != "") {
|
|
||||||
auto attr = ImportNodeAttr(node, attr_name, attr_type, attr_default);
|
|
||||||
attributes.push_back(attr);
|
|
||||||
} else {
|
|
||||||
// TODO: the attributes need special handling
|
|
||||||
// std::cout << "missing " << node.op_type() << " " << attr_name <<
|
|
||||||
// std::endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
llvm::StringRef OpName = node.op_type();
|
llvm::StringRef OpName = node.op_type();
|
||||||
|
|
||||||
|
@ -559,11 +372,11 @@ private:
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void ImportNodeMultipleOuts(
|
void ImportNodeMultipleOuts(
|
||||||
onnx::NodeProto node, int nIn, int nOut,
|
const onnx::NodeProto &node, int nIn, int nOut,
|
||||||
std::initializer_list<std::tuple<std::string, std::string, std::string>>
|
std::initializer_list<std::pair<std::string, AttrValueType>>
|
||||||
attrs) {
|
defaultAttrList) {
|
||||||
std::vector<mlir::Value> inputs;
|
std::vector<mlir::Value> inputs;
|
||||||
for (auto item : node.input()) {
|
for (const auto &item : node.input()) {
|
||||||
if (frontend_symbols_.ContainKey(legalize_name(item))) {
|
if (frontend_symbols_.ContainKey(legalize_name(item))) {
|
||||||
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
|
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
|
||||||
}
|
}
|
||||||
|
@ -575,21 +388,7 @@ private:
|
||||||
mlir::UnrankedTensorType::get(builder_.getF32Type()));
|
mlir::UnrankedTensorType::get(builder_.getF32Type()));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<mlir::NamedAttribute> attributes;
|
auto attributes = ImportNodeAttributes(node, defaultAttrList);
|
||||||
for (auto oneAttr : attrs) {
|
|
||||||
std::string attr_name;
|
|
||||||
std::string attr_type;
|
|
||||||
std::string attr_default;
|
|
||||||
std::tie(attr_name, attr_type, attr_default) = oneAttr;
|
|
||||||
if (attr_type != "") {
|
|
||||||
auto attr = ImportNodeAttr(node, attr_name, attr_type, attr_default);
|
|
||||||
attributes.push_back(attr);
|
|
||||||
} else {
|
|
||||||
// TODO: the attributes need special handling
|
|
||||||
// std::cout << "missing " << node.op_type() << " " << attr_name <<
|
|
||||||
// std::endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
llvm::StringRef OpName = node.op_type();
|
llvm::StringRef OpName = node.op_type();
|
||||||
|
|
||||||
|
@ -610,10 +409,10 @@ private:
|
||||||
* c++ does not allow template specialization inside a class scope
|
* c++ does not allow template specialization inside a class scope
|
||||||
* a specialized function is used
|
* a specialized function is used
|
||||||
*/
|
*/
|
||||||
void ImportNodeConv(
|
void
|
||||||
onnx::NodeProto node, int nOut,
|
ImportNodeConv(onnx::NodeProto node, int nIn, int nOut,
|
||||||
std::initializer_list<std::tuple<std::string, std::string, std::string>>
|
std::initializer_list<std::pair<std::string, AttrValueType>>
|
||||||
attrs) {
|
defaultAttrList) {
|
||||||
// Conv has attribute dilations, kernel_shape, pads, the default value of
|
// Conv has attribute dilations, kernel_shape, pads, the default value of
|
||||||
// which is determined by the shape of first argument. However, since the
|
// which is determined by the shape of first argument. However, since the
|
||||||
// shape is unknown now, these attributes can be not generated auto
|
// shape is unknown now, these attributes can be not generated auto
|
||||||
|
@ -627,29 +426,32 @@ private:
|
||||||
int nOps = node.input().size();
|
int nOps = node.input().size();
|
||||||
|
|
||||||
if (nOps == 2)
|
if (nOps == 2)
|
||||||
ImportNodeOneOut<mlir::ONNXConvNoBiasOp>(node, nOps, nOut, attrs);
|
ImportNodeOneOut<mlir::ONNXConvNoBiasOp>(
|
||||||
|
node, nOps, nOut, defaultAttrList);
|
||||||
else
|
else
|
||||||
ImportNodeOneOut<mlir::ONNXConvOp>(node, nOps, nOut, attrs);
|
ImportNodeOneOut<mlir::ONNXConvOp>(node, nOps, nOut, defaultAttrList);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* Special handle for MaxPool operations.
|
* Special handle for MaxPool operations.
|
||||||
*/
|
*/
|
||||||
void ImportNodeMaxPool(
|
void ImportNodeMaxPool(
|
||||||
onnx::NodeProto node, int nIn,
|
onnx::NodeProto node, int nIn, int nOut,
|
||||||
std::initializer_list<std::tuple<std::string, std::string, std::string>>
|
std::initializer_list<std::pair<std::string, AttrValueType>>
|
||||||
attrs) {
|
defaultAttrList) {
|
||||||
int nOuts = node.output().size();
|
int nOuts = node.output().size();
|
||||||
if (nOuts == 1) {
|
if (nOuts == 1) {
|
||||||
ImportNodeOneOut<mlir::ONNXMaxPoolSingleOutOp>(node, nIn, nOuts, attrs);
|
ImportNodeOneOut<mlir::ONNXMaxPoolSingleOutOp>(
|
||||||
|
node, nIn, nOuts, defaultAttrList);
|
||||||
} else {
|
} else {
|
||||||
ImportNodeMultipleOuts<mlir::ONNXMaxPoolOp>(node, nIn, nOuts, attrs);
|
ImportNodeMultipleOuts<mlir::ONNXMaxPoolOp>(
|
||||||
|
node, nIn, nOuts, defaultAttrList);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void ImportNode(onnx::NodeProto node) {
|
void ImportNode(const onnx::NodeProto &node) {
|
||||||
std::vector<mlir::Value> inputs;
|
std::vector<mlir::Value> inputs;
|
||||||
for (auto item : node.input()) {
|
for (const auto &item : node.input()) {
|
||||||
if (frontend_symbols_.ContainKey(legalize_name(item))) {
|
if (frontend_symbols_.ContainKey(legalize_name(item))) {
|
||||||
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
|
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
|
||||||
}
|
}
|
||||||
|
@ -689,9 +491,8 @@ private:
|
||||||
llvm::SmallVectorImpl<mlir::Type> &ret_types,
|
llvm::SmallVectorImpl<mlir::Type> &ret_types,
|
||||||
llvm::SmallVectorImpl<mlir::Value> &ret_vals) {
|
llvm::SmallVectorImpl<mlir::Value> &ret_vals) {
|
||||||
auto output_tensor_legalized_name = legalize_name(output.name());
|
auto output_tensor_legalized_name = legalize_name(output.name());
|
||||||
assert(
|
assert(frontend_symbols_.ContainKey(output_tensor_legalized_name) &&
|
||||||
frontend_symbols_.ContainKey(output_tensor_legalized_name) &&
|
"Output tensor not found");
|
||||||
"Output tensor not found");
|
|
||||||
|
|
||||||
auto tensor_val =
|
auto tensor_val =
|
||||||
frontend_symbols_.GetTensorByOnnxName(output_tensor_legalized_name);
|
frontend_symbols_.GetTensorByOnnxName(output_tensor_legalized_name);
|
||||||
|
@ -750,9 +551,9 @@ private:
|
||||||
funcType = builder_.getFunctionType(arg_types, ret_types);
|
funcType = builder_.getFunctionType(arg_types, ret_types);
|
||||||
mainFunc.setType(funcType);
|
mainFunc.setType(funcType);
|
||||||
}
|
}
|
||||||
}; // FrontendGenImpl class
|
}; // FrontendGenImpl class
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace onnf
|
} // namespace onnf
|
||||||
|
|
||||||
namespace onnf {
|
namespace onnf {
|
||||||
|
|
||||||
|
@ -775,4 +576,4 @@ void ImportFrontendModelFile(std::string model_fname,
|
||||||
FrontendGenImpl myONNXGen(context);
|
FrontendGenImpl myONNXGen(context);
|
||||||
module = myONNXGen.ImportONNXModel(model);
|
module = myONNXGen.ImportONNXModel(model);
|
||||||
}
|
}
|
||||||
} // namespace onnf
|
} // namespace onnf
|
||||||
|
|
|
@ -16,13 +16,13 @@
|
||||||
});
|
});
|
||||||
}else if (OpName == "ArgMax") {
|
}else if (OpName == "ArgMax") {
|
||||||
ImportNodeOneOut<mlir::ONNXArgMaxOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXArgMaxOp>(node, 1, 1, {
|
||||||
{"axis","int","0"}
|
{"axis", 0}
|
||||||
,{"keepdims","int","1"}
|
,{"keepdims", 1}
|
||||||
});
|
});
|
||||||
}else if (OpName == "ArgMin") {
|
}else if (OpName == "ArgMin") {
|
||||||
ImportNodeOneOut<mlir::ONNXArgMinOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXArgMinOp>(node, 1, 1, {
|
||||||
{"axis","int","0"}
|
{"axis", 0}
|
||||||
,{"keepdims","int","1"}
|
,{"keepdims", 1}
|
||||||
});
|
});
|
||||||
}else if (OpName == "Asin") {
|
}else if (OpName == "Asin") {
|
||||||
ImportNodeOneOut<mlir::ONNXAsinOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXAsinOp>(node, 1, 1, {
|
||||||
|
@ -38,25 +38,22 @@
|
||||||
});
|
});
|
||||||
}else if (OpName == "AveragePool") {
|
}else if (OpName == "AveragePool") {
|
||||||
ImportNodeOneOut<mlir::ONNXAveragePoolOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXAveragePoolOp>(node, 1, 1, {
|
||||||
{"auto_pad","str","NOTSET"}
|
{"auto_pad", "NOTSET"}
|
||||||
,{"ceil_mode","int","0"}
|
,{"ceil_mode", 0}
|
||||||
,{"count_include_pad","int","0"}
|
,{"count_include_pad", 0}
|
||||||
,{"kernel_shape","ints", ""}
|
,{"kernel_shape", std::vector<int64_t> {}}
|
||||||
,{"pads","", ""}
|
|
||||||
,{"strides","", ""}
|
|
||||||
});
|
});
|
||||||
}else if (OpName == "BatchNormalization") {
|
}else if (OpName == "BatchNormalization") {
|
||||||
ImportNodeMultipleOuts<mlir::ONNXBatchNormalizationOp>(node, 5, 5, {
|
ImportNodeMultipleOuts<mlir::ONNXBatchNormalizationOp>(node, 5, 5, {
|
||||||
{"epsilon","float","1e-05"}
|
{"epsilon", (float)1e-05}
|
||||||
,{"momentum","float","0.9"}
|
,{"momentum", (float)0.9}
|
||||||
});
|
});
|
||||||
}else if (OpName == "BitShift") {
|
}else if (OpName == "BitShift") {
|
||||||
ImportNodeOneOut<mlir::ONNXBitShiftOp>(node, 2, 1, {
|
ImportNodeOneOut<mlir::ONNXBitShiftOp>(node, 2, 1, {
|
||||||
{"direction","", ""}
|
|
||||||
});
|
});
|
||||||
}else if (OpName == "Cast") {
|
}else if (OpName == "Cast") {
|
||||||
ImportNodeOneOut<mlir::ONNXCastOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXCastOp>(node, 1, 1, {
|
||||||
{"to","int", "0"}
|
{"to", 0}
|
||||||
});
|
});
|
||||||
}else if (OpName == "Ceil") {
|
}else if (OpName == "Ceil") {
|
||||||
ImportNodeOneOut<mlir::ONNXCeilOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXCeilOp>(node, 1, 1, {
|
||||||
|
@ -66,54 +63,35 @@
|
||||||
});
|
});
|
||||||
}else if (OpName == "Compress") {
|
}else if (OpName == "Compress") {
|
||||||
ImportNodeOneOut<mlir::ONNXCompressOp>(node, 2, 1, {
|
ImportNodeOneOut<mlir::ONNXCompressOp>(node, 2, 1, {
|
||||||
{"axis","", ""}
|
|
||||||
});
|
});
|
||||||
}else if (OpName == "Concat") {
|
}else if (OpName == "Concat") {
|
||||||
ImportNodeOneOut<mlir::ONNXConcatOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXConcatOp>(node, 1, 1, {
|
||||||
{"axis","int", "0"}
|
{"axis", 0}
|
||||||
});
|
});
|
||||||
}else if (OpName == "ConcatFromSequence") {
|
}else if (OpName == "ConcatFromSequence") {
|
||||||
ImportNodeOneOut<mlir::ONNXConcatFromSequenceOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXConcatFromSequenceOp>(node, 1, 1, {
|
||||||
{"axis","", ""}
|
{"new_axis", 0}
|
||||||
,{"new_axis","int","0"}
|
|
||||||
});
|
});
|
||||||
}else if (OpName == "Constant") {
|
}else if (OpName == "Constant") {
|
||||||
ImportNodeOneOut<mlir::ONNXConstantOp>(node, 0, 1, {
|
ImportNodeOneOut<mlir::ONNXConstantOp>(node, 0, 1, {
|
||||||
{"sparse_value","", ""}
|
|
||||||
,{"value","", ""}
|
|
||||||
});
|
});
|
||||||
}else if (OpName == "ConstantOfShape") {
|
}else if (OpName == "ConstantOfShape") {
|
||||||
ImportNodeOneOut<mlir::ONNXConstantOfShapeOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXConstantOfShapeOp>(node, 1, 1, {
|
||||||
{"value","", ""}
|
|
||||||
});
|
});
|
||||||
}else if (OpName == "Conv") {
|
}else if (OpName == "Conv") {
|
||||||
ImportNodeConv(node, 1, {
|
ImportNodeConv(node, 3, 1, {
|
||||||
{"auto_pad","str","NOTSET"}
|
{"auto_pad", "NOTSET"}
|
||||||
,{"dilations","", ""}
|
,{"group", 1}
|
||||||
,{"group","int", "1"}
|
|
||||||
,{"kernel_shape","", ""}
|
|
||||||
,{"pads","", ""}
|
|
||||||
,{"strides","", ""}
|
|
||||||
});
|
});
|
||||||
}else if (OpName == "ConvInteger") {
|
}else if (OpName == "ConvInteger") {
|
||||||
ImportNodeOneOut<mlir::ONNXConvIntegerOp>(node, 4, 1, {
|
ImportNodeOneOut<mlir::ONNXConvIntegerOp>(node, 4, 1, {
|
||||||
{"auto_pad","str","NOTSET"}
|
{"auto_pad", "NOTSET"}
|
||||||
,{"dilations","", ""}
|
,{"group", 1}
|
||||||
,{"group","int","1"}
|
|
||||||
,{"kernel_shape","", ""}
|
|
||||||
,{"pads","", ""}
|
|
||||||
,{"strides","", ""}
|
|
||||||
});
|
});
|
||||||
}else if (OpName == "ConvTranspose") {
|
}else if (OpName == "ConvTranspose") {
|
||||||
ImportNodeOneOut<mlir::ONNXConvTransposeOp>(node, 3, 1, {
|
ImportNodeOneOut<mlir::ONNXConvTransposeOp>(node, 3, 1, {
|
||||||
{"auto_pad","str","NOTSET"}
|
{"auto_pad", "NOTSET"}
|
||||||
,{"dilations","", ""}
|
,{"group", 1}
|
||||||
,{"group","int","1"}
|
|
||||||
,{"kernel_shape","", ""}
|
|
||||||
,{"output_padding","", ""}
|
|
||||||
,{"output_shape","", ""}
|
|
||||||
,{"pads","", ""}
|
|
||||||
,{"strides","", ""}
|
|
||||||
});
|
});
|
||||||
}else if (OpName == "Cos") {
|
}else if (OpName == "Cos") {
|
||||||
ImportNodeOneOut<mlir::ONNXCosOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXCosOp>(node, 1, 1, {
|
||||||
|
@ -123,13 +101,12 @@
|
||||||
});
|
});
|
||||||
}else if (OpName == "CumSum") {
|
}else if (OpName == "CumSum") {
|
||||||
ImportNodeOneOut<mlir::ONNXCumSumOp>(node, 2, 1, {
|
ImportNodeOneOut<mlir::ONNXCumSumOp>(node, 2, 1, {
|
||||||
{"exclusive","int","0"}
|
{"exclusive", 0}
|
||||||
,{"reverse","int","0"}
|
,{"reverse", 0}
|
||||||
});
|
});
|
||||||
}else if (OpName == "DepthToSpace") {
|
}else if (OpName == "DepthToSpace") {
|
||||||
ImportNodeOneOut<mlir::ONNXDepthToSpaceOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXDepthToSpaceOp>(node, 1, 1, {
|
||||||
{"blocksize","", ""}
|
{"mode", "DCR"}
|
||||||
,{"mode","str","DCR"}
|
|
||||||
});
|
});
|
||||||
}else if (OpName == "DequantizeLinear") {
|
}else if (OpName == "DequantizeLinear") {
|
||||||
ImportNodeOneOut<mlir::ONNXDequantizeLinearOp>(node, 3, 1, {
|
ImportNodeOneOut<mlir::ONNXDequantizeLinearOp>(node, 3, 1, {
|
||||||
|
@ -142,14 +119,14 @@
|
||||||
});
|
});
|
||||||
}else if (OpName == "Dropout") {
|
}else if (OpName == "Dropout") {
|
||||||
ImportNodeMultipleOuts<mlir::ONNXDropoutOp>(node, 1, 2, {
|
ImportNodeMultipleOuts<mlir::ONNXDropoutOp>(node, 1, 2, {
|
||||||
{"ratio","float","0.5"}
|
{"ratio", (float)0.5}
|
||||||
});
|
});
|
||||||
}else if (OpName == "DynamicQuantizeLinear") {
|
}else if (OpName == "DynamicQuantizeLinear") {
|
||||||
ImportNodeMultipleOuts<mlir::ONNXDynamicQuantizeLinearOp>(node, 1, 3, {
|
ImportNodeMultipleOuts<mlir::ONNXDynamicQuantizeLinearOp>(node, 1, 3, {
|
||||||
});
|
});
|
||||||
}else if (OpName == "Elu") {
|
}else if (OpName == "Elu") {
|
||||||
ImportNodeOneOut<mlir::ONNXEluOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXEluOp>(node, 1, 1, {
|
||||||
{"alpha","float","1.0"}
|
{"alpha", (float)1.0}
|
||||||
});
|
});
|
||||||
}else if (OpName == "Equal") {
|
}else if (OpName == "Equal") {
|
||||||
ImportNodeOneOut<mlir::ONNXEqualOp>(node, 2, 1, {
|
ImportNodeOneOut<mlir::ONNXEqualOp>(node, 2, 1, {
|
||||||
|
@ -165,50 +142,44 @@
|
||||||
});
|
});
|
||||||
}else if (OpName == "EyeLike") {
|
}else if (OpName == "EyeLike") {
|
||||||
ImportNodeOneOut<mlir::ONNXEyeLikeOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXEyeLikeOp>(node, 1, 1, {
|
||||||
{"dtype","", ""}
|
{"k", 0}
|
||||||
,{"k","int","0"}
|
|
||||||
});
|
});
|
||||||
}else if (OpName == "Flatten") {
|
}else if (OpName == "Flatten") {
|
||||||
ImportNodeOneOut<mlir::ONNXFlattenOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXFlattenOp>(node, 1, 1, {
|
||||||
{"axis","int","1"}
|
{"axis", 1}
|
||||||
});
|
});
|
||||||
}else if (OpName == "Floor") {
|
}else if (OpName == "Floor") {
|
||||||
ImportNodeOneOut<mlir::ONNXFloorOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXFloorOp>(node, 1, 1, {
|
||||||
});
|
});
|
||||||
}else if (OpName == "GRU") {
|
}else if (OpName == "GRU") {
|
||||||
ImportNodeMultipleOuts<mlir::ONNXGRUOp>(node, 6, 2, {
|
ImportNodeMultipleOuts<mlir::ONNXGRUOp>(node, 6, 2, {
|
||||||
{"activation_alpha","", ""}
|
{"direction", "forward"}
|
||||||
,{"activation_beta","", ""}
|
,{"linear_before_reset", 0}
|
||||||
,{"activations","", ""}
|
|
||||||
,{"clip","", ""}
|
|
||||||
,{"direction","str","forward"}
|
|
||||||
,{"hidden_size","", ""}
|
|
||||||
,{"linear_before_reset","int","0"}
|
|
||||||
});
|
});
|
||||||
}else if (OpName == "Gather") {
|
}else if (OpName == "Gather") {
|
||||||
ImportNodeOneOut<mlir::ONNXGatherOp>(node, 2, 1, {
|
ImportNodeOneOut<mlir::ONNXGatherOp>(node, 2, 1, {
|
||||||
{"axis","int","0"}
|
{"axis", 0}
|
||||||
});
|
});
|
||||||
}else if (OpName == "GatherElements") {
|
}else if (OpName == "GatherElements") {
|
||||||
ImportNodeOneOut<mlir::ONNXGatherElementsOp>(node, 2, 1, {
|
ImportNodeOneOut<mlir::ONNXGatherElementsOp>(node, 2, 1, {
|
||||||
{"axis","int","0"}
|
{"axis", 0}
|
||||||
});
|
});
|
||||||
}else if (OpName == "GatherND") {
|
}else if (OpName == "GatherND") {
|
||||||
ImportNodeOneOut<mlir::ONNXGatherNDOp>(node, 2, 1, {
|
ImportNodeOneOut<mlir::ONNXGatherNDOp>(node, 2, 1, {
|
||||||
});
|
});
|
||||||
}else if (OpName == "Gemm") {
|
}else if (OpName == "Gemm") {
|
||||||
ImportNodeOneOut<mlir::ONNXGemmOp>(node, 3, 1, {
|
ImportNodeOneOut<mlir::ONNXGemmOp>(node, 3, 1, {
|
||||||
{"alpha","float","1.0"}
|
{"alpha", (float)1.0}
|
||||||
,{"beta","float","1.0"}
|
,{"beta", (float)1.0}
|
||||||
,{"transA","int","0"}
|
,{"transA", 0}
|
||||||
,{"transB","int","0"}
|
,{"transB", 0}
|
||||||
});
|
});
|
||||||
}else if (OpName == "GlobalAveragePool") {
|
}else if (OpName == "GlobalAveragePool") {
|
||||||
ImportNodeOneOut<mlir::ONNXGlobalAveragePoolOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXGlobalAveragePoolOp>(node, 1, 1, {
|
||||||
});
|
});
|
||||||
}else if (OpName == "GlobalLpPool") {
|
}else if (OpName == "GlobalLpPool") {
|
||||||
ImportNodeOneOut<mlir::ONNXGlobalLpPoolOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXGlobalLpPoolOp>(node, 1, 1, {
|
||||||
{"p","int","2"}
|
{"p", 2}
|
||||||
});
|
});
|
||||||
}else if (OpName == "GlobalMaxPool") {
|
}else if (OpName == "GlobalMaxPool") {
|
||||||
ImportNodeOneOut<mlir::ONNXGlobalMaxPoolOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXGlobalMaxPoolOp>(node, 1, 1, {
|
||||||
|
@ -218,53 +189,45 @@
|
||||||
});
|
});
|
||||||
}else if (OpName == "HardSigmoid") {
|
}else if (OpName == "HardSigmoid") {
|
||||||
ImportNodeOneOut<mlir::ONNXHardSigmoidOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXHardSigmoidOp>(node, 1, 1, {
|
||||||
{"alpha","float","0.2"}
|
{"alpha", (float)0.2}
|
||||||
,{"beta","float","0.5"}
|
,{"beta", (float)0.5}
|
||||||
});
|
});
|
||||||
}else if (OpName == "Hardmax") {
|
}else if (OpName == "Hardmax") {
|
||||||
ImportNodeOneOut<mlir::ONNXHardmaxOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXHardmaxOp>(node, 1, 1, {
|
||||||
{"axis","int","1"}
|
{"axis", 1}
|
||||||
});
|
});
|
||||||
}else if (OpName == "Identity") {
|
}else if (OpName == "Identity") {
|
||||||
ImportNodeOneOut<mlir::ONNXIdentityOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXIdentityOp>(node, 1, 1, {
|
||||||
});
|
});
|
||||||
}else if (OpName == "If") {
|
}else if (OpName == "If") {
|
||||||
ImportNodeOneOut<mlir::ONNXIfOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXIfOp>(node, 1, 1, {
|
||||||
{"else_branch","", ""}
|
|
||||||
,{"then_branch","", ""}
|
|
||||||
});
|
});
|
||||||
}else if (OpName == "InstanceNormalization") {
|
}else if (OpName == "InstanceNormalization") {
|
||||||
ImportNodeOneOut<mlir::ONNXInstanceNormalizationOp>(node, 3, 1, {
|
ImportNodeOneOut<mlir::ONNXInstanceNormalizationOp>(node, 3, 1, {
|
||||||
{"epsilon","float","1e-05"}
|
{"epsilon", (float)1e-05}
|
||||||
});
|
});
|
||||||
}else if (OpName == "IsInf") {
|
}else if (OpName == "IsInf") {
|
||||||
ImportNodeOneOut<mlir::ONNXIsInfOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXIsInfOp>(node, 1, 1, {
|
||||||
{"detect_negative","int","1"}
|
{"detect_negative", 1}
|
||||||
,{"detect_positive","int","1"}
|
,{"detect_positive", 1}
|
||||||
});
|
});
|
||||||
}else if (OpName == "IsNaN") {
|
}else if (OpName == "IsNaN") {
|
||||||
ImportNodeOneOut<mlir::ONNXIsNaNOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXIsNaNOp>(node, 1, 1, {
|
||||||
});
|
});
|
||||||
}else if (OpName == "LRN") {
|
}else if (OpName == "LRN") {
|
||||||
ImportNodeOneOut<mlir::ONNXLRNOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXLRNOp>(node, 1, 1, {
|
||||||
{"alpha","float","0.0001"}
|
{"alpha", (float)0.0001}
|
||||||
,{"beta","float","0.75"}
|
,{"beta", (float)0.75}
|
||||||
,{"bias","float","1.0"}
|
,{"bias", (float)1.0}
|
||||||
,{"size","int", ""}
|
|
||||||
});
|
});
|
||||||
}else if (OpName == "LSTM") {
|
}else if (OpName == "LSTM") {
|
||||||
ImportNodeMultipleOuts<mlir::ONNXLSTMOp>(node, 8, 3, {
|
ImportNodeMultipleOuts<mlir::ONNXLSTMOp>(node, 8, 3, {
|
||||||
{"activation_alpha","", ""}
|
{"direction", "forward"}
|
||||||
,{"activation_beta","", ""}
|
,{"input_forget", 0}
|
||||||
,{"activations","", ""}
|
|
||||||
,{"clip","", ""}
|
|
||||||
,{"direction","str","forward"}
|
|
||||||
,{"hidden_size","", ""}
|
|
||||||
,{"input_forget","int","0"}
|
|
||||||
});
|
});
|
||||||
}else if (OpName == "LeakyRelu") {
|
}else if (OpName == "LeakyRelu") {
|
||||||
ImportNodeOneOut<mlir::ONNXLeakyReluOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXLeakyReluOp>(node, 1, 1, {
|
||||||
{"alpha","float","0.01"}
|
{"alpha", (float)0.01}
|
||||||
});
|
});
|
||||||
}else if (OpName == "Less") {
|
}else if (OpName == "Less") {
|
||||||
ImportNodeOneOut<mlir::ONNXLessOp>(node, 2, 1, {
|
ImportNodeOneOut<mlir::ONNXLessOp>(node, 2, 1, {
|
||||||
|
@ -274,24 +237,20 @@
|
||||||
});
|
});
|
||||||
}else if (OpName == "LogSoftmax") {
|
}else if (OpName == "LogSoftmax") {
|
||||||
ImportNodeOneOut<mlir::ONNXLogSoftmaxOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXLogSoftmaxOp>(node, 1, 1, {
|
||||||
{"axis","int","1"}
|
{"axis", 1}
|
||||||
});
|
});
|
||||||
}else if (OpName == "Loop") {
|
}else if (OpName == "Loop") {
|
||||||
ImportNodeOneOut<mlir::ONNXLoopOp>(node, 3, 1, {
|
ImportNodeOneOut<mlir::ONNXLoopOp>(node, 3, 1, {
|
||||||
{"body","", ""}
|
|
||||||
});
|
});
|
||||||
}else if (OpName == "LpNormalization") {
|
}else if (OpName == "LpNormalization") {
|
||||||
ImportNodeOneOut<mlir::ONNXLpNormalizationOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXLpNormalizationOp>(node, 1, 1, {
|
||||||
{"axis","int","-1"}
|
{"axis", -1}
|
||||||
,{"p","int","2"}
|
,{"p", 2}
|
||||||
});
|
});
|
||||||
}else if (OpName == "LpPool") {
|
}else if (OpName == "LpPool") {
|
||||||
ImportNodeOneOut<mlir::ONNXLpPoolOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXLpPoolOp>(node, 1, 1, {
|
||||||
{"auto_pad","str","NOTSET"}
|
{"auto_pad", "NOTSET"}
|
||||||
,{"kernel_shape","", ""}
|
,{"p", 2}
|
||||||
,{"p","int","2"}
|
|
||||||
,{"pads","", ""}
|
|
||||||
,{"strides","", ""}
|
|
||||||
});
|
});
|
||||||
}else if (OpName == "MatMul") {
|
}else if (OpName == "MatMul") {
|
||||||
ImportNodeOneOut<mlir::ONNXMatMulOp>(node, 2, 1, {
|
ImportNodeOneOut<mlir::ONNXMatMulOp>(node, 2, 1, {
|
||||||
|
@ -303,55 +262,47 @@
|
||||||
ImportNodeOneOut<mlir::ONNXMaxOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXMaxOp>(node, 1, 1, {
|
||||||
});
|
});
|
||||||
}else if (OpName == "MaxPool") {
|
}else if (OpName == "MaxPool") {
|
||||||
ImportNodeMaxPool(node, 1, {
|
ImportNodeMaxPool(node, 1, 2, {
|
||||||
{"auto_pad","str","NOTSET"}
|
{"auto_pad", "NOTSET"}
|
||||||
,{"ceil_mode","int","0"}
|
,{"ceil_mode", 0}
|
||||||
,{"dilations","", ""}
|
,{"kernel_shape", std::vector<int64_t> {}}
|
||||||
,{"kernel_shape","ints", ""}
|
,{"storage_order", 0}
|
||||||
,{"pads","", ""}
|
|
||||||
,{"storage_order","int","0"}
|
|
||||||
,{"strides","", ""}
|
|
||||||
});
|
});
|
||||||
}else if (OpName == "MaxRoiPool") {
|
}else if (OpName == "MaxRoiPool") {
|
||||||
ImportNodeOneOut<mlir::ONNXMaxRoiPoolOp>(node, 2, 1, {
|
ImportNodeOneOut<mlir::ONNXMaxRoiPoolOp>(node, 2, 1, {
|
||||||
{"pooled_shape","", ""}
|
{"spatial_scale", (float)1.0}
|
||||||
,{"spatial_scale","float","1.0"}
|
|
||||||
});
|
});
|
||||||
}else if (OpName == "MaxUnpool") {
|
}else if (OpName == "MaxUnpool") {
|
||||||
ImportNodeOneOut<mlir::ONNXMaxUnpoolOp>(node, 3, 1, {
|
ImportNodeOneOut<mlir::ONNXMaxUnpoolOp>(node, 3, 1, {
|
||||||
{"kernel_shape","", ""}
|
|
||||||
,{"pads","", ""}
|
|
||||||
,{"strides","", ""}
|
|
||||||
});
|
});
|
||||||
}else if (OpName == "Mean") {
|
}else if (OpName == "Mean") {
|
||||||
ImportNodeOneOut<mlir::ONNXMeanOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXMeanOp>(node, 1, 1, {
|
||||||
});
|
});
|
||||||
}else if (OpName == "MeanVarianceNormalization") {
|
}else if (OpName == "MeanVarianceNormalization") {
|
||||||
ImportNodeOneOut<mlir::ONNXMeanVarianceNormalizationOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXMeanVarianceNormalizationOp>(node, 1, 1, {
|
||||||
{"axes","ints","{'0', '2', '3'}"}
|
{"axes", std::vector<int64_t>{0, 2, 3}}
|
||||||
});
|
});
|
||||||
}else if (OpName == "Min") {
|
}else if (OpName == "Min") {
|
||||||
ImportNodeOneOut<mlir::ONNXMinOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXMinOp>(node, 1, 1, {
|
||||||
});
|
});
|
||||||
}else if (OpName == "Mod") {
|
}else if (OpName == "Mod") {
|
||||||
ImportNodeOneOut<mlir::ONNXModOp>(node, 2, 1, {
|
ImportNodeOneOut<mlir::ONNXModOp>(node, 2, 1, {
|
||||||
{"fmod","int","0"}
|
{"fmod", 0}
|
||||||
});
|
});
|
||||||
}else if (OpName == "Mul") {
|
}else if (OpName == "Mul") {
|
||||||
ImportNodeOneOut<mlir::ONNXMulOp>(node, 2, 1, {
|
ImportNodeOneOut<mlir::ONNXMulOp>(node, 2, 1, {
|
||||||
});
|
});
|
||||||
}else if (OpName == "Multinomial") {
|
}else if (OpName == "Multinomial") {
|
||||||
ImportNodeOneOut<mlir::ONNXMultinomialOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXMultinomialOp>(node, 1, 1, {
|
||||||
{"dtype","int","6"}
|
{"dtype", 6}
|
||||||
,{"sample_size","int","1"}
|
,{"sample_size", 1}
|
||||||
,{"seed","", ""}
|
|
||||||
});
|
});
|
||||||
}else if (OpName == "Neg") {
|
}else if (OpName == "Neg") {
|
||||||
ImportNodeOneOut<mlir::ONNXNegOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXNegOp>(node, 1, 1, {
|
||||||
});
|
});
|
||||||
}else if (OpName == "NonMaxSuppression") {
|
}else if (OpName == "NonMaxSuppression") {
|
||||||
ImportNodeOneOut<mlir::ONNXNonMaxSuppressionOp>(node, 5, 1, {
|
ImportNodeOneOut<mlir::ONNXNonMaxSuppressionOp>(node, 5, 1, {
|
||||||
{"center_point_box","int","0"}
|
{"center_point_box", 0}
|
||||||
});
|
});
|
||||||
}else if (OpName == "NonZero") {
|
}else if (OpName == "NonZero") {
|
||||||
ImportNodeOneOut<mlir::ONNXNonZeroOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXNonZeroOp>(node, 1, 1, {
|
||||||
|
@ -361,7 +312,7 @@
|
||||||
});
|
});
|
||||||
}else if (OpName == "OneHot") {
|
}else if (OpName == "OneHot") {
|
||||||
ImportNodeOneOut<mlir::ONNXOneHotOp>(node, 3, 1, {
|
ImportNodeOneOut<mlir::ONNXOneHotOp>(node, 3, 1, {
|
||||||
{"axis","int","-1"}
|
{"axis", -1}
|
||||||
});
|
});
|
||||||
}else if (OpName == "Or") {
|
}else if (OpName == "Or") {
|
||||||
ImportNodeOneOut<mlir::ONNXOrOp>(node, 2, 1, {
|
ImportNodeOneOut<mlir::ONNXOrOp>(node, 2, 1, {
|
||||||
|
@ -371,19 +322,15 @@
|
||||||
});
|
});
|
||||||
}else if (OpName == "Pad") {
|
}else if (OpName == "Pad") {
|
||||||
ImportNodeOneOut<mlir::ONNXPadOp>(node, 3, 1, {
|
ImportNodeOneOut<mlir::ONNXPadOp>(node, 3, 1, {
|
||||||
{"mode","str","constant"}
|
{"mode", "constant"}
|
||||||
});
|
});
|
||||||
}else if (OpName == "Pow") {
|
}else if (OpName == "Pow") {
|
||||||
ImportNodeOneOut<mlir::ONNXPowOp>(node, 2, 1, {
|
ImportNodeOneOut<mlir::ONNXPowOp>(node, 2, 1, {
|
||||||
});
|
});
|
||||||
}else if (OpName == "QLinearConv") {
|
}else if (OpName == "QLinearConv") {
|
||||||
ImportNodeOneOut<mlir::ONNXQLinearConvOp>(node, 9, 1, {
|
ImportNodeOneOut<mlir::ONNXQLinearConvOp>(node, 9, 1, {
|
||||||
{"auto_pad","str","NOTSET"}
|
{"auto_pad", "NOTSET"}
|
||||||
,{"dilations","", ""}
|
,{"group", 1}
|
||||||
,{"group","int","1"}
|
|
||||||
,{"kernel_shape","", ""}
|
|
||||||
,{"pads","", ""}
|
|
||||||
,{"strides","", ""}
|
|
||||||
});
|
});
|
||||||
}else if (OpName == "QLinearMatMul") {
|
}else if (OpName == "QLinearMatMul") {
|
||||||
ImportNodeOneOut<mlir::ONNXQLinearMatMulOp>(node, 8, 1, {
|
ImportNodeOneOut<mlir::ONNXQLinearMatMulOp>(node, 8, 1, {
|
||||||
|
@ -393,42 +340,32 @@
|
||||||
});
|
});
|
||||||
}else if (OpName == "RNN") {
|
}else if (OpName == "RNN") {
|
||||||
ImportNodeMultipleOuts<mlir::ONNXRNNOp>(node, 6, 2, {
|
ImportNodeMultipleOuts<mlir::ONNXRNNOp>(node, 6, 2, {
|
||||||
{"activation_alpha","floats", "{}"}
|
{"activation_alpha", std::vector<float> {}}
|
||||||
,{"activation_beta","floats", "{}"}
|
,{"activation_beta", std::vector<float> {}}
|
||||||
,{"activations","", "{Tannh, Tanh}"}
|
,{"activations", std::vector<std::string>{"Tanh", "Tanh"}}
|
||||||
,{"clip","", ""}
|
,{"direction", "forward"}
|
||||||
,{"direction","str","forward"}
|
|
||||||
,{"hidden_size","", ""}
|
|
||||||
});
|
});
|
||||||
}else if (OpName == "RandomNormal") {
|
}else if (OpName == "RandomNormal") {
|
||||||
ImportNodeOneOut<mlir::ONNXRandomNormalOp>(node, 0, 1, {
|
ImportNodeOneOut<mlir::ONNXRandomNormalOp>(node, 0, 1, {
|
||||||
{"dtype","int","1"}
|
{"dtype", 1}
|
||||||
,{"mean","float","0.0"}
|
,{"mean", (float)0.0}
|
||||||
,{"scale","float","1.0"}
|
,{"scale", (float)1.0}
|
||||||
,{"seed","", ""}
|
|
||||||
,{"shape","", ""}
|
|
||||||
});
|
});
|
||||||
}else if (OpName == "RandomNormalLike") {
|
}else if (OpName == "RandomNormalLike") {
|
||||||
ImportNodeOneOut<mlir::ONNXRandomNormalLikeOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXRandomNormalLikeOp>(node, 1, 1, {
|
||||||
{"dtype","", ""}
|
{"mean", (float)0.0}
|
||||||
,{"mean","float","0.0"}
|
,{"scale", (float)1.0}
|
||||||
,{"scale","float","1.0"}
|
|
||||||
,{"seed","", ""}
|
|
||||||
});
|
});
|
||||||
}else if (OpName == "RandomUniform") {
|
}else if (OpName == "RandomUniform") {
|
||||||
ImportNodeOneOut<mlir::ONNXRandomUniformOp>(node, 0, 1, {
|
ImportNodeOneOut<mlir::ONNXRandomUniformOp>(node, 0, 1, {
|
||||||
{"dtype","int","1"}
|
{"dtype", 1}
|
||||||
,{"high","float","1.0"}
|
,{"high", (float)1.0}
|
||||||
,{"low","float","0.0"}
|
,{"low", (float)0.0}
|
||||||
,{"seed","", ""}
|
|
||||||
,{"shape","", ""}
|
|
||||||
});
|
});
|
||||||
}else if (OpName == "RandomUniformLike") {
|
}else if (OpName == "RandomUniformLike") {
|
||||||
ImportNodeOneOut<mlir::ONNXRandomUniformLikeOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXRandomUniformLikeOp>(node, 1, 1, {
|
||||||
{"dtype","", ""}
|
{"high", (float)1.0}
|
||||||
,{"high","float","1.0"}
|
,{"low", (float)0.0}
|
||||||
,{"low","float","0.0"}
|
|
||||||
,{"seed","", ""}
|
|
||||||
});
|
});
|
||||||
}else if (OpName == "Range") {
|
}else if (OpName == "Range") {
|
||||||
ImportNodeOneOut<mlir::ONNXRangeOp>(node, 3, 1, {
|
ImportNodeOneOut<mlir::ONNXRangeOp>(node, 3, 1, {
|
||||||
|
@ -438,53 +375,43 @@
|
||||||
});
|
});
|
||||||
}else if (OpName == "ReduceL1") {
|
}else if (OpName == "ReduceL1") {
|
||||||
ImportNodeOneOut<mlir::ONNXReduceL1Op>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXReduceL1Op>(node, 1, 1, {
|
||||||
{"axes","", ""}
|
{"keepdims", 1}
|
||||||
,{"keepdims","int","1"}
|
|
||||||
});
|
});
|
||||||
}else if (OpName == "ReduceL2") {
|
}else if (OpName == "ReduceL2") {
|
||||||
ImportNodeOneOut<mlir::ONNXReduceL2Op>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXReduceL2Op>(node, 1, 1, {
|
||||||
{"axes","", ""}
|
{"keepdims", 1}
|
||||||
,{"keepdims","int","1"}
|
|
||||||
});
|
});
|
||||||
}else if (OpName == "ReduceLogSum") {
|
}else if (OpName == "ReduceLogSum") {
|
||||||
ImportNodeOneOut<mlir::ONNXReduceLogSumOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXReduceLogSumOp>(node, 1, 1, {
|
||||||
{"axes","", ""}
|
{"keepdims", 1}
|
||||||
,{"keepdims","int","1"}
|
|
||||||
});
|
});
|
||||||
}else if (OpName == "ReduceLogSumExp") {
|
}else if (OpName == "ReduceLogSumExp") {
|
||||||
ImportNodeOneOut<mlir::ONNXReduceLogSumExpOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXReduceLogSumExpOp>(node, 1, 1, {
|
||||||
{"axes","", ""}
|
{"keepdims", 1}
|
||||||
,{"keepdims","int","1"}
|
|
||||||
});
|
});
|
||||||
}else if (OpName == "ReduceMax") {
|
}else if (OpName == "ReduceMax") {
|
||||||
ImportNodeOneOut<mlir::ONNXReduceMaxOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXReduceMaxOp>(node, 1, 1, {
|
||||||
{"axes","", ""}
|
{"keepdims", 1}
|
||||||
,{"keepdims","int","1"}
|
|
||||||
});
|
});
|
||||||
}else if (OpName == "ReduceMean") {
|
}else if (OpName == "ReduceMean") {
|
||||||
ImportNodeOneOut<mlir::ONNXReduceMeanOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXReduceMeanOp>(node, 1, 1, {
|
||||||
{"axes","", ""}
|
{"keepdims", 1}
|
||||||
,{"keepdims","int","1"}
|
|
||||||
});
|
});
|
||||||
}else if (OpName == "ReduceMin") {
|
}else if (OpName == "ReduceMin") {
|
||||||
ImportNodeOneOut<mlir::ONNXReduceMinOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXReduceMinOp>(node, 1, 1, {
|
||||||
{"axes","", ""}
|
{"keepdims", 1}
|
||||||
,{"keepdims","int","1"}
|
|
||||||
});
|
});
|
||||||
}else if (OpName == "ReduceProd") {
|
}else if (OpName == "ReduceProd") {
|
||||||
ImportNodeOneOut<mlir::ONNXReduceProdOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXReduceProdOp>(node, 1, 1, {
|
||||||
{"axes","", ""}
|
{"keepdims", 1}
|
||||||
,{"keepdims","int","1"}
|
|
||||||
});
|
});
|
||||||
}else if (OpName == "ReduceSum") {
|
}else if (OpName == "ReduceSum") {
|
||||||
ImportNodeOneOut<mlir::ONNXReduceSumOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXReduceSumOp>(node, 1, 1, {
|
||||||
{"axes","", ""}
|
{"keepdims", 1}
|
||||||
,{"keepdims","int","1"}
|
|
||||||
});
|
});
|
||||||
}else if (OpName == "ReduceSumSquare") {
|
}else if (OpName == "ReduceSumSquare") {
|
||||||
ImportNodeOneOut<mlir::ONNXReduceSumSquareOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXReduceSumSquareOp>(node, 1, 1, {
|
||||||
{"axes","", ""}
|
{"keepdims", 1}
|
||||||
,{"keepdims","int","1"}
|
|
||||||
});
|
});
|
||||||
}else if (OpName == "Relu") {
|
}else if (OpName == "Relu") {
|
||||||
ImportNodeOneOut<mlir::ONNXReluOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXReluOp>(node, 1, 1, {
|
||||||
|
@ -494,53 +421,47 @@
|
||||||
});
|
});
|
||||||
}else if (OpName == "Resize") {
|
}else if (OpName == "Resize") {
|
||||||
ImportNodeOneOut<mlir::ONNXResizeOp>(node, 4, 1, {
|
ImportNodeOneOut<mlir::ONNXResizeOp>(node, 4, 1, {
|
||||||
{"coordinate_transformation_mode","str","half_pixel"}
|
{"coordinate_transformation_mode", "half_pixel"}
|
||||||
,{"cubic_coeff_a","float","-0.75"}
|
,{"cubic_coeff_a", (float)-0.75}
|
||||||
,{"exclude_outside","int","0"}
|
,{"exclude_outside", 0}
|
||||||
,{"extrapolation_value","float","0.0"}
|
,{"extrapolation_value", (float)0.0}
|
||||||
,{"mode","str","nearest"}
|
,{"mode", "nearest"}
|
||||||
,{"nearest_mode","str","round_prefer_floor"}
|
,{"nearest_mode", "round_prefer_floor"}
|
||||||
});
|
});
|
||||||
}else if (OpName == "ReverseSequence") {
|
}else if (OpName == "ReverseSequence") {
|
||||||
ImportNodeOneOut<mlir::ONNXReverseSequenceOp>(node, 2, 1, {
|
ImportNodeOneOut<mlir::ONNXReverseSequenceOp>(node, 2, 1, {
|
||||||
{"batch_axis","int","1"}
|
{"batch_axis", 1}
|
||||||
,{"time_axis","int","0"}
|
,{"time_axis", 0}
|
||||||
});
|
});
|
||||||
}else if (OpName == "RoiAlign") {
|
}else if (OpName == "RoiAlign") {
|
||||||
ImportNodeOneOut<mlir::ONNXRoiAlignOp>(node, 3, 1, {
|
ImportNodeOneOut<mlir::ONNXRoiAlignOp>(node, 3, 1, {
|
||||||
{"mode","str","avg"}
|
{"mode", "avg"}
|
||||||
,{"output_height","int","1"}
|
,{"output_height", 1}
|
||||||
,{"output_width","int","1"}
|
,{"output_width", 1}
|
||||||
,{"sampling_ratio","int","0"}
|
,{"sampling_ratio", 0}
|
||||||
,{"spatial_scale","float","1.0"}
|
,{"spatial_scale", (float)1.0}
|
||||||
});
|
});
|
||||||
}else if (OpName == "Round") {
|
}else if (OpName == "Round") {
|
||||||
ImportNodeOneOut<mlir::ONNXRoundOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXRoundOp>(node, 1, 1, {
|
||||||
});
|
});
|
||||||
}else if (OpName == "Scan") {
|
}else if (OpName == "Scan") {
|
||||||
ImportNodeOneOut<mlir::ONNXScanOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXScanOp>(node, 1, 1, {
|
||||||
{"body","", ""}
|
|
||||||
,{"num_scan_inputs","", ""}
|
|
||||||
,{"scan_input_axes","", ""}
|
|
||||||
,{"scan_input_directions","", ""}
|
|
||||||
,{"scan_output_axes","", ""}
|
|
||||||
,{"scan_output_directions","", ""}
|
|
||||||
});
|
});
|
||||||
}else if (OpName == "Scatter") {
|
}else if (OpName == "Scatter") {
|
||||||
ImportNodeOneOut<mlir::ONNXScatterOp>(node, 3, 1, {
|
ImportNodeOneOut<mlir::ONNXScatterOp>(node, 3, 1, {
|
||||||
{"axis","int","0"}
|
{"axis", 0}
|
||||||
});
|
});
|
||||||
}else if (OpName == "ScatterElements") {
|
}else if (OpName == "ScatterElements") {
|
||||||
ImportNodeOneOut<mlir::ONNXScatterElementsOp>(node, 3, 1, {
|
ImportNodeOneOut<mlir::ONNXScatterElementsOp>(node, 3, 1, {
|
||||||
{"axis","int","0"}
|
{"axis", 0}
|
||||||
});
|
});
|
||||||
}else if (OpName == "ScatterND") {
|
}else if (OpName == "ScatterND") {
|
||||||
ImportNodeOneOut<mlir::ONNXScatterNDOp>(node, 3, 1, {
|
ImportNodeOneOut<mlir::ONNXScatterNDOp>(node, 3, 1, {
|
||||||
});
|
});
|
||||||
}else if (OpName == "Selu") {
|
}else if (OpName == "Selu") {
|
||||||
ImportNodeOneOut<mlir::ONNXSeluOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXSeluOp>(node, 1, 1, {
|
||||||
{"alpha","float","1.67326"}
|
{"alpha", (float)1.67326}
|
||||||
,{"gamma","float","1.0507"}
|
,{"gamma", (float)1.0507}
|
||||||
});
|
});
|
||||||
}else if (OpName == "SequenceAt") {
|
}else if (OpName == "SequenceAt") {
|
||||||
ImportNodeOneOut<mlir::ONNXSequenceAtOp>(node, 2, 1, {
|
ImportNodeOneOut<mlir::ONNXSequenceAtOp>(node, 2, 1, {
|
||||||
|
@ -550,7 +471,6 @@
|
||||||
});
|
});
|
||||||
}else if (OpName == "SequenceEmpty") {
|
}else if (OpName == "SequenceEmpty") {
|
||||||
ImportNodeOneOut<mlir::ONNXSequenceEmptyOp>(node, 0, 1, {
|
ImportNodeOneOut<mlir::ONNXSequenceEmptyOp>(node, 0, 1, {
|
||||||
{"dtype","", ""}
|
|
||||||
});
|
});
|
||||||
}else if (OpName == "SequenceErase") {
|
}else if (OpName == "SequenceErase") {
|
||||||
ImportNodeOneOut<mlir::ONNXSequenceEraseOp>(node, 2, 1, {
|
ImportNodeOneOut<mlir::ONNXSequenceEraseOp>(node, 2, 1, {
|
||||||
|
@ -566,8 +486,8 @@
|
||||||
});
|
});
|
||||||
}else if (OpName == "Shrink") {
|
}else if (OpName == "Shrink") {
|
||||||
ImportNodeOneOut<mlir::ONNXShrinkOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXShrinkOp>(node, 1, 1, {
|
||||||
{"bias","float","0.0"}
|
{"bias", (float)0.0}
|
||||||
,{"lambd","float","0.5"}
|
,{"lambd", (float)0.5}
|
||||||
});
|
});
|
||||||
}else if (OpName == "Sigmoid") {
|
}else if (OpName == "Sigmoid") {
|
||||||
ImportNodeOneOut<mlir::ONNXSigmoidOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXSigmoidOp>(node, 1, 1, {
|
||||||
|
@ -589,7 +509,7 @@
|
||||||
});
|
});
|
||||||
}else if (OpName == "Softmax") {
|
}else if (OpName == "Softmax") {
|
||||||
ImportNodeOneOut<mlir::ONNXSoftmaxOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXSoftmaxOp>(node, 1, 1, {
|
||||||
{"axis","int","1"}
|
{"axis", 1}
|
||||||
});
|
});
|
||||||
}else if (OpName == "Softplus") {
|
}else if (OpName == "Softplus") {
|
||||||
ImportNodeOneOut<mlir::ONNXSoftplusOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXSoftplusOp>(node, 1, 1, {
|
||||||
|
@ -599,31 +519,26 @@
|
||||||
});
|
});
|
||||||
}else if (OpName == "SpaceToDepth") {
|
}else if (OpName == "SpaceToDepth") {
|
||||||
ImportNodeOneOut<mlir::ONNXSpaceToDepthOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXSpaceToDepthOp>(node, 1, 1, {
|
||||||
{"blocksize","", ""}
|
|
||||||
});
|
});
|
||||||
}else if (OpName == "Split") {
|
}else if (OpName == "Split") {
|
||||||
ImportNodeOneOut<mlir::ONNXSplitOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXSplitOp>(node, 1, 1, {
|
||||||
{"axis","int","0"}
|
{"axis", 0}
|
||||||
,{"split","", ""}
|
|
||||||
});
|
});
|
||||||
}else if (OpName == "SplitToSequence") {
|
}else if (OpName == "SplitToSequence") {
|
||||||
ImportNodeOneOut<mlir::ONNXSplitToSequenceOp>(node, 2, 1, {
|
ImportNodeOneOut<mlir::ONNXSplitToSequenceOp>(node, 2, 1, {
|
||||||
{"axis","int","0"}
|
{"axis", 0}
|
||||||
,{"keepdims","int","1"}
|
,{"keepdims", 1}
|
||||||
});
|
});
|
||||||
}else if (OpName == "Sqrt") {
|
}else if (OpName == "Sqrt") {
|
||||||
ImportNodeOneOut<mlir::ONNXSqrtOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXSqrtOp>(node, 1, 1, {
|
||||||
});
|
});
|
||||||
}else if (OpName == "Squeeze") {
|
}else if (OpName == "Squeeze") {
|
||||||
ImportNodeOneOut<mlir::ONNXSqueezeOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXSqueezeOp>(node, 1, 1, {
|
||||||
{"axes","", ""}
|
|
||||||
});
|
});
|
||||||
}else if (OpName == "StringNormalizer") {
|
}else if (OpName == "StringNormalizer") {
|
||||||
ImportNodeOneOut<mlir::ONNXStringNormalizerOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXStringNormalizerOp>(node, 1, 1, {
|
||||||
{"case_change_action","str","NONE"}
|
{"case_change_action", "NONE"}
|
||||||
,{"is_case_sensitive","int","0"}
|
,{"is_case_sensitive", 0}
|
||||||
,{"locale","", ""}
|
|
||||||
,{"stopwords","", ""}
|
|
||||||
});
|
});
|
||||||
}else if (OpName == "Sub") {
|
}else if (OpName == "Sub") {
|
||||||
ImportNodeOneOut<mlir::ONNXSubOp>(node, 2, 1, {
|
ImportNodeOneOut<mlir::ONNXSubOp>(node, 2, 1, {
|
||||||
|
@ -639,45 +554,34 @@
|
||||||
});
|
});
|
||||||
}else if (OpName == "TfIdfVectorizer") {
|
}else if (OpName == "TfIdfVectorizer") {
|
||||||
ImportNodeOneOut<mlir::ONNXTfIdfVectorizerOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXTfIdfVectorizerOp>(node, 1, 1, {
|
||||||
{"max_gram_length","", ""}
|
|
||||||
,{"max_skip_count","", ""}
|
|
||||||
,{"min_gram_length","", ""}
|
|
||||||
,{"mode","", ""}
|
|
||||||
,{"ngram_counts","", ""}
|
|
||||||
,{"ngram_indexes","", ""}
|
|
||||||
,{"pool_int64s","", ""}
|
|
||||||
,{"pool_strings","", ""}
|
|
||||||
,{"weights","", ""}
|
|
||||||
});
|
});
|
||||||
}else if (OpName == "ThresholdedRelu") {
|
}else if (OpName == "ThresholdedRelu") {
|
||||||
ImportNodeOneOut<mlir::ONNXThresholdedReluOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXThresholdedReluOp>(node, 1, 1, {
|
||||||
{"alpha","float","1.0"}
|
{"alpha", (float)1.0}
|
||||||
});
|
});
|
||||||
}else if (OpName == "Tile") {
|
}else if (OpName == "Tile") {
|
||||||
ImportNodeOneOut<mlir::ONNXTileOp>(node, 2, 1, {
|
ImportNodeOneOut<mlir::ONNXTileOp>(node, 2, 1, {
|
||||||
});
|
});
|
||||||
}else if (OpName == "TopK") {
|
}else if (OpName == "TopK") {
|
||||||
ImportNodeMultipleOuts<mlir::ONNXTopKOp>(node, 2, 2, {
|
ImportNodeMultipleOuts<mlir::ONNXTopKOp>(node, 2, 2, {
|
||||||
{"axis","int","-1"}
|
{"axis", -1}
|
||||||
,{"largest","int","1"}
|
,{"largest", 1}
|
||||||
,{"sorted","int","1"}
|
,{"sorted", 1}
|
||||||
});
|
});
|
||||||
}else if (OpName == "Transpose") {
|
}else if (OpName == "Transpose") {
|
||||||
ImportNodeOneOut<mlir::ONNXTransposeOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXTransposeOp>(node, 1, 1, {
|
||||||
{"perm","", ""}
|
|
||||||
});
|
});
|
||||||
}else if (OpName == "Unique") {
|
}else if (OpName == "Unique") {
|
||||||
ImportNodeMultipleOuts<mlir::ONNXUniqueOp>(node, 1, 4, {
|
ImportNodeMultipleOuts<mlir::ONNXUniqueOp>(node, 1, 4, {
|
||||||
{"axis","", ""}
|
{"sorted", 1}
|
||||||
,{"sorted","int","1"}
|
|
||||||
});
|
});
|
||||||
}else if (OpName == "Unsqueeze") {
|
}else if (OpName == "Unsqueeze") {
|
||||||
ImportNodeOneOut<mlir::ONNXUnsqueezeOp>(node, 1, 1, {
|
ImportNodeOneOut<mlir::ONNXUnsqueezeOp>(node, 1, 1, {
|
||||||
{"axes","ints", ""}
|
{"axes", std::vector<int64_t> {}}
|
||||||
});
|
});
|
||||||
}else if (OpName == "Upsample") {
|
}else if (OpName == "Upsample") {
|
||||||
ImportNodeOneOut<mlir::ONNXUpsampleOp>(node, 2, 1, {
|
ImportNodeOneOut<mlir::ONNXUpsampleOp>(node, 2, 1, {
|
||||||
{"mode","str","nearest"}
|
{"mode", "nearest"}
|
||||||
});
|
});
|
||||||
}else if (OpName == "Where") {
|
}else if (OpName == "Where") {
|
||||||
ImportNodeOneOut<mlir::ONNXWhereOp>(node, 3, 1, {
|
ImportNodeOneOut<mlir::ONNXWhereOp>(node, 3, 1, {
|
||||||
|
|
|
@ -267,7 +267,7 @@ def gen_schema(schema) :
|
||||||
'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor',
|
'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor',
|
||||||
'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu',
|
'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu',
|
||||||
'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
|
'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
|
||||||
'Identity', 'Cos', 'Log', 'Transpose']
|
'Identity', 'Cos', 'Log', 'Transpose', 'Softmax']
|
||||||
CanonicalList=['Add', 'Identity']
|
CanonicalList=['Add', 'Identity']
|
||||||
line_indent = ' '
|
line_indent = ' '
|
||||||
|
|
||||||
|
@ -368,17 +368,17 @@ def gen_code(schema,fefile) :
|
||||||
("MaxPool", "ImportNodeMaxPool"),
|
("MaxPool", "ImportNodeMaxPool"),
|
||||||
#("Transpose", "ImportNodeTranspose")
|
#("Transpose", "ImportNodeTranspose")
|
||||||
])
|
])
|
||||||
special_type = dict([
|
list_str = 'std::vector'
|
||||||
("AveragePool "+"kernel_shape", '"ints", ""'),
|
empty_ints = list_str+'<int> {}'
|
||||||
("MaxPool "+"kernel_shape", '"ints", ""'),
|
empty_floats = list_str+'<float> {}'
|
||||||
("Cast "+"to", '"int", "0"'),
|
special_default = dict([
|
||||||
("Concat "+"axis", '"int", "0"'),
|
("AveragePool "+"kernel_shape", empty_ints),
|
||||||
("Conv "+"group", '"int", "1"'),
|
("MaxPool "+"kernel_shape", empty_ints),
|
||||||
("Unsqueeze "+"axes", '"ints", ""'),
|
("Cast "+"to", '0'),
|
||||||
("RNN "+"activation_alpha", '"floats", "{}"'),
|
("Concat "+"axis", '0'),
|
||||||
("RNN "+"activation_beta", '"floats", "{}"'),
|
("Unsqueeze "+"axes", empty_ints),
|
||||||
("RNN "+"activations", '"", "{Tannh, Tanh}"'),
|
("RNN "+"activation_alpha", empty_floats),
|
||||||
("LRN "+"size", '"int", ""')
|
("RNN "+"activation_beta", empty_floats)
|
||||||
])
|
])
|
||||||
line_indent = ' '
|
line_indent = ' '
|
||||||
fefile.write(' '+'}else if (OpName == "'+schema.name+'") {\n')
|
fefile.write(' '+'}else if (OpName == "'+schema.name+'") {\n')
|
||||||
|
@ -400,21 +400,9 @@ def gen_code(schema,fefile) :
|
||||||
if schema.attributes:
|
if schema.attributes:
|
||||||
first_attr = True
|
first_attr = True
|
||||||
for _, attr in sorted(schema.attributes.items()):
|
for _, attr in sorted(schema.attributes.items()):
|
||||||
attr_line = line_indent+line_indent+line_indent+line_indent
|
#only generate default attr list
|
||||||
if not first_attr:
|
if schema.name+' '+attr.name in special_default:
|
||||||
attr_line += ',{'
|
attr_value = special_default[schema.name+' '+attr.name]
|
||||||
else :
|
|
||||||
attr_line += ' {'
|
|
||||||
first_attr = False
|
|
||||||
|
|
||||||
attr_line += '"'+attr.name+'",'
|
|
||||||
|
|
||||||
if schema.name+' '+attr.name in special_type:
|
|
||||||
attr_line += special_type[schema.name+' '+attr.name]
|
|
||||||
# option holds either required or default value
|
|
||||||
elif attr.required:
|
|
||||||
attr_line += '"", ""'
|
|
||||||
|
|
||||||
elif attr.default_value.name:
|
elif attr.default_value.name:
|
||||||
default_value = helper.get_attribute_value(attr.default_value)
|
default_value = helper.get_attribute_value(attr.default_value)
|
||||||
|
|
||||||
|
@ -430,28 +418,35 @@ def gen_code(schema,fefile) :
|
||||||
return str(value)
|
return str(value)
|
||||||
|
|
||||||
if isinstance(default_value, list):
|
if isinstance(default_value, list):
|
||||||
|
|
||||||
value = default_value[0]
|
value = default_value[0]
|
||||||
default_value = [format_value(val) for val in default_value]
|
default_value = [format_value(val) for val in default_value]
|
||||||
|
attr_option_str = '{}'.format(default_value)
|
||||||
|
attr_option_str = attr_option_str.replace('[', '{', 1)
|
||||||
|
attr_option_str = attr_option_str.replace(']', '}', 1)
|
||||||
# TODO the list type is homogenous or htergeneous?
|
# TODO the list type is homogenous or htergeneous?
|
||||||
|
|
||||||
if isinstance(value, float) :
|
if isinstance(value, float) :
|
||||||
attr_type_str = '"floats"'
|
attr_type_str = list_str+'<float>'
|
||||||
|
attr_option_str = attr_option_str.replace("'", '')
|
||||||
elif isinstance(value, int) :
|
elif isinstance(value, int) :
|
||||||
attr_type_str = '"ints"'
|
attr_type_str = list_str+'<int>'
|
||||||
|
attr_option_str = attr_option_str.replace("'", '')
|
||||||
elif isinstance(value, str) :
|
elif isinstance(value, str) :
|
||||||
attr_type_str = '"strs"'
|
attr_type_str = list_str+'<std::string>'
|
||||||
|
attr_option_str = attr_option_str.replace("'", '"')
|
||||||
elif isinstance(value, (bytes, bytearray)) :
|
elif isinstance(value, (bytes, bytearray)) :
|
||||||
attr_type_str = '"strs"'
|
attr_type_str = list_str+'<std::string>'
|
||||||
|
attr_option_str = attr_option_str.replace("'", '"')
|
||||||
else :
|
else :
|
||||||
attr_type_str = '"unknowns"'
|
attr_type_str = '"unknowns"'
|
||||||
attr_option_str = '"{}"'.format(default_value)
|
|
||||||
attr_option_str = attr_option_str.replace('[', '{', 1)
|
|
||||||
attr_option_str = attr_option_str.replace(']', '}', 1)
|
|
||||||
else:
|
else:
|
||||||
if isinstance(default_value, float) :
|
if isinstance(default_value, float) :
|
||||||
attr_type_str = '"float"'
|
attr_type_str = '(float)'
|
||||||
|
attr_option_str = default_value
|
||||||
elif isinstance(default_value, int) :
|
elif isinstance(default_value, int) :
|
||||||
attr_type_str = '"int"'
|
attr_option_str = default_value
|
||||||
|
attr_type_str=''
|
||||||
elif isinstance(default_value, str) :
|
elif isinstance(default_value, str) :
|
||||||
attr_type_str = '"str"'
|
attr_type_str = '"str"'
|
||||||
elif isinstance(default_value, (bytes, bytearray)) :
|
elif isinstance(default_value, (bytes, bytearray)) :
|
||||||
|
@ -459,11 +454,25 @@ def gen_code(schema,fefile) :
|
||||||
else :
|
else :
|
||||||
attr_type_str = '"unknown"'
|
attr_type_str = '"unknown"'
|
||||||
default_value = format_value(default_value)
|
default_value = format_value(default_value)
|
||||||
attr_option_str = '"{}"'.format(default_value)
|
if attr_type_str == '"str"' :
|
||||||
attr_line += attr_type_str+','+attr_option_str
|
attr_option_str = '"'+default_value+'"'
|
||||||
|
attr_type_str=''
|
||||||
|
else :
|
||||||
|
attr_option_str = default_value
|
||||||
|
attr_value = attr_type_str+attr_option_str
|
||||||
else:
|
else:
|
||||||
#TODO why?
|
#no default value
|
||||||
attr_line += '"", ""'
|
continue
|
||||||
|
|
||||||
|
attr_line = line_indent+line_indent+line_indent+line_indent
|
||||||
|
if not first_attr:
|
||||||
|
attr_line += ',{'
|
||||||
|
else :
|
||||||
|
attr_line += ' {'
|
||||||
|
first_attr = False
|
||||||
|
|
||||||
|
attr_line += '"'+attr.name+'", '
|
||||||
|
attr_line += attr_value
|
||||||
attr_line += '}\n'
|
attr_line += '}\n'
|
||||||
fefile.write(attr_line)
|
fefile.write(attr_line)
|
||||||
fefile.write(line_indent+line_indent+line_indent+'});\n')
|
fefile.write(line_indent+line_indent+line_indent+'});\n')
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
#include "mlir/IR/Function.h"
|
#include "mlir/IR/Function.h"
|
||||||
#include "mlir/IR/IntegerSet.h"
|
#include "mlir/IR/IntegerSet.h"
|
||||||
#include "mlir/IR/Matchers.h"
|
#include "mlir/IR/Matchers.h"
|
||||||
|
#include "mlir/IR/Module.h"
|
||||||
#include "mlir/IR/OpImplementation.h"
|
#include "mlir/IR/OpImplementation.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
#include "llvm/ADT/SetVector.h"
|
#include "llvm/ADT/SetVector.h"
|
||||||
|
@ -157,6 +158,14 @@ void ONNXReciprocalOp::inferShapes() {
|
||||||
getResult().setType(getOperand().getType());
|
getResult().setType(getOperand().getType());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Softmax
|
||||||
|
/// Infer the output shape of the ONNXSoftmaxOp. This method is required by
|
||||||
|
/// the shape inference interface.
|
||||||
|
void ONNXSoftmaxOp::inferShapes() {
|
||||||
|
getResult().setType(getOperand().getType());
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Add
|
// Add
|
||||||
/// Infer the output shape of the ONNXAddOp. This method is required by the
|
/// Infer the output shape of the ONNXAddOp. This method is required by the
|
||||||
|
@ -484,13 +493,38 @@ void ONNXTransposeOp::inferShapes() {
|
||||||
|
|
||||||
// Naive transposition which handles the default case of
|
// Naive transposition which handles the default case of
|
||||||
// reversing the shape of the tensor (similar to numpy.transpose).
|
// reversing the shape of the tensor (similar to numpy.transpose).
|
||||||
// TODO: Once attributes are supported we can handle the case where the
|
|
||||||
// transposition uses a permutation vector to interchange the axes.
|
|
||||||
auto arrayTy = getOperand().getType().cast<RankedTensorType>();
|
auto arrayTy = getOperand().getType().cast<RankedTensorType>();
|
||||||
SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
|
SmallVector<int64_t, 2> dims;
|
||||||
|
|
||||||
|
if (auto permutation = getAttrOfType<ArrayAttr>(
|
||||||
|
ONNXTransposeOp::getPermAttrName())) {
|
||||||
|
// Perform transposition according to perm attribute.
|
||||||
|
for (auto perm : permutation.getValue())
|
||||||
|
dims.emplace_back(arrayTy.getShape()[perm.cast<IntegerAttr>().getInt()]);
|
||||||
|
} else {
|
||||||
|
// Default
|
||||||
|
for (auto dim : llvm::reverse(arrayTy.getShape()))
|
||||||
|
dims.emplace_back(dim);
|
||||||
|
}
|
||||||
|
|
||||||
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
|
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LogicalResult verify(ONNXTransposeOp op) {
|
||||||
|
auto module = op.getParentOfType<ModuleOp>();
|
||||||
|
if (!module)
|
||||||
|
op.emitError("Expected to belong to a module.");
|
||||||
|
|
||||||
|
if (auto permutation = op.getAttrOfType<ArrayAttr>(
|
||||||
|
ONNXTransposeOp::getPermAttrName())) {
|
||||||
|
for (auto perm : permutation.getValue())
|
||||||
|
if (perm.cast<IntegerAttr>().getInt() < 0)
|
||||||
|
op.emitError("Cannot tranpose, permuation contains negative index.");
|
||||||
|
}
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// TableGen'd op method definitions
|
// TableGen'd op method definitions
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -2831,7 +2831,7 @@ def ONNXSliceOp:ONNX_Op<"Slice",
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXSoftmaxOp:ONNX_Op<"Softmax",
|
def ONNXSoftmaxOp:ONNX_Op<"Softmax",
|
||||||
[NoSideEffect]> {
|
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||||
let summary = "ONNX Softmax operation";
|
let summary = "ONNX Softmax operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
"The operator computes the softmax (normalized exponential) values for each layer in the batch"
|
"The operator computes the softmax (normalized exponential) values for each layer in the batch"
|
||||||
|
@ -3098,6 +3098,12 @@ def ONNXTransposeOp:ONNX_Op<"Transpose",
|
||||||
}];
|
}];
|
||||||
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data);
|
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data);
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
|
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
static StringRef getPermAttrName() { return "perm"; }
|
||||||
|
}];
|
||||||
|
|
||||||
|
let verifier = [{ return ::verify(*this); }];
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXUniqueOp:ONNX_Op<"Unique",
|
def ONNXUniqueOp:ONNX_Op<"Unique",
|
||||||
|
|
11
src/main.cpp
11
src/main.cpp
|
@ -135,10 +135,13 @@ int main(int argc, char *argv[]) {
|
||||||
if (mlir::failed(pm.run(*module)))
|
if (mlir::failed(pm.run(*module)))
|
||||||
return 4;
|
return 4;
|
||||||
|
|
||||||
module->dump();
|
|
||||||
|
|
||||||
// Write LLVM bitcode to disk.
|
if (emissionTarget == EmitLLVMBC) {
|
||||||
if (emissionTarget == EmitLLVMBC)
|
// Write LLVM bitcode to disk.
|
||||||
EmitLLVMBitCode(module);
|
EmitLLVMBitCode(module);
|
||||||
|
printf("LLVM bitcode written to ./model.bc");
|
||||||
|
} else
|
||||||
|
module->dump();
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
|
@ -420,8 +420,8 @@ Value mapToLowerScalarOp<ONNXHardSigmoidOp>(
|
||||||
// Constant 1)
|
// Constant 1)
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
Value operand = operands[0];
|
Value operand = operands[0];
|
||||||
auto alphaAttr = op->getAttrOfType<FloatAttr>("HardSigmoid.alpha");
|
auto alphaAttr = op->getAttrOfType<FloatAttr>("alpha");
|
||||||
auto betaAttr = op->getAttrOfType<FloatAttr>("HardSigmoid.beta");
|
auto betaAttr = op->getAttrOfType<FloatAttr>("beta");
|
||||||
auto elementType = result_types[0];
|
auto elementType = result_types[0];
|
||||||
|
|
||||||
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
|
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
|
||||||
|
@ -455,7 +455,7 @@ Value mapToLowerScalarOp<ONNXEluOp>(Operation *op, ArrayRef<Type> result_types,
|
||||||
Value operand = operands[0];
|
Value operand = operands[0];
|
||||||
auto elementType = result_types[0];
|
auto elementType = result_types[0];
|
||||||
|
|
||||||
auto alphaAttr = op->getAttrOfType<FloatAttr>("Elu.alpha");
|
auto alphaAttr = op->getAttrOfType<FloatAttr>("alpha");
|
||||||
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
|
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
|
||||||
auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
|
auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
|
||||||
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
|
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
|
||||||
|
@ -508,7 +508,7 @@ Value mapToLowerScalarOp<ONNXLeakyReluOp>(Operation *op,
|
||||||
Value operand = operands[0];
|
Value operand = operands[0];
|
||||||
auto elementType = result_types[0];
|
auto elementType = result_types[0];
|
||||||
|
|
||||||
auto alphaAttr = op->getAttrOfType<FloatAttr>("LeakyRelu.alpha");
|
auto alphaAttr = op->getAttrOfType<FloatAttr>("alpha");
|
||||||
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
|
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
|
||||||
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
|
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
|
||||||
auto lessThanZero =
|
auto lessThanZero =
|
||||||
|
@ -533,8 +533,8 @@ Value mapToLowerScalarOp<ONNXSeluOp>(Operation *op, ArrayRef<Type> result_types,
|
||||||
// alpha)))
|
// alpha)))
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
Value operand = operands[0];
|
Value operand = operands[0];
|
||||||
auto alphaAttr = op->getAttrOfType<FloatAttr>("Selu.alpha");
|
auto alphaAttr = op->getAttrOfType<FloatAttr>("alpha");
|
||||||
auto gammaAttr = op->getAttrOfType<FloatAttr>("Selu.gamma");
|
auto gammaAttr = op->getAttrOfType<FloatAttr>("gamma");
|
||||||
auto elementType = result_types[0];
|
auto elementType = result_types[0];
|
||||||
|
|
||||||
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
|
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
|
||||||
|
@ -824,6 +824,225 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct ONNXSoftmaxOpLowering : public ConversionPattern {
|
||||||
|
ONNXSoftmaxOpLowering(MLIRContext *ctx)
|
||||||
|
: ConversionPattern(mlir::ONNXSoftmaxOp::getOperationName(), 1, ctx) {}
|
||||||
|
PatternMatchResult
|
||||||
|
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const final {
|
||||||
|
// softmax(x) = let max_x = max(x) in
|
||||||
|
// let exp_x = exp(x - max_x) in
|
||||||
|
// let sum = sum(exp_x) in
|
||||||
|
// exp_x / sum
|
||||||
|
auto tensorType = (*op->result_type_begin()).cast<RankedTensorType>();
|
||||||
|
int64_t rank = tensorType.getRank();
|
||||||
|
int64_t axis = op->getAttrOfType<IntegerAttr>("axis").getInt();
|
||||||
|
axis = axis >= 0 ? axis : rank + axis;
|
||||||
|
assert(axis >= -rank && axis <= rank - 1);
|
||||||
|
|
||||||
|
auto loc = op->getLoc();
|
||||||
|
|
||||||
|
// Insert an allocation and deallocation for the result of this operation.
|
||||||
|
auto memRefType = convertTensorToMemRef(tensorType);
|
||||||
|
auto elementType = memRefType.getElementType();
|
||||||
|
|
||||||
|
Value alloc;
|
||||||
|
bool insertDealloc = checkInsertDealloc(op);
|
||||||
|
if (hasAllConstantDimensions(memRefType))
|
||||||
|
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
|
||||||
|
else
|
||||||
|
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
|
||||||
|
operands[0]);
|
||||||
|
|
||||||
|
// Shape of the result
|
||||||
|
auto memRefShape = memRefType.getShape();
|
||||||
|
|
||||||
|
// Insert allocations and deallocations for sum and max.
|
||||||
|
MemRefType scalarMemRefType = MemRefType::get({}, elementType, {}, 0);
|
||||||
|
Value sumOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true);
|
||||||
|
Value maxOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true);
|
||||||
|
Value zero =
|
||||||
|
rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
|
||||||
|
Value negInfinity = rewriter.create<ConstantOp>(
|
||||||
|
loc,
|
||||||
|
FloatAttr::get(elementType, -std::numeric_limits<float>::infinity()));
|
||||||
|
|
||||||
|
// Define loops.
|
||||||
|
auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, rank);
|
||||||
|
std::vector<Value> originalLoops;
|
||||||
|
originalLoops.reserve(rank);
|
||||||
|
for (auto result : loopsOp.getResults()) {
|
||||||
|
originalLoops.push_back(result);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Define loop optimization.
|
||||||
|
auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, rank);
|
||||||
|
std::vector<Value> optimizedLoops;
|
||||||
|
optimizedLoops.reserve(rank);
|
||||||
|
for (auto result : optimizedLoopsOp.getResults()) {
|
||||||
|
optimizedLoops.push_back(result);
|
||||||
|
}
|
||||||
|
Block &optimizationBlock = optimizedLoopsOp.region().front();
|
||||||
|
|
||||||
|
// Coerce the input into a 2-D tensor. `axis` will be the coercing point.
|
||||||
|
// This coercing follows the softmax definition in ONNX:
|
||||||
|
// https://github.com/onnx/onnx/blob/master/docs/Operators.md#Softmax
|
||||||
|
// Here, we create an outer loop and inner loop for handling the two
|
||||||
|
// dimensions. The outer loop is only created once `axis` is not zero.
|
||||||
|
|
||||||
|
// Define an outer loop with respect to axis.
|
||||||
|
std::vector<Value> outerLoops, optimizedOuterLoops;
|
||||||
|
outerLoops.reserve(axis);
|
||||||
|
optimizedOuterLoops.reserve(axis);
|
||||||
|
for (int i = 0; i < axis; ++i) {
|
||||||
|
outerLoops.push_back(originalLoops[i]);
|
||||||
|
optimizedOuterLoops.push_back(optimizedLoops[i]);
|
||||||
|
}
|
||||||
|
KrnlIterateOperandPack outerPack(rewriter, outerLoops, optimizedOuterLoops);
|
||||||
|
for (int i = 0; i < axis; ++i) {
|
||||||
|
if (memRefShape[i] < 0) {
|
||||||
|
outerPack.pushConstantBound(0);
|
||||||
|
outerPack.pushOperandBound(
|
||||||
|
rewriter.create<DimOp>(loc, operands[0], i).getResult());
|
||||||
|
} else {
|
||||||
|
outerPack.pushConstantBound(0);
|
||||||
|
outerPack.pushConstantBound(memRefShape[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Define an inner loop with respect to axis.
|
||||||
|
std::vector<Value> innerLoops, optimizedInnerLoops;
|
||||||
|
innerLoops.reserve(rank - axis);
|
||||||
|
optimizedInnerLoops.reserve(rank - axis);
|
||||||
|
for (int i = axis; i < rank; ++i) {
|
||||||
|
innerLoops.push_back(originalLoops[i]);
|
||||||
|
optimizedInnerLoops.push_back(optimizedLoops[i]);
|
||||||
|
}
|
||||||
|
KrnlIterateOperandPack innerPack(rewriter, innerLoops, optimizedInnerLoops);
|
||||||
|
for (int i = axis; i < rank; ++i) {
|
||||||
|
if (memRefShape[i] < 0) {
|
||||||
|
innerPack.pushConstantBound(0);
|
||||||
|
innerPack.pushOperandBound(
|
||||||
|
rewriter.create<DimOp>(loc, operands[0], i).getResult());
|
||||||
|
} else {
|
||||||
|
innerPack.pushConstantBound(0);
|
||||||
|
innerPack.pushConstantBound(memRefShape[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
KrnlIterateOp outerIterateOp, maxIterateOp, sumIterateOp, softmaxIterateOp;
|
||||||
|
SmallVector<Value, 4> outerLoopIVs;
|
||||||
|
if (axis != 0) {
|
||||||
|
outerIterateOp = rewriter.create<KrnlIterateOp>(loc, outerPack);
|
||||||
|
|
||||||
|
// No optimization
|
||||||
|
rewriter.setInsertionPointToEnd(&optimizationBlock);
|
||||||
|
rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
|
||||||
|
rewriter.setInsertionPoint(optimizedLoopsOp);
|
||||||
|
|
||||||
|
// Insert instructions inside the outer loop.
|
||||||
|
Block &outerIterationBlock = outerIterateOp.bodyRegion().front();
|
||||||
|
rewriter.setInsertionPointToStart(&outerIterationBlock);
|
||||||
|
for (auto arg : outerIterationBlock.getArguments())
|
||||||
|
outerLoopIVs.push_back(arg);
|
||||||
|
|
||||||
|
// Reset accumulators.
|
||||||
|
rewriter.create<StoreOp>(loc, zero, sumOp);
|
||||||
|
rewriter.create<StoreOp>(loc, negInfinity, maxOp);
|
||||||
|
|
||||||
|
// Create an inner loop to compute max.
|
||||||
|
maxIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
|
||||||
|
// Create an inner loop to compute sum.
|
||||||
|
sumIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
|
||||||
|
// Create an inner loop to compute softmax.
|
||||||
|
softmaxIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
|
||||||
|
} else {
|
||||||
|
// Reset accumulators.
|
||||||
|
rewriter.create<StoreOp>(loc, zero, sumOp);
|
||||||
|
rewriter.create<StoreOp>(loc, negInfinity, maxOp);
|
||||||
|
|
||||||
|
// Create an inner loop to compute max.
|
||||||
|
maxIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
|
||||||
|
// Create an inner loop to compute sum.
|
||||||
|
sumIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
|
||||||
|
// Create an inner loop to compute softmax.
|
||||||
|
softmaxIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
|
||||||
|
|
||||||
|
// No optimization
|
||||||
|
rewriter.setInsertionPointToEnd(&optimizationBlock);
|
||||||
|
rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
|
||||||
|
rewriter.setInsertionPoint(optimizedLoopsOp);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert instructions inside the max loop.
|
||||||
|
Block &maxIterationBlock = maxIterateOp.bodyRegion().front();
|
||||||
|
rewriter.setInsertionPointToStart(&maxIterationBlock);
|
||||||
|
|
||||||
|
// Get induction variables.
|
||||||
|
SmallVector<Value, 4> maxLoopIVs;
|
||||||
|
for (auto arg : outerLoopIVs)
|
||||||
|
maxLoopIVs.push_back(arg);
|
||||||
|
for (auto arg : maxIterationBlock.getArguments())
|
||||||
|
maxLoopIVs.push_back(arg);
|
||||||
|
|
||||||
|
// Compute the max value.
|
||||||
|
Value max = rewriter.create<LoadOp>(loc, maxOp);
|
||||||
|
Value nextMax = rewriter.create<LoadOp>(loc, operands[0], maxLoopIVs);
|
||||||
|
auto maxCond =
|
||||||
|
rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, max, nextMax);
|
||||||
|
max = rewriter.create<SelectOp>(loc, maxCond, max, nextMax);
|
||||||
|
rewriter.create<StoreOp>(loc, max, maxOp);
|
||||||
|
|
||||||
|
// Get the max.
|
||||||
|
rewriter.setInsertionPoint(sumIterateOp);
|
||||||
|
max = rewriter.create<LoadOp>(loc, maxOp);
|
||||||
|
|
||||||
|
// Insert instructions inside the sum loop.
|
||||||
|
Block &sumIterationBlock = sumIterateOp.bodyRegion().front();
|
||||||
|
rewriter.setInsertionPointToStart(&sumIterationBlock);
|
||||||
|
|
||||||
|
// Get induction variables.
|
||||||
|
SmallVector<Value, 4> sumLoopIVs;
|
||||||
|
for (auto arg : outerLoopIVs)
|
||||||
|
sumLoopIVs.push_back(arg);
|
||||||
|
for (auto arg : sumIterationBlock.getArguments())
|
||||||
|
sumLoopIVs.push_back(arg);
|
||||||
|
|
||||||
|
// Sum up values.
|
||||||
|
Value sum = rewriter.create<LoadOp>(loc, sumOp);
|
||||||
|
Value next = rewriter.create<LoadOp>(loc, operands[0], sumLoopIVs);
|
||||||
|
Value sub = rewriter.create<SubFOp>(loc, next, max);
|
||||||
|
Value exp = rewriter.create<ExpOp>(loc, sub);
|
||||||
|
sum = rewriter.create<AddFOp>(loc, sum, exp);
|
||||||
|
rewriter.create<StoreOp>(loc, sum, sumOp);
|
||||||
|
// Store intermediate values in the result to avoid recomputation.
|
||||||
|
rewriter.create<StoreOp>(loc, exp, alloc, sumLoopIVs);
|
||||||
|
|
||||||
|
// Get the sum.
|
||||||
|
rewriter.setInsertionPoint(softmaxIterateOp);
|
||||||
|
sum = rewriter.create<LoadOp>(loc, sumOp);
|
||||||
|
|
||||||
|
// Insert instructions inside the softmax loop.
|
||||||
|
Block &softmaxIterationBlock = softmaxIterateOp.bodyRegion().front();
|
||||||
|
rewriter.setInsertionPointToStart(&softmaxIterationBlock);
|
||||||
|
|
||||||
|
// Get induction variables.
|
||||||
|
SmallVector<Value, 4> softmaxLoopIVs;
|
||||||
|
for (auto arg : outerLoopIVs)
|
||||||
|
softmaxLoopIVs.push_back(arg);
|
||||||
|
for (auto arg : softmaxIterationBlock.getArguments())
|
||||||
|
softmaxLoopIVs.push_back(arg);
|
||||||
|
|
||||||
|
// Compute softmax.
|
||||||
|
Value expLoadedVal = rewriter.create<LoadOp>(loc, alloc, softmaxLoopIVs);
|
||||||
|
Value result = rewriter.create<DivFOp>(loc, expLoadedVal, sum);
|
||||||
|
rewriter.create<StoreOp>(loc, result, alloc, softmaxLoopIVs);
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, alloc);
|
||||||
|
|
||||||
|
return matchSuccess();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
struct ONNXReshapeOpLowering : public ConversionPattern {
|
struct ONNXReshapeOpLowering : public ConversionPattern {
|
||||||
ONNXReshapeOpLowering(MLIRContext *ctx)
|
ONNXReshapeOpLowering(MLIRContext *ctx)
|
||||||
: ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {}
|
: ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {}
|
||||||
|
@ -1005,7 +1224,8 @@ void FrontendToKrnlLoweringPass::runOnModule() {
|
||||||
ONNXElementwiseVariadicOpLowering<mlir::ONNXSumOp>,
|
ONNXElementwiseVariadicOpLowering<mlir::ONNXSumOp>,
|
||||||
ONNXElementwiseVariadicOpLowering<mlir::ONNXMaxOp>,
|
ONNXElementwiseVariadicOpLowering<mlir::ONNXMaxOp>,
|
||||||
ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>,
|
ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>,
|
||||||
ONNXReshapeOpLowering, ONNXEntryPointLowering>(&getContext());
|
ONNXReshapeOpLowering, ONNXEntryPointLowering,
|
||||||
|
ONNXSoftmaxOpLowering>(&getContext());
|
||||||
|
|
||||||
// With the target and rewrite patterns defined, we can now attempt the
|
// With the target and rewrite patterns defined, we can now attempt the
|
||||||
// conversion. The conversion will signal failure if any of our `illegal`
|
// conversion. The conversion will signal failure if any of our `illegal`
|
||||||
|
|
|
@ -116,7 +116,8 @@ public:
|
||||||
op->getName().getStringRef() != "onnx.Gemm" &&
|
op->getName().getStringRef() != "onnx.Gemm" &&
|
||||||
op->getName().getStringRef() != "onnx.GemmNoBias" &&
|
op->getName().getStringRef() != "onnx.GemmNoBias" &&
|
||||||
op->getName().getStringRef() != "onnx.Reshape" &&
|
op->getName().getStringRef() != "onnx.Reshape" &&
|
||||||
op->getName().getStringRef() != "onnx.Transpose")
|
op->getName().getStringRef() != "onnx.Transpose" &&
|
||||||
|
op->getName().getStringRef() != "onnx.Softmax")
|
||||||
return false;
|
return false;
|
||||||
return llvm::any_of(op->getResultTypes(), [](Type result_type) {
|
return llvm::any_of(op->getResultTypes(), [](Type result_type) {
|
||||||
return !result_type.isa<RankedTensorType>();
|
return !result_type.isa<RankedTensorType>();
|
||||||
|
|
|
@ -1 +1,2 @@
|
||||||
add_subdirectory(mlir)
|
add_subdirectory(mlir)
|
||||||
|
add_subdirectory(backend)
|
|
@ -0,0 +1,10 @@
|
||||||
|
configure_file(test.py test.py COPYONLY)
|
||||||
|
configure_file(test_config.py.in test_config.py)
|
||||||
|
|
||||||
|
find_package(PythonInterp 3 REQUIRED)
|
||||||
|
add_custom_target(run-onnx-backend-test
|
||||||
|
COMMAND ${PYTHON_EXECUTABLE}
|
||||||
|
${CMAKE_CURRENT_BINARY_DIR}/test.py)
|
||||||
|
|
||||||
|
add_dependencies(run-onnx-backend-test onnf)
|
||||||
|
add_dependencies(run-onnx-backend-test pyruntime)
|
|
@ -3,46 +3,51 @@ from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
import itertools
|
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
import unittest
|
import unittest
|
||||||
import onnx.backend.base
|
import onnx.backend.base
|
||||||
import onnx.backend.test
|
import onnx.backend.test
|
||||||
|
|
||||||
from onnx.backend.base import Device, DeviceType
|
from onnx.backend.base import Device, DeviceType
|
||||||
import onnx.shape_inference
|
|
||||||
import onnx.version_converter
|
|
||||||
import subprocess
|
import subprocess
|
||||||
|
import test_config
|
||||||
|
|
||||||
|
VERBOSE = bool(os.environ.get("VERBOSE"))
|
||||||
|
|
||||||
|
CXX = test_config.CXX_PATH
|
||||||
|
ONNF = os.path.join(test_config.ONNF_BUILD_PATH, "bin/onnf")
|
||||||
|
LLC = os.path.join(test_config.LLVM_PROJ_BUILD_PATH, "bin/llc")
|
||||||
|
|
||||||
|
# Make lib folder under build directory visible in PYTHONPATH
|
||||||
|
doc_check_base_dir = os.path.dirname(os.path.realpath(__file__))
|
||||||
|
RUNTIME_DIR = os.path.join(test_config.ONNF_BUILD_PATH, "lib")
|
||||||
|
sys.path.append(RUNTIME_DIR)
|
||||||
from pyruntime import ExecutionSession
|
from pyruntime import ExecutionSession
|
||||||
|
|
||||||
CXX = os.getenv('CXX')
|
|
||||||
ONNF = os.getenv('ONNF')
|
def execute_commands(cmds):
|
||||||
LLC = os.getenv('LLC')
|
if (VERBOSE):
|
||||||
RT_DIR = os.getenv('RT_DIR')
|
print(" ".join(cmds))
|
||||||
assert CXX and ONNF and LLC and RT_DIR, "tools path not set"
|
subprocess.run(cmds, stdout=subprocess.PIPE)
|
||||||
|
|
||||||
|
|
||||||
class DummyBackend(onnx.backend.base.Backend):
|
class DummyBackend(onnx.backend.base.Backend):
|
||||||
@classmethod
|
@classmethod
|
||||||
def prepare(
|
def prepare(cls, model, device='CPU', **kwargs):
|
||||||
cls,
|
|
||||||
model,
|
|
||||||
device='CPU',
|
|
||||||
**kwargs
|
|
||||||
):
|
|
||||||
super(DummyBackend, cls).prepare(model, device, **kwargs)
|
super(DummyBackend, cls).prepare(model, device, **kwargs)
|
||||||
# Save model to disk as temp_model.onnx.
|
# Save model to disk as temp_model.onnx.
|
||||||
onnx.save(model, "temp_model.onnx")
|
onnx.save(model, "temp_model.onnx")
|
||||||
# Call frontend to process temp_model.onnx, bit code will be generated.
|
# Call frontend to process temp_model.onnx, bit code will be generated.
|
||||||
subprocess.run([ONNF, "temp_model.onnx"], stdout=subprocess.PIPE)
|
execute_commands([ONNF, "temp_model.onnx"])
|
||||||
# Call llc to generate object file from bitcode.
|
# Call llc to generate object file from bitcode.
|
||||||
subprocess.run([LLC, "-filetype=obj", "model.bc"],
|
execute_commands(
|
||||||
stdout=subprocess.PIPE)
|
[LLC, "-filetype=obj", "-relocation-model=pic", "model.bc"])
|
||||||
# Generate shared library from object file, linking with c runtime.
|
# Generate shared library from object file, linking with c runtime.
|
||||||
subprocess.run([
|
execute_commands([
|
||||||
CXX, "-shared", "model.o", "-o", "model.so", "-L" + RT_DIR,
|
CXX, "-shared", "-fPIC", "model.o", "-o", "model.so",
|
||||||
"-lcruntime"
|
"-L" + RUNTIME_DIR, "-lcruntime"
|
||||||
],
|
])
|
||||||
stdout=subprocess.PIPE)
|
|
||||||
return ExecutionSession("./model.so", "_dyn_entry_point_main_graph")
|
return ExecutionSession("./model.so", "_dyn_entry_point_main_graph")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -125,6 +130,14 @@ test_to_enable = [
|
||||||
"test_sigmoid_cpu",
|
"test_sigmoid_cpu",
|
||||||
"test_sigmoid_example_cpu",
|
"test_sigmoid_example_cpu",
|
||||||
|
|
||||||
|
# Softmax Op:
|
||||||
|
"test_softmax_axis_0_cpu",
|
||||||
|
"test_softmax_axis_1_cpu",
|
||||||
|
"test_softmax_axis_2_cpu",
|
||||||
|
"test_softmax_default_axis_cpu",
|
||||||
|
"test_softmax_example_cpu",
|
||||||
|
"test_softmax_large_number_cpu",
|
||||||
|
|
||||||
# Sum Op:
|
# Sum Op:
|
||||||
#"test_sum_example_cpu", <- error
|
#"test_sum_example_cpu", <- error
|
||||||
"test_sum_one_input_cpu",
|
"test_sum_one_input_cpu",
|
||||||
|
@ -140,18 +153,15 @@ import inspect
|
||||||
all_tests = inspect.getmembers(
|
all_tests = inspect.getmembers(
|
||||||
backend_test.test_cases["OnnxBackendNodeModelTest"])
|
backend_test.test_cases["OnnxBackendNodeModelTest"])
|
||||||
all_test_names = list(map(lambda x: x[0], all_tests))
|
all_test_names = list(map(lambda x: x[0], all_tests))
|
||||||
|
|
||||||
|
# Ensure that test names specified in test_to_enable actually exist.
|
||||||
for test_name in test_to_enable:
|
for test_name in test_to_enable:
|
||||||
assert test_name in all_test_names, "test name {} not found".format(test_name)
|
assert test_name in all_test_names, "test name {} not found, it is likely "
|
||||||
|
"that you may have misspelled the test name or the specified test does not "
|
||||||
|
"exist in the version of onnx package you installed.".format(
|
||||||
|
test_name)
|
||||||
backend_test.include(r"^{}$".format(test_name))
|
backend_test.include(r"^{}$".format(test_name))
|
||||||
|
|
||||||
|
|
||||||
def tearDownModule():
|
|
||||||
print()
|
|
||||||
print("*" * 40)
|
|
||||||
print("A total of {} tests should have run".format(len(test_to_enable)))
|
|
||||||
print("*" * 40)
|
|
||||||
|
|
||||||
|
|
||||||
# import all test cases at global scope to make them visible to python.unittest
|
# import all test cases at global scope to make them visible to python.unittest
|
||||||
globals().update(backend_test.test_cases)
|
globals().update(backend_test.test_cases)
|
||||||
|
|
|
@ -0,0 +1,3 @@
|
||||||
|
ONNF_BUILD_PATH = "@CMAKE_BINARY_DIR@"
|
||||||
|
LLVM_PROJ_BUILD_PATH = "@LLVM_PROJ_BUILD@"
|
||||||
|
CXX_PATH = "@CMAKE_CXX_COMPILER@"
|
|
@ -1,12 +1,12 @@
|
||||||
// RUN: onnf-opt %s -mlir-print-op-generic | FileCheck -check-prefix=GENERIC %s
|
// RUN: onnf-opt %s -mlir-print-op-generic | FileCheck -check-prefix=GENERIC %s
|
||||||
// RUN: onnf-opt %s | FileCheck %s
|
// RUN: onnf-opt %s | FileCheck %s
|
||||||
|
|
||||||
// GENERIC-DAG: #{{.*}} = () -> (0)
|
// GENERIC-DAG: #{{.*}} = affine_map<() -> (0)>
|
||||||
// GENERIC-DAG: #{{.*}} = () -> (10)
|
// GENERIC-DAG: #{{.*}} = affine_map<() -> (10)>
|
||||||
// GENERIC-DAG: #{{.*}} = () -> (1)
|
// GENERIC-DAG: #{{.*}} = affine_map<() -> (1)>
|
||||||
// GENERIC-DAG: #{{.*}} = () -> (11)
|
// GENERIC-DAG: #{{.*}} = affine_map<() -> (11)>
|
||||||
// GENERIC-DAG: #{{.*}} = (d0, d1) -> (d0 - d1)
|
// GENERIC-DAG: #{{.*}} = affine_map<(d0, d1) -> (d0 - d1)>
|
||||||
// GENERIC-DAG: #{{.*}} = (d0, d1) -> (d0 + d1)
|
// GENERIC-DAG: #{{.*}} = affine_map<(d0, d1) -> (d0 + d1)>
|
||||||
|
|
||||||
func @simple_iterate(%N : index) {
|
func @simple_iterate(%N : index) {
|
||||||
%ii, %ij, %ik = krnl.define_loops 3
|
%ii, %ij, %ik = krnl.define_loops 3
|
||||||
|
@ -55,18 +55,18 @@ func @affine_map_bound(%N : index) {
|
||||||
// GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( {
|
// GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( {
|
||||||
// GENERIC-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index):
|
// GENERIC-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index):
|
||||||
// CHECK: krnl.iterate(%{{.*}}, %{{.*}}) with (%{{.*}} -> %{{.*}} = 0 to 10, %{{.*}} -> %{{.*}} = 0 to 10) {
|
// CHECK: krnl.iterate(%{{.*}}, %{{.*}}) with (%{{.*}} -> %{{.*}} = 0 to 10, %{{.*}} -> %{{.*}} = 0 to 10) {
|
||||||
krnl.iterate(%oi, %oj) with (%ii -> %i = ()->(0)() to ()->(10)(), %ij -> %j = 0 to 10) {
|
krnl.iterate(%oi, %oj) with (%ii -> %i = affine_map<()->(0)>() to affine_map<()->(10)>(), %ij -> %j = 0 to 10) {
|
||||||
// GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( {
|
// GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( {
|
||||||
// GENERIC-NEXT: ^bb0(%{{.*}}: index):
|
// GENERIC-NEXT: ^bb0(%{{.*}}: index):
|
||||||
// CHECK: krnl.iterate(%{{.*}}) with (%{{.*}} -> %{{.*}} = #{{.*}}(%{{.*}}, %{{.*}}) to #{{.*}}(%{{.*}}, %{{.*}})) {
|
// CHECK: krnl.iterate(%{{.*}}) with (%{{.*}} -> %{{.*}} = #{{.*}}(%{{.*}}, %{{.*}}) to #{{.*}}(%{{.*}}, %{{.*}})) {
|
||||||
krnl.iterate(%ok) with (%ik -> %k = (d0, d1)->(d0 - d1)(%i, %j) to (d0, d1)->(d0 + d1)(%i, %j)) {
|
krnl.iterate(%ok) with (%ik -> %k = affine_map<(d0, d1)->(d0 - d1)>(%i, %j) to affine_map<(d0, d1)->(d0 + d1)>(%i, %j)) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( {
|
// GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( {
|
||||||
// GENERIC-NEXT: ^bb0(%{{.*}}: index):
|
// GENERIC-NEXT: ^bb0(%{{.*}}: index):
|
||||||
// CHECK: krnl.iterate(%{{.*}}) with (%{{.*}} -> %{{.*}} = max #map{{.*}}(%{{.*}}, %{{.*}}) to min #map{{.*}}(%{{.*}}, %{{.*}})[%{{.*}}]) {
|
// CHECK: krnl.iterate(%{{.*}}) with (%{{.*}} -> %{{.*}} = max #map{{.*}}(%{{.*}}, %{{.*}}) to min #map{{.*}}(%{{.*}}, %{{.*}})[%{{.*}}]) {
|
||||||
krnl.iterate(%ok) with (%ik -> %k = max (d0, d1)->(d0 - d1, 0)(%i, %j) to min (d0, d1)[s0]->(d0 + d1, s0)(%i, %j)[%N]) {
|
krnl.iterate(%ok) with (%ik -> %k = max affine_map<(d0, d1)->(d0 - d1, 0)>(%i, %j) to min affine_map<(d0, d1)[s0]->(d0 + d1, s0)>(%i, %j)[%N]) {
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -385,7 +385,7 @@ func @test_min(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tensor<*
|
||||||
}
|
}
|
||||||
|
|
||||||
func @test_elu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
func @test_elu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
||||||
%0 = "onnx.Elu"(%arg0) {Elu.alpha=2.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32>
|
%0 = "onnx.Elu"(%arg0) {alpha=2.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32>
|
||||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
// CHECK-LABEL: test_elu
|
// CHECK-LABEL: test_elu
|
||||||
|
@ -411,7 +411,7 @@ func @test_elu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
||||||
}
|
}
|
||||||
|
|
||||||
func @test_leakyrelu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
func @test_leakyrelu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
||||||
%0 = "onnx.LeakyRelu"(%arg0) {LeakyRelu.alpha=1.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32>
|
%0 = "onnx.LeakyRelu"(%arg0) {alpha=1.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32>
|
||||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
// CHECK-LABEL: test_leakyrelu
|
// CHECK-LABEL: test_leakyrelu
|
||||||
|
@ -434,7 +434,7 @@ func @test_leakyrelu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
||||||
}
|
}
|
||||||
|
|
||||||
func @test_selu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
func @test_selu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
||||||
%0 = "onnx.Selu"(%arg0) {Selu.alpha=1.0:f32, Selu.gamma=2.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32>
|
%0 = "onnx.Selu"(%arg0) {alpha=1.0:f32, gamma=2.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32>
|
||||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
// CHECK-LABEL: test_selu
|
// CHECK-LABEL: test_selu
|
||||||
|
@ -461,7 +461,7 @@ func @test_selu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
||||||
}
|
}
|
||||||
|
|
||||||
func @test_hardsigmoid(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
func @test_hardsigmoid(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
||||||
%0 = "onnx.HardSigmoid"(%arg0) {HardSigmoid.alpha=1.0:f32, HardSigmoid.beta=2.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32>
|
%0 = "onnx.HardSigmoid"(%arg0) {alpha=1.0:f32, beta=2.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32>
|
||||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
// CHECK-LABEL: test_hardsigmoid
|
// CHECK-LABEL: test_hardsigmoid
|
||||||
|
@ -533,3 +533,49 @@ func @test_add_with_broadcasting(%arg0 : tensor<?xf32>, %arg1 : tensor<?x10xf32>
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
// CHECK: return [[RES]] : memref<?x10xf32>
|
// CHECK: return [[RES]] : memref<?x10xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func @test_softmax(%arg0 : tensor<10x10xf32>) -> tensor<*xf32> {
|
||||||
|
%0 = "onnx.Softmax"(%arg0) {axis=1:i32} : (tensor<10x10xf32>) -> tensor<*xf32>
|
||||||
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_softmax
|
||||||
|
// CHECK: [[MAX:%.+]] = alloc() : memref<f32>
|
||||||
|
// CHECK: [[SUM:%.+]] = alloc() : memref<f32>
|
||||||
|
// CHECK: [[RES:%.+]] = alloc() : memref<10x10xf32>
|
||||||
|
// CHECK: [[CST:%.+]] = constant 0.000000e+00 : f32
|
||||||
|
// CHECK: [[CST_0:%.+]] = constant 0xFF800000 : f32
|
||||||
|
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
|
||||||
|
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
|
||||||
|
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, %3#1
|
||||||
|
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
|
||||||
|
// CHECK: krnl.iterate([[OPT_LOOPS]]#0) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to 10) {
|
||||||
|
// CHECK: store [[CST]], [[SUM]][] : memref<f32>
|
||||||
|
// CHECK: store [[CST_0]], [[MAX]][] : memref<f32>
|
||||||
|
// CHECK: krnl.iterate([[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
|
||||||
|
// CHECK: [[LOAD1:%.+]] = load [[MAX]][] : memref<f32>
|
||||||
|
// CHECK: [[LOAD2:%.+]] = load %arg0[%arg1, %arg2] : memref<10x10xf32>
|
||||||
|
// CHECK: [[COND:%.+]] = cmpf "ogt", [[LOAD1]], [[LOAD2]] : f32
|
||||||
|
// CHECK: [[SELECT:%.+]] = select [[COND]], [[LOAD1]], [[LOAD2]] : f32
|
||||||
|
// CHECK: store [[SELECT]], [[MAX]][] : memref<f32>
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: %5 = load [[MAX]][] : memref<f32>
|
||||||
|
// CHECK: krnl.iterate([[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
|
||||||
|
// CHECK: [[LOAD1]] = load [[SUM]][] : memref<f32>
|
||||||
|
// CHECK: [[LOAD2]] = load %arg0[%arg1, %arg2] : memref<10x10xf32>
|
||||||
|
// CHECK: [[SUB:%.+]] = subf [[LOAD2]], %5 : f32
|
||||||
|
// CHECK: [[EXP:%.+]] = exp [[SUB]] : f32
|
||||||
|
// CHECK: [[ADD:%.+]] = addf [[LOAD1]], [[EXP]] : f32
|
||||||
|
// CHECK: store [[ADD]], [[SUM]][] : memref<f32>
|
||||||
|
// CHECK: store %10, [[RES]][%arg1, %arg2] : memref<10x10xf32>
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: %6 = load [[SUM]][] : memref<f32>
|
||||||
|
// CHECK: krnl.iterate([[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
|
||||||
|
// CHECK: [[LOAD1]] = load [[RES]][%arg1, %arg2] : memref<10x10xf32>
|
||||||
|
// CHECK: [[DIV:%.+]] = divf [[LOAD1]], %6 : f32
|
||||||
|
// CHECK: store [[DIV]], [[RES]][%arg1, %arg2] : memref<10x10xf32>
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: dealloc [[SUM]] : memref<f32>
|
||||||
|
// CHECK: dealloc [[MAX]] : memref<f32>
|
||||||
|
// CHECK: return [[RES]] : memref<10x10xf32>
|
||||||
|
}
|
||||||
|
|
|
@ -648,8 +648,8 @@ func @test_min_min(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tens
|
||||||
}
|
}
|
||||||
|
|
||||||
func @test_elu_elu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
func @test_elu_elu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
||||||
%0 = "onnx.Elu"(%arg0) {Elu.alpha=2.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32>
|
%0 = "onnx.Elu"(%arg0) {alpha=2.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32>
|
||||||
%1 = "onnx.Elu"(%0) {Elu.alpha=2.0:f32} : (tensor<*xf32>) -> tensor<*xf32>
|
%1 = "onnx.Elu"(%0) {alpha=2.0:f32} : (tensor<*xf32>) -> tensor<*xf32>
|
||||||
"std.return"(%1) : (tensor<*xf32>) -> ()
|
"std.return"(%1) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
// CHECK-LABEL: test_elu_elu
|
// CHECK-LABEL: test_elu_elu
|
||||||
|
@ -701,8 +701,8 @@ func @test_elu_elu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
||||||
}
|
}
|
||||||
|
|
||||||
func @test_leakyrelu_leakyrelu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
func @test_leakyrelu_leakyrelu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
||||||
%0 = "onnx.LeakyRelu"(%arg0) {LeakyRelu.alpha=1.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32>
|
%0 = "onnx.LeakyRelu"(%arg0) {alpha=1.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32>
|
||||||
%1 = "onnx.LeakyRelu"(%0) {LeakyRelu.alpha=1.0:f32} : (tensor<*xf32>) -> tensor<*xf32>
|
%1 = "onnx.LeakyRelu"(%0) {alpha=1.0:f32} : (tensor<*xf32>) -> tensor<*xf32>
|
||||||
"std.return"(%1) : (tensor<*xf32>) -> ()
|
"std.return"(%1) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
// CHECK-LABEL: test_leakyrelu_leakyrelu
|
// CHECK-LABEL: test_leakyrelu_leakyrelu
|
||||||
|
@ -748,8 +748,8 @@ func @test_leakyrelu_leakyrelu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
||||||
}
|
}
|
||||||
|
|
||||||
func @test_selu_selu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
func @test_selu_selu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
||||||
%0 = "onnx.Selu"(%arg0) {Selu.alpha=1.0:f32, Selu.gamma=2.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32>
|
%0 = "onnx.Selu"(%arg0) {alpha=1.0:f32, gamma=2.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32>
|
||||||
%1 = "onnx.Selu"(%0) {Selu.alpha=1.0:f32, Selu.gamma=2.0:f32} : (tensor<*xf32>) -> tensor<*xf32>
|
%1 = "onnx.Selu"(%0) {alpha=1.0:f32, gamma=2.0:f32} : (tensor<*xf32>) -> tensor<*xf32>
|
||||||
"std.return"(%1) : (tensor<*xf32>) -> ()
|
"std.return"(%1) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
// CHECK-LABEL: test_selu_selu
|
// CHECK-LABEL: test_selu_selu
|
||||||
|
@ -803,8 +803,8 @@ func @test_selu_selu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
||||||
}
|
}
|
||||||
|
|
||||||
func @test_hardsigmoid_hardsigmoid(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
func @test_hardsigmoid_hardsigmoid(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
||||||
%0 = "onnx.HardSigmoid"(%arg0) {HardSigmoid.alpha=1.0:f32, HardSigmoid.beta=2.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32>
|
%0 = "onnx.HardSigmoid"(%arg0) {alpha=1.0:f32, beta=2.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32>
|
||||||
%1 = "onnx.HardSigmoid"(%0) {HardSigmoid.alpha=1.0:f32, HardSigmoid.beta=2.0:f32} : (tensor<*xf32>) -> tensor<*xf32>
|
%1 = "onnx.HardSigmoid"(%0) {alpha=1.0:f32, beta=2.0:f32} : (tensor<*xf32>) -> tensor<*xf32>
|
||||||
"std.return"(%1) : (tensor<*xf32>) -> ()
|
"std.return"(%1) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
// CHECK-LABEL: test_hardsigmoid_hardsigmoid
|
// CHECK-LABEL: test_hardsigmoid_hardsigmoid
|
||||||
|
|
|
@ -11,8 +11,26 @@ func @test_default_transpose(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
|
||||||
// CHECK: return [[RES]] : tensor<32x1x5x5xf32>
|
// CHECK: return [[RES]] : tensor<32x1x5x5xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_default_transpose
|
||||||
|
// CHECK: [[RES:%.+]] = "onnx.Transpose"(%arg0) : (tensor<5x5x1x32xf32>) -> tensor<32x1x5x5xf32>
|
||||||
|
// CHECK: return [[RES]] : tensor<32x1x5x5xf32>
|
||||||
|
|
||||||
|
/// Test shape inference for transposition when perm attribute is specified.
|
||||||
|
func @test_transpose(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
|
||||||
|
%0 = "onnx.Transpose"(%arg0) {perm = [2, 0, 3, 1]} : (tensor<5x5x1x32xf32>) -> tensor<*xf32>
|
||||||
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_transpose
|
||||||
|
// CHECK: [[RES_ATTR:%.+]] = "onnx.Transpose"(%arg0) {perm = [2, 0, 3, 1]} : (tensor<5x5x1x32xf32>) -> tensor<1x5x32x5xf32>
|
||||||
|
// CHECK: return [[RES_ATTR]] : tensor<1x5x32x5xf32>
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
/// Test the shape inferencing scheme for the matmul operation.
|
/// Test the shape inferencing scheme for the matmul operation.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
/// MatMul: 1-D x 1-D
|
/// MatMul: 1-D x 1-D
|
||||||
|
|
||||||
func @test_matmul_1(%arg0 : tensor<32xf32>, %arg1 : tensor<32xf32>) -> tensor<*xf32> {
|
func @test_matmul_1(%arg0 : tensor<32xf32>, %arg1 : tensor<32xf32>) -> tensor<*xf32> {
|
||||||
%0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<32xf32>, tensor<32xf32>) -> tensor<*xf32>
|
%0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<32xf32>, tensor<32xf32>) -> tensor<*xf32>
|
||||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
@ -23,6 +41,7 @@ func @test_matmul_1(%arg0 : tensor<32xf32>, %arg1 : tensor<32xf32>) -> tensor<*x
|
||||||
}
|
}
|
||||||
|
|
||||||
/// MatMul: K-D x 2-D (K > 2)
|
/// MatMul: K-D x 2-D (K > 2)
|
||||||
|
|
||||||
func @test_matmul_2(%arg0 : tensor<16x?x64x42xf32>, %arg1 : tensor<42x32xf32>) -> tensor<*xf32> {
|
func @test_matmul_2(%arg0 : tensor<16x?x64x42xf32>, %arg1 : tensor<42x32xf32>) -> tensor<*xf32> {
|
||||||
%0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<16x?x64x42xf32>, tensor<42x32xf32>) -> tensor<*xf32>
|
%0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<16x?x64x42xf32>, tensor<42x32xf32>) -> tensor<*xf32>
|
||||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
@ -33,6 +52,7 @@ func @test_matmul_2(%arg0 : tensor<16x?x64x42xf32>, %arg1 : tensor<42x32xf32>) -
|
||||||
}
|
}
|
||||||
|
|
||||||
/// MatMul: 2-D x K-D (K > 2)
|
/// MatMul: 2-D x K-D (K > 2)
|
||||||
|
|
||||||
func @test_matmul_3(%arg0 : tensor<64x42xf32>, %arg1 : tensor<16x?x42x32xf32>) -> tensor<*xf32> {
|
func @test_matmul_3(%arg0 : tensor<64x42xf32>, %arg1 : tensor<16x?x42x32xf32>) -> tensor<*xf32> {
|
||||||
%0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<64x42xf32>, tensor<16x?x42x32xf32>) -> tensor<*xf32>
|
%0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<64x42xf32>, tensor<16x?x42x32xf32>) -> tensor<*xf32>
|
||||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
@ -43,6 +63,7 @@ func @test_matmul_3(%arg0 : tensor<64x42xf32>, %arg1 : tensor<16x?x42x32xf32>) -
|
||||||
}
|
}
|
||||||
|
|
||||||
/// MatMul: 2-D x K-D (K > 2)
|
/// MatMul: 2-D x K-D (K > 2)
|
||||||
|
|
||||||
func @test_matmul_4(%arg0 : tensor<64x42xf32>, %arg1 : tensor<?x?x?x?xf32>) -> tensor<*xf32> {
|
func @test_matmul_4(%arg0 : tensor<64x42xf32>, %arg1 : tensor<?x?x?x?xf32>) -> tensor<*xf32> {
|
||||||
%0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<64x42xf32>, tensor<?x?x?x?xf32>) -> tensor<*xf32>
|
%0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<64x42xf32>, tensor<?x?x?x?xf32>) -> tensor<*xf32>
|
||||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
@ -53,6 +74,7 @@ func @test_matmul_4(%arg0 : tensor<64x42xf32>, %arg1 : tensor<?x?x?x?xf32>) -> t
|
||||||
}
|
}
|
||||||
|
|
||||||
/// MatMul: K1-D x K2-D (K1 > 2, K2 > 2)
|
/// MatMul: K1-D x K2-D (K1 > 2, K2 > 2)
|
||||||
|
|
||||||
func @test_matmul_5(%arg0 : tensor<16x?x?x42xf32>, %arg1 : tensor<32x?x64x42x32xf32>) -> tensor<*xf32> {
|
func @test_matmul_5(%arg0 : tensor<16x?x?x42xf32>, %arg1 : tensor<32x?x64x42x32xf32>) -> tensor<*xf32> {
|
||||||
%0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<16x?x?x42xf32>, tensor<32x?x64x42x32xf32>) -> tensor<*xf32>
|
%0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<16x?x?x42xf32>, tensor<32x?x64x42x32xf32>) -> tensor<*xf32>
|
||||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
@ -63,6 +85,7 @@ func @test_matmul_5(%arg0 : tensor<16x?x?x42xf32>, %arg1 : tensor<32x?x64x42x32x
|
||||||
}
|
}
|
||||||
|
|
||||||
/// MatMul: 1-D x 2-D
|
/// MatMul: 1-D x 2-D
|
||||||
|
|
||||||
func @test_matmul_6(%arg0 : tensor<32xf32>, %arg1 : tensor<32x64xf32>) -> tensor<*xf32> {
|
func @test_matmul_6(%arg0 : tensor<32xf32>, %arg1 : tensor<32x64xf32>) -> tensor<*xf32> {
|
||||||
%0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<32xf32>, tensor<32x64xf32>) -> tensor<*xf32>
|
%0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<32xf32>, tensor<32x64xf32>) -> tensor<*xf32>
|
||||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
@ -73,6 +96,7 @@ func @test_matmul_6(%arg0 : tensor<32xf32>, %arg1 : tensor<32x64xf32>) -> tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
/// MatMul: 2-D x 1-D
|
/// MatMul: 2-D x 1-D
|
||||||
|
|
||||||
func @test_matmul_7(%arg0 : tensor<32x64xf32>, %arg1 : tensor<64xf32>) -> tensor<*xf32> {
|
func @test_matmul_7(%arg0 : tensor<32x64xf32>, %arg1 : tensor<64xf32>) -> tensor<*xf32> {
|
||||||
%0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<32x64xf32>, tensor<64xf32>) -> tensor<*xf32>
|
%0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<32x64xf32>, tensor<64xf32>) -> tensor<*xf32>
|
||||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
@ -83,6 +107,7 @@ func @test_matmul_7(%arg0 : tensor<32x64xf32>, %arg1 : tensor<64xf32>) -> tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
/// MatMul: 2-D x 2-D
|
/// MatMul: 2-D x 2-D
|
||||||
|
|
||||||
func @test_matmul_8(%arg0 : tensor<32x64xf32>, %arg1 : tensor<64x128xf32>) -> tensor<*xf32> {
|
func @test_matmul_8(%arg0 : tensor<32x64xf32>, %arg1 : tensor<64x128xf32>) -> tensor<*xf32> {
|
||||||
%0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<32x64xf32>, tensor<64x128xf32>) -> tensor<*xf32>
|
%0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<32x64xf32>, tensor<64x128xf32>) -> tensor<*xf32>
|
||||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
|
|
@ -0,0 +1 @@
|
||||||
|
Subproject commit 3c7fc8266bb46046b42c2dc2663f9f505f0cec28
|
|
@ -9,4 +9,5 @@ cmake -G Ninja ../llvm \
|
||||||
-DLLVM_ENABLE_ASSERTIONS=ON \
|
-DLLVM_ENABLE_ASSERTIONS=ON \
|
||||||
-DLLVM_ENABLE_RTTI=ON
|
-DLLVM_ENABLE_RTTI=ON
|
||||||
|
|
||||||
cmake --build . --target check-mlir -- ${MAKEFLAGS}
|
cmake --build . --target
|
||||||
|
cmake --build . --target check-mlir
|
|
@ -7,4 +7,5 @@ cmake ..
|
||||||
cmake --build . --target onnf
|
cmake --build . --target onnf
|
||||||
|
|
||||||
# Run FileCheck tests:
|
# Run FileCheck tests:
|
||||||
|
export LIT_OPTS=-v
|
||||||
cmake --build . --target check-mlir-lit
|
cmake --build . --target check-mlir-lit
|
Loading…
Reference in New Issue