Merge remote-tracking branch 'origin/master' into matmul-shape

Doru Bercea 2020-01-22 15:29:09 -05:00
commit 0bc07ef661
26 changed files with 818 additions and 728 deletions


@@ -3,27 +3,12 @@ jobs:
build:
docker:
- image: circleci/python
resource_class: medium+
steps:
- run:
name: Installing GCC, CMake, Ninja, Protobuf
command: sudo apt-get update && sudo apt-get install -y gcc g++ cmake ninja-build protobuf-compiler
# Use cached mlir installation if possible.
- restore_cache:
key: V2-LLVM-PROJECT-{{ arch }}
- run:
name: Install MLIR
command: |
# Check whether cache restoration succeeds by checking whether
# mlir-opt executable exists.
if [ ! -f llvm-project/build/bin/mlir-opt ]; then
export MAKEFLAGS=-j4
source utils/install-mlir.sh
fi
- save_cache:
key: V2-LLVM-PROJECT-{{ arch }}
paths:
- llvm-project
- checkout:
path: ONNF
- run:
@@ -31,9 +16,30 @@ jobs:
command: |
cd ONNF
git submodule update --init --recursive
# Use cached mlir installation if possible.
- restore_cache:
key: V4-LLVM-PROJECT-{{ arch }}
- run:
name: Install MLIR
command: |
# Check whether cache restoration succeeds by checking whether
# mlir-opt executable exists.
if [ ! -f llvm-project/build/bin/mlir-opt ]; then
source ONNF/utils/install-mlir.sh
fi
- save_cache:
key: V4-LLVM-PROJECT-{{ arch }}
paths:
- llvm-project
- run:
name: Install ONNF
command: source ONNF/utils/install-onnf.sh
- run:
name: Run End-To-End Tests
command: |
sudo pip install -q onnx
cd ONNF/build
cmake --build . --target run-onnx-backend-test
- run:
name: Run DocCheck
command: cd ONNF/build && cmake --build . --target check-doc

.gitmodules

@@ -7,3 +7,6 @@
[submodule "third_party/pybind11"]
path = third_party/pybind11
url = https://github.com/pybind/pybind11.git
[submodule "third_party/variant"]
path = third_party/variant
url = git@github.com:mpark/variant.git


@@ -22,9 +22,9 @@ include(MLIR.cmake)
add_subdirectory(third_party/onnx)
add_subdirectory(third_party/benchmark)
add_subdirectory(third_party/pybind11)
add_subdirectory(third_party/variant)
set(CMAKE_CXX_STANDARD 14)
add_subdirectory(src)
add_subdirectory(doc)
add_subdirectory(test)


@@ -20,7 +20,8 @@ cmake -G Ninja ../llvm \
-DLLVM_ENABLE_ASSERTIONS=ON \
-DLLVM_ENABLE_RTTI=ON
cmake --build . --target check-mlir -- ${MAKEFLAGS}
cmake --build . --target
cmake --build . --target check-mlir
```
Two environment variables need to be set:
@@ -42,6 +43,7 @@ cmake ..
cmake --build . --target onnf
# Run FileCheck tests:
export LIT_OPTS=-v
cmake --build . --target check-mlir-lit
```


@@ -27,6 +27,7 @@ ONNX operations for which some work is needed.
| Selu | Tung | v | v | |
| Sigmoid | Tung | v | v | |
| Sinh | Tung | v | v | |
| Softmax | Tung | v | v | |
| Sub | Tung | v | v | M |
| Sum | Tung | v | v | M |
| Tanh | Tung | v | v | |


@@ -69,8 +69,9 @@ add_subdirectory(runtime)
add_executable(onnf main.cpp)
target_link_libraries(onnf builder ${MLIRLibs} onnf_transform onnf_shape_inference onnf_lower_frontend)
set_target_properties(onnf PROPERTIES LINK_FLAGS "-lz")
whole_archive_link_mlir(onnf ${MLIRWholeArchiveLibs})
find_package(ZLIB REQUIRED)
target_link_libraries(onnf ${ZLIB_LIBRARIES})
target_include_directories(onnf PRIVATE ${CMAKE_SOURCE_DIR})
target_include_directories(onnf PRIVATE ${CMAKE_BINARY_DIR})


@@ -7,8 +7,9 @@ add_library(builder
target_include_directories(builder PRIVATE ${CMAKE_SOURCE_DIR})
target_include_directories(builder PRIVATE ${CMAKE_BINARY_DIR})
target_link_libraries(builder compiler onnx ${MLIRLibs} curses)
target_link_libraries(builder compiler onnx ${MLIRLibs} curses mpark_variant)
target_include_directories(builder
PRIVATE
${CMAKE_SOURCE_DIR}/third_party/onnx
${CMAKE_SOURCE_DIR}/third_party/variant
${CMAKE_SOURCE_DIR})


@@ -14,11 +14,16 @@
//
//===----------------------------------------------------------------------===//
#include <map>
#include <numeric>
#include <regex>
#include <string>
#include <tuple>
#include <map>
// Using backported variant.
// bstd = backported standard library.
#include <mpark/variant.hpp>
namespace bstd = mpark;
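
The backported variant deserves a word: mpark::variant offers the C++17 std::variant interface under the C++14 standard this project builds with (see the third_party/variant submodule added above). A minimal standalone sketch of the variant-plus-visit idiom — illustrative only, not part of this commit:

```
// Illustrative sketch only: the bstd::variant + visit idiom the importer
// relies on below. Compiles against the vendored mpark/variant header.
#include <cstdint>
#include <iostream>
#include <string>
#include <mpark/variant.hpp>
namespace bstd = mpark;

int main() {
  // A value that may hold either an integer or a string.
  bstd::variant<int64_t, std::string> v = std::string("onnf");

  // bstd::visit dispatches on the currently held alternative, which is
  // how the attribute importer selects the right MLIR attribute kind.
  bstd::visit([](const auto &x) { std::cout << x << "\n"; }, v);

  v = int64_t(42);
  std::cout << bstd::get<int64_t>(v) << "\n"; // prints 42
  return 0;
}
```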
#include "mlir/Analysis/Verifier.h" #include "mlir/Analysis/Verifier.h"
#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/Dialect/StandardOps/Ops.h"
@@ -42,15 +47,15 @@
namespace onnf {
namespace {
void replaceAll(
std::string& str, const std::string& from, const std::string& to) {
void replaceAll(std::string &str, const std::string &from,
const std::string &to) {
if (from.empty())
return;
size_t start_pos = 0;
while ((start_pos = str.find(from, start_pos)) != std::string::npos) {
str.replace(start_pos, from.length(), to);
start_pos += to.length(); // In case 'to' contains 'from', like replacing
// 'x' with 'yx'
}
}
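
The trailing comment is the important detail: advancing start_pos by to.length() is what keeps the loop from re-matching text it just inserted. A self-contained usage sketch (the helper is restated here so the example compiles on its own):

```
// Standalone demo of replaceAll's edge case: replacing 'x' with 'yx'
// must not loop forever, because start_pos skips past the inserted text.
#include <iostream>
#include <string>

static void replaceAll(std::string &str, const std::string &from,
                       const std::string &to) {
  if (from.empty())
    return;
  size_t start_pos = 0;
  while ((start_pos = str.find(from, start_pos)) != std::string::npos) {
    str.replace(start_pos, from.length(), to);
    start_pos += to.length(); // skip past the replacement text
  }
}

int main() {
  std::string s = "xaxb";
  replaceAll(s, "x", "yx"); // 'to' contains 'from'
  std::cout << s << "\n";   // prints "yxayxb"
  return 0;
}
```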
@@ -71,10 +76,10 @@ struct OnnxOnnfSymbolMapping {
* @param name onnx tensor name.
* @return onnf tensor corresponding to `name`.
*/
mlir::Value GetTensorByOnnxName(std::string name) {
mlir::Value GetTensorByOnnxName(const std::string &name) {
assert(onnx_name2onnf_tensor.find(legalize_name(name)) !=
onnx_name2onnf_tensor.end() &&
"Tensor not found");
return onnx_name2onnf_tensor.at(legalize_name(name));
}
@@ -83,9 +88,9 @@ struct OnnxOnnfSymbolMapping {
* @param name onnx tensor name.
* @param tensor MLIR Value pointer.
*/
void AddMapping(std::string name, mlir::Value tensor) {
void AddMapping(const std::string &name, mlir::Value tensor) {
assert(onnx_name2onnf_tensor.count(legalize_name(name)) == 0 &&
"Tensor already exists.");
onnx_name2onnf_tensor.emplace(legalize_name(name), tensor);
}
@@ -124,34 +129,34 @@ private:
// Convert type to MLIR type.
// A complete list of types can be found in:
// <onnf-build-folder>/third_party/onnx/onnx/onnx.pb.h
mlir::Type TypeConvert(onnx::TensorProto_DataType intype) {
switch (intype) {
mlir::Type convertONNXTypeToMLIRType(onnx::TensorProto_DataType onnxType) {
switch (onnxType) {
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16:
return builder_.getF16Type();
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT:
return builder_.getF32Type();
case onnx::TensorProto_DataType::TensorProto_DataType_DOUBLE:
return builder_.getF64Type();
case onnx::TensorProto_DataType::TensorProto_DataType_INT8:
case onnx::TensorProto_DataType::TensorProto_DataType_UINT8:
return builder_.getIntegerType(8);
case onnx::TensorProto_DataType::TensorProto_DataType_INT16:
case onnx::TensorProto_DataType::TensorProto_DataType_UINT16:
return builder_.getIntegerType(16);
case onnx::TensorProto_DataType::TensorProto_DataType_INT32:
case onnx::TensorProto_DataType::TensorProto_DataType_UINT32:
return builder_.getIntegerType(32);
case onnx::TensorProto_DataType::TensorProto_DataType_INT64:
case onnx::TensorProto_DataType::TensorProto_DataType_UINT64:
return builder_.getIntegerType(64);
case onnx::TensorProto_DataType::TensorProto_DataType_BOOL:
return builder_.getI1Type();
case onnx::TensorProto_DataType::TensorProto_DataType_STRING:
case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX64:
case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX128:
case onnx::TensorProto_DataType::TensorProto_DataType_UNDEFINED:
assert(false && "Unsupported data type encountered.");
return nullptr;
}
}
@@ -169,8 +174,8 @@ private:
for (int i = 0; i < shape_proto.dim_size(); i++) {
if (shape_proto.dim()[i].dim_value()) {
int dim_numeric_size = shape_proto.dim()[i].dim_value();
assert(
dim_numeric_size != 0 && "Parsed an input tensor with a dimension size of zero");
assert(dim_numeric_size != 0 &&
"Parsed an input tensor with a dimension size of zero");
if (dim_numeric_size > 0) {
dims.push_back(dim_numeric_size);
} else { // If dim_value < 0, then dim is parametric.
@@ -184,7 +189,7 @@ private:
}
mlir::Type elementType =
TypeConvert(input.type().tensor_type().elem_type());
convertONNXTypeToMLIRType(input.type().tensor_type().elem_type());
llvm::ArrayRef<int64_t> tensor_dims(dims.data(), dims.size());
arg_types.emplace_back(
mlir::RankedTensorType::get(tensor_dims, elementType));
@@ -200,288 +205,111 @@ private:
void ImportInputTensorSymbol(const onnx::ValueInfoProto &input,
mlir::Value symbol) {
auto input_tensor_legalized_name = legalize_name(input.name());
assert(
!frontend_symbols_.ContainKey(input_tensor_legalized_name) &&
"Found duplicate legalized input tensor names.");
assert(!frontend_symbols_.ContainKey(input_tensor_legalized_name) &&
"Found duplicate legalized input tensor names.");
frontend_symbols_.AddMapping(input_tensor_legalized_name, symbol);
}
typedef bstd::variant<int64_t, std::vector<int64_t>, float,
std::vector<float>, std::string,
std::vector<std::string>>
AttrValueType;
struct ONNXAttrVisitor {
ONNXAttrVisitor(std::string name, mlir::OpBuilder &builder)
: _builder(builder), _name(std::move(name)) {}
// Op builder.
mlir::OpBuilder &_builder;
// Name of the attribute being inspected.
std::string _name;
mlir::NamedAttribute operator()(int64_t const &r) {
auto val = _builder.getI32IntegerAttr(r);
return _builder.getNamedAttr(_name, val);
}
mlir::NamedAttribute operator()(std::vector<int64_t> const &ints) {
auto val = _builder.getI64ArrayAttr(ints);
return _builder.getNamedAttr(_name, val);
}
mlir::NamedAttribute operator()(float const &r) {
auto val = _builder.getF32FloatAttr(r);
return _builder.getNamedAttr(_name, val);
}
mlir::NamedAttribute operator()(std::vector<float> const &floats) {
auto val = _builder.getF32ArrayAttr(floats);
return _builder.getNamedAttr(_name, val);
}
mlir::NamedAttribute operator()(std::string const &s) {
auto val = _builder.getStringAttr(s);
return _builder.getNamedAttr(_name, val);
}
mlir::NamedAttribute operator()(std::vector<std::string> const &r) {
assert(false && "type of attribute value is not implemented");
auto val = _builder.getI32IntegerAttr(1);
return _builder.getNamedAttr(_name, val);
};
};
mlir::NamedAttribute convertNameValuePairToNamedAttribute(
std::pair<std::string, AttrValueType> nameAndVal) {
auto visitor = ONNXAttrVisitor(nameAndVal.first, builder_);
return mpark::visit(visitor, nameAndVal.second);
}
static std::pair<std::string, AttrValueType>
convertAttributeProtoToNameValuePair(onnx::AttributeProto &attr) {
AttrValueType val;
switch (attr.type()) {
case onnx::AttributeProto::FLOAT:
return std::make_pair(attr.name(), AttrValueType(attr.f()));
case onnx::AttributeProto::INT:
return std::make_pair(attr.name(), AttrValueType(attr.i()));
case onnx::AttributeProto::STRING:
return std::make_pair(attr.name(), AttrValueType(attr.s()));
case onnx::AttributeProto::FLOATS:
val = AttrValueType(
std::vector<float>(attr.floats().begin(), attr.floats().end()));
return std::make_pair(attr.name(), val);
case onnx::AttributeProto::INTS:
val = AttrValueType(
std::vector<int64_t>(attr.ints().begin(), attr.ints().end()));
return std::make_pair(attr.name(), val);
default:
assert(false && "datatype for attribute is not implemented");
break;
}
}
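
Taken together, convertAttributeProtoToNameValuePair and the visitor above form a two-step pipeline: protobuf attribute → {name, variant} pair → named MLIR attribute. A simplified standalone model of that dispatch, with the MLIR and protobuf types stubbed out (PrintVisitor is a hypothetical stand-in, illustrative only):

```
// Simplified, self-contained model of the visitor dispatch used above.
// The real code builds mlir::NamedAttribute; here we just print.
#include <cstdint>
#include <iostream>
#include <string>
#include <utility>
#include <vector>
#include <mpark/variant.hpp>

using AttrValue =
    mpark::variant<int64_t, float, std::string, std::vector<int64_t>>;

// Plays the role of ONNXAttrVisitor: one overload per alternative.
struct PrintVisitor {
  std::string name;
  void operator()(int64_t const &i) { std::cout << name << ": int " << i << "\n"; }
  void operator()(float const &f) { std::cout << name << ": float " << f << "\n"; }
  void operator()(std::string const &s) { std::cout << name << ": str " << s << "\n"; }
  void operator()(std::vector<int64_t> const &v) {
    std::cout << name << ": ints[" << v.size() << "]\n";
  }
};

int main() {
  // Name/value pairs as produced by convertAttributeProtoToNameValuePair.
  std::vector<std::pair<std::string, AttrValue>> attrs = {
      {"axis", int64_t(0)},
      {"alpha", 1.0f},
      {"auto_pad", std::string("NOTSET")},
      {"kernel_shape", std::vector<int64_t>{3, 3}}};
  for (auto &kv : attrs)
    mpark::visit(PrintVisitor{kv.first}, kv.second); // dispatch on held type
  return 0;
}
```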
std::vector<mlir::NamedAttribute> ImportNodeAttributes(
const onnx::NodeProto &node,
std::initializer_list<std::pair<std::string, AttrValueType>>
defaultAttrList) {
std::vector<mlir::NamedAttribute> attributes;
std::set<std::string> definedAttributeSet;
for (int i = 0; i < node.attribute_size(); ++i) {
auto attr = node.attribute(i);
auto nameValPair = convertAttributeProtoToNameValuePair(attr);
attributes.push_back(convertNameValuePairToNamedAttribute(nameValPair));
definedAttributeSet.insert(attr.name());
}
for (const auto &defaultAttr : defaultAttrList) {
if (definedAttributeSet.find(defaultAttr.first) ==
definedAttributeSet.end())
attributes.push_back(convertNameValuePairToNamedAttribute(defaultAttr));
}
return attributes;
}
template <typename T>
T get_attr_generic(onnx::NodeProto &node, std::string name,
std::function<T(onnx::AttributeProto &)> attr_getter,
T default_val) {
for (int i = 0; i < node.attribute_size(); ++i) {
auto attr = node.attribute(i);
if (attr.name() == name) {
return attr_getter(attr);
}
}
return default_val;
}
template <typename T>
T get_attr_generic(onnx::NodeProto &node, std::string name,
std::function<T(onnx::AttributeProto &)> attr_getter) {
for (int i = 0; i < node.attribute_size(); ++i) {
auto attr = node.attribute(i);
if (attr.name() == name) {
return attr_getter(attr);
}
}
assert(false && "ONNX Node Attribute Not Found!");
}
auto get_attr_ints(onnx::NodeProto &node, std::string name,
std::vector<int> default_val) {
std::function<std::vector<int>(onnx::AttributeProto &)> attr_getter =
[](onnx::AttributeProto &attr) {
std::vector<int> ints(attr.ints_size());
std::copy(attr.ints().begin(), attr.ints().end(), ints.begin());
return ints;
};
auto r = get_attr_generic(node, name, attr_getter, default_val);
auto dataType =
mlir::RankedTensorType::get(r.size(), builder_.getIntegerType(32));
auto attr_v = mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(r));
auto aname = node.op_type() + "." + name;
auto attr_output = builder_.getNamedAttr(aname, attr_v);
return attr_output;
}
auto get_attr_ints(onnx::NodeProto &node, std::string name) {
std::function<std::vector<int>(onnx::AttributeProto &)> attr_getter =
[](onnx::AttributeProto &attr) {
std::vector<int> ints(attr.ints_size());
std::copy(attr.ints().begin(), attr.ints().end(), ints.begin());
return ints;
};
auto r = get_attr_generic(node, name, attr_getter);
auto dataType =
mlir::RankedTensorType::get(r.size(), builder_.getIntegerType(32));
auto attr_v = mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(r));
auto aname = node.op_type() + "." + name;
auto attr_output = builder_.getNamedAttr(aname, attr_v);
return attr_output;
}
auto get_attr_floats(onnx::NodeProto &node, std::string name) {
std::function<std::vector<float>(onnx::AttributeProto &)> attr_getter =
[](onnx::AttributeProto &attr) {
std::vector<float> floats(attr.floats_size());
std::copy(attr.floats().begin(), attr.floats().end(), floats.begin());
return floats;
};
auto r = get_attr_generic(node, name, attr_getter);
auto dataType =
mlir::RankedTensorType::get(r.size(), builder_.getF32Type());
auto attr_v = mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(r));
auto aname = node.op_type() + "." + name;
auto attr_output = builder_.getNamedAttr(aname, attr_v);
return attr_output;
}
auto get_attr_floats(onnx::NodeProto &node, std::string name,
std::vector<float> default_val) {
std::function<std::vector<float>(onnx::AttributeProto &)> attr_getter =
[](onnx::AttributeProto &attr) {
std::vector<float> floats(attr.floats_size());
std::copy(attr.floats().begin(), attr.floats().end(), floats.begin());
return floats;
};
auto r = get_attr_generic(node, name, attr_getter, default_val);
auto dataType =
mlir::RankedTensorType::get(r.size(), builder_.getF32Type());
auto attr_v = mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(r));
auto aname = node.op_type() + "." + name;
auto attr_output = builder_.getNamedAttr(aname, attr_v);
return attr_output;
}
auto get_attr_int(onnx::NodeProto &node, std::string name) {
std::function<int(onnx::AttributeProto &)> attr_getter =
[](onnx::AttributeProto &attr) { return attr.i(); };
int r = get_attr_generic(node, name, attr_getter);
auto attr_v = builder_.getI32IntegerAttr(r);
auto aname = node.op_type() + "." + name;
auto attr_output = builder_.getNamedAttr(aname, attr_v);
return attr_output;
}
auto get_attr_int(onnx::NodeProto &node, std::string name, int default_val) {
std::function<int(onnx::AttributeProto &)> attr_getter =
[](onnx::AttributeProto &attr) { return attr.i(); };
int r = get_attr_generic(node, name, attr_getter, default_val);
auto attr_v = builder_.getI32IntegerAttr(r);
auto aname = node.op_type() + "." + name;
auto attr_output = builder_.getNamedAttr(aname, attr_v);
return attr_output;
}
auto get_attr_float(onnx::NodeProto &node, std::string name) {
std::function<float(onnx::AttributeProto &)> attr_getter =
[](onnx::AttributeProto &attr) { return attr.f(); };
auto r = get_attr_generic(node, name, attr_getter);
auto attr_v = builder_.getF32FloatAttr(r);
auto aname = node.op_type() + "." + name;
return builder_.getNamedAttr(aname, attr_v);
}
auto get_attr_float(onnx::NodeProto &node, std::string name,
float default_val) {
std::function<float(onnx::AttributeProto &)> attr_getter =
[](onnx::AttributeProto &attr) { return attr.f(); };
auto r = get_attr_generic(node, name, attr_getter, default_val);
auto attr_v = builder_.getF32FloatAttr(r);
auto aname = node.op_type() + "." + name;
return builder_.getNamedAttr(aname, attr_v);
}
auto get_attr_string(onnx::NodeProto &node, std::string name) {
std::function<std::string(onnx::AttributeProto &)> attr_getter =
[](onnx::AttributeProto &attr) { return attr.s(); };
auto r = get_attr_generic(node, name, attr_getter);
auto attr_v = builder_.getStringAttr(r);
auto aname = node.op_type() + "." + name;
return builder_.getNamedAttr(aname, attr_v);
}
auto get_attr_string(onnx::NodeProto &node, std::string name,
std::string default_val) {
std::function<std::string(onnx::AttributeProto &)> attr_getter =
[](onnx::AttributeProto &attr) { return attr.s(); };
auto r = get_attr_generic(node, name, attr_getter, default_val);
auto attr_v = builder_.getStringAttr(r);
auto aname = node.op_type() + "." + name;
return builder_.getNamedAttr(aname, attr_v);
}
/*
auto get_attr_strings(onnx::NodeProto &node, std::string name) {
std::function<std::vector<std::string>(onnx::AttributeProto &)>
attr_getter =
[](onnx::AttributeProto &attr) {
std::vector<std::string> strings(attr.strings_size());
std::copy(attr.strings().begin(), attr.strings().end(),
strings.begin()); return strings;
};
auto r = get_attr_generic(node, name, attr_getter);
return r;
return builder_.getNamedAttr(aname, attr_v);
auto dataType =
mlir::RankedTensorType::get(r.size(), builder_.get???Type());
auto attr_v = mlir::DenseElementsAttr::get(dataType,
llvm::makeArrayRef(r)); auto aname = node.op_type() + "." + name; auto
attr_output = builder_.getNamedAttr(aname, attr_v); return attr_output;
}
*/
auto get_default_ints(std::string default_str) {
std::vector<int> r;
auto start = default_str.find("{");
while (true) {
auto end = default_str.find(",", start + 1);
if (end == std::string::npos) {
end = default_str.find("}", start + 1);
if (end != std::string::npos && end > start + 1) {
r.push_back(std::stoi(default_str.substr(start + 1, end)));
}
break;
} else {
r.push_back(std::stoi(default_str.substr(start + 1, end)));
}
start = end + 1;
}
return r;
}
auto get_default_floats(std::string default_str) {
std::vector<float> r;
auto start = default_str.find("{");
while (true) {
auto end = default_str.find(",", start + 1);
if (end == std::string::npos) {
end = default_str.find("}", start + 1);
if (end != std::string::npos && end > start + 1) {
r.push_back(std::stof(default_str.substr(start + 1, end)));
}
break;
} else {
r.push_back(std::stof(default_str.substr(start + 1, end)));
}
start = end + 1;
}
return r;
}
auto get_default_strings(std::string default_str) {
std::vector<std::string> r;
auto start = default_str.find("{");
while (true) {
auto end = default_str.find(",", start + 1);
if (end == std::string::npos) {
end = default_str.find("}", start + 1);
if (end != std::string::npos && end > start + 1) {
r.push_back(default_str.substr(start + 1, end));
}
break;
} else {
r.push_back(default_str.substr(start + 1, end));
}
start = end + 1;
}
return r;
}
onnx::TensorProto get_attr_tensor(onnx::NodeProto &node, std::string name) {
std::function<onnx::TensorProto(onnx::AttributeProto &)> attr_getter =
[](onnx::AttributeProto &attr) { return attr.t(); };
return get_attr_generic(node, name, attr_getter);
}
auto ImportNodeAttr(onnx::NodeProto node, std::string attr_name,
std::string type_name, std::string default_str) {
if (default_str == "") {
if (type_name == "int") {
return get_attr_int(node, attr_name);
} else if (type_name == "float") {
return get_attr_float(node, attr_name);
} else if (type_name == "str") {
return get_attr_string(node, attr_name);
} else if (type_name == "ints") {
return get_attr_ints(node, attr_name);
} else if (type_name == "floats") {
return get_attr_floats(node, attr_name);
} else {
assert(
false &&
"Got an empty initializer or initializer for this "
"datatype is not implemented. Something is wrong.");
}
} else {
// with default value
if (type_name == "int") {
return get_attr_int(node, attr_name, std::stoi(default_str));
} else if (type_name == "float") {
return get_attr_float(node, attr_name, std::stof(default_str));
} else if (type_name == "str") {
return get_attr_string(node, attr_name, default_str);
} else if (type_name == "ints") {
return get_attr_ints(node, attr_name, get_default_ints(default_str));
} else if (type_name == "floats") {
return get_attr_floats(node, attr_name,
get_default_floats(default_str));
} else {
assert(
false &&
"Got an empty initializer or initializer for this "
"datatype is not implemented. Something is wrong.");
}
}
}
void ImportNodeGeneric(onnx::NodeProto node) {
void ImportNodeGeneric(const onnx::NodeProto &node) {
std::vector<mlir::Value> inputs;
for (auto item : node.input()) {
for (const auto &item : node.input()) {
if (frontend_symbols_.ContainKey(legalize_name(item))) {
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
}
}
@@ -511,12 +339,12 @@ private:
* default}
*/
template <typename T>
void ImportNodeOneOut(
onnx::NodeProto node, int nIn, int nOut,
std::initializer_list<std::tuple<std::string, std::string, std::string>>
attrs) {
void
ImportNodeOneOut(const onnx::NodeProto &node, int nIn, int nOut,
std::initializer_list<std::pair<std::string, AttrValueType>>
defaultAttrList) {
std::vector<mlir::Value> inputs;
for (auto item : node.input()) {
for (const auto &item : node.input()) {
if (frontend_symbols_.ContainKey(legalize_name(item))) {
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
}
@@ -528,22 +356,7 @@ private:
mlir::UnrankedTensorType::get(builder_.getF32Type()));
}
std::vector<mlir::NamedAttribute> attributes;
auto attributes = ImportNodeAttributes(node, defaultAttrList);
// for (auto [attr_name, attr_type, attr_default] : attrs) {
for (auto oneAttr : attrs) {
std::string attr_name;
std::string attr_type;
std::string attr_default;
std::tie(attr_name, attr_type, attr_default) = oneAttr;
if (attr_type != "") {
auto attr = ImportNodeAttr(node, attr_name, attr_type, attr_default);
attributes.push_back(attr);
} else {
// TODO: the attributes need special handling
// std::cout << "missing " << node.op_type() << " " << attr_name <<
// std::endl;
}
}
llvm::StringRef OpName = node.op_type();
@@ -559,11 +372,11 @@ private:
template <typename T>
void ImportNodeMultipleOuts(
onnx::NodeProto node, int nIn, int nOut,
std::initializer_list<std::tuple<std::string, std::string, std::string>>
attrs) {
const onnx::NodeProto &node, int nIn, int nOut,
std::initializer_list<std::pair<std::string, AttrValueType>>
defaultAttrList) {
std::vector<mlir::Value> inputs;
for (auto item : node.input()) {
for (const auto &item : node.input()) {
if (frontend_symbols_.ContainKey(legalize_name(item))) {
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
}
@@ -575,21 +388,7 @@ private:
mlir::UnrankedTensorType::get(builder_.getF32Type()));
}
std::vector<mlir::NamedAttribute> attributes;
auto attributes = ImportNodeAttributes(node, defaultAttrList);
for (auto oneAttr : attrs) {
std::string attr_name;
std::string attr_type;
std::string attr_default;
std::tie(attr_name, attr_type, attr_default) = oneAttr;
if (attr_type != "") {
auto attr = ImportNodeAttr(node, attr_name, attr_type, attr_default);
attributes.push_back(attr);
} else {
// TODO: the attributes need special handling
// std::cout << "missing " << node.op_type() << " " << attr_name <<
// std::endl;
}
}
llvm::StringRef OpName = node.op_type();
@@ -610,10 +409,10 @@ private:
* c++ does not allow template specialization inside a class scope
* a specialized function is used
*/
void ImportNodeConv(
onnx::NodeProto node, int nOut,
std::initializer_list<std::tuple<std::string, std::string, std::string>>
attrs) {
void
ImportNodeConv(onnx::NodeProto node, int nIn, int nOut,
std::initializer_list<std::pair<std::string, AttrValueType>>
defaultAttrList) {
// Conv has attribute dilations, kernel_shape, pads, the default value of
// which is determined by the shape of first argument. However, since the
// shape is unknown now, these attributes can be not generated auto
@@ -627,29 +426,32 @@ private:
int nOps = node.input().size();
if (nOps == 2)
ImportNodeOneOut<mlir::ONNXConvNoBiasOp>(node, nOps, nOut, attrs);
ImportNodeOneOut<mlir::ONNXConvNoBiasOp>(
node, nOps, nOut, defaultAttrList);
else
ImportNodeOneOut<mlir::ONNXConvOp>(node, nOps, nOut, attrs);
ImportNodeOneOut<mlir::ONNXConvOp>(node, nOps, nOut, defaultAttrList);
}
/*!
* Special handle for MaxPool operations.
*/
void ImportNodeMaxPool(
onnx::NodeProto node, int nIn,
std::initializer_list<std::tuple<std::string, std::string, std::string>>
attrs) {
onnx::NodeProto node, int nIn, int nOut,
std::initializer_list<std::pair<std::string, AttrValueType>>
defaultAttrList) {
int nOuts = node.output().size();
if (nOuts == 1) {
ImportNodeOneOut<mlir::ONNXMaxPoolSingleOutOp>(node, nIn, nOuts, attrs);
ImportNodeOneOut<mlir::ONNXMaxPoolSingleOutOp>(
node, nIn, nOuts, defaultAttrList);
} else {
ImportNodeMultipleOuts<mlir::ONNXMaxPoolOp>(node, nIn, nOuts, attrs);
ImportNodeMultipleOuts<mlir::ONNXMaxPoolOp>(
node, nIn, nOuts, defaultAttrList);
}
}
void ImportNode(onnx::NodeProto node) {
void ImportNode(const onnx::NodeProto &node) {
std::vector<mlir::Value> inputs;
for (auto item : node.input()) {
for (const auto &item : node.input()) {
if (frontend_symbols_.ContainKey(legalize_name(item))) {
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
}
@@ -689,9 +491,8 @@ private:
llvm::SmallVectorImpl<mlir::Type> &ret_types,
llvm::SmallVectorImpl<mlir::Value> &ret_vals) {
auto output_tensor_legalized_name = legalize_name(output.name());
assert(
frontend_symbols_.ContainKey(output_tensor_legalized_name) &&
"Output tensor not found");
assert(frontend_symbols_.ContainKey(output_tensor_legalized_name) &&
"Output tensor not found");
auto tensor_val =
frontend_symbols_.GetTensorByOnnxName(output_tensor_legalized_name);
@@ -750,9 +551,9 @@ private:
funcType = builder_.getFunctionType(arg_types, ret_types);
mainFunc.setType(funcType);
}
}; // FrontendGenImpl class
} // namespace
} // namespace onnf
namespace onnf {
@@ -775,4 +576,4 @@ void ImportFrontendModelFile(std::string model_fname,
FrontendGenImpl myONNXGen(context);
module = myONNXGen.ImportONNXModel(model);
}
} // namespace onnf
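
For orientation, a sketch of how a driver might invoke this entry point. The parameter list is assumed from the visible context (the hunk header shows only the first parameter), and the headers follow the llvm-project snapshot this repository pins — treat it as a sketch, not a guaranteed API:

```
// Hypothetical driver: load an ONNX model file into an MLIR module.
// Assumes ImportFrontendModelFile(std::string, mlir::MLIRContext &,
// mlir::OwningModuleRef &), as suggested by the function body above.
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"

namespace onnf {
void ImportFrontendModelFile(std::string model_fname,
                             mlir::MLIRContext &context,
                             mlir::OwningModuleRef &module);
} // namespace onnf

int main(int argc, char **argv) {
  if (argc < 2)
    return 1;
  mlir::MLIRContext context;
  mlir::OwningModuleRef module;
  onnf::ImportFrontendModelFile(argv[1], context, module);
  module->dump(); // print the imported ONNX-dialect module
  return 0;
}
```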


@@ -16,13 +16,13 @@
});
}else if (OpName == "ArgMax") {
ImportNodeOneOut<mlir::ONNXArgMaxOp>(node, 1, 1, {
{"axis","int","0"}
,{"keepdims","int","1"}
{"axis", 0}
,{"keepdims", 1}
});
}else if (OpName == "ArgMin") {
ImportNodeOneOut<mlir::ONNXArgMinOp>(node, 1, 1, {
{"axis","int","0"}
,{"keepdims","int","1"}
{"axis", 0}
,{"keepdims", 1}
});
}else if (OpName == "Asin") {
ImportNodeOneOut<mlir::ONNXAsinOp>(node, 1, 1, {
@@ -38,25 +38,22 @@
});
}else if (OpName == "AveragePool") {
ImportNodeOneOut<mlir::ONNXAveragePoolOp>(node, 1, 1, {
{"auto_pad","str","NOTSET"}
,{"ceil_mode","int","0"}
,{"count_include_pad","int","0"}
,{"kernel_shape","ints", ""}
,{"pads","", ""}
,{"strides","", ""}
{"auto_pad", "NOTSET"}
,{"ceil_mode", 0}
,{"count_include_pad", 0}
,{"kernel_shape", std::vector<int64_t> {}}
});
}else if (OpName == "BatchNormalization") {
ImportNodeMultipleOuts<mlir::ONNXBatchNormalizationOp>(node, 5, 5, {
{"epsilon","float","1e-05"}
,{"momentum","float","0.9"}
{"epsilon", (float)1e-05}
,{"momentum", (float)0.9}
});
}else if (OpName == "BitShift") {
ImportNodeOneOut<mlir::ONNXBitShiftOp>(node, 2, 1, {
{"direction","", ""}
});
}else if (OpName == "Cast") {
ImportNodeOneOut<mlir::ONNXCastOp>(node, 1, 1, {
{"to","int", "0"}
{"to", 0}
});
}else if (OpName == "Ceil") {
ImportNodeOneOut<mlir::ONNXCeilOp>(node, 1, 1, {
@@ -66,54 +63,35 @@
});
}else if (OpName == "Compress") {
ImportNodeOneOut<mlir::ONNXCompressOp>(node, 2, 1, {
{"axis","", ""}
});
}else if (OpName == "Concat") {
ImportNodeOneOut<mlir::ONNXConcatOp>(node, 1, 1, {
{"axis","int", "0"}
{"axis", 0}
});
}else if (OpName == "ConcatFromSequence") {
ImportNodeOneOut<mlir::ONNXConcatFromSequenceOp>(node, 1, 1, {
{"axis","", ""}
,{"new_axis","int","0"}
{"new_axis", 0}
});
}else if (OpName == "Constant") {
ImportNodeOneOut<mlir::ONNXConstantOp>(node, 0, 1, {
{"sparse_value","", ""}
,{"value","", ""}
});
}else if (OpName == "ConstantOfShape") {
ImportNodeOneOut<mlir::ONNXConstantOfShapeOp>(node, 1, 1, {
{"value","", ""}
});
}else if (OpName == "Conv") {
ImportNodeConv(node, 1, {
ImportNodeConv(node, 3, 1, {
{"auto_pad","str","NOTSET"}
,{"dilations","", ""}
,{"group","int", "1"}
,{"kernel_shape","", ""}
,{"pads","", ""}
,{"strides","", ""}
{"auto_pad", "NOTSET"}
,{"group", 1}
});
}else if (OpName == "ConvInteger") {
ImportNodeOneOut<mlir::ONNXConvIntegerOp>(node, 4, 1, {
{"auto_pad","str","NOTSET"}
,{"dilations","", ""}
,{"group","int","1"}
,{"kernel_shape","", ""}
,{"pads","", ""}
,{"strides","", ""}
{"auto_pad", "NOTSET"}
,{"group", 1}
});
}else if (OpName == "ConvTranspose") {
ImportNodeOneOut<mlir::ONNXConvTransposeOp>(node, 3, 1, {
{"auto_pad","str","NOTSET"}
,{"dilations","", ""}
,{"group","int","1"}
,{"kernel_shape","", ""}
,{"output_padding","", ""}
,{"output_shape","", ""}
,{"pads","", ""}
,{"strides","", ""}
{"auto_pad", "NOTSET"}
,{"group", 1}
});
}else if (OpName == "Cos") {
ImportNodeOneOut<mlir::ONNXCosOp>(node, 1, 1, {
@@ -123,13 +101,12 @@
});
}else if (OpName == "CumSum") {
ImportNodeOneOut<mlir::ONNXCumSumOp>(node, 2, 1, {
{"exclusive","int","0"}
,{"reverse","int","0"}
{"exclusive", 0}
,{"reverse", 0}
});
}else if (OpName == "DepthToSpace") {
ImportNodeOneOut<mlir::ONNXDepthToSpaceOp>(node, 1, 1, {
{"blocksize","", ""}
,{"mode","str","DCR"}
{"mode", "DCR"}
});
}else if (OpName == "DequantizeLinear") {
ImportNodeOneOut<mlir::ONNXDequantizeLinearOp>(node, 3, 1, {
@@ -142,14 +119,14 @@
});
}else if (OpName == "Dropout") {
ImportNodeMultipleOuts<mlir::ONNXDropoutOp>(node, 1, 2, {
{"ratio","float","0.5"}
{"ratio", (float)0.5}
});
}else if (OpName == "DynamicQuantizeLinear") {
ImportNodeMultipleOuts<mlir::ONNXDynamicQuantizeLinearOp>(node, 1, 3, {
});
}else if (OpName == "Elu") {
ImportNodeOneOut<mlir::ONNXEluOp>(node, 1, 1, {
{"alpha","float","1.0"}
{"alpha", (float)1.0}
});
}else if (OpName == "Equal") {
ImportNodeOneOut<mlir::ONNXEqualOp>(node, 2, 1, {
@@ -165,50 +142,44 @@
});
}else if (OpName == "EyeLike") {
ImportNodeOneOut<mlir::ONNXEyeLikeOp>(node, 1, 1, {
{"dtype","", ""}
,{"k","int","0"}
{"k", 0}
});
}else if (OpName == "Flatten") {
ImportNodeOneOut<mlir::ONNXFlattenOp>(node, 1, 1, {
{"axis","int","1"}
{"axis", 1}
});
}else if (OpName == "Floor") {
ImportNodeOneOut<mlir::ONNXFloorOp>(node, 1, 1, {
});
}else if (OpName == "GRU") {
ImportNodeMultipleOuts<mlir::ONNXGRUOp>(node, 6, 2, {
{"activation_alpha","", ""}
,{"activation_beta","", ""}
,{"activations","", ""}
,{"clip","", ""}
,{"direction","str","forward"}
,{"hidden_size","", ""}
,{"linear_before_reset","int","0"}
{"direction", "forward"}
,{"linear_before_reset", 0}
});
}else if (OpName == "Gather") {
ImportNodeOneOut<mlir::ONNXGatherOp>(node, 2, 1, {
{"axis","int","0"}
{"axis", 0}
});
}else if (OpName == "GatherElements") {
ImportNodeOneOut<mlir::ONNXGatherElementsOp>(node, 2, 1, {
{"axis","int","0"}
{"axis", 0}
});
}else if (OpName == "GatherND") {
ImportNodeOneOut<mlir::ONNXGatherNDOp>(node, 2, 1, {
});
}else if (OpName == "Gemm") {
ImportNodeOneOut<mlir::ONNXGemmOp>(node, 3, 1, {
{"alpha","float","1.0"}
,{"beta","float","1.0"}
,{"transA","int","0"}
,{"transB","int","0"}
{"alpha", (float)1.0}
,{"beta", (float)1.0}
,{"transA", 0}
,{"transB", 0}
});
}else if (OpName == "GlobalAveragePool") {
ImportNodeOneOut<mlir::ONNXGlobalAveragePoolOp>(node, 1, 1, {
});
}else if (OpName == "GlobalLpPool") {
ImportNodeOneOut<mlir::ONNXGlobalLpPoolOp>(node, 1, 1, {
{"p","int","2"}
{"p", 2}
});
}else if (OpName == "GlobalMaxPool") {
ImportNodeOneOut<mlir::ONNXGlobalMaxPoolOp>(node, 1, 1, {
@@ -218,53 +189,45 @@
});
}else if (OpName == "HardSigmoid") {
ImportNodeOneOut<mlir::ONNXHardSigmoidOp>(node, 1, 1, {
{"alpha","float","0.2"}
,{"beta","float","0.5"}
{"alpha", (float)0.2}
,{"beta", (float)0.5}
});
}else if (OpName == "Hardmax") {
ImportNodeOneOut<mlir::ONNXHardmaxOp>(node, 1, 1, {
{"axis","int","1"}
{"axis", 1}
});
}else if (OpName == "Identity") {
ImportNodeOneOut<mlir::ONNXIdentityOp>(node, 1, 1, {
});
}else if (OpName == "If") {
ImportNodeOneOut<mlir::ONNXIfOp>(node, 1, 1, {
{"else_branch","", ""}
,{"then_branch","", ""}
});
}else if (OpName == "InstanceNormalization") {
ImportNodeOneOut<mlir::ONNXInstanceNormalizationOp>(node, 3, 1, {
{"epsilon","float","1e-05"}
{"epsilon", (float)1e-05}
});
}else if (OpName == "IsInf") {
ImportNodeOneOut<mlir::ONNXIsInfOp>(node, 1, 1, {
{"detect_negative","int","1"}
,{"detect_positive","int","1"}
{"detect_negative", 1}
,{"detect_positive", 1}
});
}else if (OpName == "IsNaN") {
ImportNodeOneOut<mlir::ONNXIsNaNOp>(node, 1, 1, {
});
}else if (OpName == "LRN") {
ImportNodeOneOut<mlir::ONNXLRNOp>(node, 1, 1, {
{"alpha","float","0.0001"}
,{"beta","float","0.75"}
,{"bias","float","1.0"}
,{"size","int", ""}
{"alpha", (float)0.0001}
,{"beta", (float)0.75}
,{"bias", (float)1.0}
});
}else if (OpName == "LSTM") {
ImportNodeMultipleOuts<mlir::ONNXLSTMOp>(node, 8, 3, {
{"activation_alpha","", ""}
,{"activation_beta","", ""}
,{"activations","", ""}
,{"clip","", ""}
,{"direction","str","forward"}
,{"hidden_size","", ""}
,{"input_forget","int","0"}
{"direction", "forward"}
,{"input_forget", 0}
});
}else if (OpName == "LeakyRelu") {
ImportNodeOneOut<mlir::ONNXLeakyReluOp>(node, 1, 1, {
{"alpha","float","0.01"}
{"alpha", (float)0.01}
});
}else if (OpName == "Less") {
ImportNodeOneOut<mlir::ONNXLessOp>(node, 2, 1, {
@@ -274,24 +237,20 @@
});
}else if (OpName == "LogSoftmax") {
ImportNodeOneOut<mlir::ONNXLogSoftmaxOp>(node, 1, 1, {
{"axis","int","1"}
{"axis", 1}
});
}else if (OpName == "Loop") {
ImportNodeOneOut<mlir::ONNXLoopOp>(node, 3, 1, {
{"body","", ""}
});
}else if (OpName == "LpNormalization") {
ImportNodeOneOut<mlir::ONNXLpNormalizationOp>(node, 1, 1, {
{"axis","int","-1"}
,{"p","int","2"}
{"axis", -1}
,{"p", 2}
});
}else if (OpName == "LpPool") {
ImportNodeOneOut<mlir::ONNXLpPoolOp>(node, 1, 1, {
{"auto_pad","str","NOTSET"}
,{"kernel_shape","", ""}
,{"p","int","2"}
,{"pads","", ""}
,{"strides","", ""}
{"auto_pad", "NOTSET"}
,{"p", 2}
});
}else if (OpName == "MatMul") {
ImportNodeOneOut<mlir::ONNXMatMulOp>(node, 2, 1, {
@@ -303,55 +262,47 @@
ImportNodeOneOut<mlir::ONNXMaxOp>(node, 1, 1, {
});
}else if (OpName == "MaxPool") {
ImportNodeMaxPool(node, 1, {
ImportNodeMaxPool(node, 1, 2, {
{"auto_pad","str","NOTSET"}
,{"ceil_mode","int","0"}
,{"dilations","", ""}
,{"kernel_shape","ints", ""}
,{"pads","", ""}
,{"storage_order","int","0"}
,{"strides","", ""}
{"auto_pad", "NOTSET"}
,{"ceil_mode", 0}
,{"kernel_shape", std::vector<int64_t> {}}
,{"storage_order", 0}
});
}else if (OpName == "MaxRoiPool") {
ImportNodeOneOut<mlir::ONNXMaxRoiPoolOp>(node, 2, 1, {
{"pooled_shape","", ""}
,{"spatial_scale","float","1.0"}
{"spatial_scale", (float)1.0}
});
}else if (OpName == "MaxUnpool") {
ImportNodeOneOut<mlir::ONNXMaxUnpoolOp>(node, 3, 1, {
{"kernel_shape","", ""}
,{"pads","", ""}
,{"strides","", ""}
});
}else if (OpName == "Mean") {
ImportNodeOneOut<mlir::ONNXMeanOp>(node, 1, 1, {
});
}else if (OpName == "MeanVarianceNormalization") {
ImportNodeOneOut<mlir::ONNXMeanVarianceNormalizationOp>(node, 1, 1, {
{"axes","ints","{'0', '2', '3'}"}
{"axes", std::vector<int64_t>{0, 2, 3}}
});
}else if (OpName == "Min") {
ImportNodeOneOut<mlir::ONNXMinOp>(node, 1, 1, {
});
}else if (OpName == "Mod") {
ImportNodeOneOut<mlir::ONNXModOp>(node, 2, 1, {
{"fmod","int","0"}
{"fmod", 0}
});
}else if (OpName == "Mul") {
ImportNodeOneOut<mlir::ONNXMulOp>(node, 2, 1, {
});
}else if (OpName == "Multinomial") {
ImportNodeOneOut<mlir::ONNXMultinomialOp>(node, 1, 1, {
{"dtype","int","6"}
,{"sample_size","int","1"}
,{"seed","", ""}
{"dtype", 6}
,{"sample_size", 1}
});
}else if (OpName == "Neg") {
ImportNodeOneOut<mlir::ONNXNegOp>(node, 1, 1, {
});
}else if (OpName == "NonMaxSuppression") {
ImportNodeOneOut<mlir::ONNXNonMaxSuppressionOp>(node, 5, 1, {
{"center_point_box","int","0"}
{"center_point_box", 0}
});
}else if (OpName == "NonZero") {
ImportNodeOneOut<mlir::ONNXNonZeroOp>(node, 1, 1, {
@@ -361,7 +312,7 @@
});
}else if (OpName == "OneHot") {
ImportNodeOneOut<mlir::ONNXOneHotOp>(node, 3, 1, {
{"axis","int","-1"}
{"axis", -1}
});
}else if (OpName == "Or") {
ImportNodeOneOut<mlir::ONNXOrOp>(node, 2, 1, {
@@ -371,19 +322,15 @@
});
}else if (OpName == "Pad") {
ImportNodeOneOut<mlir::ONNXPadOp>(node, 3, 1, {
{"mode","str","constant"}
{"mode", "constant"}
});
}else if (OpName == "Pow") {
ImportNodeOneOut<mlir::ONNXPowOp>(node, 2, 1, {
});
}else if (OpName == "QLinearConv") {
ImportNodeOneOut<mlir::ONNXQLinearConvOp>(node, 9, 1, {
{"auto_pad","str","NOTSET"}
,{"dilations","", ""}
,{"group","int","1"}
,{"kernel_shape","", ""}
,{"pads","", ""}
,{"strides","", ""}
{"auto_pad", "NOTSET"}
,{"group", 1}
});
}else if (OpName == "QLinearMatMul") {
ImportNodeOneOut<mlir::ONNXQLinearMatMulOp>(node, 8, 1, {
@@ -393,42 +340,32 @@
});
}else if (OpName == "RNN") {
ImportNodeMultipleOuts<mlir::ONNXRNNOp>(node, 6, 2, {
{"activation_alpha","floats", "{}"}
,{"activation_beta","floats", "{}"}
,{"activations","", "{Tannh, Tanh}"}
,{"clip","", ""}
,{"direction","str","forward"}
,{"hidden_size","", ""}
{"activation_alpha", std::vector<float> {}}
,{"activation_beta", std::vector<float> {}}
,{"activations", std::vector<std::string>{"Tanh", "Tanh"}}
,{"direction", "forward"}
});
}else if (OpName == "RandomNormal") {
ImportNodeOneOut<mlir::ONNXRandomNormalOp>(node, 0, 1, {
{"dtype","int","1"}
,{"mean","float","0.0"}
,{"scale","float","1.0"}
,{"seed","", ""}
,{"shape","", ""}
{"dtype", 1}
,{"mean", (float)0.0}
,{"scale", (float)1.0}
});
}else if (OpName == "RandomNormalLike") {
ImportNodeOneOut<mlir::ONNXRandomNormalLikeOp>(node, 1, 1, {
{"dtype","", ""}
,{"mean","float","0.0"}
,{"scale","float","1.0"}
,{"seed","", ""}
{"mean", (float)0.0}
,{"scale", (float)1.0}
});
}else if (OpName == "RandomUniform") {
ImportNodeOneOut<mlir::ONNXRandomUniformOp>(node, 0, 1, {
{"dtype","int","1"}
,{"high","float","1.0"}
,{"low","float","0.0"}
,{"seed","", ""}
,{"shape","", ""}
{"dtype", 1}
,{"high", (float)1.0}
,{"low", (float)0.0}
});
}else if (OpName == "RandomUniformLike") {
ImportNodeOneOut<mlir::ONNXRandomUniformLikeOp>(node, 1, 1, {
{"dtype","", ""}
,{"high","float","1.0"}
,{"low","float","0.0"}
,{"seed","", ""}
{"high", (float)1.0}
,{"low", (float)0.0}
});
}else if (OpName == "Range") {
ImportNodeOneOut<mlir::ONNXRangeOp>(node, 3, 1, {
@@ -438,53 +375,43 @@
});
}else if (OpName == "ReduceL1") {
ImportNodeOneOut<mlir::ONNXReduceL1Op>(node, 1, 1, {
{"axes","", ""}
,{"keepdims","int","1"}
{"keepdims", 1}
});
}else if (OpName == "ReduceL2") {
ImportNodeOneOut<mlir::ONNXReduceL2Op>(node, 1, 1, {
{"axes","", ""}
,{"keepdims","int","1"}
{"keepdims", 1}
});
}else if (OpName == "ReduceLogSum") {
ImportNodeOneOut<mlir::ONNXReduceLogSumOp>(node, 1, 1, {
{"axes","", ""}
,{"keepdims","int","1"}
{"keepdims", 1}
});
}else if (OpName == "ReduceLogSumExp") {
ImportNodeOneOut<mlir::ONNXReduceLogSumExpOp>(node, 1, 1, {
{"axes","", ""}
,{"keepdims","int","1"}
{"keepdims", 1}
});
}else if (OpName == "ReduceMax") {
ImportNodeOneOut<mlir::ONNXReduceMaxOp>(node, 1, 1, {
{"axes","", ""}
,{"keepdims","int","1"}
{"keepdims", 1}
});
}else if (OpName == "ReduceMean") {
ImportNodeOneOut<mlir::ONNXReduceMeanOp>(node, 1, 1, {
{"axes","", ""}
,{"keepdims","int","1"}
{"keepdims", 1}
});
}else if (OpName == "ReduceMin") {
ImportNodeOneOut<mlir::ONNXReduceMinOp>(node, 1, 1, {
{"axes","", ""}
,{"keepdims","int","1"}
{"keepdims", 1}
});
}else if (OpName == "ReduceProd") {
ImportNodeOneOut<mlir::ONNXReduceProdOp>(node, 1, 1, {
{"axes","", ""}
,{"keepdims","int","1"}
{"keepdims", 1}
});
}else if (OpName == "ReduceSum") {
ImportNodeOneOut<mlir::ONNXReduceSumOp>(node, 1, 1, {
{"axes","", ""}
,{"keepdims","int","1"}
{"keepdims", 1}
});
}else if (OpName == "ReduceSumSquare") {
ImportNodeOneOut<mlir::ONNXReduceSumSquareOp>(node, 1, 1, {
{"axes","", ""}
,{"keepdims","int","1"}
{"keepdims", 1}
});
}else if (OpName == "Relu") {
ImportNodeOneOut<mlir::ONNXReluOp>(node, 1, 1, {
@@ -494,53 +421,47 @@
});
}else if (OpName == "Resize") {
ImportNodeOneOut<mlir::ONNXResizeOp>(node, 4, 1, {
{"coordinate_transformation_mode","str","half_pixel"}
,{"cubic_coeff_a","float","-0.75"}
,{"exclude_outside","int","0"}
,{"extrapolation_value","float","0.0"}
,{"mode","str","nearest"}
,{"nearest_mode","str","round_prefer_floor"}
{"coordinate_transformation_mode", "half_pixel"}
,{"cubic_coeff_a", (float)-0.75}
,{"exclude_outside", 0}
,{"extrapolation_value", (float)0.0}
,{"mode", "nearest"}
,{"nearest_mode", "round_prefer_floor"}
});
}else if (OpName == "ReverseSequence") {
ImportNodeOneOut<mlir::ONNXReverseSequenceOp>(node, 2, 1, {
{"batch_axis","int","1"}
,{"time_axis","int","0"}
{"batch_axis", 1}
,{"time_axis", 0}
});
}else if (OpName == "RoiAlign") {
ImportNodeOneOut<mlir::ONNXRoiAlignOp>(node, 3, 1, {
{"mode","str","avg"}
,{"output_height","int","1"}
,{"output_width","int","1"}
,{"sampling_ratio","int","0"}
,{"spatial_scale","float","1.0"}
{"mode", "avg"}
,{"output_height", 1}
,{"output_width", 1}
,{"sampling_ratio", 0}
,{"spatial_scale", (float)1.0}
});
}else if (OpName == "Round") {
ImportNodeOneOut<mlir::ONNXRoundOp>(node, 1, 1, {
});
}else if (OpName == "Scan") {
ImportNodeOneOut<mlir::ONNXScanOp>(node, 1, 1, {
{"body","", ""}
,{"num_scan_inputs","", ""}
,{"scan_input_axes","", ""}
,{"scan_input_directions","", ""}
,{"scan_output_axes","", ""}
,{"scan_output_directions","", ""}
});
}else if (OpName == "Scatter") {
ImportNodeOneOut<mlir::ONNXScatterOp>(node, 3, 1, {
{"axis","int","0"}
{"axis", 0}
});
}else if (OpName == "ScatterElements") {
ImportNodeOneOut<mlir::ONNXScatterElementsOp>(node, 3, 1, {
{"axis","int","0"}
{"axis", 0}
});
}else if (OpName == "ScatterND") {
ImportNodeOneOut<mlir::ONNXScatterNDOp>(node, 3, 1, {
});
}else if (OpName == "Selu") {
ImportNodeOneOut<mlir::ONNXSeluOp>(node, 1, 1, {
{"alpha","float","1.67326"}
,{"gamma","float","1.0507"}
{"alpha", (float)1.67326}
,{"gamma", (float)1.0507}
});
}else if (OpName == "SequenceAt") {
ImportNodeOneOut<mlir::ONNXSequenceAtOp>(node, 2, 1, {
@ -550,7 +471,6 @@
}); });
}else if (OpName == "SequenceEmpty") { }else if (OpName == "SequenceEmpty") {
ImportNodeOneOut<mlir::ONNXSequenceEmptyOp>(node, 0, 1, { ImportNodeOneOut<mlir::ONNXSequenceEmptyOp>(node, 0, 1, {
{"dtype","", ""}
}); });
}else if (OpName == "SequenceErase") { }else if (OpName == "SequenceErase") {
ImportNodeOneOut<mlir::ONNXSequenceEraseOp>(node, 2, 1, { ImportNodeOneOut<mlir::ONNXSequenceEraseOp>(node, 2, 1, {
@ -566,8 +486,8 @@
}); });
}else if (OpName == "Shrink") { }else if (OpName == "Shrink") {
ImportNodeOneOut<mlir::ONNXShrinkOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXShrinkOp>(node, 1, 1, {
{"bias","float","0.0"} {"bias", (float)0.0}
,{"lambd","float","0.5"} ,{"lambd", (float)0.5}
}); });
}else if (OpName == "Sigmoid") { }else if (OpName == "Sigmoid") {
ImportNodeOneOut<mlir::ONNXSigmoidOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXSigmoidOp>(node, 1, 1, {
@ -589,7 +509,7 @@
}); });
}else if (OpName == "Softmax") { }else if (OpName == "Softmax") {
ImportNodeOneOut<mlir::ONNXSoftmaxOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXSoftmaxOp>(node, 1, 1, {
{"axis","int","1"} {"axis", 1}
}); });
}else if (OpName == "Softplus") { }else if (OpName == "Softplus") {
ImportNodeOneOut<mlir::ONNXSoftplusOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXSoftplusOp>(node, 1, 1, {
@ -599,31 +519,26 @@
}); });
}else if (OpName == "SpaceToDepth") { }else if (OpName == "SpaceToDepth") {
ImportNodeOneOut<mlir::ONNXSpaceToDepthOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXSpaceToDepthOp>(node, 1, 1, {
{"blocksize","", ""}
}); });
}else if (OpName == "Split") { }else if (OpName == "Split") {
ImportNodeOneOut<mlir::ONNXSplitOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXSplitOp>(node, 1, 1, {
{"axis","int","0"} {"axis", 0}
,{"split","", ""}
}); });
}else if (OpName == "SplitToSequence") { }else if (OpName == "SplitToSequence") {
ImportNodeOneOut<mlir::ONNXSplitToSequenceOp>(node, 2, 1, { ImportNodeOneOut<mlir::ONNXSplitToSequenceOp>(node, 2, 1, {
{"axis","int","0"} {"axis", 0}
,{"keepdims","int","1"} ,{"keepdims", 1}
}); });
}else if (OpName == "Sqrt") { }else if (OpName == "Sqrt") {
ImportNodeOneOut<mlir::ONNXSqrtOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXSqrtOp>(node, 1, 1, {
}); });
}else if (OpName == "Squeeze") { }else if (OpName == "Squeeze") {
ImportNodeOneOut<mlir::ONNXSqueezeOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXSqueezeOp>(node, 1, 1, {
{"axes","", ""}
}); });
}else if (OpName == "StringNormalizer") { }else if (OpName == "StringNormalizer") {
ImportNodeOneOut<mlir::ONNXStringNormalizerOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXStringNormalizerOp>(node, 1, 1, {
{"case_change_action","str","NONE"} {"case_change_action", "NONE"}
,{"is_case_sensitive","int","0"} ,{"is_case_sensitive", 0}
,{"locale","", ""}
,{"stopwords","", ""}
}); });
}else if (OpName == "Sub") { }else if (OpName == "Sub") {
ImportNodeOneOut<mlir::ONNXSubOp>(node, 2, 1, { ImportNodeOneOut<mlir::ONNXSubOp>(node, 2, 1, {
@ -639,45 +554,34 @@
}); });
}else if (OpName == "TfIdfVectorizer") { }else if (OpName == "TfIdfVectorizer") {
ImportNodeOneOut<mlir::ONNXTfIdfVectorizerOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXTfIdfVectorizerOp>(node, 1, 1, {
{"max_gram_length","", ""}
,{"max_skip_count","", ""}
,{"min_gram_length","", ""}
,{"mode","", ""}
,{"ngram_counts","", ""}
,{"ngram_indexes","", ""}
,{"pool_int64s","", ""}
,{"pool_strings","", ""}
,{"weights","", ""}
}); });
}else if (OpName == "ThresholdedRelu") { }else if (OpName == "ThresholdedRelu") {
ImportNodeOneOut<mlir::ONNXThresholdedReluOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXThresholdedReluOp>(node, 1, 1, {
{"alpha","float","1.0"} {"alpha", (float)1.0}
}); });
}else if (OpName == "Tile") { }else if (OpName == "Tile") {
ImportNodeOneOut<mlir::ONNXTileOp>(node, 2, 1, { ImportNodeOneOut<mlir::ONNXTileOp>(node, 2, 1, {
}); });
}else if (OpName == "TopK") { }else if (OpName == "TopK") {
ImportNodeMultipleOuts<mlir::ONNXTopKOp>(node, 2, 2, { ImportNodeMultipleOuts<mlir::ONNXTopKOp>(node, 2, 2, {
{"axis","int","-1"} {"axis", -1}
,{"largest","int","1"} ,{"largest", 1}
,{"sorted","int","1"} ,{"sorted", 1}
}); });
}else if (OpName == "Transpose") { }else if (OpName == "Transpose") {
ImportNodeOneOut<mlir::ONNXTransposeOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXTransposeOp>(node, 1, 1, {
{"perm","", ""}
}); });
}else if (OpName == "Unique") { }else if (OpName == "Unique") {
ImportNodeMultipleOuts<mlir::ONNXUniqueOp>(node, 1, 4, { ImportNodeMultipleOuts<mlir::ONNXUniqueOp>(node, 1, 4, {
{"axis","", ""} {"sorted", 1}
,{"sorted","int","1"}
}); });
}else if (OpName == "Unsqueeze") { }else if (OpName == "Unsqueeze") {
ImportNodeOneOut<mlir::ONNXUnsqueezeOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXUnsqueezeOp>(node, 1, 1, {
{"axes","ints", ""} {"axes", std::vector<int64_t> {}}
}); });
}else if (OpName == "Upsample") { }else if (OpName == "Upsample") {
ImportNodeOneOut<mlir::ONNXUpsampleOp>(node, 2, 1, { ImportNodeOneOut<mlir::ONNXUpsampleOp>(node, 2, 1, {
{"mode","str","nearest"} {"mode", "nearest"}
}); });
}else if (OpName == "Where") { }else if (OpName == "Where") {
ImportNodeOneOut<mlir::ONNXWhereOp>(node, 3, 1, { ImportNodeOneOut<mlir::ONNXWhereOp>(node, 3, 1, {
View File
@ -267,7 +267,7 @@ def gen_schema(schema) :
'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor', 'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor',
'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu', 'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu',
'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
'Identity', 'Cos', 'Log', 'Transpose'] 'Identity', 'Cos', 'Log', 'Transpose', 'Softmax']
CanonicalList=['Add', 'Identity'] CanonicalList=['Add', 'Identity']
line_indent = ' ' line_indent = ' '
@ -368,17 +368,17 @@ def gen_code(schema,fefile) :
("MaxPool", "ImportNodeMaxPool"), ("MaxPool", "ImportNodeMaxPool"),
#("Transpose", "ImportNodeTranspose") #("Transpose", "ImportNodeTranspose")
]) ])
special_type = dict([ list_str = 'std::vector'
("AveragePool "+"kernel_shape", '"ints", ""'), empty_ints = list_str+'<int> {}'
("MaxPool "+"kernel_shape", '"ints", ""'), empty_floats = list_str+'<float> {}'
("Cast "+"to", '"int", "0"'), special_default = dict([
("Concat "+"axis", '"int", "0"'), ("AveragePool "+"kernel_shape", empty_ints),
("Conv "+"group", '"int", "1"'), ("MaxPool "+"kernel_shape", empty_ints),
("Unsqueeze "+"axes", '"ints", ""'), ("Cast "+"to", '0'),
("RNN "+"activation_alpha", '"floats", "{}"'), ("Concat "+"axis", '0'),
("RNN "+"activation_beta", '"floats", "{}"'), ("Unsqueeze "+"axes", empty_ints),
("RNN "+"activations", '"", "{Tannh, Tanh}"'), ("RNN "+"activation_alpha", empty_floats),
("LRN "+"size", '"int", ""') ("RNN "+"activation_beta", empty_floats)
]) ])
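# A sketch of the effect (taking AveragePool as the example): the entry above
# makes the generated importer emit
#   {"kernel_shape", std::vector<int> {}}
# while attributes that are neither listed here nor carry an ONNX default are
# skipped entirely.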
line_indent = ' ' line_indent = ' '
fefile.write(' '+'}else if (OpName == "'+schema.name+'") {\n') fefile.write(' '+'}else if (OpName == "'+schema.name+'") {\n')
@ -400,21 +400,9 @@ def gen_code(schema,fefile) :
if schema.attributes: if schema.attributes:
first_attr = True first_attr = True
for _, attr in sorted(schema.attributes.items()): for _, attr in sorted(schema.attributes.items()):
attr_line = line_indent+line_indent+line_indent+line_indent #only generate default attr list
if not first_attr: if schema.name+' '+attr.name in special_default:
attr_line += ',{' attr_value = special_default[schema.name+' '+attr.name]
else :
attr_line += ' {'
first_attr = False
attr_line += '"'+attr.name+'",'
if schema.name+' '+attr.name in special_type:
attr_line += special_type[schema.name+' '+attr.name]
# option holds either required or default value
elif attr.required:
attr_line += '"", ""'
elif attr.default_value.name: elif attr.default_value.name:
default_value = helper.get_attribute_value(attr.default_value) default_value = helper.get_attribute_value(attr.default_value)
@ -430,28 +418,35 @@ def gen_code(schema,fefile) :
return str(value) return str(value)
if isinstance(default_value, list): if isinstance(default_value, list):
value = default_value[0] value = default_value[0]
default_value = [format_value(val) for val in default_value] default_value = [format_value(val) for val in default_value]
attr_option_str = '{}'.format(default_value)
attr_option_str = attr_option_str.replace('[', '{', 1)
attr_option_str = attr_option_str.replace(']', '}', 1)
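# For illustration: a float-list default such as [0.1, 0.2] is first
# formatted as "['0.1', '0.2']", the bracket rewriting above turns it into
# "{'0.1', '0.2'}", and the quote stripping below yields the C++ initializer
#   std::vector<float> {0.1, 0.2}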
# TODO: is the list type homogeneous or heterogeneous? # TODO: is the list type homogeneous or heterogeneous?
if isinstance(value, float) : if isinstance(value, float) :
attr_type_str = '"floats"' attr_type_str = list_str+'<float>'
attr_option_str = attr_option_str.replace("'", '')
elif isinstance(value, int) : elif isinstance(value, int) :
attr_type_str = '"ints"' attr_type_str = list_str+'<int>'
attr_option_str = attr_option_str.replace("'", '')
elif isinstance(value, str) : elif isinstance(value, str) :
attr_type_str = '"strs"' attr_type_str = list_str+'<std::string>'
attr_option_str = attr_option_str.replace("'", '"')
elif isinstance(value, (bytes, bytearray)) : elif isinstance(value, (bytes, bytearray)) :
attr_type_str = '"strs"' attr_type_str = list_str+'<std::string>'
attr_option_str = attr_option_str.replace("'", '"')
else : else :
attr_type_str = '"unknowns"' attr_type_str = '"unknowns"'
attr_option_str = '"{}"'.format(default_value)
attr_option_str = attr_option_str.replace('[', '{', 1)
attr_option_str = attr_option_str.replace(']', '}', 1)
else: else:
if isinstance(default_value, float) : if isinstance(default_value, float) :
attr_type_str = '"float"' attr_type_str = '(float)'
attr_option_str = default_value
elif isinstance(default_value, int) : elif isinstance(default_value, int) :
attr_type_str = '"int"' attr_option_str = default_value
attr_type_str=''
elif isinstance(default_value, str) : elif isinstance(default_value, str) :
attr_type_str = '"str"' attr_type_str = '"str"'
elif isinstance(default_value, (bytes, bytearray)) : elif isinstance(default_value, (bytes, bytearray)) :
@ -459,11 +454,25 @@ def gen_code(schema,fefile) :
else : else :
attr_type_str = '"unknown"' attr_type_str = '"unknown"'
default_value = format_value(default_value) default_value = format_value(default_value)
attr_option_str = '"{}"'.format(default_value) if attr_type_str == '"str"' :
attr_line += attr_type_str+','+attr_option_str attr_option_str = '"'+default_value+'"'
attr_type_str=''
else :
attr_option_str = default_value
attr_value = attr_type_str+attr_option_str
else: else:
#TODO why? #no default value
attr_line += '"", ""' continue
attr_line = line_indent+line_indent+line_indent+line_indent
if not first_attr:
attr_line += ',{'
else :
attr_line += ' {'
first_attr = False
attr_line += '"'+attr.name+'", '
attr_line += attr_value
attr_line += '}\n' attr_line += '}\n'
fefile.write(attr_line) fefile.write(attr_line)
fefile.write(line_indent+line_indent+line_indent+'});\n') fefile.write(line_indent+line_indent+line_indent+'});\n')
View File
@ -13,6 +13,7 @@
#include "mlir/IR/Function.h" #include "mlir/IR/Function.h"
#include "mlir/IR/IntegerSet.h" #include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Matchers.h" #include "mlir/IR/Matchers.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h" #include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SetVector.h"
@ -157,6 +158,14 @@ void ONNXReciprocalOp::inferShapes() {
getResult().setType(getOperand().getType()); getResult().setType(getOperand().getType());
} }
//===----------------------------------------------------------------------===//
// Softmax
/// Infer the output shape of the ONNXSoftmaxOp. This method is required by
/// the shape inference interface.
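/// Softmax normalizes values along `axis` but never changes the shape, so
/// the operand type is propagated to the result unchanged.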
void ONNXSoftmaxOp::inferShapes() {
getResult().setType(getOperand().getType());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Add // Add
/// Infer the output shape of the ONNXAddOp. This method is required by the /// Infer the output shape of the ONNXAddOp. This method is required by the
@ -484,13 +493,38 @@ void ONNXTransposeOp::inferShapes() {
// Naive transposition which handles the default case of // Naive transposition which handles the default case of
// reversing the shape of the tensor (similar to numpy.transpose). // reversing the shape of the tensor (similar to numpy.transpose).
// TODO: Once attributes are supported we can handle the case where the
// transposition uses a permutation vector to interchange the axes.
auto arrayTy = getOperand().getType().cast<RankedTensorType>(); auto arrayTy = getOperand().getType().cast<RankedTensorType>();
SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape())); SmallVector<int64_t, 2> dims;
if (auto permutation = getAttrOfType<ArrayAttr>(
ONNXTransposeOp::getPermAttrName())) {
// Perform transposition according to perm attribute.
for (auto perm : permutation.getValue())
dims.emplace_back(arrayTy.getShape()[perm.cast<IntegerAttr>().getInt()]);
} else {
// Default
for (auto dim : llvm::reverse(arrayTy.getShape()))
dims.emplace_back(dim);
}
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType())); getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
} }
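// For example, tensor<5x5x1x32xf32> with perm = [2, 0, 3, 1] infers
// tensor<1x5x32x5xf32> (dims[i] = inShape[perm[i]]), while the default with
// no perm attribute reverses the shape to tensor<32x1x5x5xf32>.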
LogicalResult verify(ONNXTransposeOp op) {
auto module = op.getParentOfType<ModuleOp>();
if (!module)
return op.emitError("Expected to belong to a module.");
if (auto permutation = op.getAttrOfType<ArrayAttr>(
ONNXTransposeOp::getPermAttrName())) {
// Reject negative indices so inferShapes above can index the input
// shape directly.
for (auto perm : permutation.getValue())
if (perm.cast<IntegerAttr>().getInt() < 0)
return op.emitError(
"Cannot transpose, permutation contains a negative index.");
}
return success();
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// TableGen'd op method definitions // TableGen'd op method definitions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
View File
@ -2831,7 +2831,7 @@ def ONNXSliceOp:ONNX_Op<"Slice",
} }
def ONNXSoftmaxOp:ONNX_Op<"Softmax", def ONNXSoftmaxOp:ONNX_Op<"Softmax",
[NoSideEffect]> { [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX Softmax operation"; let summary = "ONNX Softmax operation";
let description = [{ let description = [{
"The operator computes the softmax (normalized exponential) values for each layer in the batch" "The operator computes the softmax (normalized exponential) values for each layer in the batch"
@ -3098,6 +3098,12 @@ def ONNXTransposeOp:ONNX_Op<"Transpose",
}]; }];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data); let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
let extraClassDeclaration = [{
static StringRef getPermAttrName() { return "perm"; }
}];
let verifier = [{ return ::verify(*this); }];
} }
def ONNXUniqueOp:ONNX_Op<"Unique", def ONNXUniqueOp:ONNX_Op<"Unique",
View File
@ -135,10 +135,13 @@ int main(int argc, char *argv[]) {
if (mlir::failed(pm.run(*module))) if (mlir::failed(pm.run(*module)))
return 4; return 4;
module->dump();
// Write LLVM bitcode to disk. if (emissionTarget == EmitLLVMBC) {
if (emissionTarget == EmitLLVMBC) // Write LLVM bitcode to disk.
EmitLLVMBitCode(module); EmitLLVMBitCode(module);
printf("LLVM bitcode written to ./model.bc");
} else
module->dump();
return 0; return 0;
} }
View File
@ -420,8 +420,8 @@ Value mapToLowerScalarOp<ONNXHardSigmoidOp>(
// Constant 1) // Constant 1)
auto loc = op->getLoc(); auto loc = op->getLoc();
Value operand = operands[0]; Value operand = operands[0];
auto alphaAttr = op->getAttrOfType<FloatAttr>("HardSigmoid.alpha"); auto alphaAttr = op->getAttrOfType<FloatAttr>("alpha");
auto betaAttr = op->getAttrOfType<FloatAttr>("HardSigmoid.beta"); auto betaAttr = op->getAttrOfType<FloatAttr>("beta");
auto elementType = result_types[0]; auto elementType = result_types[0];
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0)); auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
@ -455,7 +455,7 @@ Value mapToLowerScalarOp<ONNXEluOp>(Operation *op, ArrayRef<Type> result_types,
Value operand = operands[0]; Value operand = operands[0];
auto elementType = result_types[0]; auto elementType = result_types[0];
auto alphaAttr = op->getAttrOfType<FloatAttr>("Elu.alpha"); auto alphaAttr = op->getAttrOfType<FloatAttr>("alpha");
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0)); auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1)); auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr); auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
@ -508,7 +508,7 @@ Value mapToLowerScalarOp<ONNXLeakyReluOp>(Operation *op,
Value operand = operands[0]; Value operand = operands[0];
auto elementType = result_types[0]; auto elementType = result_types[0];
auto alphaAttr = op->getAttrOfType<FloatAttr>("LeakyRelu.alpha"); auto alphaAttr = op->getAttrOfType<FloatAttr>("alpha");
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0)); auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr); auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
auto lessThanZero = auto lessThanZero =
@ -533,8 +533,8 @@ Value mapToLowerScalarOp<ONNXSeluOp>(Operation *op, ArrayRef<Type> result_types,
// alpha))) // alpha)))
auto loc = op->getLoc(); auto loc = op->getLoc();
Value operand = operands[0]; Value operand = operands[0];
auto alphaAttr = op->getAttrOfType<FloatAttr>("Selu.alpha"); auto alphaAttr = op->getAttrOfType<FloatAttr>("alpha");
auto gammaAttr = op->getAttrOfType<FloatAttr>("Selu.gamma"); auto gammaAttr = op->getAttrOfType<FloatAttr>("gamma");
auto elementType = result_types[0]; auto elementType = result_types[0];
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0)); auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
@ -824,6 +824,225 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
} }
}; };
struct ONNXSoftmaxOpLowering : public ConversionPattern {
ONNXSoftmaxOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXSoftmaxOp::getOperationName(), 1, ctx) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
// softmax(x) = let max_x = max(x) in
// let exp_x = exp(x - max_x) in
// let sum = sum(exp_x) in
// exp_x / sum
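// Subtracting max(x) before exponentiation is purely for numerical
// stability: it keeps exp from overflowing on large inputs and does not
// change the result, since the common factor exp(-max_x) cancels in the
// quotient exp(x - max_x) / sum(exp(x - max_x)).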
auto tensorType = (*op->result_type_begin()).cast<RankedTensorType>();
int64_t rank = tensorType.getRank();
int64_t axis = op->getAttrOfType<IntegerAttr>("axis").getInt();
axis = axis >= 0 ? axis : rank + axis;
assert(axis >= 0 && axis <= rank - 1 && "Softmax axis out of range.");
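// e.g. with rank = 3, axis = -1 is normalized to 2; from here on axis
// always lies in [0, rank).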
auto loc = op->getLoc();
// Insert an allocation and deallocation for the result of this operation.
auto memRefType = convertTensorToMemRef(tensorType);
auto elementType = memRefType.getElementType();
Value alloc;
bool insertDealloc = checkInsertDealloc(op);
if (hasAllConstantDimensions(memRefType))
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
else
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
operands[0]);
// Shape of the result
auto memRefShape = memRefType.getShape();
// Insert allocations and deallocations for sum and max.
MemRefType scalarMemRefType = MemRefType::get({}, elementType, {}, 0);
Value sumOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true);
Value maxOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true);
Value zero =
rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
Value negInfinity = rewriter.create<ConstantOp>(
loc,
FloatAttr::get(elementType, -std::numeric_limits<float>::infinity()));
// Define loops.
auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, rank);
std::vector<Value> originalLoops;
originalLoops.reserve(rank);
for (auto result : loopsOp.getResults()) {
originalLoops.push_back(result);
}
// Define loop optimization.
auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, rank);
std::vector<Value> optimizedLoops;
optimizedLoops.reserve(rank);
for (auto result : optimizedLoopsOp.getResults()) {
optimizedLoops.push_back(result);
}
Block &optimizationBlock = optimizedLoopsOp.region().front();
// Coerce the input into a 2-D tensor. `axis` will be the coercing point.
// This coercing follows the softmax definition in ONNX:
// https://github.com/onnx/onnx/blob/master/docs/Operators.md#Softmax
// Here, we create an outer loop and an inner loop to handle the two
// dimensions. The outer loop is created only when `axis` is not zero.
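// For instance, a 2x3x4 input with axis = 1 is handled as a 2 x 12 matrix:
// the outer loop covers dims [0, axis) and the inner loops cover dims
// [axis, rank), i.e. each softmax is taken over a 12-element row.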
// Define an outer loop with respect to axis.
std::vector<Value> outerLoops, optimizedOuterLoops;
outerLoops.reserve(axis);
optimizedOuterLoops.reserve(axis);
for (int i = 0; i < axis; ++i) {
outerLoops.push_back(originalLoops[i]);
optimizedOuterLoops.push_back(optimizedLoops[i]);
}
KrnlIterateOperandPack outerPack(rewriter, outerLoops, optimizedOuterLoops);
for (int i = 0; i < axis; ++i) {
if (memRefShape[i] < 0) {
outerPack.pushConstantBound(0);
outerPack.pushOperandBound(
rewriter.create<DimOp>(loc, operands[0], i).getResult());
} else {
outerPack.pushConstantBound(0);
outerPack.pushConstantBound(memRefShape[i]);
}
}
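// A static dimension contributes the constant bounds [0, dim), while a
// dynamic dimension (-1 in memRefShape) takes its upper bound from a DimOp
// on the input, evaluated at runtime. The same scheme is repeated for the
// inner loops below.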
// Define an inner loop with respect to axis.
std::vector<Value> innerLoops, optimizedInnerLoops;
innerLoops.reserve(rank - axis);
optimizedInnerLoops.reserve(rank - axis);
for (int i = axis; i < rank; ++i) {
innerLoops.push_back(originalLoops[i]);
optimizedInnerLoops.push_back(optimizedLoops[i]);
}
KrnlIterateOperandPack innerPack(rewriter, innerLoops, optimizedInnerLoops);
for (int i = axis; i < rank; ++i) {
if (memRefShape[i] < 0) {
innerPack.pushConstantBound(0);
innerPack.pushOperandBound(
rewriter.create<DimOp>(loc, operands[0], i).getResult());
} else {
innerPack.pushConstantBound(0);
innerPack.pushConstantBound(memRefShape[i]);
}
}
KrnlIterateOp outerIterateOp, maxIterateOp, sumIterateOp, softmaxIterateOp;
SmallVector<Value, 4> outerLoopIVs;
if (axis != 0) {
outerIterateOp = rewriter.create<KrnlIterateOp>(loc, outerPack);
// No optimization
rewriter.setInsertionPointToEnd(&optimizationBlock);
rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
rewriter.setInsertionPoint(optimizedLoopsOp);
// Insert instructions inside the outer loop.
Block &outerIterationBlock = outerIterateOp.bodyRegion().front();
rewriter.setInsertionPointToStart(&outerIterationBlock);
for (auto arg : outerIterationBlock.getArguments())
outerLoopIVs.push_back(arg);
// Reset accumulators.
rewriter.create<StoreOp>(loc, zero, sumOp);
rewriter.create<StoreOp>(loc, negInfinity, maxOp);
// Create an inner loop to compute max.
maxIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
// Create an inner loop to compute sum.
sumIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
// Create an inner loop to compute softmax.
softmaxIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
} else {
// Reset accumulators.
rewriter.create<StoreOp>(loc, zero, sumOp);
rewriter.create<StoreOp>(loc, negInfinity, maxOp);
// Create an inner loop to compute max.
maxIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
// Create an inner loop to compute sum.
sumIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
// Create an inner loop to compute softmax.
softmaxIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
// No optimization
rewriter.setInsertionPointToEnd(&optimizationBlock);
rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
rewriter.setInsertionPoint(optimizedLoopsOp);
}
// Insert instructions inside the max loop.
Block &maxIterationBlock = maxIterateOp.bodyRegion().front();
rewriter.setInsertionPointToStart(&maxIterationBlock);
// Get induction variables.
SmallVector<Value, 4> maxLoopIVs;
for (auto arg : outerLoopIVs)
maxLoopIVs.push_back(arg);
for (auto arg : maxIterationBlock.getArguments())
maxLoopIVs.push_back(arg);
// Compute the max value.
Value max = rewriter.create<LoadOp>(loc, maxOp);
Value nextMax = rewriter.create<LoadOp>(loc, operands[0], maxLoopIVs);
auto maxCond =
rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, max, nextMax);
max = rewriter.create<SelectOp>(loc, maxCond, max, nextMax);
rewriter.create<StoreOp>(loc, max, maxOp);
// Get the max.
rewriter.setInsertionPoint(sumIterateOp);
max = rewriter.create<LoadOp>(loc, maxOp);
// Insert instructions inside the sum loop.
Block &sumIterationBlock = sumIterateOp.bodyRegion().front();
rewriter.setInsertionPointToStart(&sumIterationBlock);
// Get induction variables.
SmallVector<Value, 4> sumLoopIVs;
for (auto arg : outerLoopIVs)
sumLoopIVs.push_back(arg);
for (auto arg : sumIterationBlock.getArguments())
sumLoopIVs.push_back(arg);
// Sum up values.
Value sum = rewriter.create<LoadOp>(loc, sumOp);
Value next = rewriter.create<LoadOp>(loc, operands[0], sumLoopIVs);
Value sub = rewriter.create<SubFOp>(loc, next, max);
Value exp = rewriter.create<ExpOp>(loc, sub);
sum = rewriter.create<AddFOp>(loc, sum, exp);
rewriter.create<StoreOp>(loc, sum, sumOp);
// Store intermediate values in the result to avoid recomputation.
rewriter.create<StoreOp>(loc, exp, alloc, sumLoopIVs);
// Get the sum.
rewriter.setInsertionPoint(softmaxIterateOp);
sum = rewriter.create<LoadOp>(loc, sumOp);
// Insert instructions inside the softmax loop.
Block &softmaxIterationBlock = softmaxIterateOp.bodyRegion().front();
rewriter.setInsertionPointToStart(&softmaxIterationBlock);
// Get induction variables.
SmallVector<Value, 4> softmaxLoopIVs;
for (auto arg : outerLoopIVs)
softmaxLoopIVs.push_back(arg);
for (auto arg : softmaxIterationBlock.getArguments())
softmaxLoopIVs.push_back(arg);
// Compute softmax.
Value expLoadedVal = rewriter.create<LoadOp>(loc, alloc, softmaxLoopIVs);
Value result = rewriter.create<DivFOp>(loc, expLoadedVal, sum);
rewriter.create<StoreOp>(loc, result, alloc, softmaxLoopIVs);
rewriter.replaceOp(op, alloc);
return matchSuccess();
}
};
struct ONNXReshapeOpLowering : public ConversionPattern { struct ONNXReshapeOpLowering : public ConversionPattern {
ONNXReshapeOpLowering(MLIRContext *ctx) ONNXReshapeOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {} : ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {}
@ -1005,7 +1224,8 @@ void FrontendToKrnlLoweringPass::runOnModule() {
ONNXElementwiseVariadicOpLowering<mlir::ONNXSumOp>, ONNXElementwiseVariadicOpLowering<mlir::ONNXSumOp>,
ONNXElementwiseVariadicOpLowering<mlir::ONNXMaxOp>, ONNXElementwiseVariadicOpLowering<mlir::ONNXMaxOp>,
ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>, ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>,
ONNXReshapeOpLowering, ONNXEntryPointLowering>(&getContext()); ONNXReshapeOpLowering, ONNXEntryPointLowering,
ONNXSoftmaxOpLowering>(&getContext());
// With the target and rewrite patterns defined, we can now attempt the // With the target and rewrite patterns defined, we can now attempt the
// conversion. The conversion will signal failure if any of our `illegal` // conversion. The conversion will signal failure if any of our `illegal`
View File
@ -116,7 +116,8 @@ public:
op->getName().getStringRef() != "onnx.Gemm" && op->getName().getStringRef() != "onnx.Gemm" &&
op->getName().getStringRef() != "onnx.GemmNoBias" && op->getName().getStringRef() != "onnx.GemmNoBias" &&
op->getName().getStringRef() != "onnx.Reshape" && op->getName().getStringRef() != "onnx.Reshape" &&
op->getName().getStringRef() != "onnx.Transpose") op->getName().getStringRef() != "onnx.Transpose" &&
op->getName().getStringRef() != "onnx.Softmax")
return false; return false;
return llvm::any_of(op->getResultTypes(), [](Type result_type) { return llvm::any_of(op->getResultTypes(), [](Type result_type) {
return !result_type.isa<RankedTensorType>(); return !result_type.isa<RankedTensorType>();
View File
@ -1 +1,2 @@
add_subdirectory(mlir) add_subdirectory(mlir)
add_subdirectory(backend)
View File
@ -0,0 +1,10 @@
configure_file(test.py test.py COPYONLY)
configure_file(test_config.py.in test_config.py)
find_package(PythonInterp 3 REQUIRED)
add_custom_target(run-onnx-backend-test
COMMAND ${PYTHON_EXECUTABLE}
${CMAKE_CURRENT_BINARY_DIR}/test.py)
add_dependencies(run-onnx-backend-test onnf)
add_dependencies(run-onnx-backend-test pyruntime)
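# test.py shells out to the onnf driver and imports the pyruntime module at
# run time, so the two dependencies above ensure both are rebuilt first when
# stale.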
View File
@ -3,46 +3,51 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from __future__ import unicode_literals from __future__ import unicode_literals
import itertools
import os import os
import sys
import unittest import unittest
import onnx.backend.base import onnx.backend.base
import onnx.backend.test import onnx.backend.test
from onnx.backend.base import Device, DeviceType from onnx.backend.base import Device, DeviceType
import onnx.shape_inference
import onnx.version_converter
import subprocess import subprocess
import test_config
VERBOSE = bool(os.environ.get("VERBOSE"))
CXX = test_config.CXX_PATH
ONNF = os.path.join(test_config.ONNF_BUILD_PATH, "bin/onnf")
LLC = os.path.join(test_config.LLVM_PROJ_BUILD_PATH, "bin/llc")
# Make lib folder under build directory visible in PYTHONPATH
doc_check_base_dir = os.path.dirname(os.path.realpath(__file__))
RUNTIME_DIR = os.path.join(test_config.ONNF_BUILD_PATH, "lib")
sys.path.append(RUNTIME_DIR)
from pyruntime import ExecutionSession from pyruntime import ExecutionSession
CXX = os.getenv('CXX')
ONNF = os.getenv('ONNF') def execute_commands(cmds):
LLC = os.getenv('LLC') if (VERBOSE):
RT_DIR = os.getenv('RT_DIR') print(" ".join(cmds))
assert CXX and ONNF and LLC and RT_DIR, "tools path not set" subprocess.run(cmds, stdout=subprocess.PIPE)
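# e.g. execute_commands([ONNF, "temp_model.onnx"]) echoes the full command
# line when VERBOSE is set in the environment, then runs it with stdout
# captured.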
class DummyBackend(onnx.backend.base.Backend): class DummyBackend(onnx.backend.base.Backend):
@classmethod @classmethod
def prepare( def prepare(cls, model, device='CPU', **kwargs):
cls,
model,
device='CPU',
**kwargs
):
super(DummyBackend, cls).prepare(model, device, **kwargs) super(DummyBackend, cls).prepare(model, device, **kwargs)
# Save model to disk as temp_model.onnx. # Save model to disk as temp_model.onnx.
onnx.save(model, "temp_model.onnx") onnx.save(model, "temp_model.onnx")
# Call frontend to process temp_model.onnx, bit code will be generated. # Call frontend to process temp_model.onnx, bit code will be generated.
subprocess.run([ONNF, "temp_model.onnx"], stdout=subprocess.PIPE) execute_commands([ONNF, "temp_model.onnx"])
# Call llc to generate object file from bitcode. # Call llc to generate object file from bitcode.
subprocess.run([LLC, "-filetype=obj", "model.bc"], execute_commands(
stdout=subprocess.PIPE) [LLC, "-filetype=obj", "-relocation-model=pic", "model.bc"])
# Generate shared library from object file, linking with c runtime. # Generate shared library from object file, linking with c runtime.
subprocess.run([ execute_commands([
CXX, "-shared", "model.o", "-o", "model.so", "-L" + RT_DIR, CXX, "-shared", "-fPIC", "model.o", "-o", "model.so",
"-lcruntime" "-L" + RUNTIME_DIR, "-lcruntime"
], ])
stdout=subprocess.PIPE)
return ExecutionSession("./model.so", "_dyn_entry_point_main_graph") return ExecutionSession("./model.so", "_dyn_entry_point_main_graph")
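# The pipeline above is: temp_model.onnx -(onnf)-> model.bc -(llc)-> model.o
# -(CXX -shared)-> model.so, which pyruntime then loads. The matching
# "-relocation-model=pic" and "-fPIC" flags are required because the object
# file ends up inside a shared library.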
@classmethod @classmethod
@ -125,6 +130,14 @@ test_to_enable = [
"test_sigmoid_cpu", "test_sigmoid_cpu",
"test_sigmoid_example_cpu", "test_sigmoid_example_cpu",
# Softmax Op:
"test_softmax_axis_0_cpu",
"test_softmax_axis_1_cpu",
"test_softmax_axis_2_cpu",
"test_softmax_default_axis_cpu",
"test_softmax_example_cpu",
"test_softmax_large_number_cpu",
# Sum Op: # Sum Op:
#"test_sum_example_cpu", <- error #"test_sum_example_cpu", <- error
"test_sum_one_input_cpu", "test_sum_one_input_cpu",
@ -140,18 +153,15 @@ import inspect
all_tests = inspect.getmembers( all_tests = inspect.getmembers(
backend_test.test_cases["OnnxBackendNodeModelTest"]) backend_test.test_cases["OnnxBackendNodeModelTest"])
all_test_names = list(map(lambda x: x[0], all_tests)) all_test_names = list(map(lambda x: x[0], all_tests))
# Ensure that test names specified in test_to_enable actually exist.
for test_name in test_to_enable: for test_name in test_to_enable:
assert test_name in all_test_names, "test name {} not found".format(test_name) assert test_name in all_test_names, (
"test name {} not found; you may have misspelled the test name, or the "
"specified test does not exist in the version of the onnx package you "
"installed.".format(test_name))
backend_test.include(r"^{}$".format(test_name)) backend_test.include(r"^{}$".format(test_name))
def tearDownModule():
print()
print("*" * 40)
print("A total of {} tests should have run".format(len(test_to_enable)))
print("*" * 40)
# import all test cases at global scope to make them visible to python.unittest # import all test cases at global scope to make them visible to python.unittest
globals().update(backend_test.test_cases) globals().update(backend_test.test_cases)
View File
@ -0,0 +1,3 @@
ONNF_BUILD_PATH = "@CMAKE_BINARY_DIR@"
LLVM_PROJ_BUILD_PATH = "@LLVM_PROJ_BUILD@"
CXX_PATH = "@CMAKE_CXX_COMPILER@"
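# These @...@ placeholders are substituted by CMake's configure_file() in
# the backend CMakeLists.txt above, so the generated test_config.py hands
# absolute build paths to test.py.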
View File
@ -1,12 +1,12 @@
// RUN: onnf-opt %s -mlir-print-op-generic | FileCheck -check-prefix=GENERIC %s // RUN: onnf-opt %s -mlir-print-op-generic | FileCheck -check-prefix=GENERIC %s
// RUN: onnf-opt %s | FileCheck %s // RUN: onnf-opt %s | FileCheck %s
// GENERIC-DAG: #{{.*}} = () -> (0) // GENERIC-DAG: #{{.*}} = affine_map<() -> (0)>
// GENERIC-DAG: #{{.*}} = () -> (10) // GENERIC-DAG: #{{.*}} = affine_map<() -> (10)>
// GENERIC-DAG: #{{.*}} = () -> (1) // GENERIC-DAG: #{{.*}} = affine_map<() -> (1)>
// GENERIC-DAG: #{{.*}} = () -> (11) // GENERIC-DAG: #{{.*}} = affine_map<() -> (11)>
// GENERIC-DAG: #{{.*}} = (d0, d1) -> (d0 - d1) // GENERIC-DAG: #{{.*}} = affine_map<(d0, d1) -> (d0 - d1)>
// GENERIC-DAG: #{{.*}} = (d0, d1) -> (d0 + d1) // GENERIC-DAG: #{{.*}} = affine_map<(d0, d1) -> (d0 + d1)>
func @simple_iterate(%N : index) { func @simple_iterate(%N : index) {
%ii, %ij, %ik = krnl.define_loops 3 %ii, %ij, %ik = krnl.define_loops 3
@ -55,18 +55,18 @@ func @affine_map_bound(%N : index) {
// GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( { // GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( {
// GENERIC-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index): // GENERIC-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index):
// CHECK: krnl.iterate(%{{.*}}, %{{.*}}) with (%{{.*}} -> %{{.*}} = 0 to 10, %{{.*}} -> %{{.*}} = 0 to 10) { // CHECK: krnl.iterate(%{{.*}}, %{{.*}}) with (%{{.*}} -> %{{.*}} = 0 to 10, %{{.*}} -> %{{.*}} = 0 to 10) {
krnl.iterate(%oi, %oj) with (%ii -> %i = ()->(0)() to ()->(10)(), %ij -> %j = 0 to 10) { krnl.iterate(%oi, %oj) with (%ii -> %i = affine_map<()->(0)>() to affine_map<()->(10)>(), %ij -> %j = 0 to 10) {
// GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( { // GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( {
// GENERIC-NEXT: ^bb0(%{{.*}}: index): // GENERIC-NEXT: ^bb0(%{{.*}}: index):
// CHECK: krnl.iterate(%{{.*}}) with (%{{.*}} -> %{{.*}} = #{{.*}}(%{{.*}}, %{{.*}}) to #{{.*}}(%{{.*}}, %{{.*}})) { // CHECK: krnl.iterate(%{{.*}}) with (%{{.*}} -> %{{.*}} = #{{.*}}(%{{.*}}, %{{.*}}) to #{{.*}}(%{{.*}}, %{{.*}})) {
krnl.iterate(%ok) with (%ik -> %k = (d0, d1)->(d0 - d1)(%i, %j) to (d0, d1)->(d0 + d1)(%i, %j)) { krnl.iterate(%ok) with (%ik -> %k = affine_map<(d0, d1)->(d0 - d1)>(%i, %j) to affine_map<(d0, d1)->(d0 + d1)>(%i, %j)) {
} }
// GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( { // GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( {
// GENERIC-NEXT: ^bb0(%{{.*}}: index): // GENERIC-NEXT: ^bb0(%{{.*}}: index):
// CHECK: krnl.iterate(%{{.*}}) with (%{{.*}} -> %{{.*}} = max #map{{.*}}(%{{.*}}, %{{.*}}) to min #map{{.*}}(%{{.*}}, %{{.*}})[%{{.*}}]) { // CHECK: krnl.iterate(%{{.*}}) with (%{{.*}} -> %{{.*}} = max #map{{.*}}(%{{.*}}, %{{.*}}) to min #map{{.*}}(%{{.*}}, %{{.*}})[%{{.*}}]) {
krnl.iterate(%ok) with (%ik -> %k = max (d0, d1)->(d0 - d1, 0)(%i, %j) to min (d0, d1)[s0]->(d0 + d1, s0)(%i, %j)[%N]) { krnl.iterate(%ok) with (%ik -> %k = max affine_map<(d0, d1)->(d0 - d1, 0)>(%i, %j) to min affine_map<(d0, d1)[s0]->(d0 + d1, s0)>(%i, %j)[%N]) {
} }
} }
View File
@ -385,7 +385,7 @@ func @test_min(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tensor<*
} }
func @test_elu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> { func @test_elu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Elu"(%arg0) {Elu.alpha=2.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32> %0 = "onnx.Elu"(%arg0) {alpha=2.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_elu // CHECK-LABEL: test_elu
@ -411,7 +411,7 @@ func @test_elu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
} }
func @test_leakyrelu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> { func @test_leakyrelu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.LeakyRelu"(%arg0) {LeakyRelu.alpha=1.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32> %0 = "onnx.LeakyRelu"(%arg0) {alpha=1.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_leakyrelu // CHECK-LABEL: test_leakyrelu
@ -434,7 +434,7 @@ func @test_leakyrelu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
} }
func @test_selu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> { func @test_selu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Selu"(%arg0) {Selu.alpha=1.0:f32, Selu.gamma=2.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32> %0 = "onnx.Selu"(%arg0) {alpha=1.0:f32, gamma=2.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_selu // CHECK-LABEL: test_selu
@ -461,7 +461,7 @@ func @test_selu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
} }
func @test_hardsigmoid(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> { func @test_hardsigmoid(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.HardSigmoid"(%arg0) {HardSigmoid.alpha=1.0:f32, HardSigmoid.beta=2.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32> %0 = "onnx.HardSigmoid"(%arg0) {alpha=1.0:f32, beta=2.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_hardsigmoid // CHECK-LABEL: test_hardsigmoid
@ -533,3 +533,49 @@ func @test_add_with_broadcasting(%arg0 : tensor<?xf32>, %arg1 : tensor<?x10xf32>
// CHECK: } // CHECK: }
// CHECK: return [[RES]] : memref<?x10xf32> // CHECK: return [[RES]] : memref<?x10xf32>
} }
func @test_softmax(%arg0 : tensor<10x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Softmax"(%arg0) {axis=1:i32} : (tensor<10x10xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_softmax
// CHECK: [[MAX:%.+]] = alloc() : memref<f32>
// CHECK: [[SUM:%.+]] = alloc() : memref<f32>
// CHECK: [[RES:%.+]] = alloc() : memref<10x10xf32>
// CHECK: [[CST:%.+]] = constant 0.000000e+00 : f32
// CHECK: [[CST_0:%.+]] = constant 0xFF800000 : f32
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_LOOPS]]#0) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to 10) {
// CHECK: store [[CST]], [[SUM]][] : memref<f32>
// CHECK: store [[CST_0]], [[MAX]][] : memref<f32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load [[MAX]][] : memref<f32>
// CHECK: [[LOAD2:%.+]] = load %arg0[%arg1, %arg2] : memref<10x10xf32>
// CHECK: [[COND:%.+]] = cmpf "ogt", [[LOAD1]], [[LOAD2]] : f32
// CHECK: [[SELECT:%.+]] = select [[COND]], [[LOAD1]], [[LOAD2]] : f32
// CHECK: store [[SELECT]], [[MAX]][] : memref<f32>
// CHECK: }
// CHECK: [[MAX_VAL:%.+]] = load [[MAX]][] : memref<f32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load [[SUM]][] : memref<f32>
// CHECK: [[LOAD2:%.+]] = load %arg0[%arg1, %arg2] : memref<10x10xf32>
// CHECK: [[SUB:%.+]] = subf [[LOAD2]], [[MAX_VAL]] : f32
// CHECK: [[EXP:%.+]] = exp [[SUB]] : f32
// CHECK: [[ADD:%.+]] = addf [[LOAD1]], [[EXP]] : f32
// CHECK: store [[ADD]], [[SUM]][] : memref<f32>
// CHECK: store [[EXP]], [[RES]][%arg1, %arg2] : memref<10x10xf32>
// CHECK: }
// CHECK: [[SUM_VAL:%.+]] = load [[SUM]][] : memref<f32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
// CHECK: [[LOAD3:%.+]] = load [[RES]][%arg1, %arg2] : memref<10x10xf32>
// CHECK: [[DIV:%.+]] = divf [[LOAD3]], [[SUM_VAL]] : f32
// CHECK: store [[DIV]], [[RES]][%arg1, %arg2] : memref<10x10xf32>
// CHECK: }
// CHECK: }
// CHECK: dealloc [[SUM]] : memref<f32>
// CHECK: dealloc [[MAX]] : memref<f32>
// CHECK: return [[RES]] : memref<10x10xf32>
}
View File
@ -648,8 +648,8 @@ func @test_min_min(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tens
} }
func @test_elu_elu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> { func @test_elu_elu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Elu"(%arg0) {Elu.alpha=2.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32> %0 = "onnx.Elu"(%arg0) {alpha=2.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32>
%1 = "onnx.Elu"(%0) {Elu.alpha=2.0:f32} : (tensor<*xf32>) -> tensor<*xf32> %1 = "onnx.Elu"(%0) {alpha=2.0:f32} : (tensor<*xf32>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> () "std.return"(%1) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_elu_elu // CHECK-LABEL: test_elu_elu
@ -701,8 +701,8 @@ func @test_elu_elu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
} }
func @test_leakyrelu_leakyrelu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> { func @test_leakyrelu_leakyrelu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.LeakyRelu"(%arg0) {LeakyRelu.alpha=1.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32> %0 = "onnx.LeakyRelu"(%arg0) {alpha=1.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32>
%1 = "onnx.LeakyRelu"(%0) {LeakyRelu.alpha=1.0:f32} : (tensor<*xf32>) -> tensor<*xf32> %1 = "onnx.LeakyRelu"(%0) {alpha=1.0:f32} : (tensor<*xf32>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> () "std.return"(%1) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_leakyrelu_leakyrelu // CHECK-LABEL: test_leakyrelu_leakyrelu
@ -748,8 +748,8 @@ func @test_leakyrelu_leakyrelu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
} }
func @test_selu_selu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> { func @test_selu_selu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Selu"(%arg0) {Selu.alpha=1.0:f32, Selu.gamma=2.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32> %0 = "onnx.Selu"(%arg0) {alpha=1.0:f32, gamma=2.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32>
%1 = "onnx.Selu"(%0) {Selu.alpha=1.0:f32, Selu.gamma=2.0:f32} : (tensor<*xf32>) -> tensor<*xf32> %1 = "onnx.Selu"(%0) {alpha=1.0:f32, gamma=2.0:f32} : (tensor<*xf32>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> () "std.return"(%1) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_selu_selu // CHECK-LABEL: test_selu_selu
@ -803,8 +803,8 @@ func @test_selu_selu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
} }
func @test_hardsigmoid_hardsigmoid(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> { func @test_hardsigmoid_hardsigmoid(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.HardSigmoid"(%arg0) {HardSigmoid.alpha=1.0:f32, HardSigmoid.beta=2.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32> %0 = "onnx.HardSigmoid"(%arg0) {alpha=1.0:f32, beta=2.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32>
%1 = "onnx.HardSigmoid"(%0) {HardSigmoid.alpha=1.0:f32, HardSigmoid.beta=2.0:f32} : (tensor<*xf32>) -> tensor<*xf32> %1 = "onnx.HardSigmoid"(%0) {alpha=1.0:f32, beta=2.0:f32} : (tensor<*xf32>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> () "std.return"(%1) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_hardsigmoid_hardsigmoid // CHECK-LABEL: test_hardsigmoid_hardsigmoid
View File
@ -11,8 +11,26 @@ func @test_default_transpose(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
// CHECK: return [[RES]] : tensor<32x1x5x5xf32> // CHECK: return [[RES]] : tensor<32x1x5x5xf32>
} }
// CHECK-LABEL: test_default_transpose
// CHECK: [[RES:%.+]] = "onnx.Transpose"(%arg0) : (tensor<5x5x1x32xf32>) -> tensor<32x1x5x5xf32>
// CHECK: return [[RES]] : tensor<32x1x5x5xf32>
/// Test shape inference for transposition when perm attribute is specified.
func @test_transpose(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
%0 = "onnx.Transpose"(%arg0) {perm = [2, 0, 3, 1]} : (tensor<5x5x1x32xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
}
// CHECK-LABEL: test_transpose
// CHECK: [[RES_ATTR:%.+]] = "onnx.Transpose"(%arg0) {perm = [2, 0, 3, 1]} : (tensor<5x5x1x32xf32>) -> tensor<1x5x32x5xf32>
// CHECK: return [[RES_ATTR]] : tensor<1x5x32x5xf32>
//===----------------------------------------------------------------------===//
/// Test the shape inferencing scheme for the matmul operation. /// Test the shape inferencing scheme for the matmul operation.
//===----------------------------------------------------------------------===//
/// MatMul: 1-D x 1-D /// MatMul: 1-D x 1-D
func @test_matmul_1(%arg0 : tensor<32xf32>, %arg1 : tensor<32xf32>) -> tensor<*xf32> { func @test_matmul_1(%arg0 : tensor<32xf32>, %arg1 : tensor<32xf32>) -> tensor<*xf32> {
%0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<32xf32>, tensor<32xf32>) -> tensor<*xf32> %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<32xf32>, tensor<32xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()
@ -23,6 +41,7 @@ func @test_matmul_1(%arg0 : tensor<32xf32>, %arg1 : tensor<32xf32>) -> tensor<*x
} }
/// MatMul: K-D x 2-D (K > 2) /// MatMul: K-D x 2-D (K > 2)
func @test_matmul_2(%arg0 : tensor<16x?x64x42xf32>, %arg1 : tensor<42x32xf32>) -> tensor<*xf32> { func @test_matmul_2(%arg0 : tensor<16x?x64x42xf32>, %arg1 : tensor<42x32xf32>) -> tensor<*xf32> {
%0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<16x?x64x42xf32>, tensor<42x32xf32>) -> tensor<*xf32> %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<16x?x64x42xf32>, tensor<42x32xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()
@ -33,6 +52,7 @@ func @test_matmul_2(%arg0 : tensor<16x?x64x42xf32>, %arg1 : tensor<42x32xf32>) -
} }
/// MatMul: 2-D x K-D (K > 2) /// MatMul: 2-D x K-D (K > 2)
func @test_matmul_3(%arg0 : tensor<64x42xf32>, %arg1 : tensor<16x?x42x32xf32>) -> tensor<*xf32> { func @test_matmul_3(%arg0 : tensor<64x42xf32>, %arg1 : tensor<16x?x42x32xf32>) -> tensor<*xf32> {
%0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<64x42xf32>, tensor<16x?x42x32xf32>) -> tensor<*xf32> %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<64x42xf32>, tensor<16x?x42x32xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()
@ -43,6 +63,7 @@ func @test_matmul_3(%arg0 : tensor<64x42xf32>, %arg1 : tensor<16x?x42x32xf32>) -
} }
/// MatMul: 2-D x K-D (K > 2) /// MatMul: 2-D x K-D (K > 2)
func @test_matmul_4(%arg0 : tensor<64x42xf32>, %arg1 : tensor<?x?x?x?xf32>) -> tensor<*xf32> { func @test_matmul_4(%arg0 : tensor<64x42xf32>, %arg1 : tensor<?x?x?x?xf32>) -> tensor<*xf32> {
%0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<64x42xf32>, tensor<?x?x?x?xf32>) -> tensor<*xf32> %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<64x42xf32>, tensor<?x?x?x?xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()
@ -53,6 +74,7 @@ func @test_matmul_4(%arg0 : tensor<64x42xf32>, %arg1 : tensor<?x?x?x?xf32>) -> t
} }
/// MatMul: K1-D x K2-D (K1 > 2, K2 > 2) /// MatMul: K1-D x K2-D (K1 > 2, K2 > 2)
func @test_matmul_5(%arg0 : tensor<16x?x?x42xf32>, %arg1 : tensor<32x?x64x42x32xf32>) -> tensor<*xf32> { func @test_matmul_5(%arg0 : tensor<16x?x?x42xf32>, %arg1 : tensor<32x?x64x42x32xf32>) -> tensor<*xf32> {
%0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<16x?x?x42xf32>, tensor<32x?x64x42x32xf32>) -> tensor<*xf32> %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<16x?x?x42xf32>, tensor<32x?x64x42x32xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()
@ -63,6 +85,7 @@ func @test_matmul_5(%arg0 : tensor<16x?x?x42xf32>, %arg1 : tensor<32x?x64x42x32x
} }
/// MatMul: 1-D x 2-D /// MatMul: 1-D x 2-D
func @test_matmul_6(%arg0 : tensor<32xf32>, %arg1 : tensor<32x64xf32>) -> tensor<*xf32> { func @test_matmul_6(%arg0 : tensor<32xf32>, %arg1 : tensor<32x64xf32>) -> tensor<*xf32> {
%0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<32xf32>, tensor<32x64xf32>) -> tensor<*xf32> %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<32xf32>, tensor<32x64xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()
@ -73,6 +96,7 @@ func @test_matmul_6(%arg0 : tensor<32xf32>, %arg1 : tensor<32x64xf32>) -> tensor
} }
/// MatMul: 2-D x 1-D /// MatMul: 2-D x 1-D
func @test_matmul_7(%arg0 : tensor<32x64xf32>, %arg1 : tensor<64xf32>) -> tensor<*xf32> { func @test_matmul_7(%arg0 : tensor<32x64xf32>, %arg1 : tensor<64xf32>) -> tensor<*xf32> {
%0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<32x64xf32>, tensor<64xf32>) -> tensor<*xf32> %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<32x64xf32>, tensor<64xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()
@ -83,6 +107,7 @@ func @test_matmul_7(%arg0 : tensor<32x64xf32>, %arg1 : tensor<64xf32>) -> tensor
} }
/// MatMul: 2-D x 2-D /// MatMul: 2-D x 2-D
func @test_matmul_8(%arg0 : tensor<32x64xf32>, %arg1 : tensor<64x128xf32>) -> tensor<*xf32> { func @test_matmul_8(%arg0 : tensor<32x64xf32>, %arg1 : tensor<64x128xf32>) -> tensor<*xf32> {
%0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<32x64xf32>, tensor<64x128xf32>) -> tensor<*xf32> %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<32x64xf32>, tensor<64x128xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()
third_party/variant vendored Submodule
@ -0,0 +1 @@
Subproject commit 3c7fc8266bb46046b42c2dc2663f9f505f0cec28
View File
@ -9,4 +9,5 @@ cmake -G Ninja ../llvm \
-DLLVM_ENABLE_ASSERTIONS=ON \ -DLLVM_ENABLE_ASSERTIONS=ON \
-DLLVM_ENABLE_RTTI=ON -DLLVM_ENABLE_RTTI=ON
cmake --build . --target check-mlir -- ${MAKEFLAGS} cmake --build .
cmake --build . --target check-mlir
View File
@ -7,4 +7,5 @@ cmake ..
cmake --build . --target onnf cmake --build . --target onnf
# Run FileCheck tests: # Run FileCheck tests:
export LIT_OPTS=-v
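# -v makes lit print each individual test result instead of only a summary.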
cmake --build . --target check-mlir-lit cmake --build . --target check-mlir-lit