[MLIR] import attribute of onnx node (#383)
* add attributes as NamedAttribute * support list value for attribute * use std::tie to avoid c++17 feature
This commit is contained in:
parent
45608282e0
commit
c8d591fb28
|
@ -1,9 +1,12 @@
|
||||||
add_executable(onnf main.cpp)
|
add_executable(onnf main.cpp)
|
||||||
|
|
||||||
target_link_libraries(onnf builder compiler ${MLIRLibs} ${Boost_LIBRARIES})
|
target_link_libraries(onnf builder compiler ${MLIRLibs} ${Boost_LIBRARIES})
|
||||||
|
set_target_properties(onnf PROPERTIES LINK_FLAGS "-lz")
|
||||||
|
|
||||||
whole_archive_link_mlir(onnf ${MLIRWholeArchiveLibs})
|
whole_archive_link_mlir(onnf ${MLIRWholeArchiveLibs})
|
||||||
whole_archive_link_onnf(onnf onnf_transform)
|
whole_archive_link_onnf(onnf onnf_transform)
|
||||||
|
|
||||||
|
|
||||||
target_include_directories(onnf PRIVATE ${CMAKE_SOURCE_DIR})
|
target_include_directories(onnf PRIVATE ${CMAKE_SOURCE_DIR})
|
||||||
target_include_directories(onnf PRIVATE ${CMAKE_BINARY_DIR})
|
target_include_directories(onnf PRIVATE ${CMAKE_BINARY_DIR})
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
add_library(builder
|
add_library(builder
|
||||||
frontend_dialect_transformer.cpp
|
frontend_dialect_transformer.cpp
|
||||||
|
frontend_dialect_transformer.hpp
|
||||||
op_build_table.inc
|
op_build_table.inc
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -71,7 +71,10 @@ struct OnnxOnnfSymbolMapping {
|
||||||
* @param name onnx tensor name.
|
* @param name onnx tensor name.
|
||||||
* @return onnf tensor corresponding to `name`.
|
* @return onnf tensor corresponding to `name`.
|
||||||
*/
|
*/
|
||||||
mlir::Value* GetTensorByOnnxName(std::string name) {
|
mlir::Value *GetTensorByOnnxName(std::string name) {
|
||||||
|
assert(onnx_name2onnf_tensor.find(legalize_name(name)) !=
|
||||||
|
onnx_name2onnf_tensor.end() &&
|
||||||
|
"Tensor not found");
|
||||||
return onnx_name2onnf_tensor.at(legalize_name(name));
|
return onnx_name2onnf_tensor.at(legalize_name(name));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -80,7 +83,9 @@ struct OnnxOnnfSymbolMapping {
|
||||||
* @param name onnx tensor name.
|
* @param name onnx tensor name.
|
||||||
* @param tensor MLIR Value* pointer.
|
* @param tensor MLIR Value* pointer.
|
||||||
*/
|
*/
|
||||||
void AddMapping(std::string name, mlir::Value* tensor) {
|
void AddMapping(std::string name, mlir::Value *tensor) {
|
||||||
|
assert(onnx_name2onnf_tensor.count(legalize_name(name)) == 0 &&
|
||||||
|
"Tensor already exists.");
|
||||||
onnx_name2onnf_tensor.emplace(legalize_name(name), tensor);
|
onnx_name2onnf_tensor.emplace(legalize_name(name), tensor);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -88,7 +93,7 @@ struct OnnxOnnfSymbolMapping {
|
||||||
return onnx_name2onnf_tensor.count(name) != 0;
|
return onnx_name2onnf_tensor.count(name) != 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/*!
|
/*!
|
||||||
* mapping from onnx tensor names to MLIR tensor.
|
* mapping from onnx tensor names to MLIR tensor.
|
||||||
*/
|
*/
|
||||||
|
@ -96,8 +101,8 @@ struct OnnxOnnfSymbolMapping {
|
||||||
};
|
};
|
||||||
|
|
||||||
class FrontendGenImpl {
|
class FrontendGenImpl {
|
||||||
public:
|
public:
|
||||||
FrontendGenImpl(mlir::MLIRContext& context)
|
FrontendGenImpl(mlir::MLIRContext &context)
|
||||||
: context_(context), builder_(&context) {
|
: context_(context), builder_(&context) {
|
||||||
module_ = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
module_ = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||||
}
|
}
|
||||||
|
@ -107,8 +112,8 @@ class FrontendGenImpl {
|
||||||
return module_;
|
return module_;
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
mlir::MLIRContext& context_;
|
mlir::MLIRContext &context_;
|
||||||
mlir::ModuleOp module_;
|
mlir::ModuleOp module_;
|
||||||
mlir::OpBuilder builder_;
|
mlir::OpBuilder builder_;
|
||||||
// mapping between string name and symbol
|
// mapping between string name and symbol
|
||||||
|
@ -145,55 +150,31 @@ class FrontendGenImpl {
|
||||||
case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX64:
|
case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX64:
|
||||||
case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX128:
|
case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX128:
|
||||||
case onnx::TensorProto_DataType::TensorProto_DataType_UNDEFINED:
|
case onnx::TensorProto_DataType::TensorProto_DataType_UNDEFINED:
|
||||||
|
assert(false && "Unsupported data type encountered.");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//if c++17 is used, these two def can be combined with 'if constexpr'
|
|
||||||
//leave n there for possible future use
|
|
||||||
//alternative is to use template and pass the outputTypes, inputs and attributes
|
|
||||||
//as parameter
|
|
||||||
|
|
||||||
#define MultipleOuts(name, nIn, nOut)\
|
|
||||||
{ \
|
|
||||||
if (nIn == inputs.size() && nOut == outputTypes.size()) {\
|
|
||||||
auto op = builder_.create<mlir::ONNX##name##Op>(UnknownLoc(), outputTypes, inputs, attributes); \
|
|
||||||
for (int i = 0; i < node.output().size(); i++) { \
|
|
||||||
frontend_symbols_.AddMapping(\
|
|
||||||
legalize_name(node.output()[i]), op.getResult(i));\
|
|
||||||
}\
|
|
||||||
return;\
|
|
||||||
}\
|
|
||||||
}
|
|
||||||
|
|
||||||
#define OneOut(name, nIn, nOut)\
|
|
||||||
{ \
|
|
||||||
if (nIn == inputs.size() && nOut == outputTypes.size()) {\
|
|
||||||
auto op = builder_.create<mlir::ONNX##name##Op>(UnknownLoc(), outputTypes, inputs, attributes); \
|
|
||||||
frontend_symbols_.AddMapping(\
|
|
||||||
legalize_name(node.output()[0]), op.getResult());\
|
|
||||||
return;\
|
|
||||||
}\
|
|
||||||
}
|
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* Import an onnx input tensor type by determining and recording its type
|
* Import an onnx input tensor type by determining and recording its type
|
||||||
* in a list of input tensor mlir types.
|
* in a list of input tensor mlir types.
|
||||||
* @param input onnx input tensor ValueInfoProto.
|
* @param input onnx input tensor ValueInfoProto.
|
||||||
* @param arg_types list of mlir types representing types of graph input.
|
* @param arg_types list of mlir types representing types of graph input.
|
||||||
*/
|
*/
|
||||||
void ImportInputTensorType(const onnx::ValueInfoProto& input,
|
void ImportInputTensorType(const onnx::ValueInfoProto &input,
|
||||||
llvm::SmallVector<mlir::Type, 4>& arg_types) {
|
llvm::SmallVector<mlir::Type, 4> &arg_types) {
|
||||||
std::vector<int64_t> dims;
|
std::vector<int64_t> dims;
|
||||||
auto shape_proto = input.type().tensor_type().shape();
|
auto shape_proto = input.type().tensor_type().shape();
|
||||||
auto input_tensor_legalized_name = legalize_name(input.name());
|
auto input_tensor_legalized_name = legalize_name(input.name());
|
||||||
for (int i = 0; i < shape_proto.dim_size(); i++) {
|
for (int i = 0; i < shape_proto.dim_size(); i++) {
|
||||||
if (shape_proto.dim()[i].dim_value()) {
|
if (shape_proto.dim()[i].dim_value()) {
|
||||||
int dim_numeric_size = shape_proto.dim()[i].dim_value();
|
int dim_numeric_size = shape_proto.dim()[i].dim_value();
|
||||||
|
assert(
|
||||||
|
dim_numeric_size != 0 && "Parsed an input tensor with a dimension size of zero");
|
||||||
if (dim_numeric_size > 0) {
|
if (dim_numeric_size > 0) {
|
||||||
dims.push_back(dim_numeric_size);
|
dims.push_back(dim_numeric_size);
|
||||||
} else { // If dim_value < 0, then dim is parametric.
|
} else { // If dim_value < 0, then dim is parametric.
|
||||||
// TODO Verify the unknown dim size in MLIR
|
// TODO Verify the unknown dim size in MLIR
|
||||||
dims.push_back(-1);
|
dims.push_back(-1);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -216,8 +197,8 @@ class FrontendGenImpl {
|
||||||
* @param input onnx input tensor ValueInfoProto.
|
* @param input onnx input tensor ValueInfoProto.
|
||||||
* @param symbol mlir input argument.
|
* @param symbol mlir input argument.
|
||||||
*/
|
*/
|
||||||
void ImportInputTensorSymbol(
|
void ImportInputTensorSymbol(const onnx::ValueInfoProto &input,
|
||||||
const onnx::ValueInfoProto& input, mlir::Value* symbol) {
|
mlir::Value *symbol) {
|
||||||
auto input_tensor_legalized_name = legalize_name(input.name());
|
auto input_tensor_legalized_name = legalize_name(input.name());
|
||||||
assert(
|
assert(
|
||||||
!frontend_symbols_.ContainKey(input_tensor_legalized_name) &&
|
!frontend_symbols_.ContainKey(input_tensor_legalized_name) &&
|
||||||
|
@ -225,32 +206,286 @@ class FrontendGenImpl {
|
||||||
frontend_symbols_.AddMapping(input_tensor_legalized_name, symbol);
|
frontend_symbols_.AddMapping(input_tensor_legalized_name, symbol);
|
||||||
}
|
}
|
||||||
|
|
||||||
void ImportNode(onnx::NodeProto node) {
|
template <typename T>
|
||||||
std::vector<mlir::Value*> inputs;
|
T get_attr_generic(onnx::NodeProto &node, std::string name,
|
||||||
|
std::function<T(onnx::AttributeProto &)> attr_getter,
|
||||||
|
T default_val) {
|
||||||
|
for (int i = 0; i < node.attribute_size(); ++i) {
|
||||||
|
auto attr = node.attribute(i);
|
||||||
|
if (attr.name() == name) {
|
||||||
|
return attr_getter(attr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return default_val;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
T get_attr_generic(onnx::NodeProto &node, std::string name,
|
||||||
|
std::function<T(onnx::AttributeProto &)> attr_getter) {
|
||||||
|
for (int i = 0; i < node.attribute_size(); ++i) {
|
||||||
|
auto attr = node.attribute(i);
|
||||||
|
if (attr.name() == name) {
|
||||||
|
return attr_getter(attr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert(false && "ONNX Node Attribute Not Found!");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto get_attr_ints(onnx::NodeProto &node, std::string name,
|
||||||
|
std::vector<int> default_val) {
|
||||||
|
std::function<std::vector<int>(onnx::AttributeProto &)> attr_getter =
|
||||||
|
[](onnx::AttributeProto &attr) {
|
||||||
|
std::vector<int> ints(attr.ints_size());
|
||||||
|
std::copy(attr.ints().begin(), attr.ints().end(), ints.begin());
|
||||||
|
return ints;
|
||||||
|
};
|
||||||
|
auto r = get_attr_generic(node, name, attr_getter, default_val);
|
||||||
|
auto dataType =
|
||||||
|
mlir::RankedTensorType::get(r.size(), builder_.getIntegerType(32));
|
||||||
|
auto attr_v = mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(r));
|
||||||
|
auto aname = node.op_type() + "." + name;
|
||||||
|
auto attr_output = builder_.getNamedAttr(aname, attr_v);
|
||||||
|
return attr_output;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto get_attr_ints(onnx::NodeProto &node, std::string name) {
|
||||||
|
std::function<std::vector<int>(onnx::AttributeProto &)> attr_getter =
|
||||||
|
[](onnx::AttributeProto &attr) {
|
||||||
|
std::vector<int> ints(attr.ints_size());
|
||||||
|
std::copy(attr.ints().begin(), attr.ints().end(), ints.begin());
|
||||||
|
return ints;
|
||||||
|
};
|
||||||
|
auto r = get_attr_generic(node, name, attr_getter);
|
||||||
|
auto dataType =
|
||||||
|
mlir::RankedTensorType::get(r.size(), builder_.getIntegerType(32));
|
||||||
|
auto attr_v = mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(r));
|
||||||
|
auto aname = node.op_type() + "." + name;
|
||||||
|
auto attr_output = builder_.getNamedAttr(aname, attr_v);
|
||||||
|
return attr_output;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto get_attr_floats(onnx::NodeProto &node, std::string name) {
|
||||||
|
std::function<std::vector<float>(onnx::AttributeProto &)> attr_getter =
|
||||||
|
[](onnx::AttributeProto &attr) {
|
||||||
|
std::vector<float> floats(attr.floats_size());
|
||||||
|
std::copy(attr.floats().begin(), attr.floats().end(), floats.begin());
|
||||||
|
return floats;
|
||||||
|
};
|
||||||
|
auto r = get_attr_generic(node, name, attr_getter);
|
||||||
|
auto dataType =
|
||||||
|
mlir::RankedTensorType::get(r.size(), builder_.getF32Type());
|
||||||
|
auto attr_v = mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(r));
|
||||||
|
auto aname = node.op_type() + "." + name;
|
||||||
|
auto attr_output = builder_.getNamedAttr(aname, attr_v);
|
||||||
|
return attr_output;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto get_attr_floats(onnx::NodeProto &node, std::string name,
|
||||||
|
std::vector<float> default_val) {
|
||||||
|
std::function<std::vector<float>(onnx::AttributeProto &)> attr_getter =
|
||||||
|
[](onnx::AttributeProto &attr) {
|
||||||
|
std::vector<float> floats(attr.floats_size());
|
||||||
|
std::copy(attr.floats().begin(), attr.floats().end(), floats.begin());
|
||||||
|
return floats;
|
||||||
|
};
|
||||||
|
auto r = get_attr_generic(node, name, attr_getter, default_val);
|
||||||
|
auto dataType =
|
||||||
|
mlir::RankedTensorType::get(r.size(), builder_.getF32Type());
|
||||||
|
auto attr_v = mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(r));
|
||||||
|
auto aname = node.op_type() + "." + name;
|
||||||
|
auto attr_output = builder_.getNamedAttr(aname, attr_v);
|
||||||
|
return attr_output;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto get_attr_int(onnx::NodeProto &node, std::string name) {
|
||||||
|
std::function<int(onnx::AttributeProto &)> attr_getter =
|
||||||
|
[](onnx::AttributeProto &attr) { return attr.i(); };
|
||||||
|
int r = get_attr_generic(node, name, attr_getter);
|
||||||
|
auto attr_v = builder_.getI32IntegerAttr(r);
|
||||||
|
auto aname = node.op_type() + "." + name;
|
||||||
|
auto attr_output = builder_.getNamedAttr(aname, attr_v);
|
||||||
|
return attr_output;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto get_attr_int(onnx::NodeProto &node, std::string name, int default_val) {
|
||||||
|
std::function<int(onnx::AttributeProto &)> attr_getter =
|
||||||
|
[](onnx::AttributeProto &attr) { return attr.i(); };
|
||||||
|
int r = get_attr_generic(node, name, attr_getter, default_val);
|
||||||
|
auto attr_v = builder_.getI32IntegerAttr(r);
|
||||||
|
auto aname = node.op_type() + "." + name;
|
||||||
|
auto attr_output = builder_.getNamedAttr(aname, attr_v);
|
||||||
|
return attr_output;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto get_attr_float(onnx::NodeProto &node, std::string name) {
|
||||||
|
std::function<float(onnx::AttributeProto &)> attr_getter =
|
||||||
|
[](onnx::AttributeProto &attr) { return attr.f(); };
|
||||||
|
auto r = get_attr_generic(node, name, attr_getter);
|
||||||
|
auto attr_v = builder_.getF32FloatAttr(r);
|
||||||
|
auto aname = node.op_type() + "." + name;
|
||||||
|
return builder_.getNamedAttr(aname, attr_v);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto get_attr_float(onnx::NodeProto &node, std::string name,
|
||||||
|
float default_val) {
|
||||||
|
std::function<float(onnx::AttributeProto &)> attr_getter =
|
||||||
|
[](onnx::AttributeProto &attr) { return attr.f(); };
|
||||||
|
auto r = get_attr_generic(node, name, attr_getter, default_val);
|
||||||
|
auto attr_v = builder_.getF32FloatAttr(r);
|
||||||
|
auto aname = node.op_type() + "." + name;
|
||||||
|
return builder_.getNamedAttr(aname, attr_v);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto get_attr_string(onnx::NodeProto &node, std::string name) {
|
||||||
|
std::function<std::string(onnx::AttributeProto &)> attr_getter =
|
||||||
|
[](onnx::AttributeProto &attr) { return attr.s(); };
|
||||||
|
auto r = get_attr_generic(node, name, attr_getter);
|
||||||
|
auto attr_v = builder_.getStringAttr(r);
|
||||||
|
auto aname = node.op_type() + "." + name;
|
||||||
|
return builder_.getNamedAttr(aname, attr_v);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto get_attr_string(onnx::NodeProto &node, std::string name,
|
||||||
|
std::string default_val) {
|
||||||
|
std::function<std::string(onnx::AttributeProto &)> attr_getter =
|
||||||
|
[](onnx::AttributeProto &attr) { return attr.s(); };
|
||||||
|
auto r = get_attr_generic(node, name, attr_getter, default_val);
|
||||||
|
auto attr_v = builder_.getStringAttr(r);
|
||||||
|
auto aname = node.op_type() + "." + name;
|
||||||
|
return builder_.getNamedAttr(aname, attr_v);
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
auto get_attr_strings(onnx::NodeProto &node, std::string name) {
|
||||||
|
std::function<std::vector<std::string>(onnx::AttributeProto &)>
|
||||||
|
attr_getter =
|
||||||
|
[](onnx::AttributeProto &attr) {
|
||||||
|
std::vector<std::string> strings(attr.strings_size());
|
||||||
|
std::copy(attr.strings().begin(), attr.strings().end(),
|
||||||
|
strings.begin()); return strings;
|
||||||
|
};
|
||||||
|
auto r = get_attr_generic(node, name, attr_getter);
|
||||||
|
return r;
|
||||||
|
return builder_.getNamedAttr(aname, attr_v);
|
||||||
|
auto dataType =
|
||||||
|
mlir::RankedTensorType::get(r.size(), builder_.get???Type());
|
||||||
|
auto attr_v = mlir::DenseElementsAttr::get(dataType,
|
||||||
|
llvm::makeArrayRef(r)); auto aname = node.op_type() + "." + name; auto
|
||||||
|
attr_output = builder_.getNamedAttr(aname, attr_v); return attr_output;
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
auto get_default_ints(std::string default_str) {
|
||||||
|
std::vector<int> r;
|
||||||
|
auto start = default_str.find("{");
|
||||||
|
while (true) {
|
||||||
|
auto end = default_str.find(",", start + 1);
|
||||||
|
if (end == std::string::npos) {
|
||||||
|
end = default_str.find("}", start + 1);
|
||||||
|
if (end != std::string::npos && end > start+1) {
|
||||||
|
r.push_back(std::stoi(default_str.substr(start + 1, end)));
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
r.push_back(std::stoi(default_str.substr(start + 1, end)));
|
||||||
|
}
|
||||||
|
start = end + 1;
|
||||||
|
}
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto get_default_floats(std::string default_str) {
|
||||||
|
std::vector<float> r;
|
||||||
|
auto start = default_str.find("{");
|
||||||
|
while (true) {
|
||||||
|
auto end = default_str.find(",", start + 1);
|
||||||
|
if (end == std::string::npos) {
|
||||||
|
end = default_str.find("}", start + 1);
|
||||||
|
if (end != std::string::npos && end > start+1) {
|
||||||
|
r.push_back(std::stof(default_str.substr(start + 1, end)));
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
r.push_back(std::stof(default_str.substr(start + 1, end)));
|
||||||
|
}
|
||||||
|
start = end + 1;
|
||||||
|
}
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto get_default_strings(std::string default_str) {
|
||||||
|
std::vector<std::string> r;
|
||||||
|
auto start = default_str.find("{");
|
||||||
|
while (true) {
|
||||||
|
auto end = default_str.find(",", start + 1);
|
||||||
|
if (end == std::string::npos) {
|
||||||
|
end = default_str.find("}", start + 1);
|
||||||
|
if (end != std::string::npos && end > start+1) {
|
||||||
|
r.push_back(default_str.substr(start + 1, end));
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
r.push_back(default_str.substr(start + 1, end));
|
||||||
|
}
|
||||||
|
start = end + 1;
|
||||||
|
}
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
|
||||||
|
onnx::TensorProto get_attr_tensor(onnx::NodeProto &node, std::string name) {
|
||||||
|
std::function<onnx::TensorProto(onnx::AttributeProto &)> attr_getter =
|
||||||
|
[](onnx::AttributeProto &attr) { return attr.t(); };
|
||||||
|
return get_attr_generic(node, name, attr_getter);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto ImportNodeAttr(onnx::NodeProto node, std::string attr_name,
|
||||||
|
std::string type_name, std::string default_str) {
|
||||||
|
if (default_str == "") {
|
||||||
|
if (type_name == "int") {
|
||||||
|
return get_attr_int(node, attr_name);
|
||||||
|
} else if (type_name == "float") {
|
||||||
|
return get_attr_float(node, attr_name);
|
||||||
|
} else if (type_name == "str") {
|
||||||
|
return get_attr_string(node, attr_name);
|
||||||
|
} else if (type_name == "ints") {
|
||||||
|
return get_attr_ints(node, attr_name);
|
||||||
|
} else if (type_name == "floats") {
|
||||||
|
return get_attr_floats(node, attr_name);
|
||||||
|
} else {
|
||||||
|
assert(
|
||||||
|
false &&
|
||||||
|
"Got an empty initializer or initializer for this "
|
||||||
|
"datatype is not implemented. Something is wrong.");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// with default value
|
||||||
|
if (type_name == "int") {
|
||||||
|
return get_attr_int(node, attr_name, std::stoi(default_str));
|
||||||
|
} else if (type_name == "float") {
|
||||||
|
return get_attr_float(node, attr_name, std::stof(default_str));
|
||||||
|
} else if (type_name == "str") {
|
||||||
|
return get_attr_string(node, attr_name, default_str);
|
||||||
|
} else if (type_name == "ints") {
|
||||||
|
return get_attr_ints(node, attr_name, get_default_ints(default_str));
|
||||||
|
} else if (type_name == "floats") {
|
||||||
|
return get_attr_floats(node, attr_name,
|
||||||
|
get_default_floats(default_str));
|
||||||
|
} else {
|
||||||
|
assert(
|
||||||
|
false &&
|
||||||
|
"Got an empty initializer or initializer for this "
|
||||||
|
"datatype is not implemented. Something is wrong.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void ImportNodeGeneric(onnx::NodeProto node) {
|
||||||
|
std::vector<mlir::Value *> inputs;
|
||||||
for (auto item : node.input()) {
|
for (auto item : node.input()) {
|
||||||
if (frontend_symbols_.ContainKey(legalize_name(item))) {
|
if (frontend_symbols_.ContainKey(legalize_name(item))) {
|
||||||
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
|
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<mlir::Type> outputTypes;
|
|
||||||
for (auto item : node.output()) {
|
|
||||||
outputTypes.push_back(mlir::UnrankedTensorType::get(builder_.getF32Type()));
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<mlir::NamedAttribute> attributes;
|
|
||||||
llvm::StringRef OpName = node.op_type();
|
|
||||||
|
|
||||||
//the following code is generated by gen_doc.py
|
|
||||||
//refer to dialect/onnx/onnx.td for details
|
|
||||||
//when the input or output of then op does not match the specification,
|
|
||||||
//the generic operator is used
|
|
||||||
//one known reeason is the optional input
|
|
||||||
|
|
||||||
#include "src/builder/op_build_table.inc"
|
|
||||||
|
|
||||||
|
|
||||||
// Old way of doing things.
|
|
||||||
mlir::OperationState result(UnknownLoc(), "frontend." + node.op_type());
|
mlir::OperationState result(UnknownLoc(), "frontend." + node.op_type());
|
||||||
for (auto item : node.output()) {
|
for (auto item : node.output()) {
|
||||||
result.addTypes(mlir::UnrankedTensorType::get(builder_.getF32Type()));
|
result.addTypes(mlir::UnrankedTensorType::get(builder_.getF32Type()));
|
||||||
|
@ -261,8 +496,165 @@ class FrontendGenImpl {
|
||||||
auto r = op->getResult(i);
|
auto r = op->getResult(i);
|
||||||
frontend_symbols_.AddMapping(legalize_name(node.output()[i]), r);
|
frontend_symbols_.AddMapping(legalize_name(node.output()[i]), r);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TODO more info from node: attributes
|
// if c++17 is used, ImportNodeOneOut and ImportNodeMultipleOuts can be
|
||||||
|
// combined with 'if constexpr' the issue is the type of the output is
|
||||||
|
// different. alternative way to use variadic output for all the op
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* Important onnx node which generates only one output
|
||||||
|
* @param node onnx node
|
||||||
|
* @param nIn number of expected inputs
|
||||||
|
* @param nOut number of expected outputs
|
||||||
|
* @param attrs list of desription for attributes with format {name, type,
|
||||||
|
* default}
|
||||||
|
*/
|
||||||
|
template <typename T>
|
||||||
|
void ImportNodeOneOut(
|
||||||
|
onnx::NodeProto node, int nIn, int nOut,
|
||||||
|
std::initializer_list<std::tuple<std::string, std::string, std::string>>
|
||||||
|
attrs) {
|
||||||
|
std::vector<mlir::Value *> inputs;
|
||||||
|
for (auto item : node.input()) {
|
||||||
|
if (frontend_symbols_.ContainKey(legalize_name(item))) {
|
||||||
|
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<mlir::Type> outputTypes;
|
||||||
|
for (auto item : node.output()) {
|
||||||
|
outputTypes.push_back(
|
||||||
|
mlir::UnrankedTensorType::get(builder_.getF32Type()));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<mlir::NamedAttribute> attributes;
|
||||||
|
//for (auto [attr_name, attr_type, attr_default] : attrs) {
|
||||||
|
for (auto oneAttr: attrs) {
|
||||||
|
std::string attr_name;
|
||||||
|
std::string attr_type;
|
||||||
|
std::string attr_default;
|
||||||
|
std::tie (attr_name, attr_type, attr_default) = oneAttr;
|
||||||
|
if (attr_type != "") {
|
||||||
|
auto attr = ImportNodeAttr(node, attr_name, attr_type, attr_default);
|
||||||
|
attributes.push_back(attr);
|
||||||
|
} else {
|
||||||
|
// TODO: the attributes need special handling
|
||||||
|
//std::cout << "missing " << node.op_type() << " " << attr_name << std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::StringRef OpName = node.op_type();
|
||||||
|
|
||||||
|
if (nIn == inputs.size() && nOut == outputTypes.size()) {
|
||||||
|
auto op =
|
||||||
|
builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes);
|
||||||
|
frontend_symbols_.AddMapping(legalize_name(node.output()[0]),
|
||||||
|
op.getResult());
|
||||||
|
} else {
|
||||||
|
ImportNodeGeneric(node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void ImportNodeMultipleOuts(
|
||||||
|
onnx::NodeProto node, int nIn, int nOut,
|
||||||
|
std::initializer_list<std::tuple<std::string, std::string, std::string>>
|
||||||
|
attrs) {
|
||||||
|
std::vector<mlir::Value *> inputs;
|
||||||
|
for (auto item : node.input()) {
|
||||||
|
if (frontend_symbols_.ContainKey(legalize_name(item))) {
|
||||||
|
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<mlir::Type> outputTypes;
|
||||||
|
for (auto item : node.output()) {
|
||||||
|
outputTypes.push_back(
|
||||||
|
mlir::UnrankedTensorType::get(builder_.getF32Type()));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<mlir::NamedAttribute> attributes;
|
||||||
|
for (auto oneAttr: attrs) {
|
||||||
|
std::string attr_name;
|
||||||
|
std::string attr_type;
|
||||||
|
std::string attr_default;
|
||||||
|
std::tie (attr_name, attr_type, attr_default) = oneAttr;
|
||||||
|
if (attr_type != "") {
|
||||||
|
auto attr = ImportNodeAttr(node, attr_name, attr_type, attr_default);
|
||||||
|
attributes.push_back(attr);
|
||||||
|
} else {
|
||||||
|
// TODO: the attributes need special handling
|
||||||
|
//std::cout << "missing " << node.op_type() << " " << attr_name << std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::StringRef OpName = node.op_type();
|
||||||
|
|
||||||
|
if (nIn == inputs.size() && nOut == outputTypes.size()) {
|
||||||
|
auto op =
|
||||||
|
builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes);
|
||||||
|
for (int i = 0; i < node.output().size(); i++) {
|
||||||
|
frontend_symbols_.AddMapping(legalize_name(node.output()[i]),
|
||||||
|
op.getResult(i));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
ImportNodeGeneric(node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* Special handle for Conv operations.
|
||||||
|
* c++ does not allow template specialization inside a class scope
|
||||||
|
* a specialized function is used
|
||||||
|
*/
|
||||||
|
void ImportNodeConv(
|
||||||
|
onnx::NodeProto node, int nIn, int nOut,
|
||||||
|
std::initializer_list<std::tuple<std::string, std::string, std::string>>
|
||||||
|
attrs) {
|
||||||
|
|
||||||
|
// Conv has attribute dilations, kernel_shape, pads, the default value of
|
||||||
|
// which is determined by the shape of first argument. However, since the
|
||||||
|
// shape is unknown now, these attributes can be not generated auto
|
||||||
|
// dilations_attr = get_attr_ints(node, "dilations",
|
||||||
|
// std::vector<int>(inputs[0]->getType().cast<RankedTensorType>.getDims()-2,
|
||||||
|
// 1));
|
||||||
|
// attributes.push_back(dilations_attr)
|
||||||
|
// similar situation for pads, strides in AveragePool
|
||||||
|
// axes of ReduceSum, pads, strides, dilations and kernel_shape of MaxPool
|
||||||
|
// TODO: fix this after type inference
|
||||||
|
|
||||||
|
if (node.input().size() == 1) {
|
||||||
|
ImportNodeOneOut<mlir::ONNXConv1Op>(node, nIn, nOut, attrs);
|
||||||
|
} else {
|
||||||
|
ImportNodeOneOut<mlir::ONNXConv3Op>(node, nIn, nOut, attrs);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void ImportNode(onnx::NodeProto node) {
|
||||||
|
std::vector<mlir::Value *> inputs;
|
||||||
|
for (auto item : node.input()) {
|
||||||
|
if (frontend_symbols_.ContainKey(legalize_name(item))) {
|
||||||
|
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<mlir::Type> outputTypes;
|
||||||
|
for (auto item : node.output()) {
|
||||||
|
outputTypes.push_back(
|
||||||
|
mlir::UnrankedTensorType::get(builder_.getF32Type()));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<mlir::NamedAttribute> attributes;
|
||||||
|
llvm::StringRef OpName = node.op_type();
|
||||||
|
|
||||||
|
// the following code is generated by gen_doc.py
|
||||||
|
// refer to dialect/onnx/onnx.td for details
|
||||||
|
// when the input or output of then op does not match the specification,
|
||||||
|
// the generic operator is used
|
||||||
|
// one known reeason is the optional input
|
||||||
|
|
||||||
|
#include "src/builder/op_build_table.inc"
|
||||||
}
|
}
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
|
@ -277,9 +669,9 @@ class FrontendGenImpl {
|
||||||
* @param ret_vals a vector of mlir Value* representing graph's
|
* @param ret_vals a vector of mlir Value* representing graph's
|
||||||
* output tensor.
|
* output tensor.
|
||||||
*/
|
*/
|
||||||
void ImportOutputTensor(const onnx::ValueInfoProto& output,
|
void ImportOutputTensor(const onnx::ValueInfoProto &output,
|
||||||
llvm::SmallVectorImpl<mlir::Type>& ret_types,
|
llvm::SmallVectorImpl<mlir::Type> &ret_types,
|
||||||
llvm::SmallVectorImpl<mlir::Value*>& ret_vals) {
|
llvm::SmallVectorImpl<mlir::Value *> &ret_vals) {
|
||||||
auto output_tensor_legalized_name = legalize_name(output.name());
|
auto output_tensor_legalized_name = legalize_name(output.name());
|
||||||
assert(
|
assert(
|
||||||
frontend_symbols_.ContainKey(output_tensor_legalized_name) &&
|
frontend_symbols_.ContainKey(output_tensor_legalized_name) &&
|
||||||
|
@ -291,8 +683,8 @@ class FrontendGenImpl {
|
||||||
ret_vals.push_back(tensor_val);
|
ret_vals.push_back(tensor_val);
|
||||||
}
|
}
|
||||||
|
|
||||||
void ImportGraph(
|
void ImportGraph(const onnx::GraphProto &graph,
|
||||||
const onnx::GraphProto& graph, const std::string& name = "main") {
|
const std::string &name = "main") {
|
||||||
// create a function for the graph
|
// create a function for the graph
|
||||||
// TODO:
|
// TODO:
|
||||||
// * get name and type for the function.
|
// * get name and type for the function.
|
||||||
|
@ -300,7 +692,7 @@ class FrontendGenImpl {
|
||||||
llvm::SmallVector<mlir::Type, 4> arg_types;
|
llvm::SmallVector<mlir::Type, 4> arg_types;
|
||||||
|
|
||||||
// Import the input tensor types.
|
// Import the input tensor types.
|
||||||
for (const auto& input : graph.input()) {
|
for (const auto &input : graph.input()) {
|
||||||
ImportInputTensorType(input, arg_types);
|
ImportInputTensorType(input, arg_types);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -308,7 +700,7 @@ class FrontendGenImpl {
|
||||||
auto func_type = builder_.getFunctionType(arg_types, {});
|
auto func_type = builder_.getFunctionType(arg_types, {});
|
||||||
auto main_func =
|
auto main_func =
|
||||||
mlir::FuncOp::create(UnknownLoc(), name, func_type, /* attrs = */ {});
|
mlir::FuncOp::create(UnknownLoc(), name, func_type, /* attrs = */ {});
|
||||||
auto& entryBlock = *main_func.addEntryBlock();
|
auto &entryBlock = *main_func.addEntryBlock();
|
||||||
|
|
||||||
builder_.setInsertionPointToStart(&entryBlock);
|
builder_.setInsertionPointToStart(&entryBlock);
|
||||||
module_.push_back(main_func);
|
module_.push_back(main_func);
|
||||||
|
@ -319,14 +711,14 @@ class FrontendGenImpl {
|
||||||
|
|
||||||
// import nodes in the graph
|
// import nodes in the graph
|
||||||
auto node = graph.node();
|
auto node = graph.node();
|
||||||
for (const auto& item : node) {
|
for (const auto &item : node) {
|
||||||
ImportNode(item);
|
ImportNode(item);
|
||||||
}
|
}
|
||||||
|
|
||||||
llvm::SmallVector<mlir::Type, 4> ret_types;
|
llvm::SmallVector<mlir::Type, 4> ret_types;
|
||||||
llvm::SmallVector<mlir::Value*, 4> ret_vals;
|
llvm::SmallVector<mlir::Value *, 4> ret_vals;
|
||||||
// Import the output tensors
|
// Import the output tensors
|
||||||
for (const auto& output : graph.output()) {
|
for (const auto &output : graph.output()) {
|
||||||
ImportOutputTensor(output, ret_types, ret_vals);
|
ImportOutputTensor(output, ret_types, ret_vals);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -337,7 +729,6 @@ class FrontendGenImpl {
|
||||||
func_type = builder_.getFunctionType(arg_types, ret_types);
|
func_type = builder_.getFunctionType(arg_types, ret_types);
|
||||||
main_func.setType(func_type);
|
main_func.setType(func_type);
|
||||||
}
|
}
|
||||||
|
|
||||||
}; // FrontendGenImpl class
|
}; // FrontendGenImpl class
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace onnf
|
} // namespace onnf
|
||||||
|
@ -354,11 +745,13 @@ mlir::OwningModuleRef ImportFrontendModel(onnx::ModelProto model) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void ImportFrontendModelFile(std::string model_fname,
|
void ImportFrontendModelFile(std::string model_fname,
|
||||||
mlir::MLIRContext& context, mlir::OwningModuleRef& module) {
|
mlir::MLIRContext &context,
|
||||||
|
mlir::OwningModuleRef &module) {
|
||||||
onnx::ModelProto model;
|
onnx::ModelProto model;
|
||||||
std::fstream input(model_fname, std::ios::in | std::ios::binary);
|
std::fstream input(model_fname, std::ios::in | std::ios::binary);
|
||||||
|
|
||||||
auto parse_success = model.ParseFromIstream(&input);
|
auto parse_success = model.ParseFromIstream(&input);
|
||||||
|
assert(parse_success && "Onnx Model Parsing Failed.");
|
||||||
|
|
||||||
FrontendGenImpl myONNXGen(context);
|
FrontendGenImpl myONNXGen(context);
|
||||||
module = myONNXGen.ImportONNXModel(model);
|
module = myONNXGen.ImportONNXModel(model);
|
||||||
|
|
|
@ -1,313 +1,688 @@
|
||||||
if (OpName == "Abs") {
|
if (OpName == "DUMMY") {
|
||||||
OneOut(Abs, 1, 1);
|
}else if (OpName == "Abs") {
|
||||||
|
ImportNodeOneOut<mlir::ONNXAbsOp>(node, 1, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "Acos") {
|
}else if (OpName == "Acos") {
|
||||||
OneOut(Acos, 1, 1);
|
ImportNodeOneOut<mlir::ONNXAcosOp>(node, 1, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "Acosh") {
|
}else if (OpName == "Acosh") {
|
||||||
OneOut(Acosh, 1, 1);
|
ImportNodeOneOut<mlir::ONNXAcoshOp>(node, 1, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "Add") {
|
}else if (OpName == "Add") {
|
||||||
OneOut(Add, 2, 1);
|
ImportNodeOneOut<mlir::ONNXAddOp>(node, 2, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "And") {
|
}else if (OpName == "And") {
|
||||||
OneOut(And, 2, 1);
|
ImportNodeOneOut<mlir::ONNXAndOp>(node, 2, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "ArgMax") {
|
}else if (OpName == "ArgMax") {
|
||||||
OneOut(ArgMax, 1, 1);
|
ImportNodeOneOut<mlir::ONNXArgMaxOp>(node, 1, 1, {
|
||||||
|
{"axis","int","0"}
|
||||||
|
,{"keepdims","int","1"}
|
||||||
|
});
|
||||||
}else if (OpName == "ArgMin") {
|
}else if (OpName == "ArgMin") {
|
||||||
OneOut(ArgMin, 1, 1);
|
ImportNodeOneOut<mlir::ONNXArgMinOp>(node, 1, 1, {
|
||||||
|
{"axis","int","0"}
|
||||||
|
,{"keepdims","int","1"}
|
||||||
|
});
|
||||||
}else if (OpName == "Asin") {
|
}else if (OpName == "Asin") {
|
||||||
OneOut(Asin, 1, 1);
|
ImportNodeOneOut<mlir::ONNXAsinOp>(node, 1, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "Asinh") {
|
}else if (OpName == "Asinh") {
|
||||||
OneOut(Asinh, 1, 1);
|
ImportNodeOneOut<mlir::ONNXAsinhOp>(node, 1, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "Atan") {
|
}else if (OpName == "Atan") {
|
||||||
OneOut(Atan, 1, 1);
|
ImportNodeOneOut<mlir::ONNXAtanOp>(node, 1, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "Atanh") {
|
}else if (OpName == "Atanh") {
|
||||||
OneOut(Atanh, 1, 1);
|
ImportNodeOneOut<mlir::ONNXAtanhOp>(node, 1, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "AveragePool") {
|
}else if (OpName == "AveragePool") {
|
||||||
OneOut(AveragePool, 1, 1);
|
ImportNodeOneOut<mlir::ONNXAveragePoolOp>(node, 1, 1, {
|
||||||
|
{"auto_pad","str","NOTSET"}
|
||||||
|
,{"ceil_mode","int","0"}
|
||||||
|
,{"count_include_pad","int","0"}
|
||||||
|
,{"kernel_shape","ints", ""}
|
||||||
|
,{"pads","", ""}
|
||||||
|
,{"strides","", ""}
|
||||||
|
});
|
||||||
}else if (OpName == "BatchNormalization") {
|
}else if (OpName == "BatchNormalization") {
|
||||||
MultipleOuts(BatchNormalization, 5, 5);
|
ImportNodeMultipleOuts<mlir::ONNXBatchNormalizationOp>(node, 5, 5, {
|
||||||
|
{"epsilon","float","1e-05"}
|
||||||
|
,{"momentum","float","0.9"}
|
||||||
|
});
|
||||||
}else if (OpName == "BitShift") {
|
}else if (OpName == "BitShift") {
|
||||||
OneOut(BitShift, 2, 1);
|
ImportNodeOneOut<mlir::ONNXBitShiftOp>(node, 2, 1, {
|
||||||
|
{"direction","", ""}
|
||||||
|
});
|
||||||
}else if (OpName == "Cast") {
|
}else if (OpName == "Cast") {
|
||||||
OneOut(Cast, 1, 1);
|
ImportNodeOneOut<mlir::ONNXCastOp>(node, 1, 1, {
|
||||||
|
{"to","int", "0"}
|
||||||
|
});
|
||||||
}else if (OpName == "Ceil") {
|
}else if (OpName == "Ceil") {
|
||||||
OneOut(Ceil, 1, 1);
|
ImportNodeOneOut<mlir::ONNXCeilOp>(node, 1, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "Clip") {
|
}else if (OpName == "Clip") {
|
||||||
OneOut(Clip, 3, 1);
|
ImportNodeOneOut<mlir::ONNXClipOp>(node, 3, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "Compress") {
|
}else if (OpName == "Compress") {
|
||||||
OneOut(Compress, 2, 1);
|
ImportNodeOneOut<mlir::ONNXCompressOp>(node, 2, 1, {
|
||||||
|
{"axis","", ""}
|
||||||
|
});
|
||||||
}else if (OpName == "Concat") {
|
}else if (OpName == "Concat") {
|
||||||
OneOut(Concat, 1, 1);
|
ImportNodeOneOut<mlir::ONNXConcatOp>(node, 1, 1, {
|
||||||
|
{"axis","int", "0"}
|
||||||
|
});
|
||||||
}else if (OpName == "ConcatFromSequence") {
|
}else if (OpName == "ConcatFromSequence") {
|
||||||
OneOut(ConcatFromSequence, 1, 1);
|
ImportNodeOneOut<mlir::ONNXConcatFromSequenceOp>(node, 1, 1, {
|
||||||
|
{"axis","", ""}
|
||||||
|
,{"new_axis","int","0"}
|
||||||
|
});
|
||||||
}else if (OpName == "Constant") {
|
}else if (OpName == "Constant") {
|
||||||
OneOut(Constant, 0, 1);
|
ImportNodeOneOut<mlir::ONNXConstantOp>(node, 0, 1, {
|
||||||
|
{"sparse_value","", ""}
|
||||||
|
,{"value","", ""}
|
||||||
|
});
|
||||||
}else if (OpName == "ConstantOfShape") {
|
}else if (OpName == "ConstantOfShape") {
|
||||||
OneOut(ConstantOfShape, 1, 1);
|
ImportNodeOneOut<mlir::ONNXConstantOfShapeOp>(node, 1, 1, {
|
||||||
|
{"value","", ""}
|
||||||
|
});
|
||||||
}else if (OpName == "Conv") {
|
}else if (OpName == "Conv") {
|
||||||
OneOut(Conv, 3, 1);
|
ImportNodeConv(node, 3, 1, {
|
||||||
|
{"auto_pad","str","NOTSET"}
|
||||||
|
,{"dilations","", ""}
|
||||||
|
,{"group","int", "1"}
|
||||||
|
,{"kernel_shape","", ""}
|
||||||
|
,{"pads","", ""}
|
||||||
|
,{"strides","", ""}
|
||||||
|
});
|
||||||
}else if (OpName == "ConvInteger") {
|
}else if (OpName == "ConvInteger") {
|
||||||
OneOut(ConvInteger, 4, 1);
|
ImportNodeOneOut<mlir::ONNXConvIntegerOp>(node, 4, 1, {
|
||||||
|
{"auto_pad","str","NOTSET"}
|
||||||
|
,{"dilations","", ""}
|
||||||
|
,{"group","int","1"}
|
||||||
|
,{"kernel_shape","", ""}
|
||||||
|
,{"pads","", ""}
|
||||||
|
,{"strides","", ""}
|
||||||
|
});
|
||||||
}else if (OpName == "ConvTranspose") {
|
}else if (OpName == "ConvTranspose") {
|
||||||
OneOut(ConvTranspose, 3, 1);
|
ImportNodeOneOut<mlir::ONNXConvTransposeOp>(node, 3, 1, {
|
||||||
|
{"auto_pad","str","NOTSET"}
|
||||||
|
,{"dilations","", ""}
|
||||||
|
,{"group","int","1"}
|
||||||
|
,{"kernel_shape","", ""}
|
||||||
|
,{"output_padding","", ""}
|
||||||
|
,{"output_shape","", ""}
|
||||||
|
,{"pads","", ""}
|
||||||
|
,{"strides","", ""}
|
||||||
|
});
|
||||||
}else if (OpName == "Cos") {
|
}else if (OpName == "Cos") {
|
||||||
OneOut(Cos, 1, 1);
|
ImportNodeOneOut<mlir::ONNXCosOp>(node, 1, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "Cosh") {
|
}else if (OpName == "Cosh") {
|
||||||
OneOut(Cosh, 1, 1);
|
ImportNodeOneOut<mlir::ONNXCoshOp>(node, 1, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "CumSum") {
|
}else if (OpName == "CumSum") {
|
||||||
OneOut(CumSum, 2, 1);
|
ImportNodeOneOut<mlir::ONNXCumSumOp>(node, 2, 1, {
|
||||||
|
{"exclusive","int","0"}
|
||||||
|
,{"reverse","int","0"}
|
||||||
|
});
|
||||||
}else if (OpName == "DepthToSpace") {
|
}else if (OpName == "DepthToSpace") {
|
||||||
OneOut(DepthToSpace, 1, 1);
|
ImportNodeOneOut<mlir::ONNXDepthToSpaceOp>(node, 1, 1, {
|
||||||
|
{"blocksize","", ""}
|
||||||
|
,{"mode","str","DCR"}
|
||||||
|
});
|
||||||
}else if (OpName == "DequantizeLinear") {
|
}else if (OpName == "DequantizeLinear") {
|
||||||
OneOut(DequantizeLinear, 3, 1);
|
ImportNodeOneOut<mlir::ONNXDequantizeLinearOp>(node, 3, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "Det") {
|
}else if (OpName == "Det") {
|
||||||
OneOut(Det, 1, 1);
|
ImportNodeOneOut<mlir::ONNXDetOp>(node, 1, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "Div") {
|
}else if (OpName == "Div") {
|
||||||
OneOut(Div, 2, 1);
|
ImportNodeOneOut<mlir::ONNXDivOp>(node, 2, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "Dropout") {
|
}else if (OpName == "Dropout") {
|
||||||
MultipleOuts(Dropout, 1, 2);
|
ImportNodeMultipleOuts<mlir::ONNXDropoutOp>(node, 1, 2, {
|
||||||
|
{"ratio","float","0.5"}
|
||||||
|
});
|
||||||
}else if (OpName == "DynamicQuantizeLinear") {
|
}else if (OpName == "DynamicQuantizeLinear") {
|
||||||
MultipleOuts(DynamicQuantizeLinear, 1, 3);
|
ImportNodeMultipleOuts<mlir::ONNXDynamicQuantizeLinearOp>(node, 1, 3, {
|
||||||
|
});
|
||||||
}else if (OpName == "Elu") {
|
}else if (OpName == "Elu") {
|
||||||
OneOut(Elu, 1, 1);
|
ImportNodeOneOut<mlir::ONNXEluOp>(node, 1, 1, {
|
||||||
|
{"alpha","float","1.0"}
|
||||||
|
});
|
||||||
}else if (OpName == "Equal") {
|
}else if (OpName == "Equal") {
|
||||||
OneOut(Equal, 2, 1);
|
ImportNodeOneOut<mlir::ONNXEqualOp>(node, 2, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "Erf") {
|
}else if (OpName == "Erf") {
|
||||||
OneOut(Erf, 1, 1);
|
ImportNodeOneOut<mlir::ONNXErfOp>(node, 1, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "Exp") {
|
}else if (OpName == "Exp") {
|
||||||
OneOut(Exp, 1, 1);
|
ImportNodeOneOut<mlir::ONNXExpOp>(node, 1, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "Expand") {
|
}else if (OpName == "Expand") {
|
||||||
OneOut(Expand, 2, 1);
|
ImportNodeOneOut<mlir::ONNXExpandOp>(node, 2, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "EyeLike") {
|
}else if (OpName == "EyeLike") {
|
||||||
OneOut(EyeLike, 1, 1);
|
ImportNodeOneOut<mlir::ONNXEyeLikeOp>(node, 1, 1, {
|
||||||
|
{"dtype","", ""}
|
||||||
|
,{"k","int","0"}
|
||||||
|
});
|
||||||
}else if (OpName == "Flatten") {
|
}else if (OpName == "Flatten") {
|
||||||
OneOut(Flatten, 1, 1);
|
ImportNodeOneOut<mlir::ONNXFlattenOp>(node, 1, 1, {
|
||||||
|
{"axis","int","1"}
|
||||||
|
});
|
||||||
}else if (OpName == "Floor") {
|
}else if (OpName == "Floor") {
|
||||||
OneOut(Floor, 1, 1);
|
ImportNodeOneOut<mlir::ONNXFloorOp>(node, 1, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "GRU") {
|
}else if (OpName == "GRU") {
|
||||||
MultipleOuts(GRU, 6, 2);
|
ImportNodeMultipleOuts<mlir::ONNXGRUOp>(node, 6, 2, {
|
||||||
|
{"activation_alpha","", ""}
|
||||||
|
,{"activation_beta","", ""}
|
||||||
|
,{"activations","", ""}
|
||||||
|
,{"clip","", ""}
|
||||||
|
,{"direction","str","forward"}
|
||||||
|
,{"hidden_size","", ""}
|
||||||
|
,{"linear_before_reset","int","0"}
|
||||||
|
});
|
||||||
}else if (OpName == "Gather") {
|
}else if (OpName == "Gather") {
|
||||||
OneOut(Gather, 2, 1);
|
ImportNodeOneOut<mlir::ONNXGatherOp>(node, 2, 1, {
|
||||||
|
{"axis","int","0"}
|
||||||
|
});
|
||||||
}else if (OpName == "GatherElements") {
|
}else if (OpName == "GatherElements") {
|
||||||
OneOut(GatherElements, 2, 1);
|
ImportNodeOneOut<mlir::ONNXGatherElementsOp>(node, 2, 1, {
|
||||||
|
{"axis","int","0"}
|
||||||
|
});
|
||||||
}else if (OpName == "GatherND") {
|
}else if (OpName == "GatherND") {
|
||||||
OneOut(GatherND, 2, 1);
|
ImportNodeOneOut<mlir::ONNXGatherNDOp>(node, 2, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "Gemm") {
|
}else if (OpName == "Gemm") {
|
||||||
OneOut(Gemm, 3, 1);
|
ImportNodeOneOut<mlir::ONNXGemmOp>(node, 3, 1, {
|
||||||
|
{"alpha","float","1.0"}
|
||||||
|
,{"beta","float","1.0"}
|
||||||
|
,{"transA","int","0"}
|
||||||
|
,{"transB","int","0"}
|
||||||
|
});
|
||||||
}else if (OpName == "GlobalAveragePool") {
|
}else if (OpName == "GlobalAveragePool") {
|
||||||
OneOut(GlobalAveragePool, 1, 1);
|
ImportNodeOneOut<mlir::ONNXGlobalAveragePoolOp>(node, 1, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "GlobalLpPool") {
|
}else if (OpName == "GlobalLpPool") {
|
||||||
OneOut(GlobalLpPool, 1, 1);
|
ImportNodeOneOut<mlir::ONNXGlobalLpPoolOp>(node, 1, 1, {
|
||||||
|
{"p","int","2"}
|
||||||
|
});
|
||||||
}else if (OpName == "GlobalMaxPool") {
|
}else if (OpName == "GlobalMaxPool") {
|
||||||
OneOut(GlobalMaxPool, 1, 1);
|
ImportNodeOneOut<mlir::ONNXGlobalMaxPoolOp>(node, 1, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "Greater") {
|
}else if (OpName == "Greater") {
|
||||||
OneOut(Greater, 2, 1);
|
ImportNodeOneOut<mlir::ONNXGreaterOp>(node, 2, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "HardSigmoid") {
|
}else if (OpName == "HardSigmoid") {
|
||||||
OneOut(HardSigmoid, 1, 1);
|
ImportNodeOneOut<mlir::ONNXHardSigmoidOp>(node, 1, 1, {
|
||||||
|
{"alpha","float","0.2"}
|
||||||
|
,{"beta","float","0.5"}
|
||||||
|
});
|
||||||
}else if (OpName == "Hardmax") {
|
}else if (OpName == "Hardmax") {
|
||||||
OneOut(Hardmax, 1, 1);
|
ImportNodeOneOut<mlir::ONNXHardmaxOp>(node, 1, 1, {
|
||||||
|
{"axis","int","1"}
|
||||||
|
});
|
||||||
}else if (OpName == "Identity") {
|
}else if (OpName == "Identity") {
|
||||||
OneOut(Identity, 1, 1);
|
ImportNodeOneOut<mlir::ONNXIdentityOp>(node, 1, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "If") {
|
}else if (OpName == "If") {
|
||||||
OneOut(If, 1, 1);
|
ImportNodeOneOut<mlir::ONNXIfOp>(node, 1, 1, {
|
||||||
|
{"else_branch","", ""}
|
||||||
|
,{"then_branch","", ""}
|
||||||
|
});
|
||||||
}else if (OpName == "InstanceNormalization") {
|
}else if (OpName == "InstanceNormalization") {
|
||||||
OneOut(InstanceNormalization, 3, 1);
|
ImportNodeOneOut<mlir::ONNXInstanceNormalizationOp>(node, 3, 1, {
|
||||||
|
{"epsilon","float","1e-05"}
|
||||||
|
});
|
||||||
}else if (OpName == "IsInf") {
|
}else if (OpName == "IsInf") {
|
||||||
OneOut(IsInf, 1, 1);
|
ImportNodeOneOut<mlir::ONNXIsInfOp>(node, 1, 1, {
|
||||||
|
{"detect_negative","int","1"}
|
||||||
|
,{"detect_positive","int","1"}
|
||||||
|
});
|
||||||
}else if (OpName == "IsNaN") {
|
}else if (OpName == "IsNaN") {
|
||||||
OneOut(IsNaN, 1, 1);
|
ImportNodeOneOut<mlir::ONNXIsNaNOp>(node, 1, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "LRN") {
|
}else if (OpName == "LRN") {
|
||||||
OneOut(LRN, 1, 1);
|
ImportNodeOneOut<mlir::ONNXLRNOp>(node, 1, 1, {
|
||||||
|
{"alpha","float","0.0001"}
|
||||||
|
,{"beta","float","0.75"}
|
||||||
|
,{"bias","float","1.0"}
|
||||||
|
,{"size","int", ""}
|
||||||
|
});
|
||||||
}else if (OpName == "LSTM") {
|
}else if (OpName == "LSTM") {
|
||||||
MultipleOuts(LSTM, 8, 3);
|
ImportNodeMultipleOuts<mlir::ONNXLSTMOp>(node, 8, 3, {
|
||||||
|
{"activation_alpha","", ""}
|
||||||
|
,{"activation_beta","", ""}
|
||||||
|
,{"activations","", ""}
|
||||||
|
,{"clip","", ""}
|
||||||
|
,{"direction","str","forward"}
|
||||||
|
,{"hidden_size","", ""}
|
||||||
|
,{"input_forget","int","0"}
|
||||||
|
});
|
||||||
}else if (OpName == "LeakyRelu") {
|
}else if (OpName == "LeakyRelu") {
|
||||||
OneOut(LeakyRelu, 1, 1);
|
ImportNodeOneOut<mlir::ONNXLeakyReluOp>(node, 1, 1, {
|
||||||
|
{"alpha","float","0.01"}
|
||||||
|
});
|
||||||
}else if (OpName == "Less") {
|
}else if (OpName == "Less") {
|
||||||
OneOut(Less, 2, 1);
|
ImportNodeOneOut<mlir::ONNXLessOp>(node, 2, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "Log") {
|
}else if (OpName == "Log") {
|
||||||
OneOut(Log, 1, 1);
|
ImportNodeOneOut<mlir::ONNXLogOp>(node, 1, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "LogSoftmax") {
|
}else if (OpName == "LogSoftmax") {
|
||||||
OneOut(LogSoftmax, 1, 1);
|
ImportNodeOneOut<mlir::ONNXLogSoftmaxOp>(node, 1, 1, {
|
||||||
|
{"axis","int","1"}
|
||||||
|
});
|
||||||
}else if (OpName == "Loop") {
|
}else if (OpName == "Loop") {
|
||||||
OneOut(Loop, 3, 1);
|
ImportNodeOneOut<mlir::ONNXLoopOp>(node, 3, 1, {
|
||||||
|
{"body","", ""}
|
||||||
|
});
|
||||||
}else if (OpName == "LpNormalization") {
|
}else if (OpName == "LpNormalization") {
|
||||||
OneOut(LpNormalization, 1, 1);
|
ImportNodeOneOut<mlir::ONNXLpNormalizationOp>(node, 1, 1, {
|
||||||
|
{"axis","int","-1"}
|
||||||
|
,{"p","int","2"}
|
||||||
|
});
|
||||||
}else if (OpName == "LpPool") {
|
}else if (OpName == "LpPool") {
|
||||||
OneOut(LpPool, 1, 1);
|
ImportNodeOneOut<mlir::ONNXLpPoolOp>(node, 1, 1, {
|
||||||
|
{"auto_pad","str","NOTSET"}
|
||||||
|
,{"kernel_shape","", ""}
|
||||||
|
,{"p","int","2"}
|
||||||
|
,{"pads","", ""}
|
||||||
|
,{"strides","", ""}
|
||||||
|
});
|
||||||
}else if (OpName == "MatMul") {
|
}else if (OpName == "MatMul") {
|
||||||
OneOut(MatMul, 2, 1);
|
ImportNodeOneOut<mlir::ONNXMatMulOp>(node, 2, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "MatMulInteger") {
|
}else if (OpName == "MatMulInteger") {
|
||||||
OneOut(MatMulInteger, 4, 1);
|
ImportNodeOneOut<mlir::ONNXMatMulIntegerOp>(node, 4, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "Max") {
|
}else if (OpName == "Max") {
|
||||||
OneOut(Max, 1, 1);
|
ImportNodeOneOut<mlir::ONNXMaxOp>(node, 1, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "MaxPool") {
|
}else if (OpName == "MaxPool") {
|
||||||
MultipleOuts(MaxPool, 1, 2);
|
ImportNodeMultipleOuts<mlir::ONNXMaxPoolOp>(node, 1, 2, {
|
||||||
|
{"auto_pad","str","NOTSET"}
|
||||||
|
,{"ceil_mode","int","0"}
|
||||||
|
,{"dilations","", ""}
|
||||||
|
,{"kernel_shape","ints", ""}
|
||||||
|
,{"pads","", ""}
|
||||||
|
,{"storage_order","int","0"}
|
||||||
|
,{"strides","", ""}
|
||||||
|
});
|
||||||
}else if (OpName == "MaxRoiPool") {
|
}else if (OpName == "MaxRoiPool") {
|
||||||
OneOut(MaxRoiPool, 2, 1);
|
ImportNodeOneOut<mlir::ONNXMaxRoiPoolOp>(node, 2, 1, {
|
||||||
|
{"pooled_shape","", ""}
|
||||||
|
,{"spatial_scale","float","1.0"}
|
||||||
|
});
|
||||||
}else if (OpName == "MaxUnpool") {
|
}else if (OpName == "MaxUnpool") {
|
||||||
OneOut(MaxUnpool, 3, 1);
|
ImportNodeOneOut<mlir::ONNXMaxUnpoolOp>(node, 3, 1, {
|
||||||
|
{"kernel_shape","", ""}
|
||||||
|
,{"pads","", ""}
|
||||||
|
,{"strides","", ""}
|
||||||
|
});
|
||||||
}else if (OpName == "Mean") {
|
}else if (OpName == "Mean") {
|
||||||
OneOut(Mean, 1, 1);
|
ImportNodeOneOut<mlir::ONNXMeanOp>(node, 1, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "MeanVarianceNormalization") {
|
}else if (OpName == "MeanVarianceNormalization") {
|
||||||
OneOut(MeanVarianceNormalization, 1, 1);
|
ImportNodeOneOut<mlir::ONNXMeanVarianceNormalizationOp>(node, 1, 1, {
|
||||||
|
{"axes","ints","{'0', '2', '3'}"}
|
||||||
|
});
|
||||||
}else if (OpName == "Min") {
|
}else if (OpName == "Min") {
|
||||||
OneOut(Min, 1, 1);
|
ImportNodeOneOut<mlir::ONNXMinOp>(node, 1, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "Mod") {
|
}else if (OpName == "Mod") {
|
||||||
OneOut(Mod, 2, 1);
|
ImportNodeOneOut<mlir::ONNXModOp>(node, 2, 1, {
|
||||||
|
{"fmod","int","0"}
|
||||||
|
});
|
||||||
}else if (OpName == "Mul") {
|
}else if (OpName == "Mul") {
|
||||||
OneOut(Mul, 2, 1);
|
ImportNodeOneOut<mlir::ONNXMulOp>(node, 2, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "Multinomial") {
|
}else if (OpName == "Multinomial") {
|
||||||
OneOut(Multinomial, 1, 1);
|
ImportNodeOneOut<mlir::ONNXMultinomialOp>(node, 1, 1, {
|
||||||
|
{"dtype","int","6"}
|
||||||
|
,{"sample_size","int","1"}
|
||||||
|
,{"seed","", ""}
|
||||||
|
});
|
||||||
}else if (OpName == "Neg") {
|
}else if (OpName == "Neg") {
|
||||||
OneOut(Neg, 1, 1);
|
ImportNodeOneOut<mlir::ONNXNegOp>(node, 1, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "NonMaxSuppression") {
|
}else if (OpName == "NonMaxSuppression") {
|
||||||
OneOut(NonMaxSuppression, 5, 1);
|
ImportNodeOneOut<mlir::ONNXNonMaxSuppressionOp>(node, 5, 1, {
|
||||||
|
{"center_point_box","int","0"}
|
||||||
|
});
|
||||||
}else if (OpName == "NonZero") {
|
}else if (OpName == "NonZero") {
|
||||||
OneOut(NonZero, 1, 1);
|
ImportNodeOneOut<mlir::ONNXNonZeroOp>(node, 1, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "Not") {
|
}else if (OpName == "Not") {
|
||||||
OneOut(Not, 1, 1);
|
ImportNodeOneOut<mlir::ONNXNotOp>(node, 1, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "OneHot") {
|
}else if (OpName == "OneHot") {
|
||||||
OneOut(OneHot, 3, 1);
|
ImportNodeOneOut<mlir::ONNXOneHotOp>(node, 3, 1, {
|
||||||
|
{"axis","int","-1"}
|
||||||
|
});
|
||||||
}else if (OpName == "Or") {
|
}else if (OpName == "Or") {
|
||||||
OneOut(Or, 2, 1);
|
ImportNodeOneOut<mlir::ONNXOrOp>(node, 2, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "PRelu") {
|
}else if (OpName == "PRelu") {
|
||||||
OneOut(PRelu, 2, 1);
|
ImportNodeOneOut<mlir::ONNXPReluOp>(node, 2, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "Pad") {
|
}else if (OpName == "Pad") {
|
||||||
OneOut(Pad, 3, 1);
|
ImportNodeOneOut<mlir::ONNXPadOp>(node, 3, 1, {
|
||||||
|
{"mode","str","constant"}
|
||||||
|
});
|
||||||
}else if (OpName == "Pow") {
|
}else if (OpName == "Pow") {
|
||||||
OneOut(Pow, 2, 1);
|
ImportNodeOneOut<mlir::ONNXPowOp>(node, 2, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "QLinearConv") {
|
}else if (OpName == "QLinearConv") {
|
||||||
OneOut(QLinearConv, 9, 1);
|
ImportNodeOneOut<mlir::ONNXQLinearConvOp>(node, 9, 1, {
|
||||||
|
{"auto_pad","str","NOTSET"}
|
||||||
|
,{"dilations","", ""}
|
||||||
|
,{"group","int","1"}
|
||||||
|
,{"kernel_shape","", ""}
|
||||||
|
,{"pads","", ""}
|
||||||
|
,{"strides","", ""}
|
||||||
|
});
|
||||||
}else if (OpName == "QLinearMatMul") {
|
}else if (OpName == "QLinearMatMul") {
|
||||||
OneOut(QLinearMatMul, 8, 1);
|
ImportNodeOneOut<mlir::ONNXQLinearMatMulOp>(node, 8, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "QuantizeLinear") {
|
}else if (OpName == "QuantizeLinear") {
|
||||||
OneOut(QuantizeLinear, 3, 1);
|
ImportNodeOneOut<mlir::ONNXQuantizeLinearOp>(node, 3, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "RNN") {
|
}else if (OpName == "RNN") {
|
||||||
MultipleOuts(RNN, 6, 2);
|
ImportNodeMultipleOuts<mlir::ONNXRNNOp>(node, 6, 2, {
|
||||||
|
{"activation_alpha","floats", "{}"}
|
||||||
|
,{"activation_beta","floats", "{}"}
|
||||||
|
,{"activations","", "{Tannh, Tanh}"}
|
||||||
|
,{"clip","", ""}
|
||||||
|
,{"direction","str","forward"}
|
||||||
|
,{"hidden_size","", ""}
|
||||||
|
});
|
||||||
}else if (OpName == "RandomNormal") {
|
}else if (OpName == "RandomNormal") {
|
||||||
OneOut(RandomNormal, 0, 1);
|
ImportNodeOneOut<mlir::ONNXRandomNormalOp>(node, 0, 1, {
|
||||||
|
{"dtype","int","1"}
|
||||||
|
,{"mean","float","0.0"}
|
||||||
|
,{"scale","float","1.0"}
|
||||||
|
,{"seed","", ""}
|
||||||
|
,{"shape","", ""}
|
||||||
|
});
|
||||||
}else if (OpName == "RandomNormalLike") {
|
}else if (OpName == "RandomNormalLike") {
|
||||||
OneOut(RandomNormalLike, 1, 1);
|
ImportNodeOneOut<mlir::ONNXRandomNormalLikeOp>(node, 1, 1, {
|
||||||
|
{"dtype","", ""}
|
||||||
|
,{"mean","float","0.0"}
|
||||||
|
,{"scale","float","1.0"}
|
||||||
|
,{"seed","", ""}
|
||||||
|
});
|
||||||
}else if (OpName == "RandomUniform") {
|
}else if (OpName == "RandomUniform") {
|
||||||
OneOut(RandomUniform, 0, 1);
|
ImportNodeOneOut<mlir::ONNXRandomUniformOp>(node, 0, 1, {
|
||||||
|
{"dtype","int","1"}
|
||||||
|
,{"high","float","1.0"}
|
||||||
|
,{"low","float","0.0"}
|
||||||
|
,{"seed","", ""}
|
||||||
|
,{"shape","", ""}
|
||||||
|
});
|
||||||
}else if (OpName == "RandomUniformLike") {
|
}else if (OpName == "RandomUniformLike") {
|
||||||
OneOut(RandomUniformLike, 1, 1);
|
ImportNodeOneOut<mlir::ONNXRandomUniformLikeOp>(node, 1, 1, {
|
||||||
|
{"dtype","", ""}
|
||||||
|
,{"high","float","1.0"}
|
||||||
|
,{"low","float","0.0"}
|
||||||
|
,{"seed","", ""}
|
||||||
|
});
|
||||||
}else if (OpName == "Range") {
|
}else if (OpName == "Range") {
|
||||||
OneOut(Range, 3, 1);
|
ImportNodeOneOut<mlir::ONNXRangeOp>(node, 3, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "Reciprocal") {
|
}else if (OpName == "Reciprocal") {
|
||||||
OneOut(Reciprocal, 1, 1);
|
ImportNodeOneOut<mlir::ONNXReciprocalOp>(node, 1, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "ReduceL1") {
|
}else if (OpName == "ReduceL1") {
|
||||||
OneOut(ReduceL1, 1, 1);
|
ImportNodeOneOut<mlir::ONNXReduceL1Op>(node, 1, 1, {
|
||||||
|
{"axes","", ""}
|
||||||
|
,{"keepdims","int","1"}
|
||||||
|
});
|
||||||
}else if (OpName == "ReduceL2") {
|
}else if (OpName == "ReduceL2") {
|
||||||
OneOut(ReduceL2, 1, 1);
|
ImportNodeOneOut<mlir::ONNXReduceL2Op>(node, 1, 1, {
|
||||||
|
{"axes","", ""}
|
||||||
|
,{"keepdims","int","1"}
|
||||||
|
});
|
||||||
}else if (OpName == "ReduceLogSum") {
|
}else if (OpName == "ReduceLogSum") {
|
||||||
OneOut(ReduceLogSum, 1, 1);
|
ImportNodeOneOut<mlir::ONNXReduceLogSumOp>(node, 1, 1, {
|
||||||
|
{"axes","", ""}
|
||||||
|
,{"keepdims","int","1"}
|
||||||
|
});
|
||||||
}else if (OpName == "ReduceLogSumExp") {
|
}else if (OpName == "ReduceLogSumExp") {
|
||||||
OneOut(ReduceLogSumExp, 1, 1);
|
ImportNodeOneOut<mlir::ONNXReduceLogSumExpOp>(node, 1, 1, {
|
||||||
|
{"axes","", ""}
|
||||||
|
,{"keepdims","int","1"}
|
||||||
|
});
|
||||||
}else if (OpName == "ReduceMax") {
|
}else if (OpName == "ReduceMax") {
|
||||||
OneOut(ReduceMax, 1, 1);
|
ImportNodeOneOut<mlir::ONNXReduceMaxOp>(node, 1, 1, {
|
||||||
|
{"axes","", ""}
|
||||||
|
,{"keepdims","int","1"}
|
||||||
|
});
|
||||||
}else if (OpName == "ReduceMean") {
|
}else if (OpName == "ReduceMean") {
|
||||||
OneOut(ReduceMean, 1, 1);
|
ImportNodeOneOut<mlir::ONNXReduceMeanOp>(node, 1, 1, {
|
||||||
|
{"axes","", ""}
|
||||||
|
,{"keepdims","int","1"}
|
||||||
|
});
|
||||||
}else if (OpName == "ReduceMin") {
|
}else if (OpName == "ReduceMin") {
|
||||||
OneOut(ReduceMin, 1, 1);
|
ImportNodeOneOut<mlir::ONNXReduceMinOp>(node, 1, 1, {
|
||||||
|
{"axes","", ""}
|
||||||
|
,{"keepdims","int","1"}
|
||||||
|
});
|
||||||
}else if (OpName == "ReduceProd") {
|
}else if (OpName == "ReduceProd") {
|
||||||
OneOut(ReduceProd, 1, 1);
|
ImportNodeOneOut<mlir::ONNXReduceProdOp>(node, 1, 1, {
|
||||||
|
{"axes","", ""}
|
||||||
|
,{"keepdims","int","1"}
|
||||||
|
});
|
||||||
}else if (OpName == "ReduceSum") {
|
}else if (OpName == "ReduceSum") {
|
||||||
OneOut(ReduceSum, 1, 1);
|
ImportNodeOneOut<mlir::ONNXReduceSumOp>(node, 1, 1, {
|
||||||
|
{"axes","", ""}
|
||||||
|
,{"keepdims","int","1"}
|
||||||
|
});
|
||||||
}else if (OpName == "ReduceSumSquare") {
|
}else if (OpName == "ReduceSumSquare") {
|
||||||
OneOut(ReduceSumSquare, 1, 1);
|
ImportNodeOneOut<mlir::ONNXReduceSumSquareOp>(node, 1, 1, {
|
||||||
|
{"axes","", ""}
|
||||||
|
,{"keepdims","int","1"}
|
||||||
|
});
|
||||||
}else if (OpName == "Relu") {
|
}else if (OpName == "Relu") {
|
||||||
OneOut(Relu, 1, 1);
|
ImportNodeOneOut<mlir::ONNXReluOp>(node, 1, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "Reshape") {
|
}else if (OpName == "Reshape") {
|
||||||
OneOut(Reshape, 2, 1);
|
ImportNodeOneOut<mlir::ONNXReshapeOp>(node, 2, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "Resize") {
|
}else if (OpName == "Resize") {
|
||||||
OneOut(Resize, 4, 1);
|
ImportNodeOneOut<mlir::ONNXResizeOp>(node, 4, 1, {
|
||||||
|
{"coordinate_transformation_mode","str","half_pixel"}
|
||||||
|
,{"cubic_coeff_a","float","-0.75"}
|
||||||
|
,{"exclude_outside","int","0"}
|
||||||
|
,{"extrapolation_value","float","0.0"}
|
||||||
|
,{"mode","str","nearest"}
|
||||||
|
,{"nearest_mode","str","round_prefer_floor"}
|
||||||
|
});
|
||||||
}else if (OpName == "ReverseSequence") {
|
}else if (OpName == "ReverseSequence") {
|
||||||
OneOut(ReverseSequence, 2, 1);
|
ImportNodeOneOut<mlir::ONNXReverseSequenceOp>(node, 2, 1, {
|
||||||
|
{"batch_axis","int","1"}
|
||||||
|
,{"time_axis","int","0"}
|
||||||
|
});
|
||||||
}else if (OpName == "RoiAlign") {
|
}else if (OpName == "RoiAlign") {
|
||||||
OneOut(RoiAlign, 3, 1);
|
ImportNodeOneOut<mlir::ONNXRoiAlignOp>(node, 3, 1, {
|
||||||
|
{"mode","str","avg"}
|
||||||
|
,{"output_height","int","1"}
|
||||||
|
,{"output_width","int","1"}
|
||||||
|
,{"sampling_ratio","int","0"}
|
||||||
|
,{"spatial_scale","float","1.0"}
|
||||||
|
});
|
||||||
}else if (OpName == "Round") {
|
}else if (OpName == "Round") {
|
||||||
OneOut(Round, 1, 1);
|
ImportNodeOneOut<mlir::ONNXRoundOp>(node, 1, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "Scan") {
|
}else if (OpName == "Scan") {
|
||||||
OneOut(Scan, 1, 1);
|
ImportNodeOneOut<mlir::ONNXScanOp>(node, 1, 1, {
|
||||||
|
{"body","", ""}
|
||||||
|
,{"num_scan_inputs","", ""}
|
||||||
|
,{"scan_input_axes","", ""}
|
||||||
|
,{"scan_input_directions","", ""}
|
||||||
|
,{"scan_output_axes","", ""}
|
||||||
|
,{"scan_output_directions","", ""}
|
||||||
|
});
|
||||||
}else if (OpName == "Scatter") {
|
}else if (OpName == "Scatter") {
|
||||||
OneOut(Scatter, 3, 1);
|
ImportNodeOneOut<mlir::ONNXScatterOp>(node, 3, 1, {
|
||||||
|
{"axis","int","0"}
|
||||||
|
});
|
||||||
}else if (OpName == "ScatterElements") {
|
}else if (OpName == "ScatterElements") {
|
||||||
OneOut(ScatterElements, 3, 1);
|
ImportNodeOneOut<mlir::ONNXScatterElementsOp>(node, 3, 1, {
|
||||||
|
{"axis","int","0"}
|
||||||
|
});
|
||||||
}else if (OpName == "ScatterND") {
|
}else if (OpName == "ScatterND") {
|
||||||
OneOut(ScatterND, 3, 1);
|
ImportNodeOneOut<mlir::ONNXScatterNDOp>(node, 3, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "Selu") {
|
}else if (OpName == "Selu") {
|
||||||
OneOut(Selu, 1, 1);
|
ImportNodeOneOut<mlir::ONNXSeluOp>(node, 1, 1, {
|
||||||
|
{"alpha","float","1.67326"}
|
||||||
|
,{"gamma","float","1.0507"}
|
||||||
|
});
|
||||||
}else if (OpName == "SequenceAt") {
|
}else if (OpName == "SequenceAt") {
|
||||||
OneOut(SequenceAt, 2, 1);
|
ImportNodeOneOut<mlir::ONNXSequenceAtOp>(node, 2, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "SequenceConstruct") {
|
}else if (OpName == "SequenceConstruct") {
|
||||||
OneOut(SequenceConstruct, 1, 1);
|
ImportNodeOneOut<mlir::ONNXSequenceConstructOp>(node, 1, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "SequenceEmpty") {
|
}else if (OpName == "SequenceEmpty") {
|
||||||
OneOut(SequenceEmpty, 0, 1);
|
ImportNodeOneOut<mlir::ONNXSequenceEmptyOp>(node, 0, 1, {
|
||||||
|
{"dtype","", ""}
|
||||||
|
});
|
||||||
}else if (OpName == "SequenceErase") {
|
}else if (OpName == "SequenceErase") {
|
||||||
OneOut(SequenceErase, 2, 1);
|
ImportNodeOneOut<mlir::ONNXSequenceEraseOp>(node, 2, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "SequenceInsert") {
|
}else if (OpName == "SequenceInsert") {
|
||||||
OneOut(SequenceInsert, 3, 1);
|
ImportNodeOneOut<mlir::ONNXSequenceInsertOp>(node, 3, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "SequenceLength") {
|
}else if (OpName == "SequenceLength") {
|
||||||
OneOut(SequenceLength, 1, 1);
|
ImportNodeOneOut<mlir::ONNXSequenceLengthOp>(node, 1, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "Shape") {
|
}else if (OpName == "Shape") {
|
||||||
OneOut(Shape, 1, 1);
|
ImportNodeOneOut<mlir::ONNXShapeOp>(node, 1, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "Shrink") {
|
}else if (OpName == "Shrink") {
|
||||||
OneOut(Shrink, 1, 1);
|
ImportNodeOneOut<mlir::ONNXShrinkOp>(node, 1, 1, {
|
||||||
|
{"bias","float","0.0"}
|
||||||
|
,{"lambd","float","0.5"}
|
||||||
|
});
|
||||||
}else if (OpName == "Sigmoid") {
|
}else if (OpName == "Sigmoid") {
|
||||||
OneOut(Sigmoid, 1, 1);
|
ImportNodeOneOut<mlir::ONNXSigmoidOp>(node, 1, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "Sign") {
|
}else if (OpName == "Sign") {
|
||||||
OneOut(Sign, 1, 1);
|
ImportNodeOneOut<mlir::ONNXSignOp>(node, 1, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "Sin") {
|
}else if (OpName == "Sin") {
|
||||||
OneOut(Sin, 1, 1);
|
ImportNodeOneOut<mlir::ONNXSinOp>(node, 1, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "Sinh") {
|
}else if (OpName == "Sinh") {
|
||||||
OneOut(Sinh, 1, 1);
|
ImportNodeOneOut<mlir::ONNXSinhOp>(node, 1, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "Size") {
|
}else if (OpName == "Size") {
|
||||||
OneOut(Size, 1, 1);
|
ImportNodeOneOut<mlir::ONNXSizeOp>(node, 1, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "Slice") {
|
}else if (OpName == "Slice") {
|
||||||
OneOut(Slice, 5, 1);
|
ImportNodeOneOut<mlir::ONNXSliceOp>(node, 5, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "Softmax") {
|
}else if (OpName == "Softmax") {
|
||||||
OneOut(Softmax, 1, 1);
|
ImportNodeOneOut<mlir::ONNXSoftmaxOp>(node, 1, 1, {
|
||||||
|
{"axis","int","1"}
|
||||||
|
});
|
||||||
}else if (OpName == "Softplus") {
|
}else if (OpName == "Softplus") {
|
||||||
OneOut(Softplus, 1, 1);
|
ImportNodeOneOut<mlir::ONNXSoftplusOp>(node, 1, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "Softsign") {
|
}else if (OpName == "Softsign") {
|
||||||
OneOut(Softsign, 1, 1);
|
ImportNodeOneOut<mlir::ONNXSoftsignOp>(node, 1, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "SpaceToDepth") {
|
}else if (OpName == "SpaceToDepth") {
|
||||||
OneOut(SpaceToDepth, 1, 1);
|
ImportNodeOneOut<mlir::ONNXSpaceToDepthOp>(node, 1, 1, {
|
||||||
|
{"blocksize","", ""}
|
||||||
|
});
|
||||||
}else if (OpName == "Split") {
|
}else if (OpName == "Split") {
|
||||||
OneOut(Split, 1, 1);
|
ImportNodeOneOut<mlir::ONNXSplitOp>(node, 1, 1, {
|
||||||
|
{"axis","int","0"}
|
||||||
|
,{"split","", ""}
|
||||||
|
});
|
||||||
}else if (OpName == "SplitToSequence") {
|
}else if (OpName == "SplitToSequence") {
|
||||||
OneOut(SplitToSequence, 2, 1);
|
ImportNodeOneOut<mlir::ONNXSplitToSequenceOp>(node, 2, 1, {
|
||||||
|
{"axis","int","0"}
|
||||||
|
,{"keepdims","int","1"}
|
||||||
|
});
|
||||||
}else if (OpName == "Sqrt") {
|
}else if (OpName == "Sqrt") {
|
||||||
OneOut(Sqrt, 1, 1);
|
ImportNodeOneOut<mlir::ONNXSqrtOp>(node, 1, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "Squeeze") {
|
}else if (OpName == "Squeeze") {
|
||||||
OneOut(Squeeze, 1, 1);
|
ImportNodeOneOut<mlir::ONNXSqueezeOp>(node, 1, 1, {
|
||||||
|
{"axes","", ""}
|
||||||
|
});
|
||||||
}else if (OpName == "StringNormalizer") {
|
}else if (OpName == "StringNormalizer") {
|
||||||
OneOut(StringNormalizer, 1, 1);
|
ImportNodeOneOut<mlir::ONNXStringNormalizerOp>(node, 1, 1, {
|
||||||
|
{"case_change_action","str","NONE"}
|
||||||
|
,{"is_case_sensitive","int","0"}
|
||||||
|
,{"locale","", ""}
|
||||||
|
,{"stopwords","", ""}
|
||||||
|
});
|
||||||
}else if (OpName == "Sub") {
|
}else if (OpName == "Sub") {
|
||||||
OneOut(Sub, 2, 1);
|
ImportNodeOneOut<mlir::ONNXSubOp>(node, 2, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "Sum") {
|
}else if (OpName == "Sum") {
|
||||||
OneOut(Sum, 1, 1);
|
ImportNodeOneOut<mlir::ONNXSumOp>(node, 1, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "Tan") {
|
}else if (OpName == "Tan") {
|
||||||
OneOut(Tan, 1, 1);
|
ImportNodeOneOut<mlir::ONNXTanOp>(node, 1, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "Tanh") {
|
}else if (OpName == "Tanh") {
|
||||||
OneOut(Tanh, 1, 1);
|
ImportNodeOneOut<mlir::ONNXTanhOp>(node, 1, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "TfIdfVectorizer") {
|
}else if (OpName == "TfIdfVectorizer") {
|
||||||
OneOut(TfIdfVectorizer, 1, 1);
|
ImportNodeOneOut<mlir::ONNXTfIdfVectorizerOp>(node, 1, 1, {
|
||||||
|
{"max_gram_length","", ""}
|
||||||
|
,{"max_skip_count","", ""}
|
||||||
|
,{"min_gram_length","", ""}
|
||||||
|
,{"mode","", ""}
|
||||||
|
,{"ngram_counts","", ""}
|
||||||
|
,{"ngram_indexes","", ""}
|
||||||
|
,{"pool_int64s","", ""}
|
||||||
|
,{"pool_strings","", ""}
|
||||||
|
,{"weights","", ""}
|
||||||
|
});
|
||||||
}else if (OpName == "ThresholdedRelu") {
|
}else if (OpName == "ThresholdedRelu") {
|
||||||
OneOut(ThresholdedRelu, 1, 1);
|
ImportNodeOneOut<mlir::ONNXThresholdedReluOp>(node, 1, 1, {
|
||||||
|
{"alpha","float","1.0"}
|
||||||
|
});
|
||||||
}else if (OpName == "Tile") {
|
}else if (OpName == "Tile") {
|
||||||
OneOut(Tile, 2, 1);
|
ImportNodeOneOut<mlir::ONNXTileOp>(node, 2, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "TopK") {
|
}else if (OpName == "TopK") {
|
||||||
MultipleOuts(TopK, 2, 2);
|
ImportNodeMultipleOuts<mlir::ONNXTopKOp>(node, 2, 2, {
|
||||||
|
{"axis","int","-1"}
|
||||||
|
,{"largest","int","1"}
|
||||||
|
,{"sorted","int","1"}
|
||||||
|
});
|
||||||
}else if (OpName == "Transpose") {
|
}else if (OpName == "Transpose") {
|
||||||
OneOut(Transpose, 1, 1);
|
ImportNodeOneOut<mlir::ONNXTransposeOp>(node, 1, 1, {
|
||||||
|
{"perm","", ""}
|
||||||
|
});
|
||||||
}else if (OpName == "Unique") {
|
}else if (OpName == "Unique") {
|
||||||
MultipleOuts(Unique, 1, 4);
|
ImportNodeMultipleOuts<mlir::ONNXUniqueOp>(node, 1, 4, {
|
||||||
|
{"axis","", ""}
|
||||||
|
,{"sorted","int","1"}
|
||||||
|
});
|
||||||
}else if (OpName == "Unsqueeze") {
|
}else if (OpName == "Unsqueeze") {
|
||||||
OneOut(Unsqueeze, 1, 1);
|
ImportNodeOneOut<mlir::ONNXUnsqueezeOp>(node, 1, 1, {
|
||||||
|
{"axes","ints", ""}
|
||||||
|
});
|
||||||
}else if (OpName == "Upsample") {
|
}else if (OpName == "Upsample") {
|
||||||
OneOut(Upsample, 2, 1);
|
ImportNodeOneOut<mlir::ONNXUpsampleOp>(node, 2, 1, {
|
||||||
|
{"mode","str","nearest"}
|
||||||
|
});
|
||||||
}else if (OpName == "Where") {
|
}else if (OpName == "Where") {
|
||||||
OneOut(Where, 3, 1);
|
ImportNodeOneOut<mlir::ONNXWhereOp>(node, 3, 1, {
|
||||||
|
});
|
||||||
}else if (OpName == "Xor") {
|
}else if (OpName == "Xor") {
|
||||||
OneOut(Xor, 2, 1);
|
ImportNodeOneOut<mlir::ONNXXorOp>(node, 2, 1, {
|
||||||
|
});
|
||||||
}
|
}
|
|
@ -352,6 +352,121 @@ def gen_schema(schema) :
|
||||||
|
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
"""
|
||||||
|
special cases:
|
||||||
|
* Split: attr split default value: sizeof(output1) namely 1
|
||||||
|
* Conv: attr dilations default value is {num_dim of first input - 2, 1}
|
||||||
|
* Conv: attr kernel_shape type is ints
|
||||||
|
* Transpose: attr perm default value is {} empty int list
|
||||||
|
"""
|
||||||
|
|
||||||
|
def gen_code(schema,fefile) :
|
||||||
|
special_handler = dict([
|
||||||
|
("Conv", "ImportNodeConv"),
|
||||||
|
#("Transpose", "ImportNodeTranspose")
|
||||||
|
])
|
||||||
|
special_type = dict([
|
||||||
|
("AveragePool "+"kernel_shape", '"ints", ""'),
|
||||||
|
("MaxPool "+"kernel_shape", '"ints", ""'),
|
||||||
|
("Cast "+"to", '"int", "0"'),
|
||||||
|
("Concat "+"axis", '"int", "0"'),
|
||||||
|
("Conv "+"group", '"int", "1"'),
|
||||||
|
("Unsqueeze "+"axes", '"ints", ""'),
|
||||||
|
("RNN "+"activation_alpha", '"floats", "{}"'),
|
||||||
|
("RNN "+"activation_beta", '"floats", "{}"'),
|
||||||
|
("RNN "+"activations", '"", "{Tannh, Tanh}"'),
|
||||||
|
("LRN "+"size", '"int", ""')
|
||||||
|
])
|
||||||
|
line_indent = ' '
|
||||||
|
fefile.write(' '+'}else if (OpName == "'+schema.name+'") {\n')
|
||||||
|
op_type_str='mlir::ONNX'+schema.name+'Op'
|
||||||
|
if schema.name in special_handler :
|
||||||
|
fefile.write(' '+special_handler[schema.name]+'(node, '
|
||||||
|
+str(len(schema.inputs))
|
||||||
|
+', ' +str(len(schema.outputs))+', {\n')
|
||||||
|
elif len(schema.outputs) > 1 :
|
||||||
|
fefile.write(' '+'ImportNodeMultipleOuts<'+op_type_str+'>(node, '
|
||||||
|
+str(len(schema.inputs))
|
||||||
|
+', ' +str(len(schema.outputs))+', {\n')
|
||||||
|
else :
|
||||||
|
fefile.write(' '+'ImportNodeOneOut<'+op_type_str+'>(node, '
|
||||||
|
+str(len(schema.inputs))
|
||||||
|
+', ' +str(len(schema.outputs))+', {\n')
|
||||||
|
|
||||||
|
|
||||||
|
if schema.attributes:
|
||||||
|
first_attr = True
|
||||||
|
for _, attr in sorted(schema.attributes.items()):
|
||||||
|
attr_line = line_indent+line_indent+line_indent+line_indent
|
||||||
|
if not first_attr:
|
||||||
|
attr_line += ',{'
|
||||||
|
else :
|
||||||
|
attr_line += ' {'
|
||||||
|
first_attr = False
|
||||||
|
|
||||||
|
attr_line += '"'+attr.name+'",'
|
||||||
|
|
||||||
|
if schema.name+' '+attr.name in special_type:
|
||||||
|
attr_line += special_type[schema.name+' '+attr.name]
|
||||||
|
# option holds either required or default value
|
||||||
|
elif attr.required:
|
||||||
|
attr_line += '"", ""'
|
||||||
|
|
||||||
|
elif attr.default_value.name:
|
||||||
|
default_value = helper.get_attribute_value(attr.default_value)
|
||||||
|
|
||||||
|
def format_value(value): # type: (Any) -> Text
|
||||||
|
if isinstance(value, float):
|
||||||
|
formatted = str(np.round(value, 5))
|
||||||
|
# use default formatting, unless too long.
|
||||||
|
if (len(formatted) > 10):
|
||||||
|
formatted = str("({:e})".format(value))
|
||||||
|
return formatted
|
||||||
|
elif isinstance(value, (bytes, bytearray)) and sys.version_info[0] == 3:
|
||||||
|
return str(value.decode('utf-8'))
|
||||||
|
return str(value)
|
||||||
|
|
||||||
|
if isinstance(default_value, list):
|
||||||
|
value = default_value[0]
|
||||||
|
default_value = [format_value(val) for val in default_value]
|
||||||
|
# TODO the list type is homogenous or htergeneous?
|
||||||
|
|
||||||
|
if isinstance(value, float) :
|
||||||
|
attr_type_str = '"floats"'
|
||||||
|
elif isinstance(value, int) :
|
||||||
|
attr_type_str = '"ints"'
|
||||||
|
elif isinstance(value, str) :
|
||||||
|
attr_type_str = '"strs"'
|
||||||
|
elif isinstance(value, (bytes, bytearray)) :
|
||||||
|
attr_type_str = '"strs"'
|
||||||
|
else :
|
||||||
|
attr_type_str = '"unknowns"'
|
||||||
|
attr_option_str = '"{}"'.format(default_value)
|
||||||
|
attr_option_str = attr_option_str.replace('[', '{', 1)
|
||||||
|
attr_option_str = attr_option_str.replace(']', '}', 1)
|
||||||
|
else:
|
||||||
|
if isinstance(default_value, float) :
|
||||||
|
attr_type_str = '"float"'
|
||||||
|
elif isinstance(default_value, int) :
|
||||||
|
attr_type_str = '"int"'
|
||||||
|
elif isinstance(default_value, str) :
|
||||||
|
attr_type_str = '"str"'
|
||||||
|
elif isinstance(default_value, (bytes, bytearray)) :
|
||||||
|
attr_type_str = '"str"'
|
||||||
|
else :
|
||||||
|
attr_type_str = '"unknown"'
|
||||||
|
default_value = format_value(default_value)
|
||||||
|
attr_option_str = '"{}"'.format(default_value)
|
||||||
|
attr_line += attr_type_str+','+attr_option_str
|
||||||
|
else:
|
||||||
|
#TODO why?
|
||||||
|
attr_line += '"", ""'
|
||||||
|
attr_line += '}\n'
|
||||||
|
fefile.write(attr_line)
|
||||||
|
fefile.write(line_indent+line_indent+line_indent+'});\n')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def main(args): # type: (Type[Args]) -> None
|
def main(args): # type: (Type[Args]) -> None
|
||||||
with io.open(args.changelog, 'w', newline='') as fout:
|
with io.open(args.changelog, 'w', newline='') as fout:
|
||||||
fout.write('## Operator Changelog\n')
|
fout.write('## Operator Changelog\n')
|
||||||
|
@ -453,6 +568,7 @@ def main(args): # type: (Type[Args]) -> None
|
||||||
fefile=io.open('op_build_table.inc', 'w', newline='')
|
fefile=io.open('op_build_table.inc', 'w', newline='')
|
||||||
firstfunc = True
|
firstfunc = True
|
||||||
|
|
||||||
|
fefile.write(' '+'if (OpName == "DUMMY") {\n')
|
||||||
for domain, supportmap in operator_schemas:
|
for domain, supportmap in operator_schemas:
|
||||||
s = '## {}\n'.format(display_domain_short(domain))
|
s = '## {}\n'.format(display_domain_short(domain))
|
||||||
fout.write(s)
|
fout.write(s)
|
||||||
|
@ -461,21 +577,8 @@ def main(args): # type: (Type[Args]) -> None
|
||||||
for op_type, schema, versions in namemap:
|
for op_type, schema, versions in namemap:
|
||||||
# op_type
|
# op_type
|
||||||
#print("check point 1", schema.name, len(schema.inputs), len(schema.outputs))
|
#print("check point 1", schema.name, len(schema.inputs), len(schema.outputs))
|
||||||
if firstfunc :
|
gen_code(schema, fefile)
|
||||||
fefile.write(' '+'if (OpName == "'+schema.name+'") {\n')
|
|
||||||
firstfunc = False
|
|
||||||
else :
|
|
||||||
fefile.write(' '+'}else if (OpName == "'+schema.name+'") {\n')
|
|
||||||
if len(schema.outputs) > 1 :
|
|
||||||
fefile.write(' '+'MultipleOuts('+schema.name+', '
|
|
||||||
+str(schema.since_version)+', '
|
|
||||||
+str(len(schema.inputs))+', '
|
|
||||||
+str(len(schema.outputs))+');\n')
|
|
||||||
else :
|
|
||||||
fefile.write(' '+'OneOut('+schema.name+', '
|
|
||||||
+str(schema.since_version)+', '
|
|
||||||
+str(len(schema.inputs))+', '
|
|
||||||
+str(len(schema.outputs))+');\n')
|
|
||||||
r = gen_schema(schema)
|
r = gen_schema(schema)
|
||||||
tdfile.write(r)
|
tdfile.write(r)
|
||||||
s = ('### {}<a name="{}"></a><a name="{}">**{}**' + (' (deprecated)' if schema.deprecated else '') + '</a>\n').format(
|
s = ('### {}<a name="{}"></a><a name="{}">**{}**' + (' (deprecated)' if schema.deprecated else '') + '</a>\n').format(
|
||||||
|
|
|
@ -71,4 +71,26 @@ def ONNXFullGemmOp: ONNX_Op<"FullGemm",
|
||||||
let results = (outs AnyTensor);
|
let results = (outs AnyTensor);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def ONNXConv1Op:ONNX_Op<"Conv1",
|
||||||
|
[NoSideEffect]> {
|
||||||
|
let summary = "ONNX Conv operation";
|
||||||
|
let description = [{
|
||||||
|
"The convolution operator consumes an input tensor and a filter, and"
|
||||||
|
"computes the output."
|
||||||
|
}];
|
||||||
|
let arguments = (ins AnyTensor:$X);
|
||||||
|
let results = (outs AnyTensor);
|
||||||
|
}
|
||||||
|
|
||||||
|
def ONNXConv3Op:ONNX_Op<"Conv3",
|
||||||
|
[NoSideEffect]> {
|
||||||
|
let summary = "ONNX Conv operation";
|
||||||
|
let description = [{
|
||||||
|
"The convolution operator consumes an input tensor and a filter, and"
|
||||||
|
"computes the output."
|
||||||
|
}];
|
||||||
|
let arguments = (ins AnyTensor:$X, AnyTensor:$W, AnyTensor:$B);
|
||||||
|
let results = (outs AnyTensor);
|
||||||
|
}
|
||||||
|
|
||||||
#endif // ONNX_OPS
|
#endif // ONNX_OPS
|
||||||
|
|
Loading…
Reference in New Issue