Chentong319 attribute with variant (#25)

* change the read-in of attribute, using variant

* Use backported variant.

* Reduce code duplication.

* 1. Make array attribute parsing more clear.
2. int -> int64_t.

* 1. Fix how array attributes are imported.

* 1. Fix clang-tidy warnings.

* 1. Nit: fix clang-tidy warnings.

* Fix MaxPool node construction.

* Fix call to MaxPool.

* Comment out backend tests that fail.

* Add path to variant submodule to enable include file detection.

* Allow unused argument to avoid special casing generator.

* Address attribute related e2e test failures for Hard sigmoid,Elu,LeakyRelu,Selu,Softmax

Co-authored-by: chentong319 <chentong@us.ibm.com>
Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
Tian Jin 2020-01-21 19:36:21 -07:00 committed by GitHub
parent 0231bb83a2
commit 51b0f4c9dd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 375 additions and 655 deletions

3
.gitmodules vendored
View File

@ -7,3 +7,6 @@
[submodule "third_party/pybind11"] [submodule "third_party/pybind11"]
path = third_party/pybind11 path = third_party/pybind11
url = https://github.com/pybind/pybind11.git url = https://github.com/pybind/pybind11.git
[submodule "third_party/variant"]
path = third_party/variant
url = git@github.com:mpark/variant.git

View File

@ -22,6 +22,7 @@ include(MLIR.cmake)
add_subdirectory(third_party/onnx) add_subdirectory(third_party/onnx)
add_subdirectory(third_party/benchmark) add_subdirectory(third_party/benchmark)
add_subdirectory(third_party/pybind11) add_subdirectory(third_party/pybind11)
add_subdirectory(third_party/variant)
set(CMAKE_CXX_STANDARD 14) set(CMAKE_CXX_STANDARD 14)
add_subdirectory(src) add_subdirectory(src)

View File

@ -7,8 +7,9 @@ add_library(builder
target_include_directories(builder PRIVATE ${CMAKE_SOURCE_DIR}) target_include_directories(builder PRIVATE ${CMAKE_SOURCE_DIR})
target_include_directories(builder PRIVATE ${CMAKE_BINARY_DIR}) target_include_directories(builder PRIVATE ${CMAKE_BINARY_DIR})
target_link_libraries(builder compiler onnx ${MLIRLibs} curses) target_link_libraries(builder compiler onnx ${MLIRLibs} curses mpark_variant)
target_include_directories(builder target_include_directories(builder
PRIVATE PRIVATE
${CMAKE_SOURCE_DIR}/third_party/onnx ${CMAKE_SOURCE_DIR}/third_party/onnx
${CMAKE_SOURCE_DIR}/third_party/variant
${CMAKE_SOURCE_DIR}) ${CMAKE_SOURCE_DIR})

View File

@ -14,11 +14,16 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include <map>
#include <numeric> #include <numeric>
#include <regex> #include <regex>
#include <string> #include <string>
#include <tuple> #include <tuple>
#include <map>
// Using backported variant.
// bstd = backported standard library.
#include <mpark/variant.hpp>
namespace bstd = mpark;
#include "mlir/Analysis/Verifier.h" #include "mlir/Analysis/Verifier.h"
#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/Dialect/StandardOps/Ops.h"
@ -42,8 +47,8 @@
namespace onnf { namespace onnf {
namespace { namespace {
void replaceAll( void replaceAll(std::string &str, const std::string &from,
std::string& str, const std::string& from, const std::string& to) { const std::string &to) {
if (from.empty()) if (from.empty())
return; return;
size_t start_pos = 0; size_t start_pos = 0;
@ -71,7 +76,7 @@ struct OnnxOnnfSymbolMapping {
* @param name onnx tensor name. * @param name onnx tensor name.
* @return onnf tensor corresponding to `name`. * @return onnf tensor corresponding to `name`.
*/ */
mlir::Value GetTensorByOnnxName(std::string name) { mlir::Value GetTensorByOnnxName(const std::string &name) {
assert(onnx_name2onnf_tensor.find(legalize_name(name)) != assert(onnx_name2onnf_tensor.find(legalize_name(name)) !=
onnx_name2onnf_tensor.end() && onnx_name2onnf_tensor.end() &&
"Tensor not found"); "Tensor not found");
@ -83,7 +88,7 @@ struct OnnxOnnfSymbolMapping {
* @param name onnx tensor name. * @param name onnx tensor name.
* @param tensor MLIR Value pointer. * @param tensor MLIR Value pointer.
*/ */
void AddMapping(std::string name, mlir::Value tensor) { void AddMapping(const std::string &name, mlir::Value tensor) {
assert(onnx_name2onnf_tensor.count(legalize_name(name)) == 0 && assert(onnx_name2onnf_tensor.count(legalize_name(name)) == 0 &&
"Tensor already exists."); "Tensor already exists.");
onnx_name2onnf_tensor.emplace(legalize_name(name), tensor); onnx_name2onnf_tensor.emplace(legalize_name(name), tensor);
@ -124,8 +129,8 @@ private:
// Convert type to MLIR type. // Convert type to MLIR type.
// A complete list of types can be found in: // A complete list of types can be found in:
// <onnf-build-folder>/third_party/onnx/onnx/onnx.pb.h // <onnf-build-folder>/third_party/onnx/onnx/onnx.pb.h
mlir::Type TypeConvert(onnx::TensorProto_DataType intype) { mlir::Type convertONNXTypeToMLIRType(onnx::TensorProto_DataType onnxType) {
switch (intype) { switch (onnxType) {
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16: case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16:
return builder_.getF16Type(); return builder_.getF16Type();
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT: case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT:
@ -169,8 +174,8 @@ private:
for (int i = 0; i < shape_proto.dim_size(); i++) { for (int i = 0; i < shape_proto.dim_size(); i++) {
if (shape_proto.dim()[i].dim_value()) { if (shape_proto.dim()[i].dim_value()) {
int dim_numeric_size = shape_proto.dim()[i].dim_value(); int dim_numeric_size = shape_proto.dim()[i].dim_value();
assert( assert(dim_numeric_size != 0 &&
dim_numeric_size != 0 && "Parsed an input tensor with a dimension size of zero"); "Parsed an input tensor with a dimension size of zero");
if (dim_numeric_size > 0) { if (dim_numeric_size > 0) {
dims.push_back(dim_numeric_size); dims.push_back(dim_numeric_size);
} else { // If dim_value < 0, then dim is parametric. } else { // If dim_value < 0, then dim is parametric.
@ -184,7 +189,7 @@ private:
} }
mlir::Type elementType = mlir::Type elementType =
TypeConvert(input.type().tensor_type().elem_type()); convertONNXTypeToMLIRType(input.type().tensor_type().elem_type());
llvm::ArrayRef<int64_t> tensor_dims(dims.data(), dims.size()); llvm::ArrayRef<int64_t> tensor_dims(dims.data(), dims.size());
arg_types.emplace_back( arg_types.emplace_back(
mlir::RankedTensorType::get(tensor_dims, elementType)); mlir::RankedTensorType::get(tensor_dims, elementType));
@ -200,288 +205,111 @@ private:
void ImportInputTensorSymbol(const onnx::ValueInfoProto &input, void ImportInputTensorSymbol(const onnx::ValueInfoProto &input,
mlir::Value symbol) { mlir::Value symbol) {
auto input_tensor_legalized_name = legalize_name(input.name()); auto input_tensor_legalized_name = legalize_name(input.name());
assert( assert(!frontend_symbols_.ContainKey(input_tensor_legalized_name) &&
!frontend_symbols_.ContainKey(input_tensor_legalized_name) &&
"Found duplicate legalized input tensor names."); "Found duplicate legalized input tensor names.");
frontend_symbols_.AddMapping(input_tensor_legalized_name, symbol); frontend_symbols_.AddMapping(input_tensor_legalized_name, symbol);
} }
template <typename T> typedef bstd::variant<int64_t, std::vector<int64_t>, float,
T get_attr_generic(onnx::NodeProto &node, std::string name, std::vector<float>, std::string,
std::function<T(onnx::AttributeProto &)> attr_getter, std::vector<std::string>>
T default_val) { AttrValueType;
struct ONNXAttrVisitor {
ONNXAttrVisitor(std::string name, mlir::OpBuilder &builder)
: _builder(builder), _name(std::move(name)) {}
// Op builder.
mlir::OpBuilder &_builder;
// Name of the attribute being inspected.
std::string _name;
mlir::NamedAttribute operator()(int64_t const &r) {
auto val = _builder.getI32IntegerAttr(r);
return _builder.getNamedAttr(_name, val);
}
mlir::NamedAttribute operator()(std::vector<int64_t> const &ints) {
auto val = _builder.getI64ArrayAttr(ints);
return _builder.getNamedAttr(_name, val);
}
mlir::NamedAttribute operator()(float const &r) {
auto val = _builder.getF32FloatAttr(r);
return _builder.getNamedAttr(_name, val);
}
mlir::NamedAttribute operator()(std::vector<float> const &floats) {
auto val = _builder.getF32ArrayAttr(floats);
return _builder.getNamedAttr(_name, val);
}
mlir::NamedAttribute operator()(std::string const &s) {
auto val = _builder.getStringAttr(s);
return _builder.getNamedAttr(_name, val);
}
mlir::NamedAttribute operator()(std::vector<std::string> const &r) {
assert(false && "type of attribute value is not implemented");
auto val = _builder.getI32IntegerAttr(1);
return _builder.getNamedAttr(_name, val);
};
};
mlir::NamedAttribute convertNameValuePairToNamedAttribute(
std::pair<std::string, AttrValueType> nameAndVal) {
auto visitor = ONNXAttrVisitor(nameAndVal.first, builder_);
return mpark::visit(visitor, nameAndVal.second);
}
static std::pair<std::string, AttrValueType>
convertAttributeProtoToNameValuePair(onnx::AttributeProto &attr) {
AttrValueType val;
switch (attr.type()) {
case onnx::AttributeProto::FLOAT:
return std::make_pair(attr.name(), AttrValueType(attr.f()));
case onnx::AttributeProto::INT:
return std::make_pair(attr.name(), AttrValueType(attr.i()));
case onnx::AttributeProto::STRING:
return std::make_pair(attr.name(), AttrValueType(attr.s()));
case onnx::AttributeProto::FLOATS:
val = AttrValueType(
std::vector<float>(attr.floats().begin(), attr.floats().end()));
return std::make_pair(attr.name(), val);
case onnx::AttributeProto::INTS:
val = AttrValueType(
std::vector<int64_t>(attr.ints().begin(), attr.ints().end()));
return std::make_pair(attr.name(), val);
default:
assert(false && "datatype for attribute is not implemented");
break;
}
}
std::vector<mlir::NamedAttribute> ImportNodeAttributes(
const onnx::NodeProto &node,
std::initializer_list<std::pair<std::string, AttrValueType>>
defaultAttrList) {
std::vector<mlir::NamedAttribute> attributes;
std::set<std::string> definedAttributeSet;
for (int i = 0; i < node.attribute_size(); ++i) { for (int i = 0; i < node.attribute_size(); ++i) {
auto attr = node.attribute(i); auto attr = node.attribute(i);
if (attr.name() == name) { auto nameValPair = convertAttributeProtoToNameValuePair(attr);
return attr_getter(attr); attributes.push_back(convertNameValuePairToNamedAttribute(nameValPair));
definedAttributeSet.insert(attr.name());
} }
for (const auto &defaultAttr : defaultAttrList) {
if (definedAttributeSet.find(defaultAttr.first) ==
definedAttributeSet.end())
attributes.push_back(convertNameValuePairToNamedAttribute(defaultAttr));
} }
return default_val; return attributes;
} }
template <typename T> void ImportNodeGeneric(const onnx::NodeProto &node) {
T get_attr_generic(onnx::NodeProto &node, std::string name,
std::function<T(onnx::AttributeProto &)> attr_getter) {
for (int i = 0; i < node.attribute_size(); ++i) {
auto attr = node.attribute(i);
if (attr.name() == name) {
return attr_getter(attr);
}
}
assert(false && "ONNX Node Attribute Not Found!");
}
auto get_attr_ints(onnx::NodeProto &node, std::string name,
std::vector<int> default_val) {
std::function<std::vector<int>(onnx::AttributeProto &)> attr_getter =
[](onnx::AttributeProto &attr) {
std::vector<int> ints(attr.ints_size());
std::copy(attr.ints().begin(), attr.ints().end(), ints.begin());
return ints;
};
auto r = get_attr_generic(node, name, attr_getter, default_val);
auto dataType =
mlir::RankedTensorType::get(r.size(), builder_.getIntegerType(32));
auto attr_v = mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(r));
auto aname = node.op_type() + "." + name;
auto attr_output = builder_.getNamedAttr(aname, attr_v);
return attr_output;
}
auto get_attr_ints(onnx::NodeProto &node, std::string name) {
std::function<std::vector<int>(onnx::AttributeProto &)> attr_getter =
[](onnx::AttributeProto &attr) {
std::vector<int> ints(attr.ints_size());
std::copy(attr.ints().begin(), attr.ints().end(), ints.begin());
return ints;
};
auto r = get_attr_generic(node, name, attr_getter);
auto dataType =
mlir::RankedTensorType::get(r.size(), builder_.getIntegerType(32));
auto attr_v = mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(r));
auto aname = node.op_type() + "." + name;
auto attr_output = builder_.getNamedAttr(aname, attr_v);
return attr_output;
}
auto get_attr_floats(onnx::NodeProto &node, std::string name) {
std::function<std::vector<float>(onnx::AttributeProto &)> attr_getter =
[](onnx::AttributeProto &attr) {
std::vector<float> floats(attr.floats_size());
std::copy(attr.floats().begin(), attr.floats().end(), floats.begin());
return floats;
};
auto r = get_attr_generic(node, name, attr_getter);
auto dataType =
mlir::RankedTensorType::get(r.size(), builder_.getF32Type());
auto attr_v = mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(r));
auto aname = node.op_type() + "." + name;
auto attr_output = builder_.getNamedAttr(aname, attr_v);
return attr_output;
}
auto get_attr_floats(onnx::NodeProto &node, std::string name,
std::vector<float> default_val) {
std::function<std::vector<float>(onnx::AttributeProto &)> attr_getter =
[](onnx::AttributeProto &attr) {
std::vector<float> floats(attr.floats_size());
std::copy(attr.floats().begin(), attr.floats().end(), floats.begin());
return floats;
};
auto r = get_attr_generic(node, name, attr_getter, default_val);
auto dataType =
mlir::RankedTensorType::get(r.size(), builder_.getF32Type());
auto attr_v = mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(r));
auto aname = node.op_type() + "." + name;
auto attr_output = builder_.getNamedAttr(aname, attr_v);
return attr_output;
}
auto get_attr_int(onnx::NodeProto &node, std::string name) {
std::function<int(onnx::AttributeProto &)> attr_getter =
[](onnx::AttributeProto &attr) { return attr.i(); };
int r = get_attr_generic(node, name, attr_getter);
auto attr_v = builder_.getI32IntegerAttr(r);
auto aname = node.op_type() + "." + name;
auto attr_output = builder_.getNamedAttr(aname, attr_v);
return attr_output;
}
auto get_attr_int(onnx::NodeProto &node, std::string name, int default_val) {
std::function<int(onnx::AttributeProto &)> attr_getter =
[](onnx::AttributeProto &attr) { return attr.i(); };
int r = get_attr_generic(node, name, attr_getter, default_val);
auto attr_v = builder_.getI32IntegerAttr(r);
auto aname = node.op_type() + "." + name;
auto attr_output = builder_.getNamedAttr(aname, attr_v);
return attr_output;
}
auto get_attr_float(onnx::NodeProto &node, std::string name) {
std::function<float(onnx::AttributeProto &)> attr_getter =
[](onnx::AttributeProto &attr) { return attr.f(); };
auto r = get_attr_generic(node, name, attr_getter);
auto attr_v = builder_.getF32FloatAttr(r);
auto aname = node.op_type() + "." + name;
return builder_.getNamedAttr(aname, attr_v);
}
auto get_attr_float(onnx::NodeProto &node, std::string name,
float default_val) {
std::function<float(onnx::AttributeProto &)> attr_getter =
[](onnx::AttributeProto &attr) { return attr.f(); };
auto r = get_attr_generic(node, name, attr_getter, default_val);
auto attr_v = builder_.getF32FloatAttr(r);
auto aname = node.op_type() + "." + name;
return builder_.getNamedAttr(aname, attr_v);
}
auto get_attr_string(onnx::NodeProto &node, std::string name) {
std::function<std::string(onnx::AttributeProto &)> attr_getter =
[](onnx::AttributeProto &attr) { return attr.s(); };
auto r = get_attr_generic(node, name, attr_getter);
auto attr_v = builder_.getStringAttr(r);
auto aname = node.op_type() + "." + name;
return builder_.getNamedAttr(aname, attr_v);
}
auto get_attr_string(onnx::NodeProto &node, std::string name,
std::string default_val) {
std::function<std::string(onnx::AttributeProto &)> attr_getter =
[](onnx::AttributeProto &attr) { return attr.s(); };
auto r = get_attr_generic(node, name, attr_getter, default_val);
auto attr_v = builder_.getStringAttr(r);
auto aname = node.op_type() + "." + name;
return builder_.getNamedAttr(aname, attr_v);
}
/*
auto get_attr_strings(onnx::NodeProto &node, std::string name) {
std::function<std::vector<std::string>(onnx::AttributeProto &)>
attr_getter =
[](onnx::AttributeProto &attr) {
std::vector<std::string> strings(attr.strings_size());
std::copy(attr.strings().begin(), attr.strings().end(),
strings.begin()); return strings;
};
auto r = get_attr_generic(node, name, attr_getter);
return r;
return builder_.getNamedAttr(aname, attr_v);
auto dataType =
mlir::RankedTensorType::get(r.size(), builder_.get???Type());
auto attr_v = mlir::DenseElementsAttr::get(dataType,
llvm::makeArrayRef(r)); auto aname = node.op_type() + "." + name; auto
attr_output = builder_.getNamedAttr(aname, attr_v); return attr_output;
}
*/
auto get_default_ints(std::string default_str) {
std::vector<int> r;
auto start = default_str.find("{");
while (true) {
auto end = default_str.find(",", start + 1);
if (end == std::string::npos) {
end = default_str.find("}", start + 1);
if (end != std::string::npos && end > start + 1) {
r.push_back(std::stoi(default_str.substr(start + 1, end)));
}
break;
} else {
r.push_back(std::stoi(default_str.substr(start + 1, end)));
}
start = end + 1;
}
return r;
}
auto get_default_floats(std::string default_str) {
std::vector<float> r;
auto start = default_str.find("{");
while (true) {
auto end = default_str.find(",", start + 1);
if (end == std::string::npos) {
end = default_str.find("}", start + 1);
if (end != std::string::npos && end > start + 1) {
r.push_back(std::stof(default_str.substr(start + 1, end)));
}
break;
} else {
r.push_back(std::stof(default_str.substr(start + 1, end)));
}
start = end + 1;
}
return r;
}
auto get_default_strings(std::string default_str) {
std::vector<std::string> r;
auto start = default_str.find("{");
while (true) {
auto end = default_str.find(",", start + 1);
if (end == std::string::npos) {
end = default_str.find("}", start + 1);
if (end != std::string::npos && end > start + 1) {
r.push_back(default_str.substr(start + 1, end));
}
break;
} else {
r.push_back(default_str.substr(start + 1, end));
}
start = end + 1;
}
return r;
}
onnx::TensorProto get_attr_tensor(onnx::NodeProto &node, std::string name) {
std::function<onnx::TensorProto(onnx::AttributeProto &)> attr_getter =
[](onnx::AttributeProto &attr) { return attr.t(); };
return get_attr_generic(node, name, attr_getter);
}
auto ImportNodeAttr(onnx::NodeProto node, std::string attr_name,
std::string type_name, std::string default_str) {
if (default_str == "") {
if (type_name == "int") {
return get_attr_int(node, attr_name);
} else if (type_name == "float") {
return get_attr_float(node, attr_name);
} else if (type_name == "str") {
return get_attr_string(node, attr_name);
} else if (type_name == "ints") {
return get_attr_ints(node, attr_name);
} else if (type_name == "floats") {
return get_attr_floats(node, attr_name);
} else {
assert(
false &&
"Got an empty initializer or initializer for this "
"datatype is not implemented. Something is wrong.");
}
} else {
// with default value
if (type_name == "int") {
return get_attr_int(node, attr_name, std::stoi(default_str));
} else if (type_name == "float") {
return get_attr_float(node, attr_name, std::stof(default_str));
} else if (type_name == "str") {
return get_attr_string(node, attr_name, default_str);
} else if (type_name == "ints") {
return get_attr_ints(node, attr_name, get_default_ints(default_str));
} else if (type_name == "floats") {
return get_attr_floats(node, attr_name,
get_default_floats(default_str));
} else {
assert(
false &&
"Got an empty initializer or initializer for this "
"datatype is not implemented. Something is wrong.");
}
}
}
void ImportNodeGeneric(onnx::NodeProto node) {
std::vector<mlir::Value> inputs; std::vector<mlir::Value> inputs;
for (auto item : node.input()) { for (const auto &item : node.input()) {
if (frontend_symbols_.ContainKey(legalize_name(item))) { if (frontend_symbols_.ContainKey(legalize_name(item))) {
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item)); inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
} }
@ -511,12 +339,12 @@ private:
* default} * default}
*/ */
template <typename T> template <typename T>
void ImportNodeOneOut( void
onnx::NodeProto node, int nIn, int nOut, ImportNodeOneOut(const onnx::NodeProto &node, int nIn, int nOut,
std::initializer_list<std::tuple<std::string, std::string, std::string>> std::initializer_list<std::pair<std::string, AttrValueType>>
attrs) { defaultAttrList) {
std::vector<mlir::Value> inputs; std::vector<mlir::Value> inputs;
for (auto item : node.input()) { for (const auto &item : node.input()) {
if (frontend_symbols_.ContainKey(legalize_name(item))) { if (frontend_symbols_.ContainKey(legalize_name(item))) {
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item)); inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
} }
@ -528,22 +356,7 @@ private:
mlir::UnrankedTensorType::get(builder_.getF32Type())); mlir::UnrankedTensorType::get(builder_.getF32Type()));
} }
std::vector<mlir::NamedAttribute> attributes; auto attributes = ImportNodeAttributes(node, defaultAttrList);
// for (auto [attr_name, attr_type, attr_default] : attrs) {
for (auto oneAttr : attrs) {
std::string attr_name;
std::string attr_type;
std::string attr_default;
std::tie(attr_name, attr_type, attr_default) = oneAttr;
if (attr_type != "") {
auto attr = ImportNodeAttr(node, attr_name, attr_type, attr_default);
attributes.push_back(attr);
} else {
// TODO: the attributes need special handling
// std::cout << "missing " << node.op_type() << " " << attr_name <<
// std::endl;
}
}
llvm::StringRef OpName = node.op_type(); llvm::StringRef OpName = node.op_type();
@ -559,11 +372,11 @@ private:
template <typename T> template <typename T>
void ImportNodeMultipleOuts( void ImportNodeMultipleOuts(
onnx::NodeProto node, int nIn, int nOut, const onnx::NodeProto &node, int nIn, int nOut,
std::initializer_list<std::tuple<std::string, std::string, std::string>> std::initializer_list<std::pair<std::string, AttrValueType>>
attrs) { defaultAttrList) {
std::vector<mlir::Value> inputs; std::vector<mlir::Value> inputs;
for (auto item : node.input()) { for (const auto &item : node.input()) {
if (frontend_symbols_.ContainKey(legalize_name(item))) { if (frontend_symbols_.ContainKey(legalize_name(item))) {
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item)); inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
} }
@ -575,21 +388,7 @@ private:
mlir::UnrankedTensorType::get(builder_.getF32Type())); mlir::UnrankedTensorType::get(builder_.getF32Type()));
} }
std::vector<mlir::NamedAttribute> attributes; auto attributes = ImportNodeAttributes(node, defaultAttrList);
for (auto oneAttr : attrs) {
std::string attr_name;
std::string attr_type;
std::string attr_default;
std::tie(attr_name, attr_type, attr_default) = oneAttr;
if (attr_type != "") {
auto attr = ImportNodeAttr(node, attr_name, attr_type, attr_default);
attributes.push_back(attr);
} else {
// TODO: the attributes need special handling
// std::cout << "missing " << node.op_type() << " " << attr_name <<
// std::endl;
}
}
llvm::StringRef OpName = node.op_type(); llvm::StringRef OpName = node.op_type();
@ -610,10 +409,10 @@ private:
* c++ does not allow template specialization inside a class scope * c++ does not allow template specialization inside a class scope
* a specialized function is used * a specialized function is used
*/ */
void ImportNodeConv( void
onnx::NodeProto node, int nOut, ImportNodeConv(onnx::NodeProto node, int nIn, int nOut,
std::initializer_list<std::tuple<std::string, std::string, std::string>> std::initializer_list<std::pair<std::string, AttrValueType>>
attrs) { defaultAttrList) {
// Conv has attribute dilations, kernel_shape, pads, the default value of // Conv has attribute dilations, kernel_shape, pads, the default value of
// which is determined by the shape of first argument. However, since the // which is determined by the shape of first argument. However, since the
// shape is unknown now, these attributes can be not generated auto // shape is unknown now, these attributes can be not generated auto
@ -627,29 +426,32 @@ private:
int nOps = node.input().size(); int nOps = node.input().size();
if (nOps == 2) if (nOps == 2)
ImportNodeOneOut<mlir::ONNXConvNoBiasOp>(node, nOps, nOut, attrs); ImportNodeOneOut<mlir::ONNXConvNoBiasOp>(
node, nOps, nOut, defaultAttrList);
else else
ImportNodeOneOut<mlir::ONNXConvOp>(node, nOps, nOut, attrs); ImportNodeOneOut<mlir::ONNXConvOp>(node, nOps, nOut, defaultAttrList);
} }
/*! /*!
* Special handle for MaxPool operations. * Special handle for MaxPool operations.
*/ */
void ImportNodeMaxPool( void ImportNodeMaxPool(
onnx::NodeProto node, int nIn, onnx::NodeProto node, int nIn, int nOut,
std::initializer_list<std::tuple<std::string, std::string, std::string>> std::initializer_list<std::pair<std::string, AttrValueType>>
attrs) { defaultAttrList) {
int nOuts = node.output().size(); int nOuts = node.output().size();
if (nOuts == 1) { if (nOuts == 1) {
ImportNodeOneOut<mlir::ONNXMaxPoolSingleOutOp>(node, nIn, nOuts, attrs); ImportNodeOneOut<mlir::ONNXMaxPoolSingleOutOp>(
node, nIn, nOuts, defaultAttrList);
} else { } else {
ImportNodeMultipleOuts<mlir::ONNXMaxPoolOp>(node, nIn, nOuts, attrs); ImportNodeMultipleOuts<mlir::ONNXMaxPoolOp>(
node, nIn, nOuts, defaultAttrList);
} }
} }
void ImportNode(onnx::NodeProto node) { void ImportNode(const onnx::NodeProto &node) {
std::vector<mlir::Value> inputs; std::vector<mlir::Value> inputs;
for (auto item : node.input()) { for (const auto &item : node.input()) {
if (frontend_symbols_.ContainKey(legalize_name(item))) { if (frontend_symbols_.ContainKey(legalize_name(item))) {
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item)); inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
} }
@ -689,8 +491,7 @@ private:
llvm::SmallVectorImpl<mlir::Type> &ret_types, llvm::SmallVectorImpl<mlir::Type> &ret_types,
llvm::SmallVectorImpl<mlir::Value> &ret_vals) { llvm::SmallVectorImpl<mlir::Value> &ret_vals) {
auto output_tensor_legalized_name = legalize_name(output.name()); auto output_tensor_legalized_name = legalize_name(output.name());
assert( assert(frontend_symbols_.ContainKey(output_tensor_legalized_name) &&
frontend_symbols_.ContainKey(output_tensor_legalized_name) &&
"Output tensor not found"); "Output tensor not found");
auto tensor_val = auto tensor_val =

View File

@ -16,13 +16,13 @@
}); });
}else if (OpName == "ArgMax") { }else if (OpName == "ArgMax") {
ImportNodeOneOut<mlir::ONNXArgMaxOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXArgMaxOp>(node, 1, 1, {
{"axis","int","0"} {"axis", 0}
,{"keepdims","int","1"} ,{"keepdims", 1}
}); });
}else if (OpName == "ArgMin") { }else if (OpName == "ArgMin") {
ImportNodeOneOut<mlir::ONNXArgMinOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXArgMinOp>(node, 1, 1, {
{"axis","int","0"} {"axis", 0}
,{"keepdims","int","1"} ,{"keepdims", 1}
}); });
}else if (OpName == "Asin") { }else if (OpName == "Asin") {
ImportNodeOneOut<mlir::ONNXAsinOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXAsinOp>(node, 1, 1, {
@ -38,25 +38,22 @@
}); });
}else if (OpName == "AveragePool") { }else if (OpName == "AveragePool") {
ImportNodeOneOut<mlir::ONNXAveragePoolOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXAveragePoolOp>(node, 1, 1, {
{"auto_pad","str","NOTSET"} {"auto_pad", "NOTSET"}
,{"ceil_mode","int","0"} ,{"ceil_mode", 0}
,{"count_include_pad","int","0"} ,{"count_include_pad", 0}
,{"kernel_shape","ints", ""} ,{"kernel_shape", std::vector<int64_t> {}}
,{"pads","", ""}
,{"strides","", ""}
}); });
}else if (OpName == "BatchNormalization") { }else if (OpName == "BatchNormalization") {
ImportNodeMultipleOuts<mlir::ONNXBatchNormalizationOp>(node, 5, 5, { ImportNodeMultipleOuts<mlir::ONNXBatchNormalizationOp>(node, 5, 5, {
{"epsilon","float","1e-05"} {"epsilon", (float)1e-05}
,{"momentum","float","0.9"} ,{"momentum", (float)0.9}
}); });
}else if (OpName == "BitShift") { }else if (OpName == "BitShift") {
ImportNodeOneOut<mlir::ONNXBitShiftOp>(node, 2, 1, { ImportNodeOneOut<mlir::ONNXBitShiftOp>(node, 2, 1, {
{"direction","", ""}
}); });
}else if (OpName == "Cast") { }else if (OpName == "Cast") {
ImportNodeOneOut<mlir::ONNXCastOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXCastOp>(node, 1, 1, {
{"to","int", "0"} {"to", 0}
}); });
}else if (OpName == "Ceil") { }else if (OpName == "Ceil") {
ImportNodeOneOut<mlir::ONNXCeilOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXCeilOp>(node, 1, 1, {
@ -66,54 +63,35 @@
}); });
}else if (OpName == "Compress") { }else if (OpName == "Compress") {
ImportNodeOneOut<mlir::ONNXCompressOp>(node, 2, 1, { ImportNodeOneOut<mlir::ONNXCompressOp>(node, 2, 1, {
{"axis","", ""}
}); });
}else if (OpName == "Concat") { }else if (OpName == "Concat") {
ImportNodeOneOut<mlir::ONNXConcatOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXConcatOp>(node, 1, 1, {
{"axis","int", "0"} {"axis", 0}
}); });
}else if (OpName == "ConcatFromSequence") { }else if (OpName == "ConcatFromSequence") {
ImportNodeOneOut<mlir::ONNXConcatFromSequenceOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXConcatFromSequenceOp>(node, 1, 1, {
{"axis","", ""} {"new_axis", 0}
,{"new_axis","int","0"}
}); });
}else if (OpName == "Constant") { }else if (OpName == "Constant") {
ImportNodeOneOut<mlir::ONNXConstantOp>(node, 0, 1, { ImportNodeOneOut<mlir::ONNXConstantOp>(node, 0, 1, {
{"sparse_value","", ""}
,{"value","", ""}
}); });
}else if (OpName == "ConstantOfShape") { }else if (OpName == "ConstantOfShape") {
ImportNodeOneOut<mlir::ONNXConstantOfShapeOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXConstantOfShapeOp>(node, 1, 1, {
{"value","", ""}
}); });
}else if (OpName == "Conv") { }else if (OpName == "Conv") {
ImportNodeConv(node, 1, { ImportNodeConv(node, 3, 1, {
{"auto_pad","str","NOTSET"} {"auto_pad", "NOTSET"}
,{"dilations","", ""} ,{"group", 1}
,{"group","int", "1"}
,{"kernel_shape","", ""}
,{"pads","", ""}
,{"strides","", ""}
}); });
}else if (OpName == "ConvInteger") { }else if (OpName == "ConvInteger") {
ImportNodeOneOut<mlir::ONNXConvIntegerOp>(node, 4, 1, { ImportNodeOneOut<mlir::ONNXConvIntegerOp>(node, 4, 1, {
{"auto_pad","str","NOTSET"} {"auto_pad", "NOTSET"}
,{"dilations","", ""} ,{"group", 1}
,{"group","int","1"}
,{"kernel_shape","", ""}
,{"pads","", ""}
,{"strides","", ""}
}); });
}else if (OpName == "ConvTranspose") { }else if (OpName == "ConvTranspose") {
ImportNodeOneOut<mlir::ONNXConvTransposeOp>(node, 3, 1, { ImportNodeOneOut<mlir::ONNXConvTransposeOp>(node, 3, 1, {
{"auto_pad","str","NOTSET"} {"auto_pad", "NOTSET"}
,{"dilations","", ""} ,{"group", 1}
,{"group","int","1"}
,{"kernel_shape","", ""}
,{"output_padding","", ""}
,{"output_shape","", ""}
,{"pads","", ""}
,{"strides","", ""}
}); });
}else if (OpName == "Cos") { }else if (OpName == "Cos") {
ImportNodeOneOut<mlir::ONNXCosOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXCosOp>(node, 1, 1, {
@ -123,13 +101,12 @@
}); });
}else if (OpName == "CumSum") { }else if (OpName == "CumSum") {
ImportNodeOneOut<mlir::ONNXCumSumOp>(node, 2, 1, { ImportNodeOneOut<mlir::ONNXCumSumOp>(node, 2, 1, {
{"exclusive","int","0"} {"exclusive", 0}
,{"reverse","int","0"} ,{"reverse", 0}
}); });
}else if (OpName == "DepthToSpace") { }else if (OpName == "DepthToSpace") {
ImportNodeOneOut<mlir::ONNXDepthToSpaceOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXDepthToSpaceOp>(node, 1, 1, {
{"blocksize","", ""} {"mode", "DCR"}
,{"mode","str","DCR"}
}); });
}else if (OpName == "DequantizeLinear") { }else if (OpName == "DequantizeLinear") {
ImportNodeOneOut<mlir::ONNXDequantizeLinearOp>(node, 3, 1, { ImportNodeOneOut<mlir::ONNXDequantizeLinearOp>(node, 3, 1, {
@ -142,14 +119,14 @@
}); });
}else if (OpName == "Dropout") { }else if (OpName == "Dropout") {
ImportNodeMultipleOuts<mlir::ONNXDropoutOp>(node, 1, 2, { ImportNodeMultipleOuts<mlir::ONNXDropoutOp>(node, 1, 2, {
{"ratio","float","0.5"} {"ratio", (float)0.5}
}); });
}else if (OpName == "DynamicQuantizeLinear") { }else if (OpName == "DynamicQuantizeLinear") {
ImportNodeMultipleOuts<mlir::ONNXDynamicQuantizeLinearOp>(node, 1, 3, { ImportNodeMultipleOuts<mlir::ONNXDynamicQuantizeLinearOp>(node, 1, 3, {
}); });
}else if (OpName == "Elu") { }else if (OpName == "Elu") {
ImportNodeOneOut<mlir::ONNXEluOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXEluOp>(node, 1, 1, {
{"alpha","float","1.0"} {"alpha", (float)1.0}
}); });
}else if (OpName == "Equal") { }else if (OpName == "Equal") {
ImportNodeOneOut<mlir::ONNXEqualOp>(node, 2, 1, { ImportNodeOneOut<mlir::ONNXEqualOp>(node, 2, 1, {
@ -165,50 +142,44 @@
}); });
}else if (OpName == "EyeLike") { }else if (OpName == "EyeLike") {
ImportNodeOneOut<mlir::ONNXEyeLikeOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXEyeLikeOp>(node, 1, 1, {
{"dtype","", ""} {"k", 0}
,{"k","int","0"}
}); });
}else if (OpName == "Flatten") { }else if (OpName == "Flatten") {
ImportNodeOneOut<mlir::ONNXFlattenOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXFlattenOp>(node, 1, 1, {
{"axis","int","1"} {"axis", 1}
}); });
}else if (OpName == "Floor") { }else if (OpName == "Floor") {
ImportNodeOneOut<mlir::ONNXFloorOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXFloorOp>(node, 1, 1, {
}); });
}else if (OpName == "GRU") { }else if (OpName == "GRU") {
ImportNodeMultipleOuts<mlir::ONNXGRUOp>(node, 6, 2, { ImportNodeMultipleOuts<mlir::ONNXGRUOp>(node, 6, 2, {
{"activation_alpha","", ""} {"direction", "forward"}
,{"activation_beta","", ""} ,{"linear_before_reset", 0}
,{"activations","", ""}
,{"clip","", ""}
,{"direction","str","forward"}
,{"hidden_size","", ""}
,{"linear_before_reset","int","0"}
}); });
}else if (OpName == "Gather") { }else if (OpName == "Gather") {
ImportNodeOneOut<mlir::ONNXGatherOp>(node, 2, 1, { ImportNodeOneOut<mlir::ONNXGatherOp>(node, 2, 1, {
{"axis","int","0"} {"axis", 0}
}); });
}else if (OpName == "GatherElements") { }else if (OpName == "GatherElements") {
ImportNodeOneOut<mlir::ONNXGatherElementsOp>(node, 2, 1, { ImportNodeOneOut<mlir::ONNXGatherElementsOp>(node, 2, 1, {
{"axis","int","0"} {"axis", 0}
}); });
}else if (OpName == "GatherND") { }else if (OpName == "GatherND") {
ImportNodeOneOut<mlir::ONNXGatherNDOp>(node, 2, 1, { ImportNodeOneOut<mlir::ONNXGatherNDOp>(node, 2, 1, {
}); });
}else if (OpName == "Gemm") { }else if (OpName == "Gemm") {
ImportNodeOneOut<mlir::ONNXGemmOp>(node, 3, 1, { ImportNodeOneOut<mlir::ONNXGemmOp>(node, 3, 1, {
{"alpha","float","1.0"} {"alpha", (float)1.0}
,{"beta","float","1.0"} ,{"beta", (float)1.0}
,{"transA","int","0"} ,{"transA", 0}
,{"transB","int","0"} ,{"transB", 0}
}); });
}else if (OpName == "GlobalAveragePool") { }else if (OpName == "GlobalAveragePool") {
ImportNodeOneOut<mlir::ONNXGlobalAveragePoolOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXGlobalAveragePoolOp>(node, 1, 1, {
}); });
}else if (OpName == "GlobalLpPool") { }else if (OpName == "GlobalLpPool") {
ImportNodeOneOut<mlir::ONNXGlobalLpPoolOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXGlobalLpPoolOp>(node, 1, 1, {
{"p","int","2"} {"p", 2}
}); });
}else if (OpName == "GlobalMaxPool") { }else if (OpName == "GlobalMaxPool") {
ImportNodeOneOut<mlir::ONNXGlobalMaxPoolOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXGlobalMaxPoolOp>(node, 1, 1, {
@ -218,53 +189,45 @@
}); });
}else if (OpName == "HardSigmoid") { }else if (OpName == "HardSigmoid") {
ImportNodeOneOut<mlir::ONNXHardSigmoidOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXHardSigmoidOp>(node, 1, 1, {
{"alpha","float","0.2"} {"alpha", (float)0.2}
,{"beta","float","0.5"} ,{"beta", (float)0.5}
}); });
}else if (OpName == "Hardmax") { }else if (OpName == "Hardmax") {
ImportNodeOneOut<mlir::ONNXHardmaxOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXHardmaxOp>(node, 1, 1, {
{"axis","int","1"} {"axis", 1}
}); });
}else if (OpName == "Identity") { }else if (OpName == "Identity") {
ImportNodeOneOut<mlir::ONNXIdentityOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXIdentityOp>(node, 1, 1, {
}); });
}else if (OpName == "If") { }else if (OpName == "If") {
ImportNodeOneOut<mlir::ONNXIfOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXIfOp>(node, 1, 1, {
{"else_branch","", ""}
,{"then_branch","", ""}
}); });
}else if (OpName == "InstanceNormalization") { }else if (OpName == "InstanceNormalization") {
ImportNodeOneOut<mlir::ONNXInstanceNormalizationOp>(node, 3, 1, { ImportNodeOneOut<mlir::ONNXInstanceNormalizationOp>(node, 3, 1, {
{"epsilon","float","1e-05"} {"epsilon", (float)1e-05}
}); });
}else if (OpName == "IsInf") { }else if (OpName == "IsInf") {
ImportNodeOneOut<mlir::ONNXIsInfOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXIsInfOp>(node, 1, 1, {
{"detect_negative","int","1"} {"detect_negative", 1}
,{"detect_positive","int","1"} ,{"detect_positive", 1}
}); });
}else if (OpName == "IsNaN") { }else if (OpName == "IsNaN") {
ImportNodeOneOut<mlir::ONNXIsNaNOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXIsNaNOp>(node, 1, 1, {
}); });
}else if (OpName == "LRN") { }else if (OpName == "LRN") {
ImportNodeOneOut<mlir::ONNXLRNOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXLRNOp>(node, 1, 1, {
{"alpha","float","0.0001"} {"alpha", (float)0.0001}
,{"beta","float","0.75"} ,{"beta", (float)0.75}
,{"bias","float","1.0"} ,{"bias", (float)1.0}
,{"size","int", ""}
}); });
}else if (OpName == "LSTM") { }else if (OpName == "LSTM") {
ImportNodeMultipleOuts<mlir::ONNXLSTMOp>(node, 8, 3, { ImportNodeMultipleOuts<mlir::ONNXLSTMOp>(node, 8, 3, {
{"activation_alpha","", ""} {"direction", "forward"}
,{"activation_beta","", ""} ,{"input_forget", 0}
,{"activations","", ""}
,{"clip","", ""}
,{"direction","str","forward"}
,{"hidden_size","", ""}
,{"input_forget","int","0"}
}); });
}else if (OpName == "LeakyRelu") { }else if (OpName == "LeakyRelu") {
ImportNodeOneOut<mlir::ONNXLeakyReluOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXLeakyReluOp>(node, 1, 1, {
{"alpha","float","0.01"} {"alpha", (float)0.01}
}); });
}else if (OpName == "Less") { }else if (OpName == "Less") {
ImportNodeOneOut<mlir::ONNXLessOp>(node, 2, 1, { ImportNodeOneOut<mlir::ONNXLessOp>(node, 2, 1, {
@ -274,24 +237,20 @@
}); });
}else if (OpName == "LogSoftmax") { }else if (OpName == "LogSoftmax") {
ImportNodeOneOut<mlir::ONNXLogSoftmaxOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXLogSoftmaxOp>(node, 1, 1, {
{"axis","int","1"} {"axis", 1}
}); });
}else if (OpName == "Loop") { }else if (OpName == "Loop") {
ImportNodeOneOut<mlir::ONNXLoopOp>(node, 3, 1, { ImportNodeOneOut<mlir::ONNXLoopOp>(node, 3, 1, {
{"body","", ""}
}); });
}else if (OpName == "LpNormalization") { }else if (OpName == "LpNormalization") {
ImportNodeOneOut<mlir::ONNXLpNormalizationOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXLpNormalizationOp>(node, 1, 1, {
{"axis","int","-1"} {"axis", -1}
,{"p","int","2"} ,{"p", 2}
}); });
}else if (OpName == "LpPool") { }else if (OpName == "LpPool") {
ImportNodeOneOut<mlir::ONNXLpPoolOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXLpPoolOp>(node, 1, 1, {
{"auto_pad","str","NOTSET"} {"auto_pad", "NOTSET"}
,{"kernel_shape","", ""} ,{"p", 2}
,{"p","int","2"}
,{"pads","", ""}
,{"strides","", ""}
}); });
}else if (OpName == "MatMul") { }else if (OpName == "MatMul") {
ImportNodeOneOut<mlir::ONNXMatMulOp>(node, 2, 1, { ImportNodeOneOut<mlir::ONNXMatMulOp>(node, 2, 1, {
@ -303,55 +262,47 @@
ImportNodeOneOut<mlir::ONNXMaxOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXMaxOp>(node, 1, 1, {
}); });
}else if (OpName == "MaxPool") { }else if (OpName == "MaxPool") {
ImportNodeMaxPool(node, 1, { ImportNodeMaxPool(node, 1, 2, {
{"auto_pad","str","NOTSET"} {"auto_pad", "NOTSET"}
,{"ceil_mode","int","0"} ,{"ceil_mode", 0}
,{"dilations","", ""} ,{"kernel_shape", std::vector<int64_t> {}}
,{"kernel_shape","ints", ""} ,{"storage_order", 0}
,{"pads","", ""}
,{"storage_order","int","0"}
,{"strides","", ""}
}); });
}else if (OpName == "MaxRoiPool") { }else if (OpName == "MaxRoiPool") {
ImportNodeOneOut<mlir::ONNXMaxRoiPoolOp>(node, 2, 1, { ImportNodeOneOut<mlir::ONNXMaxRoiPoolOp>(node, 2, 1, {
{"pooled_shape","", ""} {"spatial_scale", (float)1.0}
,{"spatial_scale","float","1.0"}
}); });
}else if (OpName == "MaxUnpool") { }else if (OpName == "MaxUnpool") {
ImportNodeOneOut<mlir::ONNXMaxUnpoolOp>(node, 3, 1, { ImportNodeOneOut<mlir::ONNXMaxUnpoolOp>(node, 3, 1, {
{"kernel_shape","", ""}
,{"pads","", ""}
,{"strides","", ""}
}); });
}else if (OpName == "Mean") { }else if (OpName == "Mean") {
ImportNodeOneOut<mlir::ONNXMeanOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXMeanOp>(node, 1, 1, {
}); });
}else if (OpName == "MeanVarianceNormalization") { }else if (OpName == "MeanVarianceNormalization") {
ImportNodeOneOut<mlir::ONNXMeanVarianceNormalizationOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXMeanVarianceNormalizationOp>(node, 1, 1, {
{"axes","ints","{'0', '2', '3'}"} {"axes", std::vector<int64_t>{0, 2, 3}}
}); });
}else if (OpName == "Min") { }else if (OpName == "Min") {
ImportNodeOneOut<mlir::ONNXMinOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXMinOp>(node, 1, 1, {
}); });
}else if (OpName == "Mod") { }else if (OpName == "Mod") {
ImportNodeOneOut<mlir::ONNXModOp>(node, 2, 1, { ImportNodeOneOut<mlir::ONNXModOp>(node, 2, 1, {
{"fmod","int","0"} {"fmod", 0}
}); });
}else if (OpName == "Mul") { }else if (OpName == "Mul") {
ImportNodeOneOut<mlir::ONNXMulOp>(node, 2, 1, { ImportNodeOneOut<mlir::ONNXMulOp>(node, 2, 1, {
}); });
}else if (OpName == "Multinomial") { }else if (OpName == "Multinomial") {
ImportNodeOneOut<mlir::ONNXMultinomialOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXMultinomialOp>(node, 1, 1, {
{"dtype","int","6"} {"dtype", 6}
,{"sample_size","int","1"} ,{"sample_size", 1}
,{"seed","", ""}
}); });
}else if (OpName == "Neg") { }else if (OpName == "Neg") {
ImportNodeOneOut<mlir::ONNXNegOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXNegOp>(node, 1, 1, {
}); });
}else if (OpName == "NonMaxSuppression") { }else if (OpName == "NonMaxSuppression") {
ImportNodeOneOut<mlir::ONNXNonMaxSuppressionOp>(node, 5, 1, { ImportNodeOneOut<mlir::ONNXNonMaxSuppressionOp>(node, 5, 1, {
{"center_point_box","int","0"} {"center_point_box", 0}
}); });
}else if (OpName == "NonZero") { }else if (OpName == "NonZero") {
ImportNodeOneOut<mlir::ONNXNonZeroOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXNonZeroOp>(node, 1, 1, {
@ -361,7 +312,7 @@
}); });
}else if (OpName == "OneHot") { }else if (OpName == "OneHot") {
ImportNodeOneOut<mlir::ONNXOneHotOp>(node, 3, 1, { ImportNodeOneOut<mlir::ONNXOneHotOp>(node, 3, 1, {
{"axis","int","-1"} {"axis", -1}
}); });
}else if (OpName == "Or") { }else if (OpName == "Or") {
ImportNodeOneOut<mlir::ONNXOrOp>(node, 2, 1, { ImportNodeOneOut<mlir::ONNXOrOp>(node, 2, 1, {
@ -371,19 +322,15 @@
}); });
}else if (OpName == "Pad") { }else if (OpName == "Pad") {
ImportNodeOneOut<mlir::ONNXPadOp>(node, 3, 1, { ImportNodeOneOut<mlir::ONNXPadOp>(node, 3, 1, {
{"mode","str","constant"} {"mode", "constant"}
}); });
}else if (OpName == "Pow") { }else if (OpName == "Pow") {
ImportNodeOneOut<mlir::ONNXPowOp>(node, 2, 1, { ImportNodeOneOut<mlir::ONNXPowOp>(node, 2, 1, {
}); });
}else if (OpName == "QLinearConv") { }else if (OpName == "QLinearConv") {
ImportNodeOneOut<mlir::ONNXQLinearConvOp>(node, 9, 1, { ImportNodeOneOut<mlir::ONNXQLinearConvOp>(node, 9, 1, {
{"auto_pad","str","NOTSET"} {"auto_pad", "NOTSET"}
,{"dilations","", ""} ,{"group", 1}
,{"group","int","1"}
,{"kernel_shape","", ""}
,{"pads","", ""}
,{"strides","", ""}
}); });
}else if (OpName == "QLinearMatMul") { }else if (OpName == "QLinearMatMul") {
ImportNodeOneOut<mlir::ONNXQLinearMatMulOp>(node, 8, 1, { ImportNodeOneOut<mlir::ONNXQLinearMatMulOp>(node, 8, 1, {
@ -393,42 +340,32 @@
}); });
}else if (OpName == "RNN") { }else if (OpName == "RNN") {
ImportNodeMultipleOuts<mlir::ONNXRNNOp>(node, 6, 2, { ImportNodeMultipleOuts<mlir::ONNXRNNOp>(node, 6, 2, {
{"activation_alpha","floats", "{}"} {"activation_alpha", std::vector<float> {}}
,{"activation_beta","floats", "{}"} ,{"activation_beta", std::vector<float> {}}
,{"activations","", "{Tannh, Tanh}"} ,{"activations", std::vector<std::string>{"Tanh", "Tanh"}}
,{"clip","", ""} ,{"direction", "forward"}
,{"direction","str","forward"}
,{"hidden_size","", ""}
}); });
}else if (OpName == "RandomNormal") { }else if (OpName == "RandomNormal") {
ImportNodeOneOut<mlir::ONNXRandomNormalOp>(node, 0, 1, { ImportNodeOneOut<mlir::ONNXRandomNormalOp>(node, 0, 1, {
{"dtype","int","1"} {"dtype", 1}
,{"mean","float","0.0"} ,{"mean", (float)0.0}
,{"scale","float","1.0"} ,{"scale", (float)1.0}
,{"seed","", ""}
,{"shape","", ""}
}); });
}else if (OpName == "RandomNormalLike") { }else if (OpName == "RandomNormalLike") {
ImportNodeOneOut<mlir::ONNXRandomNormalLikeOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXRandomNormalLikeOp>(node, 1, 1, {
{"dtype","", ""} {"mean", (float)0.0}
,{"mean","float","0.0"} ,{"scale", (float)1.0}
,{"scale","float","1.0"}
,{"seed","", ""}
}); });
}else if (OpName == "RandomUniform") { }else if (OpName == "RandomUniform") {
ImportNodeOneOut<mlir::ONNXRandomUniformOp>(node, 0, 1, { ImportNodeOneOut<mlir::ONNXRandomUniformOp>(node, 0, 1, {
{"dtype","int","1"} {"dtype", 1}
,{"high","float","1.0"} ,{"high", (float)1.0}
,{"low","float","0.0"} ,{"low", (float)0.0}
,{"seed","", ""}
,{"shape","", ""}
}); });
}else if (OpName == "RandomUniformLike") { }else if (OpName == "RandomUniformLike") {
ImportNodeOneOut<mlir::ONNXRandomUniformLikeOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXRandomUniformLikeOp>(node, 1, 1, {
{"dtype","", ""} {"high", (float)1.0}
,{"high","float","1.0"} ,{"low", (float)0.0}
,{"low","float","0.0"}
,{"seed","", ""}
}); });
}else if (OpName == "Range") { }else if (OpName == "Range") {
ImportNodeOneOut<mlir::ONNXRangeOp>(node, 3, 1, { ImportNodeOneOut<mlir::ONNXRangeOp>(node, 3, 1, {
@ -438,53 +375,43 @@
}); });
}else if (OpName == "ReduceL1") { }else if (OpName == "ReduceL1") {
ImportNodeOneOut<mlir::ONNXReduceL1Op>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXReduceL1Op>(node, 1, 1, {
{"axes","", ""} {"keepdims", 1}
,{"keepdims","int","1"}
}); });
}else if (OpName == "ReduceL2") { }else if (OpName == "ReduceL2") {
ImportNodeOneOut<mlir::ONNXReduceL2Op>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXReduceL2Op>(node, 1, 1, {
{"axes","", ""} {"keepdims", 1}
,{"keepdims","int","1"}
}); });
}else if (OpName == "ReduceLogSum") { }else if (OpName == "ReduceLogSum") {
ImportNodeOneOut<mlir::ONNXReduceLogSumOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXReduceLogSumOp>(node, 1, 1, {
{"axes","", ""} {"keepdims", 1}
,{"keepdims","int","1"}
}); });
}else if (OpName == "ReduceLogSumExp") { }else if (OpName == "ReduceLogSumExp") {
ImportNodeOneOut<mlir::ONNXReduceLogSumExpOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXReduceLogSumExpOp>(node, 1, 1, {
{"axes","", ""} {"keepdims", 1}
,{"keepdims","int","1"}
}); });
}else if (OpName == "ReduceMax") { }else if (OpName == "ReduceMax") {
ImportNodeOneOut<mlir::ONNXReduceMaxOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXReduceMaxOp>(node, 1, 1, {
{"axes","", ""} {"keepdims", 1}
,{"keepdims","int","1"}
}); });
}else if (OpName == "ReduceMean") { }else if (OpName == "ReduceMean") {
ImportNodeOneOut<mlir::ONNXReduceMeanOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXReduceMeanOp>(node, 1, 1, {
{"axes","", ""} {"keepdims", 1}
,{"keepdims","int","1"}
}); });
}else if (OpName == "ReduceMin") { }else if (OpName == "ReduceMin") {
ImportNodeOneOut<mlir::ONNXReduceMinOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXReduceMinOp>(node, 1, 1, {
{"axes","", ""} {"keepdims", 1}
,{"keepdims","int","1"}
}); });
}else if (OpName == "ReduceProd") { }else if (OpName == "ReduceProd") {
ImportNodeOneOut<mlir::ONNXReduceProdOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXReduceProdOp>(node, 1, 1, {
{"axes","", ""} {"keepdims", 1}
,{"keepdims","int","1"}
}); });
}else if (OpName == "ReduceSum") { }else if (OpName == "ReduceSum") {
ImportNodeOneOut<mlir::ONNXReduceSumOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXReduceSumOp>(node, 1, 1, {
{"axes","", ""} {"keepdims", 1}
,{"keepdims","int","1"}
}); });
}else if (OpName == "ReduceSumSquare") { }else if (OpName == "ReduceSumSquare") {
ImportNodeOneOut<mlir::ONNXReduceSumSquareOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXReduceSumSquareOp>(node, 1, 1, {
{"axes","", ""} {"keepdims", 1}
,{"keepdims","int","1"}
}); });
}else if (OpName == "Relu") { }else if (OpName == "Relu") {
ImportNodeOneOut<mlir::ONNXReluOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXReluOp>(node, 1, 1, {
@ -494,53 +421,47 @@
}); });
}else if (OpName == "Resize") { }else if (OpName == "Resize") {
ImportNodeOneOut<mlir::ONNXResizeOp>(node, 4, 1, { ImportNodeOneOut<mlir::ONNXResizeOp>(node, 4, 1, {
{"coordinate_transformation_mode","str","half_pixel"} {"coordinate_transformation_mode", "half_pixel"}
,{"cubic_coeff_a","float","-0.75"} ,{"cubic_coeff_a", (float)-0.75}
,{"exclude_outside","int","0"} ,{"exclude_outside", 0}
,{"extrapolation_value","float","0.0"} ,{"extrapolation_value", (float)0.0}
,{"mode","str","nearest"} ,{"mode", "nearest"}
,{"nearest_mode","str","round_prefer_floor"} ,{"nearest_mode", "round_prefer_floor"}
}); });
}else if (OpName == "ReverseSequence") { }else if (OpName == "ReverseSequence") {
ImportNodeOneOut<mlir::ONNXReverseSequenceOp>(node, 2, 1, { ImportNodeOneOut<mlir::ONNXReverseSequenceOp>(node, 2, 1, {
{"batch_axis","int","1"} {"batch_axis", 1}
,{"time_axis","int","0"} ,{"time_axis", 0}
}); });
}else if (OpName == "RoiAlign") { }else if (OpName == "RoiAlign") {
ImportNodeOneOut<mlir::ONNXRoiAlignOp>(node, 3, 1, { ImportNodeOneOut<mlir::ONNXRoiAlignOp>(node, 3, 1, {
{"mode","str","avg"} {"mode", "avg"}
,{"output_height","int","1"} ,{"output_height", 1}
,{"output_width","int","1"} ,{"output_width", 1}
,{"sampling_ratio","int","0"} ,{"sampling_ratio", 0}
,{"spatial_scale","float","1.0"} ,{"spatial_scale", (float)1.0}
}); });
}else if (OpName == "Round") { }else if (OpName == "Round") {
ImportNodeOneOut<mlir::ONNXRoundOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXRoundOp>(node, 1, 1, {
}); });
}else if (OpName == "Scan") { }else if (OpName == "Scan") {
ImportNodeOneOut<mlir::ONNXScanOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXScanOp>(node, 1, 1, {
{"body","", ""}
,{"num_scan_inputs","", ""}
,{"scan_input_axes","", ""}
,{"scan_input_directions","", ""}
,{"scan_output_axes","", ""}
,{"scan_output_directions","", ""}
}); });
}else if (OpName == "Scatter") { }else if (OpName == "Scatter") {
ImportNodeOneOut<mlir::ONNXScatterOp>(node, 3, 1, { ImportNodeOneOut<mlir::ONNXScatterOp>(node, 3, 1, {
{"axis","int","0"} {"axis", 0}
}); });
}else if (OpName == "ScatterElements") { }else if (OpName == "ScatterElements") {
ImportNodeOneOut<mlir::ONNXScatterElementsOp>(node, 3, 1, { ImportNodeOneOut<mlir::ONNXScatterElementsOp>(node, 3, 1, {
{"axis","int","0"} {"axis", 0}
}); });
}else if (OpName == "ScatterND") { }else if (OpName == "ScatterND") {
ImportNodeOneOut<mlir::ONNXScatterNDOp>(node, 3, 1, { ImportNodeOneOut<mlir::ONNXScatterNDOp>(node, 3, 1, {
}); });
}else if (OpName == "Selu") { }else if (OpName == "Selu") {
ImportNodeOneOut<mlir::ONNXSeluOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXSeluOp>(node, 1, 1, {
{"alpha","float","1.67326"} {"alpha", (float)1.67326}
,{"gamma","float","1.0507"} ,{"gamma", (float)1.0507}
}); });
}else if (OpName == "SequenceAt") { }else if (OpName == "SequenceAt") {
ImportNodeOneOut<mlir::ONNXSequenceAtOp>(node, 2, 1, { ImportNodeOneOut<mlir::ONNXSequenceAtOp>(node, 2, 1, {
@ -550,7 +471,6 @@
}); });
}else if (OpName == "SequenceEmpty") { }else if (OpName == "SequenceEmpty") {
ImportNodeOneOut<mlir::ONNXSequenceEmptyOp>(node, 0, 1, { ImportNodeOneOut<mlir::ONNXSequenceEmptyOp>(node, 0, 1, {
{"dtype","", ""}
}); });
}else if (OpName == "SequenceErase") { }else if (OpName == "SequenceErase") {
ImportNodeOneOut<mlir::ONNXSequenceEraseOp>(node, 2, 1, { ImportNodeOneOut<mlir::ONNXSequenceEraseOp>(node, 2, 1, {
@ -566,8 +486,8 @@
}); });
}else if (OpName == "Shrink") { }else if (OpName == "Shrink") {
ImportNodeOneOut<mlir::ONNXShrinkOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXShrinkOp>(node, 1, 1, {
{"bias","float","0.0"} {"bias", (float)0.0}
,{"lambd","float","0.5"} ,{"lambd", (float)0.5}
}); });
}else if (OpName == "Sigmoid") { }else if (OpName == "Sigmoid") {
ImportNodeOneOut<mlir::ONNXSigmoidOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXSigmoidOp>(node, 1, 1, {
@ -589,7 +509,7 @@
}); });
}else if (OpName == "Softmax") { }else if (OpName == "Softmax") {
ImportNodeOneOut<mlir::ONNXSoftmaxOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXSoftmaxOp>(node, 1, 1, {
{"axis","int","1"} {"axis", 1}
}); });
}else if (OpName == "Softplus") { }else if (OpName == "Softplus") {
ImportNodeOneOut<mlir::ONNXSoftplusOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXSoftplusOp>(node, 1, 1, {
@ -599,31 +519,26 @@
}); });
}else if (OpName == "SpaceToDepth") { }else if (OpName == "SpaceToDepth") {
ImportNodeOneOut<mlir::ONNXSpaceToDepthOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXSpaceToDepthOp>(node, 1, 1, {
{"blocksize","", ""}
}); });
}else if (OpName == "Split") { }else if (OpName == "Split") {
ImportNodeOneOut<mlir::ONNXSplitOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXSplitOp>(node, 1, 1, {
{"axis","int","0"} {"axis", 0}
,{"split","", ""}
}); });
}else if (OpName == "SplitToSequence") { }else if (OpName == "SplitToSequence") {
ImportNodeOneOut<mlir::ONNXSplitToSequenceOp>(node, 2, 1, { ImportNodeOneOut<mlir::ONNXSplitToSequenceOp>(node, 2, 1, {
{"axis","int","0"} {"axis", 0}
,{"keepdims","int","1"} ,{"keepdims", 1}
}); });
}else if (OpName == "Sqrt") { }else if (OpName == "Sqrt") {
ImportNodeOneOut<mlir::ONNXSqrtOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXSqrtOp>(node, 1, 1, {
}); });
}else if (OpName == "Squeeze") { }else if (OpName == "Squeeze") {
ImportNodeOneOut<mlir::ONNXSqueezeOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXSqueezeOp>(node, 1, 1, {
{"axes","", ""}
}); });
}else if (OpName == "StringNormalizer") { }else if (OpName == "StringNormalizer") {
ImportNodeOneOut<mlir::ONNXStringNormalizerOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXStringNormalizerOp>(node, 1, 1, {
{"case_change_action","str","NONE"} {"case_change_action", "NONE"}
,{"is_case_sensitive","int","0"} ,{"is_case_sensitive", 0}
,{"locale","", ""}
,{"stopwords","", ""}
}); });
}else if (OpName == "Sub") { }else if (OpName == "Sub") {
ImportNodeOneOut<mlir::ONNXSubOp>(node, 2, 1, { ImportNodeOneOut<mlir::ONNXSubOp>(node, 2, 1, {
@ -639,45 +554,34 @@
}); });
}else if (OpName == "TfIdfVectorizer") { }else if (OpName == "TfIdfVectorizer") {
ImportNodeOneOut<mlir::ONNXTfIdfVectorizerOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXTfIdfVectorizerOp>(node, 1, 1, {
{"max_gram_length","", ""}
,{"max_skip_count","", ""}
,{"min_gram_length","", ""}
,{"mode","", ""}
,{"ngram_counts","", ""}
,{"ngram_indexes","", ""}
,{"pool_int64s","", ""}
,{"pool_strings","", ""}
,{"weights","", ""}
}); });
}else if (OpName == "ThresholdedRelu") { }else if (OpName == "ThresholdedRelu") {
ImportNodeOneOut<mlir::ONNXThresholdedReluOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXThresholdedReluOp>(node, 1, 1, {
{"alpha","float","1.0"} {"alpha", (float)1.0}
}); });
}else if (OpName == "Tile") { }else if (OpName == "Tile") {
ImportNodeOneOut<mlir::ONNXTileOp>(node, 2, 1, { ImportNodeOneOut<mlir::ONNXTileOp>(node, 2, 1, {
}); });
}else if (OpName == "TopK") { }else if (OpName == "TopK") {
ImportNodeMultipleOuts<mlir::ONNXTopKOp>(node, 2, 2, { ImportNodeMultipleOuts<mlir::ONNXTopKOp>(node, 2, 2, {
{"axis","int","-1"} {"axis", -1}
,{"largest","int","1"} ,{"largest", 1}
,{"sorted","int","1"} ,{"sorted", 1}
}); });
}else if (OpName == "Transpose") { }else if (OpName == "Transpose") {
ImportNodeOneOut<mlir::ONNXTransposeOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXTransposeOp>(node, 1, 1, {
{"perm","", ""}
}); });
}else if (OpName == "Unique") { }else if (OpName == "Unique") {
ImportNodeMultipleOuts<mlir::ONNXUniqueOp>(node, 1, 4, { ImportNodeMultipleOuts<mlir::ONNXUniqueOp>(node, 1, 4, {
{"axis","", ""} {"sorted", 1}
,{"sorted","int","1"}
}); });
}else if (OpName == "Unsqueeze") { }else if (OpName == "Unsqueeze") {
ImportNodeOneOut<mlir::ONNXUnsqueezeOp>(node, 1, 1, { ImportNodeOneOut<mlir::ONNXUnsqueezeOp>(node, 1, 1, {
{"axes","ints", ""} {"axes", std::vector<int64_t> {}}
}); });
}else if (OpName == "Upsample") { }else if (OpName == "Upsample") {
ImportNodeOneOut<mlir::ONNXUpsampleOp>(node, 2, 1, { ImportNodeOneOut<mlir::ONNXUpsampleOp>(node, 2, 1, {
{"mode","str","nearest"} {"mode", "nearest"}
}); });
}else if (OpName == "Where") { }else if (OpName == "Where") {
ImportNodeOneOut<mlir::ONNXWhereOp>(node, 3, 1, { ImportNodeOneOut<mlir::ONNXWhereOp>(node, 3, 1, {

View File

@ -368,17 +368,17 @@ def gen_code(schema,fefile) :
("MaxPool", "ImportNodeMaxPool"), ("MaxPool", "ImportNodeMaxPool"),
#("Transpose", "ImportNodeTranspose") #("Transpose", "ImportNodeTranspose")
]) ])
special_type = dict([ list_str = 'std::vector'
("AveragePool "+"kernel_shape", '"ints", ""'), empty_ints = list_str+'<int> {}'
("MaxPool "+"kernel_shape", '"ints", ""'), empty_floats = list_str+'<float> {}'
("Cast "+"to", '"int", "0"'), special_default = dict([
("Concat "+"axis", '"int", "0"'), ("AveragePool "+"kernel_shape", empty_ints),
("Conv "+"group", '"int", "1"'), ("MaxPool "+"kernel_shape", empty_ints),
("Unsqueeze "+"axes", '"ints", ""'), ("Cast "+"to", '0'),
("RNN "+"activation_alpha", '"floats", "{}"'), ("Concat "+"axis", '0'),
("RNN "+"activation_beta", '"floats", "{}"'), ("Unsqueeze "+"axes", empty_ints),
("RNN "+"activations", '"", "{Tannh, Tanh}"'), ("RNN "+"activation_alpha", empty_floats),
("LRN "+"size", '"int", ""') ("RNN "+"activation_beta", empty_floats)
]) ])
line_indent = ' ' line_indent = ' '
fefile.write(' '+'}else if (OpName == "'+schema.name+'") {\n') fefile.write(' '+'}else if (OpName == "'+schema.name+'") {\n')
@ -400,21 +400,9 @@ def gen_code(schema,fefile) :
if schema.attributes: if schema.attributes:
first_attr = True first_attr = True
for _, attr in sorted(schema.attributes.items()): for _, attr in sorted(schema.attributes.items()):
attr_line = line_indent+line_indent+line_indent+line_indent #only generate default attr list
if not first_attr: if schema.name+' '+attr.name in special_default:
attr_line += ',{' attr_value = special_default[schema.name+' '+attr.name]
else :
attr_line += ' {'
first_attr = False
attr_line += '"'+attr.name+'",'
if schema.name+' '+attr.name in special_type:
attr_line += special_type[schema.name+' '+attr.name]
# option holds either required or default value
elif attr.required:
attr_line += '"", ""'
elif attr.default_value.name: elif attr.default_value.name:
default_value = helper.get_attribute_value(attr.default_value) default_value = helper.get_attribute_value(attr.default_value)
@ -430,28 +418,35 @@ def gen_code(schema,fefile) :
return str(value) return str(value)
if isinstance(default_value, list): if isinstance(default_value, list):
value = default_value[0] value = default_value[0]
default_value = [format_value(val) for val in default_value] default_value = [format_value(val) for val in default_value]
attr_option_str = '{}'.format(default_value)
attr_option_str = attr_option_str.replace('[', '{', 1)
attr_option_str = attr_option_str.replace(']', '}', 1)
# TODO the list type is homogenous or htergeneous? # TODO the list type is homogenous or htergeneous?
if isinstance(value, float) : if isinstance(value, float) :
attr_type_str = '"floats"' attr_type_str = list_str+'<float>'
attr_option_str = attr_option_str.replace("'", '')
elif isinstance(value, int) : elif isinstance(value, int) :
attr_type_str = '"ints"' attr_type_str = list_str+'<int>'
attr_option_str = attr_option_str.replace("'", '')
elif isinstance(value, str) : elif isinstance(value, str) :
attr_type_str = '"strs"' attr_type_str = list_str+'<std::string>'
attr_option_str = attr_option_str.replace("'", '"')
elif isinstance(value, (bytes, bytearray)) : elif isinstance(value, (bytes, bytearray)) :
attr_type_str = '"strs"' attr_type_str = list_str+'<std::string>'
attr_option_str = attr_option_str.replace("'", '"')
else : else :
attr_type_str = '"unknowns"' attr_type_str = '"unknowns"'
attr_option_str = '"{}"'.format(default_value)
attr_option_str = attr_option_str.replace('[', '{', 1)
attr_option_str = attr_option_str.replace(']', '}', 1)
else: else:
if isinstance(default_value, float) : if isinstance(default_value, float) :
attr_type_str = '"float"' attr_type_str = '(float)'
attr_option_str = default_value
elif isinstance(default_value, int) : elif isinstance(default_value, int) :
attr_type_str = '"int"' attr_option_str = default_value
attr_type_str=''
elif isinstance(default_value, str) : elif isinstance(default_value, str) :
attr_type_str = '"str"' attr_type_str = '"str"'
elif isinstance(default_value, (bytes, bytearray)) : elif isinstance(default_value, (bytes, bytearray)) :
@ -459,11 +454,25 @@ def gen_code(schema,fefile) :
else : else :
attr_type_str = '"unknown"' attr_type_str = '"unknown"'
default_value = format_value(default_value) default_value = format_value(default_value)
attr_option_str = '"{}"'.format(default_value) if attr_type_str == '"str"' :
attr_line += attr_type_str+','+attr_option_str attr_option_str = '"'+default_value+'"'
attr_type_str=''
else : else :
#TODO why? attr_option_str = default_value
attr_line += '"", ""' attr_value = attr_type_str+attr_option_str
else:
#no default value
continue
attr_line = line_indent+line_indent+line_indent+line_indent
if not first_attr:
attr_line += ',{'
else :
attr_line += ' {'
first_attr = False
attr_line += '"'+attr.name+'", '
attr_line += attr_value
attr_line += '}\n' attr_line += '}\n'
fefile.write(attr_line) fefile.write(attr_line)
fefile.write(line_indent+line_indent+line_indent+'});\n') fefile.write(line_indent+line_indent+line_indent+'});\n')

View File

@ -420,8 +420,8 @@ Value mapToLowerScalarOp<ONNXHardSigmoidOp>(
// Constant 1) // Constant 1)
auto loc = op->getLoc(); auto loc = op->getLoc();
Value operand = operands[0]; Value operand = operands[0];
auto alphaAttr = op->getAttrOfType<FloatAttr>("HardSigmoid.alpha"); auto alphaAttr = op->getAttrOfType<FloatAttr>("alpha");
auto betaAttr = op->getAttrOfType<FloatAttr>("HardSigmoid.beta"); auto betaAttr = op->getAttrOfType<FloatAttr>("beta");
auto elementType = result_types[0]; auto elementType = result_types[0];
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0)); auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
@ -455,7 +455,7 @@ Value mapToLowerScalarOp<ONNXEluOp>(Operation *op, ArrayRef<Type> result_types,
Value operand = operands[0]; Value operand = operands[0];
auto elementType = result_types[0]; auto elementType = result_types[0];
auto alphaAttr = op->getAttrOfType<FloatAttr>("Elu.alpha"); auto alphaAttr = op->getAttrOfType<FloatAttr>("alpha");
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0)); auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1)); auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr); auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
@ -508,7 +508,7 @@ Value mapToLowerScalarOp<ONNXLeakyReluOp>(Operation *op,
Value operand = operands[0]; Value operand = operands[0];
auto elementType = result_types[0]; auto elementType = result_types[0];
auto alphaAttr = op->getAttrOfType<FloatAttr>("LeakyRelu.alpha"); auto alphaAttr = op->getAttrOfType<FloatAttr>("alpha");
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0)); auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr); auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
auto lessThanZero = auto lessThanZero =
@ -533,8 +533,8 @@ Value mapToLowerScalarOp<ONNXSeluOp>(Operation *op, ArrayRef<Type> result_types,
// alpha))) // alpha)))
auto loc = op->getLoc(); auto loc = op->getLoc();
Value operand = operands[0]; Value operand = operands[0];
auto alphaAttr = op->getAttrOfType<FloatAttr>("Selu.alpha"); auto alphaAttr = op->getAttrOfType<FloatAttr>("alpha");
auto gammaAttr = op->getAttrOfType<FloatAttr>("Selu.gamma"); auto gammaAttr = op->getAttrOfType<FloatAttr>("gamma");
auto elementType = result_types[0]; auto elementType = result_types[0];
auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0)); auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
@ -836,7 +836,7 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
// exp_x / sum // exp_x / sum
auto tensorType = (*op->result_type_begin()).cast<RankedTensorType>(); auto tensorType = (*op->result_type_begin()).cast<RankedTensorType>();
int64_t rank = tensorType.getRank(); int64_t rank = tensorType.getRank();
int64_t axis = op->getAttrOfType<IntegerAttr>("Softmax.axis").getInt(); int64_t axis = op->getAttrOfType<IntegerAttr>("axis").getInt();
axis = axis >= 0 ? axis : rank + axis; axis = axis >= 0 ? axis : rank + axis;
assert(axis >= -rank && axis <= rank - 1); assert(axis >= -rank && axis <= rank - 1);

