[MLIR] import attribute of onnx node (#383)
* add attributes as NamedAttribute * support list value for attribute * use std::tie to avoid c++17 feature
This commit is contained in:
parent
45608282e0
commit
c8d591fb28
|
@ -1,9 +1,12 @@
|
|||
add_executable(onnf main.cpp)
|
||||
|
||||
target_link_libraries(onnf builder compiler ${MLIRLibs} ${Boost_LIBRARIES})
|
||||
set_target_properties(onnf PROPERTIES LINK_FLAGS "-lz")
|
||||
|
||||
whole_archive_link_mlir(onnf ${MLIRWholeArchiveLibs})
|
||||
whole_archive_link_onnf(onnf onnf_transform)
|
||||
|
||||
|
||||
target_include_directories(onnf PRIVATE ${CMAKE_SOURCE_DIR})
|
||||
target_include_directories(onnf PRIVATE ${CMAKE_BINARY_DIR})
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
add_library(builder
|
||||
frontend_dialect_transformer.cpp
|
||||
frontend_dialect_transformer.hpp
|
||||
op_build_table.inc
|
||||
)
|
||||
|
||||
|
|
|
@ -71,7 +71,10 @@ struct OnnxOnnfSymbolMapping {
|
|||
* @param name onnx tensor name.
|
||||
* @return onnf tensor corresponding to `name`.
|
||||
*/
|
||||
mlir::Value* GetTensorByOnnxName(std::string name) {
|
||||
mlir::Value *GetTensorByOnnxName(std::string name) {
|
||||
assert(onnx_name2onnf_tensor.find(legalize_name(name)) !=
|
||||
onnx_name2onnf_tensor.end() &&
|
||||
"Tensor not found");
|
||||
return onnx_name2onnf_tensor.at(legalize_name(name));
|
||||
}
|
||||
|
||||
|
@ -80,7 +83,9 @@ struct OnnxOnnfSymbolMapping {
|
|||
* @param name onnx tensor name.
|
||||
* @param tensor MLIR Value* pointer.
|
||||
*/
|
||||
void AddMapping(std::string name, mlir::Value* tensor) {
|
||||
void AddMapping(std::string name, mlir::Value *tensor) {
|
||||
assert(onnx_name2onnf_tensor.count(legalize_name(name)) == 0 &&
|
||||
"Tensor already exists.");
|
||||
onnx_name2onnf_tensor.emplace(legalize_name(name), tensor);
|
||||
}
|
||||
|
||||
|
@ -88,7 +93,7 @@ struct OnnxOnnfSymbolMapping {
|
|||
return onnx_name2onnf_tensor.count(name) != 0;
|
||||
}
|
||||
|
||||
private:
|
||||
private:
|
||||
/*!
|
||||
* mapping from onnx tensor names to MLIR tensor.
|
||||
*/
|
||||
|
@ -96,8 +101,8 @@ struct OnnxOnnfSymbolMapping {
|
|||
};
|
||||
|
||||
class FrontendGenImpl {
|
||||
public:
|
||||
FrontendGenImpl(mlir::MLIRContext& context)
|
||||
public:
|
||||
FrontendGenImpl(mlir::MLIRContext &context)
|
||||
: context_(context), builder_(&context) {
|
||||
module_ = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||
}
|
||||
|
@ -107,8 +112,8 @@ class FrontendGenImpl {
|
|||
return module_;
|
||||
}
|
||||
|
||||
private:
|
||||
mlir::MLIRContext& context_;
|
||||
private:
|
||||
mlir::MLIRContext &context_;
|
||||
mlir::ModuleOp module_;
|
||||
mlir::OpBuilder builder_;
|
||||
// mapping between string name and symbol
|
||||
|
@ -145,55 +150,31 @@ class FrontendGenImpl {
|
|||
case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX64:
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX128:
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_UNDEFINED:
|
||||
assert(false && "Unsupported data type encountered.");
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
//if c++17 is used, these two def can be combined with 'if constexpr'
|
||||
//leave n there for possible future use
|
||||
//alternative is to use template and pass the outputTypes, inputs and attributes
|
||||
//as parameter
|
||||
|
||||
#define MultipleOuts(name, nIn, nOut)\
|
||||
{ \
|
||||
if (nIn == inputs.size() && nOut == outputTypes.size()) {\
|
||||
auto op = builder_.create<mlir::ONNX##name##Op>(UnknownLoc(), outputTypes, inputs, attributes); \
|
||||
for (int i = 0; i < node.output().size(); i++) { \
|
||||
frontend_symbols_.AddMapping(\
|
||||
legalize_name(node.output()[i]), op.getResult(i));\
|
||||
}\
|
||||
return;\
|
||||
}\
|
||||
}
|
||||
|
||||
#define OneOut(name, nIn, nOut)\
|
||||
{ \
|
||||
if (nIn == inputs.size() && nOut == outputTypes.size()) {\
|
||||
auto op = builder_.create<mlir::ONNX##name##Op>(UnknownLoc(), outputTypes, inputs, attributes); \
|
||||
frontend_symbols_.AddMapping(\
|
||||
legalize_name(node.output()[0]), op.getResult());\
|
||||
return;\
|
||||
}\
|
||||
}
|
||||
|
||||
/*!
|
||||
* Import an onnx input tensor type by determining and recording its type
|
||||
* in a list of input tensor mlir types.
|
||||
* @param input onnx input tensor ValueInfoProto.
|
||||
* @param arg_types list of mlir types representing types of graph input.
|
||||
*/
|
||||
void ImportInputTensorType(const onnx::ValueInfoProto& input,
|
||||
llvm::SmallVector<mlir::Type, 4>& arg_types) {
|
||||
void ImportInputTensorType(const onnx::ValueInfoProto &input,
|
||||
llvm::SmallVector<mlir::Type, 4> &arg_types) {
|
||||
std::vector<int64_t> dims;
|
||||
auto shape_proto = input.type().tensor_type().shape();
|
||||
auto input_tensor_legalized_name = legalize_name(input.name());
|
||||
for (int i = 0; i < shape_proto.dim_size(); i++) {
|
||||
if (shape_proto.dim()[i].dim_value()) {
|
||||
int dim_numeric_size = shape_proto.dim()[i].dim_value();
|
||||
assert(
|
||||
dim_numeric_size != 0 && "Parsed an input tensor with a dimension size of zero");
|
||||
if (dim_numeric_size > 0) {
|
||||
dims.push_back(dim_numeric_size);
|
||||
} else { // If dim_value < 0, then dim is parametric.
|
||||
// TODO Verify the unknown dim size in MLIR
|
||||
} else { // If dim_value < 0, then dim is parametric.
|
||||
// TODO Verify the unknown dim size in MLIR
|
||||
dims.push_back(-1);
|
||||
}
|
||||
} else {
|
||||
|
@ -216,8 +197,8 @@ class FrontendGenImpl {
|
|||
* @param input onnx input tensor ValueInfoProto.
|
||||
* @param symbol mlir input argument.
|
||||
*/
|
||||
void ImportInputTensorSymbol(
|
||||
const onnx::ValueInfoProto& input, mlir::Value* symbol) {
|
||||
void ImportInputTensorSymbol(const onnx::ValueInfoProto &input,
|
||||
mlir::Value *symbol) {
|
||||
auto input_tensor_legalized_name = legalize_name(input.name());
|
||||
assert(
|
||||
!frontend_symbols_.ContainKey(input_tensor_legalized_name) &&
|
||||
|
@ -225,32 +206,286 @@ class FrontendGenImpl {
|
|||
frontend_symbols_.AddMapping(input_tensor_legalized_name, symbol);
|
||||
}
|
||||
|
||||
void ImportNode(onnx::NodeProto node) {
|
||||
std::vector<mlir::Value*> inputs;
|
||||
template <typename T>
|
||||
T get_attr_generic(onnx::NodeProto &node, std::string name,
|
||||
std::function<T(onnx::AttributeProto &)> attr_getter,
|
||||
T default_val) {
|
||||
for (int i = 0; i < node.attribute_size(); ++i) {
|
||||
auto attr = node.attribute(i);
|
||||
if (attr.name() == name) {
|
||||
return attr_getter(attr);
|
||||
}
|
||||
}
|
||||
return default_val;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T get_attr_generic(onnx::NodeProto &node, std::string name,
|
||||
std::function<T(onnx::AttributeProto &)> attr_getter) {
|
||||
for (int i = 0; i < node.attribute_size(); ++i) {
|
||||
auto attr = node.attribute(i);
|
||||
if (attr.name() == name) {
|
||||
return attr_getter(attr);
|
||||
}
|
||||
}
|
||||
assert(false && "ONNX Node Attribute Not Found!");
|
||||
}
|
||||
|
||||
auto get_attr_ints(onnx::NodeProto &node, std::string name,
|
||||
std::vector<int> default_val) {
|
||||
std::function<std::vector<int>(onnx::AttributeProto &)> attr_getter =
|
||||
[](onnx::AttributeProto &attr) {
|
||||
std::vector<int> ints(attr.ints_size());
|
||||
std::copy(attr.ints().begin(), attr.ints().end(), ints.begin());
|
||||
return ints;
|
||||
};
|
||||
auto r = get_attr_generic(node, name, attr_getter, default_val);
|
||||
auto dataType =
|
||||
mlir::RankedTensorType::get(r.size(), builder_.getIntegerType(32));
|
||||
auto attr_v = mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(r));
|
||||
auto aname = node.op_type() + "." + name;
|
||||
auto attr_output = builder_.getNamedAttr(aname, attr_v);
|
||||
return attr_output;
|
||||
}
|
||||
|
||||
auto get_attr_ints(onnx::NodeProto &node, std::string name) {
|
||||
std::function<std::vector<int>(onnx::AttributeProto &)> attr_getter =
|
||||
[](onnx::AttributeProto &attr) {
|
||||
std::vector<int> ints(attr.ints_size());
|
||||
std::copy(attr.ints().begin(), attr.ints().end(), ints.begin());
|
||||
return ints;
|
||||
};
|
||||
auto r = get_attr_generic(node, name, attr_getter);
|
||||
auto dataType =
|
||||
mlir::RankedTensorType::get(r.size(), builder_.getIntegerType(32));
|
||||
auto attr_v = mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(r));
|
||||
auto aname = node.op_type() + "." + name;
|
||||
auto attr_output = builder_.getNamedAttr(aname, attr_v);
|
||||
return attr_output;
|
||||
}
|
||||
|
||||
auto get_attr_floats(onnx::NodeProto &node, std::string name) {
|
||||
std::function<std::vector<float>(onnx::AttributeProto &)> attr_getter =
|
||||
[](onnx::AttributeProto &attr) {
|
||||
std::vector<float> floats(attr.floats_size());
|
||||
std::copy(attr.floats().begin(), attr.floats().end(), floats.begin());
|
||||
return floats;
|
||||
};
|
||||
auto r = get_attr_generic(node, name, attr_getter);
|
||||
auto dataType =
|
||||
mlir::RankedTensorType::get(r.size(), builder_.getF32Type());
|
||||
auto attr_v = mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(r));
|
||||
auto aname = node.op_type() + "." + name;
|
||||
auto attr_output = builder_.getNamedAttr(aname, attr_v);
|
||||
return attr_output;
|
||||
}
|
||||
|
||||
auto get_attr_floats(onnx::NodeProto &node, std::string name,
|
||||
std::vector<float> default_val) {
|
||||
std::function<std::vector<float>(onnx::AttributeProto &)> attr_getter =
|
||||
[](onnx::AttributeProto &attr) {
|
||||
std::vector<float> floats(attr.floats_size());
|
||||
std::copy(attr.floats().begin(), attr.floats().end(), floats.begin());
|
||||
return floats;
|
||||
};
|
||||
auto r = get_attr_generic(node, name, attr_getter, default_val);
|
||||
auto dataType =
|
||||
mlir::RankedTensorType::get(r.size(), builder_.getF32Type());
|
||||
auto attr_v = mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(r));
|
||||
auto aname = node.op_type() + "." + name;
|
||||
auto attr_output = builder_.getNamedAttr(aname, attr_v);
|
||||
return attr_output;
|
||||
}
|
||||
|
||||
auto get_attr_int(onnx::NodeProto &node, std::string name) {
|
||||
std::function<int(onnx::AttributeProto &)> attr_getter =
|
||||
[](onnx::AttributeProto &attr) { return attr.i(); };
|
||||
int r = get_attr_generic(node, name, attr_getter);
|
||||
auto attr_v = builder_.getI32IntegerAttr(r);
|
||||
auto aname = node.op_type() + "." + name;
|
||||
auto attr_output = builder_.getNamedAttr(aname, attr_v);
|
||||
return attr_output;
|
||||
}
|
||||
|
||||
auto get_attr_int(onnx::NodeProto &node, std::string name, int default_val) {
|
||||
std::function<int(onnx::AttributeProto &)> attr_getter =
|
||||
[](onnx::AttributeProto &attr) { return attr.i(); };
|
||||
int r = get_attr_generic(node, name, attr_getter, default_val);
|
||||
auto attr_v = builder_.getI32IntegerAttr(r);
|
||||
auto aname = node.op_type() + "." + name;
|
||||
auto attr_output = builder_.getNamedAttr(aname, attr_v);
|
||||
return attr_output;
|
||||
}
|
||||
|
||||
auto get_attr_float(onnx::NodeProto &node, std::string name) {
|
||||
std::function<float(onnx::AttributeProto &)> attr_getter =
|
||||
[](onnx::AttributeProto &attr) { return attr.f(); };
|
||||
auto r = get_attr_generic(node, name, attr_getter);
|
||||
auto attr_v = builder_.getF32FloatAttr(r);
|
||||
auto aname = node.op_type() + "." + name;
|
||||
return builder_.getNamedAttr(aname, attr_v);
|
||||
}
|
||||
|
||||
auto get_attr_float(onnx::NodeProto &node, std::string name,
|
||||
float default_val) {
|
||||
std::function<float(onnx::AttributeProto &)> attr_getter =
|
||||
[](onnx::AttributeProto &attr) { return attr.f(); };
|
||||
auto r = get_attr_generic(node, name, attr_getter, default_val);
|
||||
auto attr_v = builder_.getF32FloatAttr(r);
|
||||
auto aname = node.op_type() + "." + name;
|
||||
return builder_.getNamedAttr(aname, attr_v);
|
||||
}
|
||||
|
||||
auto get_attr_string(onnx::NodeProto &node, std::string name) {
|
||||
std::function<std::string(onnx::AttributeProto &)> attr_getter =
|
||||
[](onnx::AttributeProto &attr) { return attr.s(); };
|
||||
auto r = get_attr_generic(node, name, attr_getter);
|
||||
auto attr_v = builder_.getStringAttr(r);
|
||||
auto aname = node.op_type() + "." + name;
|
||||
return builder_.getNamedAttr(aname, attr_v);
|
||||
}
|
||||
|
||||
auto get_attr_string(onnx::NodeProto &node, std::string name,
|
||||
std::string default_val) {
|
||||
std::function<std::string(onnx::AttributeProto &)> attr_getter =
|
||||
[](onnx::AttributeProto &attr) { return attr.s(); };
|
||||
auto r = get_attr_generic(node, name, attr_getter, default_val);
|
||||
auto attr_v = builder_.getStringAttr(r);
|
||||
auto aname = node.op_type() + "." + name;
|
||||
return builder_.getNamedAttr(aname, attr_v);
|
||||
}
|
||||
|
||||
/*
|
||||
auto get_attr_strings(onnx::NodeProto &node, std::string name) {
|
||||
std::function<std::vector<std::string>(onnx::AttributeProto &)>
|
||||
attr_getter =
|
||||
[](onnx::AttributeProto &attr) {
|
||||
std::vector<std::string> strings(attr.strings_size());
|
||||
std::copy(attr.strings().begin(), attr.strings().end(),
|
||||
strings.begin()); return strings;
|
||||
};
|
||||
auto r = get_attr_generic(node, name, attr_getter);
|
||||
return r;
|
||||
return builder_.getNamedAttr(aname, attr_v);
|
||||
auto dataType =
|
||||
mlir::RankedTensorType::get(r.size(), builder_.get???Type());
|
||||
auto attr_v = mlir::DenseElementsAttr::get(dataType,
|
||||
llvm::makeArrayRef(r)); auto aname = node.op_type() + "." + name; auto
|
||||
attr_output = builder_.getNamedAttr(aname, attr_v); return attr_output;
|
||||
}
|
||||
*/
|
||||
|
||||
auto get_default_ints(std::string default_str) {
|
||||
std::vector<int> r;
|
||||
auto start = default_str.find("{");
|
||||
while (true) {
|
||||
auto end = default_str.find(",", start + 1);
|
||||
if (end == std::string::npos) {
|
||||
end = default_str.find("}", start + 1);
|
||||
if (end != std::string::npos && end > start+1) {
|
||||
r.push_back(std::stoi(default_str.substr(start + 1, end)));
|
||||
}
|
||||
break;
|
||||
} else {
|
||||
r.push_back(std::stoi(default_str.substr(start + 1, end)));
|
||||
}
|
||||
start = end + 1;
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
auto get_default_floats(std::string default_str) {
|
||||
std::vector<float> r;
|
||||
auto start = default_str.find("{");
|
||||
while (true) {
|
||||
auto end = default_str.find(",", start + 1);
|
||||
if (end == std::string::npos) {
|
||||
end = default_str.find("}", start + 1);
|
||||
if (end != std::string::npos && end > start+1) {
|
||||
r.push_back(std::stof(default_str.substr(start + 1, end)));
|
||||
}
|
||||
break;
|
||||
} else {
|
||||
r.push_back(std::stof(default_str.substr(start + 1, end)));
|
||||
}
|
||||
start = end + 1;
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
auto get_default_strings(std::string default_str) {
|
||||
std::vector<std::string> r;
|
||||
auto start = default_str.find("{");
|
||||
while (true) {
|
||||
auto end = default_str.find(",", start + 1);
|
||||
if (end == std::string::npos) {
|
||||
end = default_str.find("}", start + 1);
|
||||
if (end != std::string::npos && end > start+1) {
|
||||
r.push_back(default_str.substr(start + 1, end));
|
||||
}
|
||||
break;
|
||||
} else {
|
||||
r.push_back(default_str.substr(start + 1, end));
|
||||
}
|
||||
start = end + 1;
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
onnx::TensorProto get_attr_tensor(onnx::NodeProto &node, std::string name) {
|
||||
std::function<onnx::TensorProto(onnx::AttributeProto &)> attr_getter =
|
||||
[](onnx::AttributeProto &attr) { return attr.t(); };
|
||||
return get_attr_generic(node, name, attr_getter);
|
||||
}
|
||||
|
||||
auto ImportNodeAttr(onnx::NodeProto node, std::string attr_name,
|
||||
std::string type_name, std::string default_str) {
|
||||
if (default_str == "") {
|
||||
if (type_name == "int") {
|
||||
return get_attr_int(node, attr_name);
|
||||
} else if (type_name == "float") {
|
||||
return get_attr_float(node, attr_name);
|
||||
} else if (type_name == "str") {
|
||||
return get_attr_string(node, attr_name);
|
||||
} else if (type_name == "ints") {
|
||||
return get_attr_ints(node, attr_name);
|
||||
} else if (type_name == "floats") {
|
||||
return get_attr_floats(node, attr_name);
|
||||
} else {
|
||||
assert(
|
||||
false &&
|
||||
"Got an empty initializer or initializer for this "
|
||||
"datatype is not implemented. Something is wrong.");
|
||||
}
|
||||
} else {
|
||||
// with default value
|
||||
if (type_name == "int") {
|
||||
return get_attr_int(node, attr_name, std::stoi(default_str));
|
||||
} else if (type_name == "float") {
|
||||
return get_attr_float(node, attr_name, std::stof(default_str));
|
||||
} else if (type_name == "str") {
|
||||
return get_attr_string(node, attr_name, default_str);
|
||||
} else if (type_name == "ints") {
|
||||
return get_attr_ints(node, attr_name, get_default_ints(default_str));
|
||||
} else if (type_name == "floats") {
|
||||
return get_attr_floats(node, attr_name,
|
||||
get_default_floats(default_str));
|
||||
} else {
|
||||
assert(
|
||||
false &&
|
||||
"Got an empty initializer or initializer for this "
|
||||
"datatype is not implemented. Something is wrong.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ImportNodeGeneric(onnx::NodeProto node) {
|
||||
std::vector<mlir::Value *> inputs;
|
||||
for (auto item : node.input()) {
|
||||
if (frontend_symbols_.ContainKey(legalize_name(item))) {
|
||||
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<mlir::Type> outputTypes;
|
||||
for (auto item : node.output()) {
|
||||
outputTypes.push_back(mlir::UnrankedTensorType::get(builder_.getF32Type()));
|
||||
}
|
||||
|
||||
std::vector<mlir::NamedAttribute> attributes;
|
||||
llvm::StringRef OpName = node.op_type();
|
||||
|
||||
//the following code is generated by gen_doc.py
|
||||
//refer to dialect/onnx/onnx.td for details
|
||||
//when the input or output of then op does not match the specification,
|
||||
//the generic operator is used
|
||||
//one known reeason is the optional input
|
||||
|
||||
#include "src/builder/op_build_table.inc"
|
||||
|
||||
|
||||
// Old way of doing things.
|
||||
mlir::OperationState result(UnknownLoc(), "frontend." + node.op_type());
|
||||
for (auto item : node.output()) {
|
||||
result.addTypes(mlir::UnrankedTensorType::get(builder_.getF32Type()));
|
||||
|
@ -261,8 +496,165 @@ class FrontendGenImpl {
|
|||
auto r = op->getResult(i);
|
||||
frontend_symbols_.AddMapping(legalize_name(node.output()[i]), r);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO more info from node: attributes
|
||||
// if c++17 is used, ImportNodeOneOut and ImportNodeMultipleOuts can be
|
||||
// combined with 'if constexpr' the issue is the type of the output is
|
||||
// different. alternative way to use variadic output for all the op
|
||||
|
||||
/*!
|
||||
* Important onnx node which generates only one output
|
||||
* @param node onnx node
|
||||
* @param nIn number of expected inputs
|
||||
* @param nOut number of expected outputs
|
||||
* @param attrs list of desription for attributes with format {name, type,
|
||||
* default}
|
||||
*/
|
||||
template <typename T>
|
||||
void ImportNodeOneOut(
|
||||
onnx::NodeProto node, int nIn, int nOut,
|
||||
std::initializer_list<std::tuple<std::string, std::string, std::string>>
|
||||
attrs) {
|
||||
std::vector<mlir::Value *> inputs;
|
||||
for (auto item : node.input()) {
|
||||
if (frontend_symbols_.ContainKey(legalize_name(item))) {
|
||||
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<mlir::Type> outputTypes;
|
||||
for (auto item : node.output()) {
|
||||
outputTypes.push_back(
|
||||
mlir::UnrankedTensorType::get(builder_.getF32Type()));
|
||||
}
|
||||
|
||||
std::vector<mlir::NamedAttribute> attributes;
|
||||
//for (auto [attr_name, attr_type, attr_default] : attrs) {
|
||||
for (auto oneAttr: attrs) {
|
||||
std::string attr_name;
|
||||
std::string attr_type;
|
||||
std::string attr_default;
|
||||
std::tie (attr_name, attr_type, attr_default) = oneAttr;
|
||||
if (attr_type != "") {
|
||||
auto attr = ImportNodeAttr(node, attr_name, attr_type, attr_default);
|
||||
attributes.push_back(attr);
|
||||
} else {
|
||||
// TODO: the attributes need special handling
|
||||
//std::cout << "missing " << node.op_type() << " " << attr_name << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
llvm::StringRef OpName = node.op_type();
|
||||
|
||||
if (nIn == inputs.size() && nOut == outputTypes.size()) {
|
||||
auto op =
|
||||
builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes);
|
||||
frontend_symbols_.AddMapping(legalize_name(node.output()[0]),
|
||||
op.getResult());
|
||||
} else {
|
||||
ImportNodeGeneric(node);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ImportNodeMultipleOuts(
|
||||
onnx::NodeProto node, int nIn, int nOut,
|
||||
std::initializer_list<std::tuple<std::string, std::string, std::string>>
|
||||
attrs) {
|
||||
std::vector<mlir::Value *> inputs;
|
||||
for (auto item : node.input()) {
|
||||
if (frontend_symbols_.ContainKey(legalize_name(item))) {
|
||||
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<mlir::Type> outputTypes;
|
||||
for (auto item : node.output()) {
|
||||
outputTypes.push_back(
|
||||
mlir::UnrankedTensorType::get(builder_.getF32Type()));
|
||||
}
|
||||
|
||||
std::vector<mlir::NamedAttribute> attributes;
|
||||
for (auto oneAttr: attrs) {
|
||||
std::string attr_name;
|
||||
std::string attr_type;
|
||||
std::string attr_default;
|
||||
std::tie (attr_name, attr_type, attr_default) = oneAttr;
|
||||
if (attr_type != "") {
|
||||
auto attr = ImportNodeAttr(node, attr_name, attr_type, attr_default);
|
||||
attributes.push_back(attr);
|
||||
} else {
|
||||
// TODO: the attributes need special handling
|
||||
//std::cout << "missing " << node.op_type() << " " << attr_name << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
llvm::StringRef OpName = node.op_type();
|
||||
|
||||
if (nIn == inputs.size() && nOut == outputTypes.size()) {
|
||||
auto op =
|
||||
builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes);
|
||||
for (int i = 0; i < node.output().size(); i++) {
|
||||
frontend_symbols_.AddMapping(legalize_name(node.output()[i]),
|
||||
op.getResult(i));
|
||||
}
|
||||
} else {
|
||||
ImportNodeGeneric(node);
|
||||
}
|
||||
}
|
||||
|
||||
/*!
|
||||
* Special handle for Conv operations.
|
||||
* c++ does not allow template specialization inside a class scope
|
||||
* a specialized function is used
|
||||
*/
|
||||
void ImportNodeConv(
|
||||
onnx::NodeProto node, int nIn, int nOut,
|
||||
std::initializer_list<std::tuple<std::string, std::string, std::string>>
|
||||
attrs) {
|
||||
|
||||
// Conv has attribute dilations, kernel_shape, pads, the default value of
|
||||
// which is determined by the shape of first argument. However, since the
|
||||
// shape is unknown now, these attributes can be not generated auto
|
||||
// dilations_attr = get_attr_ints(node, "dilations",
|
||||
// std::vector<int>(inputs[0]->getType().cast<RankedTensorType>.getDims()-2,
|
||||
// 1));
|
||||
// attributes.push_back(dilations_attr)
|
||||
// similar situation for pads, strides in AveragePool
|
||||
// axes of ReduceSum, pads, strides, dilations and kernel_shape of MaxPool
|
||||
// TODO: fix this after type inference
|
||||
|
||||
if (node.input().size() == 1) {
|
||||
ImportNodeOneOut<mlir::ONNXConv1Op>(node, nIn, nOut, attrs);
|
||||
} else {
|
||||
ImportNodeOneOut<mlir::ONNXConv3Op>(node, nIn, nOut, attrs);
|
||||
}
|
||||
}
|
||||
|
||||
void ImportNode(onnx::NodeProto node) {
|
||||
std::vector<mlir::Value *> inputs;
|
||||
for (auto item : node.input()) {
|
||||
if (frontend_symbols_.ContainKey(legalize_name(item))) {
|
||||
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<mlir::Type> outputTypes;
|
||||
for (auto item : node.output()) {
|
||||
outputTypes.push_back(
|
||||
mlir::UnrankedTensorType::get(builder_.getF32Type()));
|
||||
}
|
||||
|
||||
std::vector<mlir::NamedAttribute> attributes;
|
||||
llvm::StringRef OpName = node.op_type();
|
||||
|
||||
// the following code is generated by gen_doc.py
|
||||
// refer to dialect/onnx/onnx.td for details
|
||||
// when the input or output of then op does not match the specification,
|
||||
// the generic operator is used
|
||||
// one known reeason is the optional input
|
||||
|
||||
#include "src/builder/op_build_table.inc"
|
||||
}
|
||||
|
||||
/*!
|
||||
|
@ -277,9 +669,9 @@ class FrontendGenImpl {
|
|||
* @param ret_vals a vector of mlir Value* representing graph's
|
||||
* output tensor.
|
||||
*/
|
||||
void ImportOutputTensor(const onnx::ValueInfoProto& output,
|
||||
llvm::SmallVectorImpl<mlir::Type>& ret_types,
|
||||
llvm::SmallVectorImpl<mlir::Value*>& ret_vals) {
|
||||
void ImportOutputTensor(const onnx::ValueInfoProto &output,
|
||||
llvm::SmallVectorImpl<mlir::Type> &ret_types,
|
||||
llvm::SmallVectorImpl<mlir::Value *> &ret_vals) {
|
||||
auto output_tensor_legalized_name = legalize_name(output.name());
|
||||
assert(
|
||||
frontend_symbols_.ContainKey(output_tensor_legalized_name) &&
|
||||
|
@ -291,8 +683,8 @@ class FrontendGenImpl {
|
|||
ret_vals.push_back(tensor_val);
|
||||
}
|
||||
|
||||
void ImportGraph(
|
||||
const onnx::GraphProto& graph, const std::string& name = "main") {
|
||||
void ImportGraph(const onnx::GraphProto &graph,
|
||||
const std::string &name = "main") {
|
||||
// create a function for the graph
|
||||
// TODO:
|
||||
// * get name and type for the function.
|
||||
|
@ -300,7 +692,7 @@ class FrontendGenImpl {
|
|||
llvm::SmallVector<mlir::Type, 4> arg_types;
|
||||
|
||||
// Import the input tensor types.
|
||||
for (const auto& input : graph.input()) {
|
||||
for (const auto &input : graph.input()) {
|
||||
ImportInputTensorType(input, arg_types);
|
||||
}
|
||||
|
||||
|
@ -308,7 +700,7 @@ class FrontendGenImpl {
|
|||
auto func_type = builder_.getFunctionType(arg_types, {});
|
||||
auto main_func =
|
||||
mlir::FuncOp::create(UnknownLoc(), name, func_type, /* attrs = */ {});
|
||||
auto& entryBlock = *main_func.addEntryBlock();
|
||||
auto &entryBlock = *main_func.addEntryBlock();
|
||||
|
||||
builder_.setInsertionPointToStart(&entryBlock);
|
||||
module_.push_back(main_func);
|
||||
|
@ -319,14 +711,14 @@ class FrontendGenImpl {
|
|||
|
||||
// import nodes in the graph
|
||||
auto node = graph.node();
|
||||
for (const auto& item : node) {
|
||||
for (const auto &item : node) {
|
||||
ImportNode(item);
|
||||
}
|
||||
|
||||
llvm::SmallVector<mlir::Type, 4> ret_types;
|
||||
llvm::SmallVector<mlir::Value*, 4> ret_vals;
|
||||
llvm::SmallVector<mlir::Value *, 4> ret_vals;
|
||||
// Import the output tensors
|
||||
for (const auto& output : graph.output()) {
|
||||
for (const auto &output : graph.output()) {
|
||||
ImportOutputTensor(output, ret_types, ret_vals);
|
||||
}
|
||||
|
||||
|
@ -337,7 +729,6 @@ class FrontendGenImpl {
|
|||
func_type = builder_.getFunctionType(arg_types, ret_types);
|
||||
main_func.setType(func_type);
|
||||
}
|
||||
|
||||
}; // FrontendGenImpl class
|
||||
} // namespace
|
||||
} // namespace onnf
|
||||
|
@ -354,11 +745,13 @@ mlir::OwningModuleRef ImportFrontendModel(onnx::ModelProto model) {
|
|||
}
|
||||
|
||||
void ImportFrontendModelFile(std::string model_fname,
|
||||
mlir::MLIRContext& context, mlir::OwningModuleRef& module) {
|
||||
mlir::MLIRContext &context,
|
||||
mlir::OwningModuleRef &module) {
|
||||
onnx::ModelProto model;
|
||||
std::fstream input(model_fname, std::ios::in | std::ios::binary);
|
||||
|
||||
auto parse_success = model.ParseFromIstream(&input);
|
||||
assert(parse_success && "Onnx Model Parsing Failed.");
|
||||
|
||||
FrontendGenImpl myONNXGen(context);
|
||||
module = myONNXGen.ImportONNXModel(model);
|
||||
|
|
|
@ -1,313 +1,688 @@
|
|||
if (OpName == "Abs") {
|
||||
OneOut(Abs, 1, 1);
|
||||
if (OpName == "DUMMY") {
|
||||
}else if (OpName == "Abs") {
|
||||
ImportNodeOneOut<mlir::ONNXAbsOp>(node, 1, 1, {
|
||||
});
|
||||
}else if (OpName == "Acos") {
|
||||
OneOut(Acos, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXAcosOp>(node, 1, 1, {
|
||||
});
|
||||
}else if (OpName == "Acosh") {
|
||||
OneOut(Acosh, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXAcoshOp>(node, 1, 1, {
|
||||
});
|
||||
}else if (OpName == "Add") {
|
||||
OneOut(Add, 2, 1);
|
||||
ImportNodeOneOut<mlir::ONNXAddOp>(node, 2, 1, {
|
||||
});
|
||||
}else if (OpName == "And") {
|
||||
OneOut(And, 2, 1);
|
||||
ImportNodeOneOut<mlir::ONNXAndOp>(node, 2, 1, {
|
||||
});
|
||||
}else if (OpName == "ArgMax") {
|
||||
OneOut(ArgMax, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXArgMaxOp>(node, 1, 1, {
|
||||
{"axis","int","0"}
|
||||
,{"keepdims","int","1"}
|
||||
});
|
||||
}else if (OpName == "ArgMin") {
|
||||
OneOut(ArgMin, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXArgMinOp>(node, 1, 1, {
|
||||
{"axis","int","0"}
|
||||
,{"keepdims","int","1"}
|
||||
});
|
||||
}else if (OpName == "Asin") {
|
||||
OneOut(Asin, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXAsinOp>(node, 1, 1, {
|
||||
});
|
||||
}else if (OpName == "Asinh") {
|
||||
OneOut(Asinh, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXAsinhOp>(node, 1, 1, {
|
||||
});
|
||||
}else if (OpName == "Atan") {
|
||||
OneOut(Atan, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXAtanOp>(node, 1, 1, {
|
||||
});
|
||||
}else if (OpName == "Atanh") {
|
||||
OneOut(Atanh, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXAtanhOp>(node, 1, 1, {
|
||||
});
|
||||
}else if (OpName == "AveragePool") {
|
||||
OneOut(AveragePool, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXAveragePoolOp>(node, 1, 1, {
|
||||
{"auto_pad","str","NOTSET"}
|
||||
,{"ceil_mode","int","0"}
|
||||
,{"count_include_pad","int","0"}
|
||||
,{"kernel_shape","ints", ""}
|
||||
,{"pads","", ""}
|
||||
,{"strides","", ""}
|
||||
});
|
||||
}else if (OpName == "BatchNormalization") {
|
||||
MultipleOuts(BatchNormalization, 5, 5);
|
||||
ImportNodeMultipleOuts<mlir::ONNXBatchNormalizationOp>(node, 5, 5, {
|
||||
{"epsilon","float","1e-05"}
|
||||
,{"momentum","float","0.9"}
|
||||
});
|
||||
}else if (OpName == "BitShift") {
|
||||
OneOut(BitShift, 2, 1);
|
||||
ImportNodeOneOut<mlir::ONNXBitShiftOp>(node, 2, 1, {
|
||||
{"direction","", ""}
|
||||
});
|
||||
}else if (OpName == "Cast") {
|
||||
OneOut(Cast, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXCastOp>(node, 1, 1, {
|
||||
{"to","int", "0"}
|
||||
});
|
||||
}else if (OpName == "Ceil") {
|
||||
OneOut(Ceil, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXCeilOp>(node, 1, 1, {
|
||||
});
|
||||
}else if (OpName == "Clip") {
|
||||
OneOut(Clip, 3, 1);
|
||||
ImportNodeOneOut<mlir::ONNXClipOp>(node, 3, 1, {
|
||||
});
|
||||
}else if (OpName == "Compress") {
|
||||
OneOut(Compress, 2, 1);
|
||||
ImportNodeOneOut<mlir::ONNXCompressOp>(node, 2, 1, {
|
||||
{"axis","", ""}
|
||||
});
|
||||
}else if (OpName == "Concat") {
|
||||
OneOut(Concat, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXConcatOp>(node, 1, 1, {
|
||||
{"axis","int", "0"}
|
||||
});
|
||||
}else if (OpName == "ConcatFromSequence") {
|
||||
OneOut(ConcatFromSequence, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXConcatFromSequenceOp>(node, 1, 1, {
|
||||
{"axis","", ""}
|
||||
,{"new_axis","int","0"}
|
||||
});
|
||||
}else if (OpName == "Constant") {
|
||||
OneOut(Constant, 0, 1);
|
||||
ImportNodeOneOut<mlir::ONNXConstantOp>(node, 0, 1, {
|
||||
{"sparse_value","", ""}
|
||||
,{"value","", ""}
|
||||
});
|
||||
}else if (OpName == "ConstantOfShape") {
|
||||
OneOut(ConstantOfShape, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXConstantOfShapeOp>(node, 1, 1, {
|
||||
{"value","", ""}
|
||||
});
|
||||
}else if (OpName == "Conv") {
|
||||
OneOut(Conv, 3, 1);
|
||||
ImportNodeConv(node, 3, 1, {
|
||||
{"auto_pad","str","NOTSET"}
|
||||
,{"dilations","", ""}
|
||||
,{"group","int", "1"}
|
||||
,{"kernel_shape","", ""}
|
||||
,{"pads","", ""}
|
||||
,{"strides","", ""}
|
||||
});
|
||||
}else if (OpName == "ConvInteger") {
|
||||
OneOut(ConvInteger, 4, 1);
|
||||
ImportNodeOneOut<mlir::ONNXConvIntegerOp>(node, 4, 1, {
|
||||
{"auto_pad","str","NOTSET"}
|
||||
,{"dilations","", ""}
|
||||
,{"group","int","1"}
|
||||
,{"kernel_shape","", ""}
|
||||
,{"pads","", ""}
|
||||
,{"strides","", ""}
|
||||
});
|
||||
}else if (OpName == "ConvTranspose") {
|
||||
OneOut(ConvTranspose, 3, 1);
|
||||
ImportNodeOneOut<mlir::ONNXConvTransposeOp>(node, 3, 1, {
|
||||
{"auto_pad","str","NOTSET"}
|
||||
,{"dilations","", ""}
|
||||
,{"group","int","1"}
|
||||
,{"kernel_shape","", ""}
|
||||
,{"output_padding","", ""}
|
||||
,{"output_shape","", ""}
|
||||
,{"pads","", ""}
|
||||
,{"strides","", ""}
|
||||
});
|
||||
}else if (OpName == "Cos") {
|
||||
OneOut(Cos, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXCosOp>(node, 1, 1, {
|
||||
});
|
||||
}else if (OpName == "Cosh") {
|
||||
OneOut(Cosh, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXCoshOp>(node, 1, 1, {
|
||||
});
|
||||
}else if (OpName == "CumSum") {
|
||||
OneOut(CumSum, 2, 1);
|
||||
ImportNodeOneOut<mlir::ONNXCumSumOp>(node, 2, 1, {
|
||||
{"exclusive","int","0"}
|
||||
,{"reverse","int","0"}
|
||||
});
|
||||
}else if (OpName == "DepthToSpace") {
|
||||
OneOut(DepthToSpace, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXDepthToSpaceOp>(node, 1, 1, {
|
||||
{"blocksize","", ""}
|
||||
,{"mode","str","DCR"}
|
||||
});
|
||||
}else if (OpName == "DequantizeLinear") {
|
||||
OneOut(DequantizeLinear, 3, 1);
|
||||
ImportNodeOneOut<mlir::ONNXDequantizeLinearOp>(node, 3, 1, {
|
||||
});
|
||||
}else if (OpName == "Det") {
|
||||
OneOut(Det, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXDetOp>(node, 1, 1, {
|
||||
});
|
||||
}else if (OpName == "Div") {
|
||||
OneOut(Div, 2, 1);
|
||||
ImportNodeOneOut<mlir::ONNXDivOp>(node, 2, 1, {
|
||||
});
|
||||
}else if (OpName == "Dropout") {
|
||||
MultipleOuts(Dropout, 1, 2);
|
||||
ImportNodeMultipleOuts<mlir::ONNXDropoutOp>(node, 1, 2, {
|
||||
{"ratio","float","0.5"}
|
||||
});
|
||||
}else if (OpName == "DynamicQuantizeLinear") {
|
||||
MultipleOuts(DynamicQuantizeLinear, 1, 3);
|
||||
ImportNodeMultipleOuts<mlir::ONNXDynamicQuantizeLinearOp>(node, 1, 3, {
|
||||
});
|
||||
}else if (OpName == "Elu") {
|
||||
OneOut(Elu, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXEluOp>(node, 1, 1, {
|
||||
{"alpha","float","1.0"}
|
||||
});
|
||||
}else if (OpName == "Equal") {
|
||||
OneOut(Equal, 2, 1);
|
||||
ImportNodeOneOut<mlir::ONNXEqualOp>(node, 2, 1, {
|
||||
});
|
||||
}else if (OpName == "Erf") {
|
||||
OneOut(Erf, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXErfOp>(node, 1, 1, {
|
||||
});
|
||||
}else if (OpName == "Exp") {
|
||||
OneOut(Exp, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXExpOp>(node, 1, 1, {
|
||||
});
|
||||
}else if (OpName == "Expand") {
|
||||
OneOut(Expand, 2, 1);
|
||||
ImportNodeOneOut<mlir::ONNXExpandOp>(node, 2, 1, {
|
||||
});
|
||||
}else if (OpName == "EyeLike") {
|
||||
OneOut(EyeLike, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXEyeLikeOp>(node, 1, 1, {
|
||||
{"dtype","", ""}
|
||||
,{"k","int","0"}
|
||||
});
|
||||
}else if (OpName == "Flatten") {
|
||||
OneOut(Flatten, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXFlattenOp>(node, 1, 1, {
|
||||
{"axis","int","1"}
|
||||
});
|
||||
}else if (OpName == "Floor") {
|
||||
OneOut(Floor, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXFloorOp>(node, 1, 1, {
|
||||
});
|
||||
}else if (OpName == "GRU") {
|
||||
MultipleOuts(GRU, 6, 2);
|
||||
ImportNodeMultipleOuts<mlir::ONNXGRUOp>(node, 6, 2, {
|
||||
{"activation_alpha","", ""}
|
||||
,{"activation_beta","", ""}
|
||||
,{"activations","", ""}
|
||||
,{"clip","", ""}
|
||||
,{"direction","str","forward"}
|
||||
,{"hidden_size","", ""}
|
||||
,{"linear_before_reset","int","0"}
|
||||
});
|
||||
}else if (OpName == "Gather") {
|
||||
OneOut(Gather, 2, 1);
|
||||
ImportNodeOneOut<mlir::ONNXGatherOp>(node, 2, 1, {
|
||||
{"axis","int","0"}
|
||||
});
|
||||
}else if (OpName == "GatherElements") {
|
||||
OneOut(GatherElements, 2, 1);
|
||||
ImportNodeOneOut<mlir::ONNXGatherElementsOp>(node, 2, 1, {
|
||||
{"axis","int","0"}
|
||||
});
|
||||
}else if (OpName == "GatherND") {
|
||||
OneOut(GatherND, 2, 1);
|
||||
ImportNodeOneOut<mlir::ONNXGatherNDOp>(node, 2, 1, {
|
||||
});
|
||||
}else if (OpName == "Gemm") {
|
||||
OneOut(Gemm, 3, 1);
|
||||
ImportNodeOneOut<mlir::ONNXGemmOp>(node, 3, 1, {
|
||||
{"alpha","float","1.0"}
|
||||
,{"beta","float","1.0"}
|
||||
,{"transA","int","0"}
|
||||
,{"transB","int","0"}
|
||||
});
|
||||
}else if (OpName == "GlobalAveragePool") {
|
||||
OneOut(GlobalAveragePool, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXGlobalAveragePoolOp>(node, 1, 1, {
|
||||
});
|
||||
}else if (OpName == "GlobalLpPool") {
|
||||
OneOut(GlobalLpPool, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXGlobalLpPoolOp>(node, 1, 1, {
|
||||
{"p","int","2"}
|
||||
});
|
||||
}else if (OpName == "GlobalMaxPool") {
|
||||
OneOut(GlobalMaxPool, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXGlobalMaxPoolOp>(node, 1, 1, {
|
||||
});
|
||||
}else if (OpName == "Greater") {
|
||||
OneOut(Greater, 2, 1);
|
||||
ImportNodeOneOut<mlir::ONNXGreaterOp>(node, 2, 1, {
|
||||
});
|
||||
}else if (OpName == "HardSigmoid") {
|
||||
OneOut(HardSigmoid, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXHardSigmoidOp>(node, 1, 1, {
|
||||
{"alpha","float","0.2"}
|
||||
,{"beta","float","0.5"}
|
||||
});
|
||||
}else if (OpName == "Hardmax") {
|
||||
OneOut(Hardmax, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXHardmaxOp>(node, 1, 1, {
|
||||
{"axis","int","1"}
|
||||
});
|
||||
}else if (OpName == "Identity") {
|
||||
OneOut(Identity, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXIdentityOp>(node, 1, 1, {
|
||||
});
|
||||
}else if (OpName == "If") {
|
||||
OneOut(If, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXIfOp>(node, 1, 1, {
|
||||
{"else_branch","", ""}
|
||||
,{"then_branch","", ""}
|
||||
});
|
||||
}else if (OpName == "InstanceNormalization") {
|
||||
OneOut(InstanceNormalization, 3, 1);
|
||||
ImportNodeOneOut<mlir::ONNXInstanceNormalizationOp>(node, 3, 1, {
|
||||
{"epsilon","float","1e-05"}
|
||||
});
|
||||
}else if (OpName == "IsInf") {
|
||||
OneOut(IsInf, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXIsInfOp>(node, 1, 1, {
|
||||
{"detect_negative","int","1"}
|
||||
,{"detect_positive","int","1"}
|
||||
});
|
||||
}else if (OpName == "IsNaN") {
|
||||
OneOut(IsNaN, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXIsNaNOp>(node, 1, 1, {
|
||||
});
|
||||
}else if (OpName == "LRN") {
|
||||
OneOut(LRN, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXLRNOp>(node, 1, 1, {
|
||||
{"alpha","float","0.0001"}
|
||||
,{"beta","float","0.75"}
|
||||
,{"bias","float","1.0"}
|
||||
,{"size","int", ""}
|
||||
});
|
||||
}else if (OpName == "LSTM") {
|
||||
MultipleOuts(LSTM, 8, 3);
|
||||
ImportNodeMultipleOuts<mlir::ONNXLSTMOp>(node, 8, 3, {
|
||||
{"activation_alpha","", ""}
|
||||
,{"activation_beta","", ""}
|
||||
,{"activations","", ""}
|
||||
,{"clip","", ""}
|
||||
,{"direction","str","forward"}
|
||||
,{"hidden_size","", ""}
|
||||
,{"input_forget","int","0"}
|
||||
});
|
||||
}else if (OpName == "LeakyRelu") {
|
||||
OneOut(LeakyRelu, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXLeakyReluOp>(node, 1, 1, {
|
||||
{"alpha","float","0.01"}
|
||||
});
|
||||
}else if (OpName == "Less") {
|
||||
OneOut(Less, 2, 1);
|
||||
ImportNodeOneOut<mlir::ONNXLessOp>(node, 2, 1, {
|
||||
});
|
||||
}else if (OpName == "Log") {
|
||||
OneOut(Log, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXLogOp>(node, 1, 1, {
|
||||
});
|
||||
}else if (OpName == "LogSoftmax") {
|
||||
OneOut(LogSoftmax, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXLogSoftmaxOp>(node, 1, 1, {
|
||||
{"axis","int","1"}
|
||||
});
|
||||
}else if (OpName == "Loop") {
|
||||
OneOut(Loop, 3, 1);
|
||||
ImportNodeOneOut<mlir::ONNXLoopOp>(node, 3, 1, {
|
||||
{"body","", ""}
|
||||
});
|
||||
}else if (OpName == "LpNormalization") {
|
||||
OneOut(LpNormalization, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXLpNormalizationOp>(node, 1, 1, {
|
||||
{"axis","int","-1"}
|
||||
,{"p","int","2"}
|
||||
});
|
||||
}else if (OpName == "LpPool") {
|
||||
OneOut(LpPool, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXLpPoolOp>(node, 1, 1, {
|
||||
{"auto_pad","str","NOTSET"}
|
||||
,{"kernel_shape","", ""}
|
||||
,{"p","int","2"}
|
||||
,{"pads","", ""}
|
||||
,{"strides","", ""}
|
||||
});
|
||||
}else if (OpName == "MatMul") {
|
||||
OneOut(MatMul, 2, 1);
|
||||
ImportNodeOneOut<mlir::ONNXMatMulOp>(node, 2, 1, {
|
||||
});
|
||||
}else if (OpName == "MatMulInteger") {
|
||||
OneOut(MatMulInteger, 4, 1);
|
||||
ImportNodeOneOut<mlir::ONNXMatMulIntegerOp>(node, 4, 1, {
|
||||
});
|
||||
}else if (OpName == "Max") {
|
||||
OneOut(Max, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXMaxOp>(node, 1, 1, {
|
||||
});
|
||||
}else if (OpName == "MaxPool") {
|
||||
MultipleOuts(MaxPool, 1, 2);
|
||||
ImportNodeMultipleOuts<mlir::ONNXMaxPoolOp>(node, 1, 2, {
|
||||
{"auto_pad","str","NOTSET"}
|
||||
,{"ceil_mode","int","0"}
|
||||
,{"dilations","", ""}
|
||||
,{"kernel_shape","ints", ""}
|
||||
,{"pads","", ""}
|
||||
,{"storage_order","int","0"}
|
||||
,{"strides","", ""}
|
||||
});
|
||||
}else if (OpName == "MaxRoiPool") {
|
||||
OneOut(MaxRoiPool, 2, 1);
|
||||
ImportNodeOneOut<mlir::ONNXMaxRoiPoolOp>(node, 2, 1, {
|
||||
{"pooled_shape","", ""}
|
||||
,{"spatial_scale","float","1.0"}
|
||||
});
|
||||
}else if (OpName == "MaxUnpool") {
|
||||
OneOut(MaxUnpool, 3, 1);
|
||||
ImportNodeOneOut<mlir::ONNXMaxUnpoolOp>(node, 3, 1, {
|
||||
{"kernel_shape","", ""}
|
||||
,{"pads","", ""}
|
||||
,{"strides","", ""}
|
||||
});
|
||||
}else if (OpName == "Mean") {
|
||||
OneOut(Mean, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXMeanOp>(node, 1, 1, {
|
||||
});
|
||||
}else if (OpName == "MeanVarianceNormalization") {
|
||||
OneOut(MeanVarianceNormalization, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXMeanVarianceNormalizationOp>(node, 1, 1, {
|
||||
{"axes","ints","{'0', '2', '3'}"}
|
||||
});
|
||||
}else if (OpName == "Min") {
|
||||
OneOut(Min, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXMinOp>(node, 1, 1, {
|
||||
});
|
||||
}else if (OpName == "Mod") {
|
||||
OneOut(Mod, 2, 1);
|
||||
ImportNodeOneOut<mlir::ONNXModOp>(node, 2, 1, {
|
||||
{"fmod","int","0"}
|
||||
});
|
||||
}else if (OpName == "Mul") {
|
||||
OneOut(Mul, 2, 1);
|
||||
ImportNodeOneOut<mlir::ONNXMulOp>(node, 2, 1, {
|
||||
});
|
||||
}else if (OpName == "Multinomial") {
|
||||
OneOut(Multinomial, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXMultinomialOp>(node, 1, 1, {
|
||||
{"dtype","int","6"}
|
||||
,{"sample_size","int","1"}
|
||||
,{"seed","", ""}
|
||||
});
|
||||
}else if (OpName == "Neg") {
|
||||
OneOut(Neg, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXNegOp>(node, 1, 1, {
|
||||
});
|
||||
}else if (OpName == "NonMaxSuppression") {
|
||||
OneOut(NonMaxSuppression, 5, 1);
|
||||
ImportNodeOneOut<mlir::ONNXNonMaxSuppressionOp>(node, 5, 1, {
|
||||
{"center_point_box","int","0"}
|
||||
});
|
||||
}else if (OpName == "NonZero") {
|
||||
OneOut(NonZero, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXNonZeroOp>(node, 1, 1, {
|
||||
});
|
||||
}else if (OpName == "Not") {
|
||||
OneOut(Not, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXNotOp>(node, 1, 1, {
|
||||
});
|
||||
}else if (OpName == "OneHot") {
|
||||
OneOut(OneHot, 3, 1);
|
||||
ImportNodeOneOut<mlir::ONNXOneHotOp>(node, 3, 1, {
|
||||
{"axis","int","-1"}
|
||||
});
|
||||
}else if (OpName == "Or") {
|
||||
OneOut(Or, 2, 1);
|
||||
ImportNodeOneOut<mlir::ONNXOrOp>(node, 2, 1, {
|
||||
});
|
||||
}else if (OpName == "PRelu") {
|
||||
OneOut(PRelu, 2, 1);
|
||||
ImportNodeOneOut<mlir::ONNXPReluOp>(node, 2, 1, {
|
||||
});
|
||||
}else if (OpName == "Pad") {
|
||||
OneOut(Pad, 3, 1);
|
||||
ImportNodeOneOut<mlir::ONNXPadOp>(node, 3, 1, {
|
||||
{"mode","str","constant"}
|
||||
});
|
||||
}else if (OpName == "Pow") {
|
||||
OneOut(Pow, 2, 1);
|
||||
ImportNodeOneOut<mlir::ONNXPowOp>(node, 2, 1, {
|
||||
});
|
||||
}else if (OpName == "QLinearConv") {
|
||||
OneOut(QLinearConv, 9, 1);
|
||||
ImportNodeOneOut<mlir::ONNXQLinearConvOp>(node, 9, 1, {
|
||||
{"auto_pad","str","NOTSET"}
|
||||
,{"dilations","", ""}
|
||||
,{"group","int","1"}
|
||||
,{"kernel_shape","", ""}
|
||||
,{"pads","", ""}
|
||||
,{"strides","", ""}
|
||||
});
|
||||
}else if (OpName == "QLinearMatMul") {
|
||||
OneOut(QLinearMatMul, 8, 1);
|
||||
ImportNodeOneOut<mlir::ONNXQLinearMatMulOp>(node, 8, 1, {
|
||||
});
|
||||
}else if (OpName == "QuantizeLinear") {
|
||||
OneOut(QuantizeLinear, 3, 1);
|
||||
ImportNodeOneOut<mlir::ONNXQuantizeLinearOp>(node, 3, 1, {
|
||||
});
|
||||
}else if (OpName == "RNN") {
|
||||
MultipleOuts(RNN, 6, 2);
|
||||
ImportNodeMultipleOuts<mlir::ONNXRNNOp>(node, 6, 2, {
|
||||
{"activation_alpha","floats", "{}"}
|
||||
,{"activation_beta","floats", "{}"}
|
||||
,{"activations","", "{Tannh, Tanh}"}
|
||||
,{"clip","", ""}
|
||||
,{"direction","str","forward"}
|
||||
,{"hidden_size","", ""}
|
||||
});
|
||||
}else if (OpName == "RandomNormal") {
|
||||
OneOut(RandomNormal, 0, 1);
|
||||
ImportNodeOneOut<mlir::ONNXRandomNormalOp>(node, 0, 1, {
|
||||
{"dtype","int","1"}
|
||||
,{"mean","float","0.0"}
|
||||
,{"scale","float","1.0"}
|
||||
,{"seed","", ""}
|
||||
,{"shape","", ""}
|
||||
});
|
||||
}else if (OpName == "RandomNormalLike") {
|
||||
OneOut(RandomNormalLike, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXRandomNormalLikeOp>(node, 1, 1, {
|
||||
{"dtype","", ""}
|
||||
,{"mean","float","0.0"}
|
||||
,{"scale","float","1.0"}
|
||||
,{"seed","", ""}
|
||||
});
|
||||
}else if (OpName == "RandomUniform") {
|
||||
OneOut(RandomUniform, 0, 1);
|
||||
ImportNodeOneOut<mlir::ONNXRandomUniformOp>(node, 0, 1, {
|
||||
{"dtype","int","1"}
|
||||
,{"high","float","1.0"}
|
||||
,{"low","float","0.0"}
|
||||
,{"seed","", ""}
|
||||
,{"shape","", ""}
|
||||
});
|
||||
}else if (OpName == "RandomUniformLike") {
|
||||
OneOut(RandomUniformLike, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXRandomUniformLikeOp>(node, 1, 1, {
|
||||
{"dtype","", ""}
|
||||
,{"high","float","1.0"}
|
||||
,{"low","float","0.0"}
|
||||
,{"seed","", ""}
|
||||
});
|
||||
}else if (OpName == "Range") {
|
||||
OneOut(Range, 3, 1);
|
||||
ImportNodeOneOut<mlir::ONNXRangeOp>(node, 3, 1, {
|
||||
});
|
||||
}else if (OpName == "Reciprocal") {
|
||||
OneOut(Reciprocal, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXReciprocalOp>(node, 1, 1, {
|
||||
});
|
||||
}else if (OpName == "ReduceL1") {
|
||||
OneOut(ReduceL1, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXReduceL1Op>(node, 1, 1, {
|
||||
{"axes","", ""}
|
||||
,{"keepdims","int","1"}
|
||||
});
|
||||
}else if (OpName == "ReduceL2") {
|
||||
OneOut(ReduceL2, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXReduceL2Op>(node, 1, 1, {
|
||||
{"axes","", ""}
|
||||
,{"keepdims","int","1"}
|
||||
});
|
||||
}else if (OpName == "ReduceLogSum") {
|
||||
OneOut(ReduceLogSum, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXReduceLogSumOp>(node, 1, 1, {
|
||||
{"axes","", ""}
|
||||
,{"keepdims","int","1"}
|
||||
});
|
||||
}else if (OpName == "ReduceLogSumExp") {
|
||||
OneOut(ReduceLogSumExp, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXReduceLogSumExpOp>(node, 1, 1, {
|
||||
{"axes","", ""}
|
||||
,{"keepdims","int","1"}
|
||||
});
|
||||
}else if (OpName == "ReduceMax") {
|
||||
OneOut(ReduceMax, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXReduceMaxOp>(node, 1, 1, {
|
||||
{"axes","", ""}
|
||||
,{"keepdims","int","1"}
|
||||
});
|
||||
}else if (OpName == "ReduceMean") {
|
||||
OneOut(ReduceMean, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXReduceMeanOp>(node, 1, 1, {
|
||||
{"axes","", ""}
|
||||
,{"keepdims","int","1"}
|
||||
});
|
||||
}else if (OpName == "ReduceMin") {
|
||||
OneOut(ReduceMin, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXReduceMinOp>(node, 1, 1, {
|
||||
{"axes","", ""}
|
||||
,{"keepdims","int","1"}
|
||||
});
|
||||
}else if (OpName == "ReduceProd") {
|
||||
OneOut(ReduceProd, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXReduceProdOp>(node, 1, 1, {
|
||||
{"axes","", ""}
|
||||
,{"keepdims","int","1"}
|
||||
});
|
||||
}else if (OpName == "ReduceSum") {
|
||||
OneOut(ReduceSum, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXReduceSumOp>(node, 1, 1, {
|
||||
{"axes","", ""}
|
||||
,{"keepdims","int","1"}
|
||||
});
|
||||
}else if (OpName == "ReduceSumSquare") {
|
||||
OneOut(ReduceSumSquare, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXReduceSumSquareOp>(node, 1, 1, {
|
||||
{"axes","", ""}
|
||||
,{"keepdims","int","1"}
|
||||
});
|
||||
}else if (OpName == "Relu") {
|
||||
OneOut(Relu, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXReluOp>(node, 1, 1, {
|
||||
});
|
||||
}else if (OpName == "Reshape") {
|
||||
OneOut(Reshape, 2, 1);
|
||||
ImportNodeOneOut<mlir::ONNXReshapeOp>(node, 2, 1, {
|
||||
});
|
||||
}else if (OpName == "Resize") {
|
||||
OneOut(Resize, 4, 1);
|
||||
ImportNodeOneOut<mlir::ONNXResizeOp>(node, 4, 1, {
|
||||
{"coordinate_transformation_mode","str","half_pixel"}
|
||||
,{"cubic_coeff_a","float","-0.75"}
|
||||
,{"exclude_outside","int","0"}
|
||||
,{"extrapolation_value","float","0.0"}
|
||||
,{"mode","str","nearest"}
|
||||
,{"nearest_mode","str","round_prefer_floor"}
|
||||
});
|
||||
}else if (OpName == "ReverseSequence") {
|
||||
OneOut(ReverseSequence, 2, 1);
|
||||
ImportNodeOneOut<mlir::ONNXReverseSequenceOp>(node, 2, 1, {
|
||||
{"batch_axis","int","1"}
|
||||
,{"time_axis","int","0"}
|
||||
});
|
||||
}else if (OpName == "RoiAlign") {
|
||||
OneOut(RoiAlign, 3, 1);
|
||||
ImportNodeOneOut<mlir::ONNXRoiAlignOp>(node, 3, 1, {
|
||||
{"mode","str","avg"}
|
||||
,{"output_height","int","1"}
|
||||
,{"output_width","int","1"}
|
||||
,{"sampling_ratio","int","0"}
|
||||
,{"spatial_scale","float","1.0"}
|
||||
});
|
||||
}else if (OpName == "Round") {
|
||||
OneOut(Round, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXRoundOp>(node, 1, 1, {
|
||||
});
|
||||
}else if (OpName == "Scan") {
|
||||
OneOut(Scan, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXScanOp>(node, 1, 1, {
|
||||
{"body","", ""}
|
||||
,{"num_scan_inputs","", ""}
|
||||
,{"scan_input_axes","", ""}
|
||||
,{"scan_input_directions","", ""}
|
||||
,{"scan_output_axes","", ""}
|
||||
,{"scan_output_directions","", ""}
|
||||
});
|
||||
}else if (OpName == "Scatter") {
|
||||
OneOut(Scatter, 3, 1);
|
||||
ImportNodeOneOut<mlir::ONNXScatterOp>(node, 3, 1, {
|
||||
{"axis","int","0"}
|
||||
});
|
||||
}else if (OpName == "ScatterElements") {
|
||||
OneOut(ScatterElements, 3, 1);
|
||||
ImportNodeOneOut<mlir::ONNXScatterElementsOp>(node, 3, 1, {
|
||||
{"axis","int","0"}
|
||||
});
|
||||
}else if (OpName == "ScatterND") {
|
||||
OneOut(ScatterND, 3, 1);
|
||||
ImportNodeOneOut<mlir::ONNXScatterNDOp>(node, 3, 1, {
|
||||
});
|
||||
}else if (OpName == "Selu") {
|
||||
OneOut(Selu, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXSeluOp>(node, 1, 1, {
|
||||
{"alpha","float","1.67326"}
|
||||
,{"gamma","float","1.0507"}
|
||||
});
|
||||
}else if (OpName == "SequenceAt") {
|
||||
OneOut(SequenceAt, 2, 1);
|
||||
ImportNodeOneOut<mlir::ONNXSequenceAtOp>(node, 2, 1, {
|
||||
});
|
||||
}else if (OpName == "SequenceConstruct") {
|
||||
OneOut(SequenceConstruct, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXSequenceConstructOp>(node, 1, 1, {
|
||||
});
|
||||
}else if (OpName == "SequenceEmpty") {
|
||||
OneOut(SequenceEmpty, 0, 1);
|
||||
ImportNodeOneOut<mlir::ONNXSequenceEmptyOp>(node, 0, 1, {
|
||||
{"dtype","", ""}
|
||||
});
|
||||
}else if (OpName == "SequenceErase") {
|
||||
OneOut(SequenceErase, 2, 1);
|
||||
ImportNodeOneOut<mlir::ONNXSequenceEraseOp>(node, 2, 1, {
|
||||
});
|
||||
}else if (OpName == "SequenceInsert") {
|
||||
OneOut(SequenceInsert, 3, 1);
|
||||
ImportNodeOneOut<mlir::ONNXSequenceInsertOp>(node, 3, 1, {
|
||||
});
|
||||
}else if (OpName == "SequenceLength") {
|
||||
OneOut(SequenceLength, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXSequenceLengthOp>(node, 1, 1, {
|
||||
});
|
||||
}else if (OpName == "Shape") {
|
||||
OneOut(Shape, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXShapeOp>(node, 1, 1, {
|
||||
});
|
||||
}else if (OpName == "Shrink") {
|
||||
OneOut(Shrink, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXShrinkOp>(node, 1, 1, {
|
||||
{"bias","float","0.0"}
|
||||
,{"lambd","float","0.5"}
|
||||
});
|
||||
}else if (OpName == "Sigmoid") {
|
||||
OneOut(Sigmoid, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXSigmoidOp>(node, 1, 1, {
|
||||
});
|
||||
}else if (OpName == "Sign") {
|
||||
OneOut(Sign, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXSignOp>(node, 1, 1, {
|
||||
});
|
||||
}else if (OpName == "Sin") {
|
||||
OneOut(Sin, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXSinOp>(node, 1, 1, {
|
||||
});
|
||||
}else if (OpName == "Sinh") {
|
||||
OneOut(Sinh, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXSinhOp>(node, 1, 1, {
|
||||
});
|
||||
}else if (OpName == "Size") {
|
||||
OneOut(Size, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXSizeOp>(node, 1, 1, {
|
||||
});
|
||||
}else if (OpName == "Slice") {
|
||||
OneOut(Slice, 5, 1);
|
||||
ImportNodeOneOut<mlir::ONNXSliceOp>(node, 5, 1, {
|
||||
});
|
||||
}else if (OpName == "Softmax") {
|
||||
OneOut(Softmax, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXSoftmaxOp>(node, 1, 1, {
|
||||
{"axis","int","1"}
|
||||
});
|
||||
}else if (OpName == "Softplus") {
|
||||
OneOut(Softplus, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXSoftplusOp>(node, 1, 1, {
|
||||
});
|
||||
}else if (OpName == "Softsign") {
|
||||
OneOut(Softsign, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXSoftsignOp>(node, 1, 1, {
|
||||
});
|
||||
}else if (OpName == "SpaceToDepth") {
|
||||
OneOut(SpaceToDepth, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXSpaceToDepthOp>(node, 1, 1, {
|
||||
{"blocksize","", ""}
|
||||
});
|
||||
}else if (OpName == "Split") {
|
||||
OneOut(Split, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXSplitOp>(node, 1, 1, {
|
||||
{"axis","int","0"}
|
||||
,{"split","", ""}
|
||||
});
|
||||
}else if (OpName == "SplitToSequence") {
|
||||
OneOut(SplitToSequence, 2, 1);
|
||||
ImportNodeOneOut<mlir::ONNXSplitToSequenceOp>(node, 2, 1, {
|
||||
{"axis","int","0"}
|
||||
,{"keepdims","int","1"}
|
||||
});
|
||||
}else if (OpName == "Sqrt") {
|
||||
OneOut(Sqrt, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXSqrtOp>(node, 1, 1, {
|
||||
});
|
||||
}else if (OpName == "Squeeze") {
|
||||
OneOut(Squeeze, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXSqueezeOp>(node, 1, 1, {
|
||||
{"axes","", ""}
|
||||
});
|
||||
}else if (OpName == "StringNormalizer") {
|
||||
OneOut(StringNormalizer, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXStringNormalizerOp>(node, 1, 1, {
|
||||
{"case_change_action","str","NONE"}
|
||||
,{"is_case_sensitive","int","0"}
|
||||
,{"locale","", ""}
|
||||
,{"stopwords","", ""}
|
||||
});
|
||||
}else if (OpName == "Sub") {
|
||||
OneOut(Sub, 2, 1);
|
||||
ImportNodeOneOut<mlir::ONNXSubOp>(node, 2, 1, {
|
||||
});
|
||||
}else if (OpName == "Sum") {
|
||||
OneOut(Sum, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXSumOp>(node, 1, 1, {
|
||||
});
|
||||
}else if (OpName == "Tan") {
|
||||
OneOut(Tan, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXTanOp>(node, 1, 1, {
|
||||
});
|
||||
}else if (OpName == "Tanh") {
|
||||
OneOut(Tanh, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXTanhOp>(node, 1, 1, {
|
||||
});
|
||||
}else if (OpName == "TfIdfVectorizer") {
|
||||
OneOut(TfIdfVectorizer, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXTfIdfVectorizerOp>(node, 1, 1, {
|
||||
{"max_gram_length","", ""}
|
||||
,{"max_skip_count","", ""}
|
||||
,{"min_gram_length","", ""}
|
||||
,{"mode","", ""}
|
||||
,{"ngram_counts","", ""}
|
||||
,{"ngram_indexes","", ""}
|
||||
,{"pool_int64s","", ""}
|
||||
,{"pool_strings","", ""}
|
||||
,{"weights","", ""}
|
||||
});
|
||||
}else if (OpName == "ThresholdedRelu") {
|
||||
OneOut(ThresholdedRelu, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXThresholdedReluOp>(node, 1, 1, {
|
||||
{"alpha","float","1.0"}
|
||||
});
|
||||
}else if (OpName == "Tile") {
|
||||
OneOut(Tile, 2, 1);
|
||||
ImportNodeOneOut<mlir::ONNXTileOp>(node, 2, 1, {
|
||||
});
|
||||
}else if (OpName == "TopK") {
|
||||
MultipleOuts(TopK, 2, 2);
|
||||
ImportNodeMultipleOuts<mlir::ONNXTopKOp>(node, 2, 2, {
|
||||
{"axis","int","-1"}
|
||||
,{"largest","int","1"}
|
||||
,{"sorted","int","1"}
|
||||
});
|
||||
}else if (OpName == "Transpose") {
|
||||
OneOut(Transpose, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXTransposeOp>(node, 1, 1, {
|
||||
{"perm","", ""}
|
||||
});
|
||||
}else if (OpName == "Unique") {
|
||||
MultipleOuts(Unique, 1, 4);
|
||||
ImportNodeMultipleOuts<mlir::ONNXUniqueOp>(node, 1, 4, {
|
||||
{"axis","", ""}
|
||||
,{"sorted","int","1"}
|
||||
});
|
||||
}else if (OpName == "Unsqueeze") {
|
||||
OneOut(Unsqueeze, 1, 1);
|
||||
ImportNodeOneOut<mlir::ONNXUnsqueezeOp>(node, 1, 1, {
|
||||
{"axes","ints", ""}
|
||||
});
|
||||
}else if (OpName == "Upsample") {
|
||||
OneOut(Upsample, 2, 1);
|
||||
ImportNodeOneOut<mlir::ONNXUpsampleOp>(node, 2, 1, {
|
||||
{"mode","str","nearest"}
|
||||
});
|
||||
}else if (OpName == "Where") {
|
||||
OneOut(Where, 3, 1);
|
||||
ImportNodeOneOut<mlir::ONNXWhereOp>(node, 3, 1, {
|
||||
});
|
||||
}else if (OpName == "Xor") {
|
||||
OneOut(Xor, 2, 1);
|
||||
ImportNodeOneOut<mlir::ONNXXorOp>(node, 2, 1, {
|
||||
});
|
||||
}
|
|
@ -352,6 +352,121 @@ def gen_schema(schema) :
|
|||
|
||||
return s
|
||||
|
||||
"""
|
||||
special cases:
|
||||
* Split: attr split default value: sizeof(output1) namely 1
|
||||
* Conv: attr dilations default value is {num_dim of first input - 2, 1}
|
||||
* Conv: attr kernel_shape type is ints
|
||||
* Transpose: attr perm default value is {} empty int list
|
||||
"""
|
||||
|
||||
def gen_code(schema,fefile) :
|
||||
special_handler = dict([
|
||||
("Conv", "ImportNodeConv"),
|
||||
#("Transpose", "ImportNodeTranspose")
|
||||
])
|
||||
special_type = dict([
|
||||
("AveragePool "+"kernel_shape", '"ints", ""'),
|
||||
("MaxPool "+"kernel_shape", '"ints", ""'),
|
||||
("Cast "+"to", '"int", "0"'),
|
||||
("Concat "+"axis", '"int", "0"'),
|
||||
("Conv "+"group", '"int", "1"'),
|
||||
("Unsqueeze "+"axes", '"ints", ""'),
|
||||
("RNN "+"activation_alpha", '"floats", "{}"'),
|
||||
("RNN "+"activation_beta", '"floats", "{}"'),
|
||||
("RNN "+"activations", '"", "{Tannh, Tanh}"'),
|
||||
("LRN "+"size", '"int", ""')
|
||||
])
|
||||
line_indent = ' '
|
||||
fefile.write(' '+'}else if (OpName == "'+schema.name+'") {\n')
|
||||
op_type_str='mlir::ONNX'+schema.name+'Op'
|
||||
if schema.name in special_handler :
|
||||
fefile.write(' '+special_handler[schema.name]+'(node, '
|
||||
+str(len(schema.inputs))
|
||||
+', ' +str(len(schema.outputs))+', {\n')
|
||||
elif len(schema.outputs) > 1 :
|
||||
fefile.write(' '+'ImportNodeMultipleOuts<'+op_type_str+'>(node, '
|
||||
+str(len(schema.inputs))
|
||||
+', ' +str(len(schema.outputs))+', {\n')
|
||||
else :
|
||||
fefile.write(' '+'ImportNodeOneOut<'+op_type_str+'>(node, '
|
||||
+str(len(schema.inputs))
|
||||
+', ' +str(len(schema.outputs))+', {\n')
|
||||
|
||||
|
||||
if schema.attributes:
|
||||
first_attr = True
|
||||
for _, attr in sorted(schema.attributes.items()):
|
||||
attr_line = line_indent+line_indent+line_indent+line_indent
|
||||
if not first_attr:
|
||||
attr_line += ',{'
|
||||
else :
|
||||
attr_line += ' {'
|
||||
first_attr = False
|
||||
|
||||
attr_line += '"'+attr.name+'",'
|
||||
|
||||
if schema.name+' '+attr.name in special_type:
|
||||
attr_line += special_type[schema.name+' '+attr.name]
|
||||
# option holds either required or default value
|
||||
elif attr.required:
|
||||
attr_line += '"", ""'
|
||||
|
||||
elif attr.default_value.name:
|
||||
default_value = helper.get_attribute_value(attr.default_value)
|
||||
|
||||
def format_value(value): # type: (Any) -> Text
|
||||
if isinstance(value, float):
|
||||
formatted = str(np.round(value, 5))
|
||||
# use default formatting, unless too long.
|
||||
if (len(formatted) > 10):
|
||||
formatted = str("({:e})".format(value))
|
||||
return formatted
|
||||
elif isinstance(value, (bytes, bytearray)) and sys.version_info[0] == 3:
|
||||
return str(value.decode('utf-8'))
|
||||
return str(value)
|
||||
|
||||
if isinstance(default_value, list):
|
||||
value = default_value[0]
|
||||
default_value = [format_value(val) for val in default_value]
|
||||
# TODO the list type is homogenous or htergeneous?
|
||||
|
||||
if isinstance(value, float) :
|
||||
attr_type_str = '"floats"'
|
||||
elif isinstance(value, int) :
|
||||
attr_type_str = '"ints"'
|
||||
elif isinstance(value, str) :
|
||||
attr_type_str = '"strs"'
|
||||
elif isinstance(value, (bytes, bytearray)) :
|
||||
attr_type_str = '"strs"'
|
||||
else :
|
||||
attr_type_str = '"unknowns"'
|
||||
attr_option_str = '"{}"'.format(default_value)
|
||||
attr_option_str = attr_option_str.replace('[', '{', 1)
|
||||
attr_option_str = attr_option_str.replace(']', '}', 1)
|
||||
else:
|
||||
if isinstance(default_value, float) :
|
||||
attr_type_str = '"float"'
|
||||
elif isinstance(default_value, int) :
|
||||
attr_type_str = '"int"'
|
||||
elif isinstance(default_value, str) :
|
||||
attr_type_str = '"str"'
|
||||
elif isinstance(default_value, (bytes, bytearray)) :
|
||||
attr_type_str = '"str"'
|
||||
else :
|
||||
attr_type_str = '"unknown"'
|
||||
default_value = format_value(default_value)
|
||||
attr_option_str = '"{}"'.format(default_value)
|
||||
attr_line += attr_type_str+','+attr_option_str
|
||||
else:
|
||||
#TODO why?
|
||||
attr_line += '"", ""'
|
||||
attr_line += '}\n'
|
||||
fefile.write(attr_line)
|
||||
fefile.write(line_indent+line_indent+line_indent+'});\n')
|
||||
|
||||
|
||||
|
||||
def main(args): # type: (Type[Args]) -> None
|
||||
with io.open(args.changelog, 'w', newline='') as fout:
|
||||
fout.write('## Operator Changelog\n')
|
||||
|
@ -453,6 +568,7 @@ def main(args): # type: (Type[Args]) -> None
|
|||
fefile=io.open('op_build_table.inc', 'w', newline='')
|
||||
firstfunc = True
|
||||
|
||||
fefile.write(' '+'if (OpName == "DUMMY") {\n')
|
||||
for domain, supportmap in operator_schemas:
|
||||
s = '## {}\n'.format(display_domain_short(domain))
|
||||
fout.write(s)
|
||||
|
@ -461,21 +577,8 @@ def main(args): # type: (Type[Args]) -> None
|
|||
for op_type, schema, versions in namemap:
|
||||
# op_type
|
||||
#print("check point 1", schema.name, len(schema.inputs), len(schema.outputs))
|
||||
if firstfunc :
|
||||
fefile.write(' '+'if (OpName == "'+schema.name+'") {\n')
|
||||
firstfunc = False
|
||||
else :
|
||||
fefile.write(' '+'}else if (OpName == "'+schema.name+'") {\n')
|
||||
if len(schema.outputs) > 1 :
|
||||
fefile.write(' '+'MultipleOuts('+schema.name+', '
|
||||
+str(schema.since_version)+', '
|
||||
+str(len(schema.inputs))+', '
|
||||
+str(len(schema.outputs))+');\n')
|
||||
else :
|
||||
fefile.write(' '+'OneOut('+schema.name+', '
|
||||
+str(schema.since_version)+', '
|
||||
+str(len(schema.inputs))+', '
|
||||
+str(len(schema.outputs))+');\n')
|
||||
gen_code(schema, fefile)
|
||||
|
||||
r = gen_schema(schema)
|
||||
tdfile.write(r)
|
||||
s = ('### {}<a name="{}"></a><a name="{}">**{}**' + (' (deprecated)' if schema.deprecated else '') + '</a>\n').format(
|
||||
|
|
|
@ -71,4 +71,26 @@ def ONNXFullGemmOp: ONNX_Op<"FullGemm",
|
|||
let results = (outs AnyTensor);
|
||||
}
|
||||
|
||||
def ONNXConv1Op:ONNX_Op<"Conv1",
|
||||
[NoSideEffect]> {
|
||||
let summary = "ONNX Conv operation";
|
||||
let description = [{
|
||||
"The convolution operator consumes an input tensor and a filter, and"
|
||||
"computes the output."
|
||||
}];
|
||||
let arguments = (ins AnyTensor:$X);
|
||||
let results = (outs AnyTensor);
|
||||
}
|
||||
|
||||
def ONNXConv3Op:ONNX_Op<"Conv3",
|
||||
[NoSideEffect]> {
|
||||
let summary = "ONNX Conv operation";
|
||||
let description = [{
|
||||
"The convolution operator consumes an input tensor and a filter, and"
|
||||
"computes the output."
|
||||
}];
|
||||
let arguments = (ins AnyTensor:$X, AnyTensor:$W, AnyTensor:$B);
|
||||
let results = (outs AnyTensor);
|
||||
}
|
||||
|
||||
#endif // ONNX_OPS
|
||||
|
|
Loading…
Reference in New Issue