View File

@ -385,7 +385,7 @@ func @test_min(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tensor<*
} }
func @test_elu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> { func @test_elu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Elu"(%arg0) {Elu.alpha=2.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32> %0 = "onnx.Elu"(%arg0) {alpha=2.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_elu // CHECK-LABEL: test_elu
@ -411,7 +411,7 @@ func @test_elu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
} }
func @test_leakyrelu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> { func @test_leakyrelu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.LeakyRelu"(%arg0) {LeakyRelu.alpha=1.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32> %0 = "onnx.LeakyRelu"(%arg0) {alpha=1.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_leakyrelu // CHECK-LABEL: test_leakyrelu
@ -434,7 +434,7 @@ func @test_leakyrelu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
} }
func @test_selu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> { func @test_selu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Selu"(%arg0) {Selu.alpha=1.0:f32, Selu.gamma=2.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32> %0 = "onnx.Selu"(%arg0) {alpha=1.0:f32, gamma=2.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_selu // CHECK-LABEL: test_selu
@ -461,7 +461,7 @@ func @test_selu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
} }
func @test_hardsigmoid(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> { func @test_hardsigmoid(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.HardSigmoid"(%arg0) {HardSigmoid.alpha=1.0:f32, HardSigmoid.beta=2.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32> %0 = "onnx.HardSigmoid"(%arg0) {alpha=1.0:f32, beta=2.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_hardsigmoid // CHECK-LABEL: test_hardsigmoid
@ -535,7 +535,7 @@ func @test_add_with_broadcasting(%arg0 : tensor<?xf32>, %arg1 : tensor<?x10xf32>
} }
func @test_softmax(%arg0 : tensor<10x10xf32>) -> tensor<*xf32> { func @test_softmax(%arg0 : tensor<10x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Softmax"(%arg0) {Softmax.axis=1:i32} : (tensor<10x10xf32>) -> tensor<*xf32> %0 = "onnx.Softmax"(%arg0) {axis=1:i32} : (tensor<10x10xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_softmax // CHECK-LABEL: test_softmax

View File

@ -648,8 +648,8 @@ func @test_min_min(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tens
} }
func @test_elu_elu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> { func @test_elu_elu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Elu"(%arg0) {Elu.alpha=2.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32> %0 = "onnx.Elu"(%arg0) {alpha=2.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32>
%1 = "onnx.Elu"(%0) {Elu.alpha=2.0:f32} : (tensor<*xf32>) -> tensor<*xf32> %1 = "onnx.Elu"(%0) {alpha=2.0:f32} : (tensor<*xf32>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> () "std.return"(%1) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_elu_elu // CHECK-LABEL: test_elu_elu
@ -701,8 +701,8 @@ func @test_elu_elu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
} }
func @test_leakyrelu_leakyrelu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> { func @test_leakyrelu_leakyrelu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.LeakyRelu"(%arg0) {LeakyRelu.alpha=1.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32> %0 = "onnx.LeakyRelu"(%arg0) {alpha=1.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32>
%1 = "onnx.LeakyRelu"(%0) {LeakyRelu.alpha=1.0:f32} : (tensor<*xf32>) -> tensor<*xf32> %1 = "onnx.LeakyRelu"(%0) {alpha=1.0:f32} : (tensor<*xf32>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> () "std.return"(%1) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_leakyrelu_leakyrelu // CHECK-LABEL: test_leakyrelu_leakyrelu
@ -748,8 +748,8 @@ func @test_leakyrelu_leakyrelu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
} }
func @test_selu_selu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> { func @test_selu_selu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Selu"(%arg0) {Selu.alpha=1.0:f32, Selu.gamma=2.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32> %0 = "onnx.Selu"(%arg0) {alpha=1.0:f32, gamma=2.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32>
%1 = "onnx.Selu"(%0) {Selu.alpha=1.0:f32, Selu.gamma=2.0:f32} : (tensor<*xf32>) -> tensor<*xf32> %1 = "onnx.Selu"(%0) {alpha=1.0:f32, gamma=2.0:f32} : (tensor<*xf32>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> () "std.return"(%1) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_selu_selu // CHECK-LABEL: test_selu_selu
@ -803,8 +803,8 @@ func @test_selu_selu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
} }
func @test_hardsigmoid_hardsigmoid(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> { func @test_hardsigmoid_hardsigmoid(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.HardSigmoid"(%arg0) {HardSigmoid.alpha=1.0:f32, HardSigmoid.beta=2.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32> %0 = "onnx.HardSigmoid"(%arg0) {alpha=1.0:f32, beta=2.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32>
%1 = "onnx.HardSigmoid"(%0) {HardSigmoid.alpha=1.0:f32, HardSigmoid.beta=2.0:f32} : (tensor<*xf32>) -> tensor<*xf32> %1 = "onnx.HardSigmoid"(%0) {alpha=1.0:f32, beta=2.0:f32} : (tensor<*xf32>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> () "std.return"(%1) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_hardsigmoid_hardsigmoid // CHECK-LABEL: test_hardsigmoid_hardsigmoid

1
third_party/variant vendored Submodule

@ -0,0 +1 @@
Subproject commit 3c7fc8266bb46046b42c2dc2663f9f505f0cec28