[MLIR] import attribute of onnx node (#383)

* add attributes as NamedAttribute

* support list value for attribute

* use std::tie to avoid c++17 feature
This commit is contained in:
TONG CHEN 2019-12-21 01:58:23 -05:00 committed by Tian Jin
parent 45608282e0
commit c8d591fb28
6 changed files with 1143 additions and 246 deletions

View File

@ -1,9 +1,12 @@
add_executable(onnf main.cpp)
target_link_libraries(onnf builder compiler ${MLIRLibs} ${Boost_LIBRARIES})
set_target_properties(onnf PROPERTIES LINK_FLAGS "-lz")
whole_archive_link_mlir(onnf ${MLIRWholeArchiveLibs})
whole_archive_link_onnf(onnf onnf_transform)
target_include_directories(onnf PRIVATE ${CMAKE_SOURCE_DIR})
target_include_directories(onnf PRIVATE ${CMAKE_BINARY_DIR})

View File

@ -1,5 +1,6 @@
add_library(builder
frontend_dialect_transformer.cpp
frontend_dialect_transformer.hpp
op_build_table.inc
)

View File

@ -71,7 +71,10 @@ struct OnnxOnnfSymbolMapping {
* @param name onnx tensor name.
* @return onnf tensor corresponding to `name`.
*/
mlir::Value* GetTensorByOnnxName(std::string name) {
mlir::Value *GetTensorByOnnxName(std::string name) {
assert(onnx_name2onnf_tensor.find(legalize_name(name)) !=
onnx_name2onnf_tensor.end() &&
"Tensor not found");
return onnx_name2onnf_tensor.at(legalize_name(name));
}
@ -80,7 +83,9 @@ struct OnnxOnnfSymbolMapping {
* @param name onnx tensor name.
* @param tensor MLIR Value* pointer.
*/
void AddMapping(std::string name, mlir::Value* tensor) {
void AddMapping(std::string name, mlir::Value *tensor) {
assert(onnx_name2onnf_tensor.count(legalize_name(name)) == 0 &&
"Tensor already exists.");
onnx_name2onnf_tensor.emplace(legalize_name(name), tensor);
}
@ -88,7 +93,7 @@ struct OnnxOnnfSymbolMapping {
return onnx_name2onnf_tensor.count(name) != 0;
}
private:
private:
/*!
* mapping from onnx tensor names to MLIR tensor.
*/
@ -96,8 +101,8 @@ struct OnnxOnnfSymbolMapping {
};
class FrontendGenImpl {
public:
FrontendGenImpl(mlir::MLIRContext& context)
public:
FrontendGenImpl(mlir::MLIRContext &context)
: context_(context), builder_(&context) {
module_ = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
}
@ -107,8 +112,8 @@ class FrontendGenImpl {
return module_;
}
private:
mlir::MLIRContext& context_;
private:
mlir::MLIRContext &context_;
mlir::ModuleOp module_;
mlir::OpBuilder builder_;
// mapping between string name and symbol
@ -145,55 +150,31 @@ class FrontendGenImpl {
case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX64:
case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX128:
case onnx::TensorProto_DataType::TensorProto_DataType_UNDEFINED:
assert(false && "Unsupported data type encountered.");
return nullptr;
}
}
//if c++17 is used, these two def can be combined with 'if constexpr'
//leave n there for possible future use
//alternative is to use template and pass the outputTypes, inputs and attributes
//as parameter
#define MultipleOuts(name, nIn, nOut)\
{ \
if (nIn == inputs.size() && nOut == outputTypes.size()) {\
auto op = builder_.create<mlir::ONNX##name##Op>(UnknownLoc(), outputTypes, inputs, attributes); \
for (int i = 0; i < node.output().size(); i++) { \
frontend_symbols_.AddMapping(\
legalize_name(node.output()[i]), op.getResult(i));\
}\
return;\
}\
}
#define OneOut(name, nIn, nOut)\
{ \
if (nIn == inputs.size() && nOut == outputTypes.size()) {\
auto op = builder_.create<mlir::ONNX##name##Op>(UnknownLoc(), outputTypes, inputs, attributes); \
frontend_symbols_.AddMapping(\
legalize_name(node.output()[0]), op.getResult());\
return;\
}\
}
/*!
* Import an onnx input tensor type by determining and recording its type
* in a list of input tensor mlir types.
* @param input onnx input tensor ValueInfoProto.
* @param arg_types list of mlir types representing types of graph input.
*/
void ImportInputTensorType(const onnx::ValueInfoProto& input,
llvm::SmallVector<mlir::Type, 4>& arg_types) {
void ImportInputTensorType(const onnx::ValueInfoProto &input,
llvm::SmallVector<mlir::Type, 4> &arg_types) {
std::vector<int64_t> dims;
auto shape_proto = input.type().tensor_type().shape();
auto input_tensor_legalized_name = legalize_name(input.name());
for (int i = 0; i < shape_proto.dim_size(); i++) {
if (shape_proto.dim()[i].dim_value()) {
int dim_numeric_size = shape_proto.dim()[i].dim_value();
assert(
dim_numeric_size != 0 && "Parsed an input tensor with a dimension size of zero");
if (dim_numeric_size > 0) {
dims.push_back(dim_numeric_size);
} else { // If dim_value < 0, then dim is parametric.
// TODO Verify the unknown dim size in MLIR
} else { // If dim_value < 0, then dim is parametric.
// TODO Verify the unknown dim size in MLIR
dims.push_back(-1);
}
} else {
@ -216,8 +197,8 @@ class FrontendGenImpl {
* @param input onnx input tensor ValueInfoProto.
* @param symbol mlir input argument.
*/
void ImportInputTensorSymbol(
const onnx::ValueInfoProto& input, mlir::Value* symbol) {
void ImportInputTensorSymbol(const onnx::ValueInfoProto &input,
mlir::Value *symbol) {
auto input_tensor_legalized_name = legalize_name(input.name());
assert(
!frontend_symbols_.ContainKey(input_tensor_legalized_name) &&
@ -225,32 +206,286 @@ class FrontendGenImpl {
frontend_symbols_.AddMapping(input_tensor_legalized_name, symbol);
}
void ImportNode(onnx::NodeProto node) {
std::vector<mlir::Value*> inputs;
template <typename T>
T get_attr_generic(onnx::NodeProto &node, std::string name,
std::function<T(onnx::AttributeProto &)> attr_getter,
T default_val) {
for (int i = 0; i < node.attribute_size(); ++i) {
auto attr = node.attribute(i);
if (attr.name() == name) {
return attr_getter(attr);
}
}
return default_val;
}
template <typename T>
T get_attr_generic(onnx::NodeProto &node, std::string name,
std::function<T(onnx::AttributeProto &)> attr_getter) {
for (int i = 0; i < node.attribute_size(); ++i) {
auto attr = node.attribute(i);
if (attr.name() == name) {
return attr_getter(attr);
}
}
assert(false && "ONNX Node Attribute Not Found!");
}
auto get_attr_ints(onnx::NodeProto &node, std::string name,
std::vector<int> default_val) {
std::function<std::vector<int>(onnx::AttributeProto &)> attr_getter =
[](onnx::AttributeProto &attr) {
std::vector<int> ints(attr.ints_size());
std::copy(attr.ints().begin(), attr.ints().end(), ints.begin());
return ints;
};
auto r = get_attr_generic(node, name, attr_getter, default_val);
auto dataType =
mlir::RankedTensorType::get(r.size(), builder_.getIntegerType(32));
auto attr_v = mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(r));
auto aname = node.op_type() + "." + name;
auto attr_output = builder_.getNamedAttr(aname, attr_v);
return attr_output;
}
auto get_attr_ints(onnx::NodeProto &node, std::string name) {
std::function<std::vector<int>(onnx::AttributeProto &)> attr_getter =
[](onnx::AttributeProto &attr) {
std::vector<int> ints(attr.ints_size());
std::copy(attr.ints().begin(), attr.ints().end(), ints.begin());
return ints;
};
auto r = get_attr_generic(node, name, attr_getter);
auto dataType =
mlir::RankedTensorType::get(r.size(), builder_.getIntegerType(32));
auto attr_v = mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(r));
auto aname = node.op_type() + "." + name;
auto attr_output = builder_.getNamedAttr(aname, attr_v);
return attr_output;
}
auto get_attr_floats(onnx::NodeProto &node, std::string name) {
std::function<std::vector<float>(onnx::AttributeProto &)> attr_getter =
[](onnx::AttributeProto &attr) {
std::vector<float> floats(attr.floats_size());
std::copy(attr.floats().begin(), attr.floats().end(), floats.begin());
return floats;
};
auto r = get_attr_generic(node, name, attr_getter);
auto dataType =
mlir::RankedTensorType::get(r.size(), builder_.getF32Type());
auto attr_v = mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(r));
auto aname = node.op_type() + "." + name;
auto attr_output = builder_.getNamedAttr(aname, attr_v);
return attr_output;
}
auto get_attr_floats(onnx::NodeProto &node, std::string name,
std::vector<float> default_val) {
std::function<std::vector<float>(onnx::AttributeProto &)> attr_getter =
[](onnx::AttributeProto &attr) {
std::vector<float> floats(attr.floats_size());
std::copy(attr.floats().begin(), attr.floats().end(), floats.begin());
return floats;
};
auto r = get_attr_generic(node, name, attr_getter, default_val);
auto dataType =
mlir::RankedTensorType::get(r.size(), builder_.getF32Type());
auto attr_v = mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(r));
auto aname = node.op_type() + "." + name;
auto attr_output = builder_.getNamedAttr(aname, attr_v);
return attr_output;
}
auto get_attr_int(onnx::NodeProto &node, std::string name) {
std::function<int(onnx::AttributeProto &)> attr_getter =
[](onnx::AttributeProto &attr) { return attr.i(); };
int r = get_attr_generic(node, name, attr_getter);
auto attr_v = builder_.getI32IntegerAttr(r);
auto aname = node.op_type() + "." + name;
auto attr_output = builder_.getNamedAttr(aname, attr_v);
return attr_output;
}
auto get_attr_int(onnx::NodeProto &node, std::string name, int default_val) {
std::function<int(onnx::AttributeProto &)> attr_getter =
[](onnx::AttributeProto &attr) { return attr.i(); };
int r = get_attr_generic(node, name, attr_getter, default_val);
auto attr_v = builder_.getI32IntegerAttr(r);
auto aname = node.op_type() + "." + name;
auto attr_output = builder_.getNamedAttr(aname, attr_v);
return attr_output;
}
auto get_attr_float(onnx::NodeProto &node, std::string name) {
std::function<float(onnx::AttributeProto &)> attr_getter =
[](onnx::AttributeProto &attr) { return attr.f(); };
auto r = get_attr_generic(node, name, attr_getter);
auto attr_v = builder_.getF32FloatAttr(r);
auto aname = node.op_type() + "." + name;
return builder_.getNamedAttr(aname, attr_v);
}
auto get_attr_float(onnx::NodeProto &node, std::string name,
float default_val) {
std::function<float(onnx::AttributeProto &)> attr_getter =
[](onnx::AttributeProto &attr) { return attr.f(); };
auto r = get_attr_generic(node, name, attr_getter, default_val);
auto attr_v = builder_.getF32FloatAttr(r);
auto aname = node.op_type() + "." + name;
return builder_.getNamedAttr(aname, attr_v);
}
auto get_attr_string(onnx::NodeProto &node, std::string name) {
std::function<std::string(onnx::AttributeProto &)> attr_getter =
[](onnx::AttributeProto &attr) { return attr.s(); };
auto r = get_attr_generic(node, name, attr_getter);
auto attr_v = builder_.getStringAttr(r);
auto aname = node.op_type() + "." + name;
return builder_.getNamedAttr(aname, attr_v);
}
auto get_attr_string(onnx::NodeProto &node, std::string name,
std::string default_val) {
std::function<std::string(onnx::AttributeProto &)> attr_getter =
[](onnx::AttributeProto &attr) { return attr.s(); };
auto r = get_attr_generic(node, name, attr_getter, default_val);
auto attr_v = builder_.getStringAttr(r);
auto aname = node.op_type() + "." + name;
return builder_.getNamedAttr(aname, attr_v);
}
/*
auto get_attr_strings(onnx::NodeProto &node, std::string name) {
std::function<std::vector<std::string>(onnx::AttributeProto &)>
attr_getter =
[](onnx::AttributeProto &attr) {
std::vector<std::string> strings(attr.strings_size());
std::copy(attr.strings().begin(), attr.strings().end(),
strings.begin()); return strings;
};
auto r = get_attr_generic(node, name, attr_getter);
return r;
return builder_.getNamedAttr(aname, attr_v);
auto dataType =
mlir::RankedTensorType::get(r.size(), builder_.get???Type());
auto attr_v = mlir::DenseElementsAttr::get(dataType,
llvm::makeArrayRef(r)); auto aname = node.op_type() + "." + name; auto
attr_output = builder_.getNamedAttr(aname, attr_v); return attr_output;
}
*/
auto get_default_ints(std::string default_str) {
std::vector<int> r;
auto start = default_str.find("{");
while (true) {
auto end = default_str.find(",", start + 1);
if (end == std::string::npos) {
end = default_str.find("}", start + 1);
if (end != std::string::npos && end > start+1) {
r.push_back(std::stoi(default_str.substr(start + 1, end)));
}
break;
} else {
r.push_back(std::stoi(default_str.substr(start + 1, end)));
}
start = end + 1;
}
return r;
}
auto get_default_floats(std::string default_str) {
std::vector<float> r;
auto start = default_str.find("{");
while (true) {
auto end = default_str.find(",", start + 1);
if (end == std::string::npos) {
end = default_str.find("}", start + 1);
if (end != std::string::npos && end > start+1) {
r.push_back(std::stof(default_str.substr(start + 1, end)));
}
break;
} else {
r.push_back(std::stof(default_str.substr(start + 1, end)));
}
start = end + 1;
}
return r;
}
auto get_default_strings(std::string default_str) {
std::vector<std::string> r;
auto start = default_str.find("{");
while (true) {
auto end = default_str.find(",", start + 1);
if (end == std::string::npos) {
end = default_str.find("}", start + 1);
if (end != std::string::npos && end > start+1) {
r.push_back(default_str.substr(start + 1, end));
}
break;
} else {
r.push_back(default_str.substr(start + 1, end));
}
start = end + 1;
}
return r;
}
onnx::TensorProto get_attr_tensor(onnx::NodeProto &node, std::string name) {
std::function<onnx::TensorProto(onnx::AttributeProto &)> attr_getter =
[](onnx::AttributeProto &attr) { return attr.t(); };
return get_attr_generic(node, name, attr_getter);
}
auto ImportNodeAttr(onnx::NodeProto node, std::string attr_name,
std::string type_name, std::string default_str) {
if (default_str == "") {
if (type_name == "int") {
return get_attr_int(node, attr_name);
} else if (type_name == "float") {
return get_attr_float(node, attr_name);
} else if (type_name == "str") {
return get_attr_string(node, attr_name);
} else if (type_name == "ints") {
return get_attr_ints(node, attr_name);
} else if (type_name == "floats") {
return get_attr_floats(node, attr_name);
} else {
assert(
false &&
"Got an empty initializer or initializer for this "
"datatype is not implemented. Something is wrong.");
}
} else {
// with default value
if (type_name == "int") {
return get_attr_int(node, attr_name, std::stoi(default_str));
} else if (type_name == "float") {
return get_attr_float(node, attr_name, std::stof(default_str));
} else if (type_name == "str") {
return get_attr_string(node, attr_name, default_str);
} else if (type_name == "ints") {
return get_attr_ints(node, attr_name, get_default_ints(default_str));
} else if (type_name == "floats") {
return get_attr_floats(node, attr_name,
get_default_floats(default_str));
} else {
assert(
false &&
"Got an empty initializer or initializer for this "
"datatype is not implemented. Something is wrong.");
}
}
}
void ImportNodeGeneric(onnx::NodeProto node) {
std::vector<mlir::Value *> inputs;
for (auto item : node.input()) {
if (frontend_symbols_.ContainKey(legalize_name(item))) {
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
}
}
std::vector<mlir::Type> outputTypes;
for (auto item : node.output()) {
outputTypes.push_back(mlir::UnrankedTensorType::get(builder_.getF32Type()));
}
std::vector<mlir::NamedAttribute> attributes;
llvm::StringRef OpName = node.op_type();
//the following code is generated by gen_doc.py
//refer to dialect/onnx/onnx.td for details
//when the input or output of then op does not match the specification,
//the generic operator is used
//one known reeason is the optional input
#include "src/builder/op_build_table.inc"
// Old way of doing things.
mlir::OperationState result(UnknownLoc(), "frontend." + node.op_type());
for (auto item : node.output()) {
result.addTypes(mlir::UnrankedTensorType::get(builder_.getF32Type()));
@ -261,8 +496,165 @@ class FrontendGenImpl {
auto r = op->getResult(i);
frontend_symbols_.AddMapping(legalize_name(node.output()[i]), r);
}
}
// TODO more info from node: attributes
// if c++17 is used, ImportNodeOneOut and ImportNodeMultipleOuts can be
// combined with 'if constexpr' the issue is the type of the output is
// different. alternative way to use variadic output for all the op
/*!
* Important onnx node which generates only one output
* @param node onnx node
* @param nIn number of expected inputs
* @param nOut number of expected outputs
* @param attrs list of desription for attributes with format {name, type,
* default}
*/
template <typename T>
void ImportNodeOneOut(
onnx::NodeProto node, int nIn, int nOut,
std::initializer_list<std::tuple<std::string, std::string, std::string>>
attrs) {
std::vector<mlir::Value *> inputs;
for (auto item : node.input()) {
if (frontend_symbols_.ContainKey(legalize_name(item))) {
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
}
}
std::vector<mlir::Type> outputTypes;
for (auto item : node.output()) {
outputTypes.push_back(
mlir::UnrankedTensorType::get(builder_.getF32Type()));
}
std::vector<mlir::NamedAttribute> attributes;
//for (auto [attr_name, attr_type, attr_default] : attrs) {
for (auto oneAttr: attrs) {
std::string attr_name;
std::string attr_type;
std::string attr_default;
std::tie (attr_name, attr_type, attr_default) = oneAttr;
if (attr_type != "") {
auto attr = ImportNodeAttr(node, attr_name, attr_type, attr_default);
attributes.push_back(attr);
} else {
// TODO: the attributes need special handling
//std::cout << "missing " << node.op_type() << " " << attr_name << std::endl;
}
}
llvm::StringRef OpName = node.op_type();
if (nIn == inputs.size() && nOut == outputTypes.size()) {
auto op =
builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes);
frontend_symbols_.AddMapping(legalize_name(node.output()[0]),
op.getResult());
} else {
ImportNodeGeneric(node);
}
}
template <typename T>
void ImportNodeMultipleOuts(
onnx::NodeProto node, int nIn, int nOut,
std::initializer_list<std::tuple<std::string, std::string, std::string>>
attrs) {
std::vector<mlir::Value *> inputs;
for (auto item : node.input()) {
if (frontend_symbols_.ContainKey(legalize_name(item))) {
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
}
}
std::vector<mlir::Type> outputTypes;
for (auto item : node.output()) {
outputTypes.push_back(
mlir::UnrankedTensorType::get(builder_.getF32Type()));
}
std::vector<mlir::NamedAttribute> attributes;
for (auto oneAttr: attrs) {
std::string attr_name;
std::string attr_type;
std::string attr_default;
std::tie (attr_name, attr_type, attr_default) = oneAttr;
if (attr_type != "") {
auto attr = ImportNodeAttr(node, attr_name, attr_type, attr_default);
attributes.push_back(attr);
} else {
// TODO: the attributes need special handling
//std::cout << "missing " << node.op_type() << " " << attr_name << std::endl;
}
}
llvm::StringRef OpName = node.op_type();
if (nIn == inputs.size() && nOut == outputTypes.size()) {
auto op =
builder_.create<T>(UnknownLoc(), outputTypes, inputs, attributes);
for (int i = 0; i < node.output().size(); i++) {
frontend_symbols_.AddMapping(legalize_name(node.output()[i]),
op.getResult(i));
}
} else {
ImportNodeGeneric(node);
}
}
/*!
* Special handle for Conv operations.
* c++ does not allow template specialization inside a class scope
* a specialized function is used
*/
void ImportNodeConv(
onnx::NodeProto node, int nIn, int nOut,
std::initializer_list<std::tuple<std::string, std::string, std::string>>
attrs) {
// Conv has attribute dilations, kernel_shape, pads, the default value of
// which is determined by the shape of first argument. However, since the
// shape is unknown now, these attributes can be not generated auto
// dilations_attr = get_attr_ints(node, "dilations",
// std::vector<int>(inputs[0]->getType().cast<RankedTensorType>.getDims()-2,
// 1));
// attributes.push_back(dilations_attr)
// similar situation for pads, strides in AveragePool
// axes of ReduceSum, pads, strides, dilations and kernel_shape of MaxPool
// TODO: fix this after type inference
if (node.input().size() == 1) {
ImportNodeOneOut<mlir::ONNXConv1Op>(node, nIn, nOut, attrs);
} else {
ImportNodeOneOut<mlir::ONNXConv3Op>(node, nIn, nOut, attrs);
}
}
void ImportNode(onnx::NodeProto node) {
std::vector<mlir::Value *> inputs;
for (auto item : node.input()) {
if (frontend_symbols_.ContainKey(legalize_name(item))) {
inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
}
}
std::vector<mlir::Type> outputTypes;
for (auto item : node.output()) {
outputTypes.push_back(
mlir::UnrankedTensorType::get(builder_.getF32Type()));
}
std::vector<mlir::NamedAttribute> attributes;
llvm::StringRef OpName = node.op_type();
// the following code is generated by gen_doc.py
// refer to dialect/onnx/onnx.td for details
// when the input or output of then op does not match the specification,
// the generic operator is used
// one known reeason is the optional input
#include "src/builder/op_build_table.inc"
}
/*!
@ -277,9 +669,9 @@ class FrontendGenImpl {
* @param ret_vals a vector of mlir Value* representing graph's
* output tensor.
*/
void ImportOutputTensor(const onnx::ValueInfoProto& output,
llvm::SmallVectorImpl<mlir::Type>& ret_types,
llvm::SmallVectorImpl<mlir::Value*>& ret_vals) {
void ImportOutputTensor(const onnx::ValueInfoProto &output,
llvm::SmallVectorImpl<mlir::Type> &ret_types,
llvm::SmallVectorImpl<mlir::Value *> &ret_vals) {
auto output_tensor_legalized_name = legalize_name(output.name());
assert(
frontend_symbols_.ContainKey(output_tensor_legalized_name) &&
@ -291,8 +683,8 @@ class FrontendGenImpl {
ret_vals.push_back(tensor_val);
}
void ImportGraph(
const onnx::GraphProto& graph, const std::string& name = "main") {
void ImportGraph(const onnx::GraphProto &graph,
const std::string &name = "main") {
// create a function for the graph
// TODO:
// * get name and type for the function.
@ -300,7 +692,7 @@ class FrontendGenImpl {
llvm::SmallVector<mlir::Type, 4> arg_types;
// Import the input tensor types.
for (const auto& input : graph.input()) {
for (const auto &input : graph.input()) {
ImportInputTensorType(input, arg_types);
}
@ -308,7 +700,7 @@ class FrontendGenImpl {
auto func_type = builder_.getFunctionType(arg_types, {});
auto main_func =
mlir::FuncOp::create(UnknownLoc(), name, func_type, /* attrs = */ {});
auto& entryBlock = *main_func.addEntryBlock();
auto &entryBlock = *main_func.addEntryBlock();
builder_.setInsertionPointToStart(&entryBlock);
module_.push_back(main_func);
@ -319,14 +711,14 @@ class FrontendGenImpl {
// import nodes in the graph
auto node = graph.node();
for (const auto& item : node) {
for (const auto &item : node) {
ImportNode(item);
}
llvm::SmallVector<mlir::Type, 4> ret_types;
llvm::SmallVector<mlir::Value*, 4> ret_vals;
llvm::SmallVector<mlir::Value *, 4> ret_vals;
// Import the output tensors
for (const auto& output : graph.output()) {
for (const auto &output : graph.output()) {
ImportOutputTensor(output, ret_types, ret_vals);
}
@ -337,7 +729,6 @@ class FrontendGenImpl {
func_type = builder_.getFunctionType(arg_types, ret_types);
main_func.setType(func_type);
}
}; // FrontendGenImpl class
} // namespace
} // namespace onnf
@ -354,11 +745,13 @@ mlir::OwningModuleRef ImportFrontendModel(onnx::ModelProto model) {
}
void ImportFrontendModelFile(std::string model_fname,
mlir::MLIRContext& context, mlir::OwningModuleRef& module) {
mlir::MLIRContext &context,
mlir::OwningModuleRef &module) {
onnx::ModelProto model;
std::fstream input(model_fname, std::ios::in | std::ios::binary);
auto parse_success = model.ParseFromIstream(&input);
assert(parse_success && "Onnx Model Parsing Failed.");
FrontendGenImpl myONNXGen(context);
module = myONNXGen.ImportONNXModel(model);

View File

@ -1,313 +1,688 @@
if (OpName == "Abs") {
OneOut(Abs, 1, 1);
if (OpName == "DUMMY") {
}else if (OpName == "Abs") {
ImportNodeOneOut<mlir::ONNXAbsOp>(node, 1, 1, {
});
}else if (OpName == "Acos") {
OneOut(Acos, 1, 1);
ImportNodeOneOut<mlir::ONNXAcosOp>(node, 1, 1, {
});
}else if (OpName == "Acosh") {
OneOut(Acosh, 1, 1);
ImportNodeOneOut<mlir::ONNXAcoshOp>(node, 1, 1, {
});
}else if (OpName == "Add") {
OneOut(Add, 2, 1);
ImportNodeOneOut<mlir::ONNXAddOp>(node, 2, 1, {
});
}else if (OpName == "And") {
OneOut(And, 2, 1);
ImportNodeOneOut<mlir::ONNXAndOp>(node, 2, 1, {
});
}else if (OpName == "ArgMax") {
OneOut(ArgMax, 1, 1);
ImportNodeOneOut<mlir::ONNXArgMaxOp>(node, 1, 1, {
{"axis","int","0"}
,{"keepdims","int","1"}
});
}else if (OpName == "ArgMin") {
OneOut(ArgMin, 1, 1);
ImportNodeOneOut<mlir::ONNXArgMinOp>(node, 1, 1, {
{"axis","int","0"}
,{"keepdims","int","1"}
});
}else if (OpName == "Asin") {
OneOut(Asin, 1, 1);
ImportNodeOneOut<mlir::ONNXAsinOp>(node, 1, 1, {
});
}else if (OpName == "Asinh") {
OneOut(Asinh, 1, 1);
ImportNodeOneOut<mlir::ONNXAsinhOp>(node, 1, 1, {
});
}else if (OpName == "Atan") {
OneOut(Atan, 1, 1);
ImportNodeOneOut<mlir::ONNXAtanOp>(node, 1, 1, {
});
}else if (OpName == "Atanh") {
OneOut(Atanh, 1, 1);
ImportNodeOneOut<mlir::ONNXAtanhOp>(node, 1, 1, {
});
}else if (OpName == "AveragePool") {
OneOut(AveragePool, 1, 1);
ImportNodeOneOut<mlir::ONNXAveragePoolOp>(node, 1, 1, {
{"auto_pad","str","NOTSET"}
,{"ceil_mode","int","0"}
,{"count_include_pad","int","0"}
,{"kernel_shape","ints", ""}
,{"pads","", ""}
,{"strides","", ""}
});
}else if (OpName == "BatchNormalization") {
MultipleOuts(BatchNormalization, 5, 5);
ImportNodeMultipleOuts<mlir::ONNXBatchNormalizationOp>(node, 5, 5, {
{"epsilon","float","1e-05"}
,{"momentum","float","0.9"}
});
}else if (OpName == "BitShift") {
OneOut(BitShift, 2, 1);
ImportNodeOneOut<mlir::ONNXBitShiftOp>(node, 2, 1, {
{"direction","", ""}
});
}else if (OpName == "Cast") {
OneOut(Cast, 1, 1);
ImportNodeOneOut<mlir::ONNXCastOp>(node, 1, 1, {
{"to","int", "0"}
});
}else if (OpName == "Ceil") {
OneOut(Ceil, 1, 1);
ImportNodeOneOut<mlir::ONNXCeilOp>(node, 1, 1, {
});
}else if (OpName == "Clip") {
OneOut(Clip, 3, 1);
ImportNodeOneOut<mlir::ONNXClipOp>(node, 3, 1, {
});
}else if (OpName == "Compress") {
OneOut(Compress, 2, 1);
ImportNodeOneOut<mlir::ONNXCompressOp>(node, 2, 1, {
{"axis","", ""}
});
}else if (OpName == "Concat") {
OneOut(Concat, 1, 1);
ImportNodeOneOut<mlir::ONNXConcatOp>(node, 1, 1, {
{"axis","int", "0"}
});
}else if (OpName == "ConcatFromSequence") {
OneOut(ConcatFromSequence, 1, 1);
ImportNodeOneOut<mlir::ONNXConcatFromSequenceOp>(node, 1, 1, {
{"axis","", ""}
,{"new_axis","int","0"}
});
}else if (OpName == "Constant") {
OneOut(Constant, 0, 1);
ImportNodeOneOut<mlir::ONNXConstantOp>(node, 0, 1, {
{"sparse_value","", ""}
,{"value","", ""}
});
}else if (OpName == "ConstantOfShape") {
OneOut(ConstantOfShape, 1, 1);
ImportNodeOneOut<mlir::ONNXConstantOfShapeOp>(node, 1, 1, {
{"value","", ""}
});
}else if (OpName == "Conv") {
OneOut(Conv, 3, 1);
ImportNodeConv(node, 3, 1, {
{"auto_pad","str","NOTSET"}
,{"dilations","", ""}
,{"group","int", "1"}
,{"kernel_shape","", ""}
,{"pads","", ""}
,{"strides","", ""}
});
}else if (OpName == "ConvInteger") {
OneOut(ConvInteger, 4, 1);
ImportNodeOneOut<mlir::ONNXConvIntegerOp>(node, 4, 1, {
{"auto_pad","str","NOTSET"}
,{"dilations","", ""}
,{"group","int","1"}
,{"kernel_shape","", ""}
,{"pads","", ""}
,{"strides","", ""}
});
}else if (OpName == "ConvTranspose") {
OneOut(ConvTranspose, 3, 1);
ImportNodeOneOut<mlir::ONNXConvTransposeOp>(node, 3, 1, {
{"auto_pad","str","NOTSET"}
,{"dilations","", ""}
,{"group","int","1"}
,{"kernel_shape","", ""}
,{"output_padding","", ""}
,{"output_shape","", ""}
,{"pads","", ""}
,{"strides","", ""}
});
}else if (OpName == "Cos") {
OneOut(Cos, 1, 1);
ImportNodeOneOut<mlir::ONNXCosOp>(node, 1, 1, {
});
}else if (OpName == "Cosh") {
OneOut(Cosh, 1, 1);
ImportNodeOneOut<mlir::ONNXCoshOp>(node, 1, 1, {
});
}else if (OpName == "CumSum") {
OneOut(CumSum, 2, 1);
ImportNodeOneOut<mlir::ONNXCumSumOp>(node, 2, 1, {
{"exclusive","int","0"}
,{"reverse","int","0"}
});
}else if (OpName == "DepthToSpace") {
OneOut(DepthToSpace, 1, 1);
ImportNodeOneOut<mlir::ONNXDepthToSpaceOp>(node, 1, 1, {
{"blocksize","", ""}
,{"mode","str","DCR"}
});
}else if (OpName == "DequantizeLinear") {
OneOut(DequantizeLinear, 3, 1);
ImportNodeOneOut<mlir::ONNXDequantizeLinearOp>(node, 3, 1, {
});
}else if (OpName == "Det") {
OneOut(Det, 1, 1);
ImportNodeOneOut<mlir::ONNXDetOp>(node, 1, 1, {
});
}else if (OpName == "Div") {
OneOut(Div, 2, 1);
ImportNodeOneOut<mlir::ONNXDivOp>(node, 2, 1, {
});
}else if (OpName == "Dropout") {
MultipleOuts(Dropout, 1, 2);
ImportNodeMultipleOuts<mlir::ONNXDropoutOp>(node, 1, 2, {
{"ratio","float","0.5"}
});
}else if (OpName == "DynamicQuantizeLinear") {
MultipleOuts(DynamicQuantizeLinear, 1, 3);
ImportNodeMultipleOuts<mlir::ONNXDynamicQuantizeLinearOp>(node, 1, 3, {
});
}else if (OpName == "Elu") {
OneOut(Elu, 1, 1);
ImportNodeOneOut<mlir::ONNXEluOp>(node, 1, 1, {
{"alpha","float","1.0"}
});
}else if (OpName == "Equal") {
OneOut(Equal, 2, 1);
ImportNodeOneOut<mlir::ONNXEqualOp>(node, 2, 1, {
});
}else if (OpName == "Erf") {
OneOut(Erf, 1, 1);
ImportNodeOneOut<mlir::ONNXErfOp>(node, 1, 1, {
});
}else if (OpName == "Exp") {
OneOut(Exp, 1, 1);
ImportNodeOneOut<mlir::ONNXExpOp>(node, 1, 1, {
});
}else if (OpName == "Expand") {
OneOut(Expand, 2, 1);
ImportNodeOneOut<mlir::ONNXExpandOp>(node, 2, 1, {
});
}else if (OpName == "EyeLike") {
OneOut(EyeLike, 1, 1);
ImportNodeOneOut<mlir::ONNXEyeLikeOp>(node, 1, 1, {
{"dtype","", ""}
,{"k","int","0"}
});
}else if (OpName == "Flatten") {
OneOut(Flatten, 1, 1);
ImportNodeOneOut<mlir::ONNXFlattenOp>(node, 1, 1, {
{"axis","int","1"}
});
}else if (OpName == "Floor") {
OneOut(Floor, 1, 1);
ImportNodeOneOut<mlir::ONNXFloorOp>(node, 1, 1, {
});
}else if (OpName == "GRU") {
MultipleOuts(GRU, 6, 2);
ImportNodeMultipleOuts<mlir::ONNXGRUOp>(node, 6, 2, {
{"activation_alpha","", ""}
,{"activation_beta","", ""}
,{"activations","", ""}
,{"clip","", ""}
,{"direction","str","forward"}
,{"hidden_size","", ""}
,{"linear_before_reset","int","0"}
});
}else if (OpName == "Gather") {
OneOut(Gather, 2, 1);
ImportNodeOneOut<mlir::ONNXGatherOp>(node, 2, 1, {
{"axis","int","0"}
});
}else if (OpName == "GatherElements") {
OneOut(GatherElements, 2, 1);
ImportNodeOneOut<mlir::ONNXGatherElementsOp>(node, 2, 1, {
{"axis","int","0"}
});
}else if (OpName == "GatherND") {
OneOut(GatherND, 2, 1);
ImportNodeOneOut<mlir::ONNXGatherNDOp>(node, 2, 1, {
});
}else if (OpName == "Gemm") {
OneOut(Gemm, 3, 1);
ImportNodeOneOut<mlir::ONNXGemmOp>(node, 3, 1, {
{"alpha","float","1.0"}
,{"beta","float","1.0"}
,{"transA","int","0"}
,{"transB","int","0"}
});
}else if (OpName == "GlobalAveragePool") {
OneOut(GlobalAveragePool, 1, 1);
ImportNodeOneOut<mlir::ONNXGlobalAveragePoolOp>(node, 1, 1, {
});
}else if (OpName == "GlobalLpPool") {
OneOut(GlobalLpPool, 1, 1);
ImportNodeOneOut<mlir::ONNXGlobalLpPoolOp>(node, 1, 1, {
{"p","int","2"}
});
}else if (OpName == "GlobalMaxPool") {
OneOut(GlobalMaxPool, 1, 1);
ImportNodeOneOut<mlir::ONNXGlobalMaxPoolOp>(node, 1, 1, {
});
}else if (OpName == "Greater") {
OneOut(Greater, 2, 1);
ImportNodeOneOut<mlir::ONNXGreaterOp>(node, 2, 1, {
});
}else if (OpName == "HardSigmoid") {
OneOut(HardSigmoid, 1, 1);
ImportNodeOneOut<mlir::ONNXHardSigmoidOp>(node, 1, 1, {
{"alpha","float","0.2"}
,{"beta","float","0.5"}
});
}else if (OpName == "Hardmax") {
OneOut(Hardmax, 1, 1);
ImportNodeOneOut<mlir::ONNXHardmaxOp>(node, 1, 1, {
{"axis","int","1"}
});
}else if (OpName == "Identity") {
OneOut(Identity, 1, 1);
ImportNodeOneOut<mlir::ONNXIdentityOp>(node, 1, 1, {
});
}else if (OpName == "If") {
OneOut(If, 1, 1);
ImportNodeOneOut<mlir::ONNXIfOp>(node, 1, 1, {
{"else_branch","", ""}
,{"then_branch","", ""}
});
}else if (OpName == "InstanceNormalization") {
OneOut(InstanceNormalization, 3, 1);
ImportNodeOneOut<mlir::ONNXInstanceNormalizationOp>(node, 3, 1, {
{"epsilon","float","1e-05"}
});
}else if (OpName == "IsInf") {
OneOut(IsInf, 1, 1);
ImportNodeOneOut<mlir::ONNXIsInfOp>(node, 1, 1, {
{"detect_negative","int","1"}
,{"detect_positive","int","1"}
});
}else if (OpName == "IsNaN") {
OneOut(IsNaN, 1, 1);
ImportNodeOneOut<mlir::ONNXIsNaNOp>(node, 1, 1, {
});
}else if (OpName == "LRN") {
OneOut(LRN, 1, 1);
ImportNodeOneOut<mlir::ONNXLRNOp>(node, 1, 1, {
{"alpha","float","0.0001"}
,{"beta","float","0.75"}
,{"bias","float","1.0"}
,{"size","int", ""}
});
}else if (OpName == "LSTM") {
MultipleOuts(LSTM, 8, 3);
ImportNodeMultipleOuts<mlir::ONNXLSTMOp>(node, 8, 3, {
{"activation_alpha","", ""}
,{"activation_beta","", ""}
,{"activations","", ""}
,{"clip","", ""}
,{"direction","str","forward"}
,{"hidden_size","", ""}
,{"input_forget","int","0"}
});
}else if (OpName == "LeakyRelu") {
OneOut(LeakyRelu, 1, 1);
ImportNodeOneOut<mlir::ONNXLeakyReluOp>(node, 1, 1, {
{"alpha","float","0.01"}
});
}else if (OpName == "Less") {
OneOut(Less, 2, 1);
ImportNodeOneOut<mlir::ONNXLessOp>(node, 2, 1, {
});
}else if (OpName == "Log") {
OneOut(Log, 1, 1);
ImportNodeOneOut<mlir::ONNXLogOp>(node, 1, 1, {
});
}else if (OpName == "LogSoftmax") {
OneOut(LogSoftmax, 1, 1);
ImportNodeOneOut<mlir::ONNXLogSoftmaxOp>(node, 1, 1, {
{"axis","int","1"}
});
}else if (OpName == "Loop") {
OneOut(Loop, 3, 1);
ImportNodeOneOut<mlir::ONNXLoopOp>(node, 3, 1, {
{"body","", ""}
});
}else if (OpName == "LpNormalization") {
OneOut(LpNormalization, 1, 1);
ImportNodeOneOut<mlir::ONNXLpNormalizationOp>(node, 1, 1, {
{"axis","int","-1"}
,{"p","int","2"}
});
}else if (OpName == "LpPool") {
OneOut(LpPool, 1, 1);
ImportNodeOneOut<mlir::ONNXLpPoolOp>(node, 1, 1, {
{"auto_pad","str","NOTSET"}
,{"kernel_shape","", ""}
,{"p","int","2"}
,{"pads","", ""}
,{"strides","", ""}
});
}else if (OpName == "MatMul") {
OneOut(MatMul, 2, 1);
ImportNodeOneOut<mlir::ONNXMatMulOp>(node, 2, 1, {
});
}else if (OpName == "MatMulInteger") {
OneOut(MatMulInteger, 4, 1);
ImportNodeOneOut<mlir::ONNXMatMulIntegerOp>(node, 4, 1, {
});
}else if (OpName == "Max") {
OneOut(Max, 1, 1);
ImportNodeOneOut<mlir::ONNXMaxOp>(node, 1, 1, {
});
}else if (OpName == "MaxPool") {
MultipleOuts(MaxPool, 1, 2);
ImportNodeMultipleOuts<mlir::ONNXMaxPoolOp>(node, 1, 2, {
{"auto_pad","str","NOTSET"}
,{"ceil_mode","int","0"}
,{"dilations","", ""}
,{"kernel_shape","ints", ""}
,{"pads","", ""}
,{"storage_order","int","0"}
,{"strides","", ""}
});
}else if (OpName == "MaxRoiPool") {
OneOut(MaxRoiPool, 2, 1);
ImportNodeOneOut<mlir::ONNXMaxRoiPoolOp>(node, 2, 1, {
{"pooled_shape","", ""}
,{"spatial_scale","float","1.0"}
});
}else if (OpName == "MaxUnpool") {
OneOut(MaxUnpool, 3, 1);
ImportNodeOneOut<mlir::ONNXMaxUnpoolOp>(node, 3, 1, {
{"kernel_shape","", ""}
,{"pads","", ""}
,{"strides","", ""}
});
}else if (OpName == "Mean") {
OneOut(Mean, 1, 1);
ImportNodeOneOut<mlir::ONNXMeanOp>(node, 1, 1, {
});
}else if (OpName == "MeanVarianceNormalization") {
OneOut(MeanVarianceNormalization, 1, 1);
ImportNodeOneOut<mlir::ONNXMeanVarianceNormalizationOp>(node, 1, 1, {
{"axes","ints","{'0', '2', '3'}"}
});
}else if (OpName == "Min") {
OneOut(Min, 1, 1);
ImportNodeOneOut<mlir::ONNXMinOp>(node, 1, 1, {
});
}else if (OpName == "Mod") {
OneOut(Mod, 2, 1);
ImportNodeOneOut<mlir::ONNXModOp>(node, 2, 1, {
{"fmod","int","0"}
});
}else if (OpName == "Mul") {
OneOut(Mul, 2, 1);
ImportNodeOneOut<mlir::ONNXMulOp>(node, 2, 1, {
});
}else if (OpName == "Multinomial") {
OneOut(Multinomial, 1, 1);
ImportNodeOneOut<mlir::ONNXMultinomialOp>(node, 1, 1, {
{"dtype","int","6"}
,{"sample_size","int","1"}
,{"seed","", ""}
});
}else if (OpName == "Neg") {
OneOut(Neg, 1, 1);
ImportNodeOneOut<mlir::ONNXNegOp>(node, 1, 1, {
});
}else if (OpName == "NonMaxSuppression") {
OneOut(NonMaxSuppression, 5, 1);
ImportNodeOneOut<mlir::ONNXNonMaxSuppressionOp>(node, 5, 1, {
{"center_point_box","int","0"}
});
}else if (OpName == "NonZero") {
OneOut(NonZero, 1, 1);
ImportNodeOneOut<mlir::ONNXNonZeroOp>(node, 1, 1, {
});
}else if (OpName == "Not") {
OneOut(Not, 1, 1);
ImportNodeOneOut<mlir::ONNXNotOp>(node, 1, 1, {
});
}else if (OpName == "OneHot") {
OneOut(OneHot, 3, 1);
ImportNodeOneOut<mlir::ONNXOneHotOp>(node, 3, 1, {
{"axis","int","-1"}
});
}else if (OpName == "Or") {
OneOut(Or, 2, 1);
ImportNodeOneOut<mlir::ONNXOrOp>(node, 2, 1, {
});
}else if (OpName == "PRelu") {
OneOut(PRelu, 2, 1);
ImportNodeOneOut<mlir::ONNXPReluOp>(node, 2, 1, {
});
}else if (OpName == "Pad") {
OneOut(Pad, 3, 1);
ImportNodeOneOut<mlir::ONNXPadOp>(node, 3, 1, {
{"mode","str","constant"}
});
}else if (OpName == "Pow") {
OneOut(Pow, 2, 1);
ImportNodeOneOut<mlir::ONNXPowOp>(node, 2, 1, {
});
}else if (OpName == "QLinearConv") {
OneOut(QLinearConv, 9, 1);
ImportNodeOneOut<mlir::ONNXQLinearConvOp>(node, 9, 1, {
{"auto_pad","str","NOTSET"}
,{"dilations","", ""}
,{"group","int","1"}
,{"kernel_shape","", ""}
,{"pads","", ""}
,{"strides","", ""}
});
}else if (OpName == "QLinearMatMul") {
OneOut(QLinearMatMul, 8, 1);
ImportNodeOneOut<mlir::ONNXQLinearMatMulOp>(node, 8, 1, {
});
}else if (OpName == "QuantizeLinear") {
OneOut(QuantizeLinear, 3, 1);
ImportNodeOneOut<mlir::ONNXQuantizeLinearOp>(node, 3, 1, {
});
}else if (OpName == "RNN") {
MultipleOuts(RNN, 6, 2);
ImportNodeMultipleOuts<mlir::ONNXRNNOp>(node, 6, 2, {
{"activation_alpha","floats", "{}"}
,{"activation_beta","floats", "{}"}
,{"activations","", "{Tannh, Tanh}"}
,{"clip","", ""}
,{"direction","str","forward"}
,{"hidden_size","", ""}
});
}else if (OpName == "RandomNormal") {
OneOut(RandomNormal, 0, 1);
ImportNodeOneOut<mlir::ONNXRandomNormalOp>(node, 0, 1, {
{"dtype","int","1"}
,{"mean","float","0.0"}
,{"scale","float","1.0"}
,{"seed","", ""}
,{"shape","", ""}
});
}else if (OpName == "RandomNormalLike") {
OneOut(RandomNormalLike, 1, 1);
ImportNodeOneOut<mlir::ONNXRandomNormalLikeOp>(node, 1, 1, {
{"dtype","", ""}
,{"mean","float","0.0"}
,{"scale","float","1.0"}
,{"seed","", ""}
});
}else if (OpName == "RandomUniform") {
OneOut(RandomUniform, 0, 1);
ImportNodeOneOut<mlir::ONNXRandomUniformOp>(node, 0, 1, {
{"dtype","int","1"}
,{"high","float","1.0"}
,{"low","float","0.0"}
,{"seed","", ""}
,{"shape","", ""}
});
}else if (OpName == "RandomUniformLike") {
OneOut(RandomUniformLike, 1, 1);
ImportNodeOneOut<mlir::ONNXRandomUniformLikeOp>(node, 1, 1, {
{"dtype","", ""}
,{"high","float","1.0"}
,{"low","float","0.0"}
,{"seed","", ""}
});
}else if (OpName == "Range") {
OneOut(Range, 3, 1);
ImportNodeOneOut<mlir::ONNXRangeOp>(node, 3, 1, {
});
}else if (OpName == "Reciprocal") {
OneOut(Reciprocal, 1, 1);
ImportNodeOneOut<mlir::ONNXReciprocalOp>(node, 1, 1, {
});
}else if (OpName == "ReduceL1") {
OneOut(ReduceL1, 1, 1);
ImportNodeOneOut<mlir::ONNXReduceL1Op>(node, 1, 1, {
{"axes","", ""}
,{"keepdims","int","1"}
});
}else if (OpName == "ReduceL2") {
OneOut(ReduceL2, 1, 1);
ImportNodeOneOut<mlir::ONNXReduceL2Op>(node, 1, 1, {
{"axes","", ""}
,{"keepdims","int","1"}
});
}else if (OpName == "ReduceLogSum") {
OneOut(ReduceLogSum, 1, 1);
ImportNodeOneOut<mlir::ONNXReduceLogSumOp>(node, 1, 1, {
{"axes","", ""}
,{"keepdims","int","1"}
});
}else if (OpName == "ReduceLogSumExp") {
OneOut(ReduceLogSumExp, 1, 1);
ImportNodeOneOut<mlir::ONNXReduceLogSumExpOp>(node, 1, 1, {
{"axes","", ""}
,{"keepdims","int","1"}
});
}else if (OpName == "ReduceMax") {
OneOut(ReduceMax, 1, 1);
ImportNodeOneOut<mlir::ONNXReduceMaxOp>(node, 1, 1, {
{"axes","", ""}
,{"keepdims","int","1"}
});
}else if (OpName == "ReduceMean") {
OneOut(ReduceMean, 1, 1);
ImportNodeOneOut<mlir::ONNXReduceMeanOp>(node, 1, 1, {
{"axes","", ""}
,{"keepdims","int","1"}
});
}else if (OpName == "ReduceMin") {
OneOut(ReduceMin, 1, 1);
ImportNodeOneOut<mlir::ONNXReduceMinOp>(node, 1, 1, {
{"axes","", ""}
,{"keepdims","int","1"}
});
}else if (OpName == "ReduceProd") {
OneOut(ReduceProd, 1, 1);
ImportNodeOneOut<mlir::ONNXReduceProdOp>(node, 1, 1, {
{"axes","", ""}
,{"keepdims","int","1"}
});
}else if (OpName == "ReduceSum") {
OneOut(ReduceSum, 1, 1);
ImportNodeOneOut<mlir::ONNXReduceSumOp>(node, 1, 1, {
{"axes","", ""}
,{"keepdims","int","1"}
});
}else if (OpName == "ReduceSumSquare") {
OneOut(ReduceSumSquare, 1, 1);
ImportNodeOneOut<mlir::ONNXReduceSumSquareOp>(node, 1, 1, {
{"axes","", ""}
,{"keepdims","int","1"}
});
}else if (OpName == "Relu") {
OneOut(Relu, 1, 1);
ImportNodeOneOut<mlir::ONNXReluOp>(node, 1, 1, {
});
}else if (OpName == "Reshape") {
OneOut(Reshape, 2, 1);
ImportNodeOneOut<mlir::ONNXReshapeOp>(node, 2, 1, {
});
}else if (OpName == "Resize") {
OneOut(Resize, 4, 1);
ImportNodeOneOut<mlir::ONNXResizeOp>(node, 4, 1, {
{"coordinate_transformation_mode","str","half_pixel"}
,{"cubic_coeff_a","float","-0.75"}
,{"exclude_outside","int","0"}
,{"extrapolation_value","float","0.0"}
,{"mode","str","nearest"}
,{"nearest_mode","str","round_prefer_floor"}
});
}else if (OpName == "ReverseSequence") {
OneOut(ReverseSequence, 2, 1);
ImportNodeOneOut<mlir::ONNXReverseSequenceOp>(node, 2, 1, {
{"batch_axis","int","1"}
,{"time_axis","int","0"}
});
}else if (OpName == "RoiAlign") {
OneOut(RoiAlign, 3, 1);
ImportNodeOneOut<mlir::ONNXRoiAlignOp>(node, 3, 1, {
{"mode","str","avg"}
,{"output_height","int","1"}
,{"output_width","int","1"}
,{"sampling_ratio","int","0"}
,{"spatial_scale","float","1.0"}
});
}else if (OpName == "Round") {
OneOut(Round, 1, 1);
ImportNodeOneOut<mlir::ONNXRoundOp>(node, 1, 1, {
});
}else if (OpName == "Scan") {
OneOut(Scan, 1, 1);
ImportNodeOneOut<mlir::ONNXScanOp>(node, 1, 1, {
{"body","", ""}
,{"num_scan_inputs","", ""}
,{"scan_input_axes","", ""}
,{"scan_input_directions","", ""}
,{"scan_output_axes","", ""}
,{"scan_output_directions","", ""}
});
}else if (OpName == "Scatter") {
OneOut(Scatter, 3, 1);
ImportNodeOneOut<mlir::ONNXScatterOp>(node, 3, 1, {
{"axis","int","0"}
});
}else if (OpName == "ScatterElements") {
OneOut(ScatterElements, 3, 1);
ImportNodeOneOut<mlir::ONNXScatterElementsOp>(node, 3, 1, {
{"axis","int","0"}
});
}else if (OpName == "ScatterND") {
OneOut(ScatterND, 3, 1);
ImportNodeOneOut<mlir::ONNXScatterNDOp>(node, 3, 1, {
});
}else if (OpName == "Selu") {
OneOut(Selu, 1, 1);
ImportNodeOneOut<mlir::ONNXSeluOp>(node, 1, 1, {
{"alpha","float","1.67326"}
,{"gamma","float","1.0507"}
});
}else if (OpName == "SequenceAt") {
OneOut(SequenceAt, 2, 1);
ImportNodeOneOut<mlir::ONNXSequenceAtOp>(node, 2, 1, {
});
}else if (OpName == "SequenceConstruct") {
OneOut(SequenceConstruct, 1, 1);
ImportNodeOneOut<mlir::ONNXSequenceConstructOp>(node, 1, 1, {
});
}else if (OpName == "SequenceEmpty") {
OneOut(SequenceEmpty, 0, 1);
ImportNodeOneOut<mlir::ONNXSequenceEmptyOp>(node, 0, 1, {
{"dtype","", ""}
});
}else if (OpName == "SequenceErase") {
OneOut(SequenceErase, 2, 1);
ImportNodeOneOut<mlir::ONNXSequenceEraseOp>(node, 2, 1, {
});
}else if (OpName == "SequenceInsert") {
OneOut(SequenceInsert, 3, 1);
ImportNodeOneOut<mlir::ONNXSequenceInsertOp>(node, 3, 1, {
});
}else if (OpName == "SequenceLength") {
OneOut(SequenceLength, 1, 1);
ImportNodeOneOut<mlir::ONNXSequenceLengthOp>(node, 1, 1, {
});
}else if (OpName == "Shape") {
OneOut(Shape, 1, 1);
ImportNodeOneOut<mlir::ONNXShapeOp>(node, 1, 1, {
});
}else if (OpName == "Shrink") {
OneOut(Shrink, 1, 1);
ImportNodeOneOut<mlir::ONNXShrinkOp>(node, 1, 1, {
{"bias","float","0.0"}
,{"lambd","float","0.5"}
});
}else if (OpName == "Sigmoid") {
OneOut(Sigmoid, 1, 1);
ImportNodeOneOut<mlir::ONNXSigmoidOp>(node, 1, 1, {
});
}else if (OpName == "Sign") {
OneOut(Sign, 1, 1);
ImportNodeOneOut<mlir::ONNXSignOp>(node, 1, 1, {
});
}else if (OpName == "Sin") {
OneOut(Sin, 1, 1);
ImportNodeOneOut<mlir::ONNXSinOp>(node, 1, 1, {
});
}else if (OpName == "Sinh") {
OneOut(Sinh, 1, 1);
ImportNodeOneOut<mlir::ONNXSinhOp>(node, 1, 1, {
});
}else if (OpName == "Size") {
OneOut(Size, 1, 1);
ImportNodeOneOut<mlir::ONNXSizeOp>(node, 1, 1, {
});
}else if (OpName == "Slice") {
OneOut(Slice, 5, 1);
ImportNodeOneOut<mlir::ONNXSliceOp>(node, 5, 1, {
});
}else if (OpName == "Softmax") {
OneOut(Softmax, 1, 1);
ImportNodeOneOut<mlir::ONNXSoftmaxOp>(node, 1, 1, {
{"axis","int","1"}
});
}else if (OpName == "Softplus") {
OneOut(Softplus, 1, 1);
ImportNodeOneOut<mlir::ONNXSoftplusOp>(node, 1, 1, {
});
}else if (OpName == "Softsign") {
OneOut(Softsign, 1, 1);
ImportNodeOneOut<mlir::ONNXSoftsignOp>(node, 1, 1, {
});
}else if (OpName == "SpaceToDepth") {
OneOut(SpaceToDepth, 1, 1);
ImportNodeOneOut<mlir::ONNXSpaceToDepthOp>(node, 1, 1, {
{"blocksize","", ""}
});
}else if (OpName == "Split") {
OneOut(Split, 1, 1);
ImportNodeOneOut<mlir::ONNXSplitOp>(node, 1, 1, {
{"axis","int","0"}
,{"split","", ""}
});
}else if (OpName == "SplitToSequence") {
OneOut(SplitToSequence, 2, 1);
ImportNodeOneOut<mlir::ONNXSplitToSequenceOp>(node, 2, 1, {
{"axis","int","0"}
,{"keepdims","int","1"}
});
}else if (OpName == "Sqrt") {
OneOut(Sqrt, 1, 1);
ImportNodeOneOut<mlir::ONNXSqrtOp>(node, 1, 1, {
});
}else if (OpName == "Squeeze") {
OneOut(Squeeze, 1, 1);
ImportNodeOneOut<mlir::ONNXSqueezeOp>(node, 1, 1, {
{"axes","", ""}
});
}else if (OpName == "StringNormalizer") {
OneOut(StringNormalizer, 1, 1);
ImportNodeOneOut<mlir::ONNXStringNormalizerOp>(node, 1, 1, {
{"case_change_action","str","NONE"}
,{"is_case_sensitive","int","0"}
,{"locale","", ""}
,{"stopwords","", ""}
});
}else if (OpName == "Sub") {
OneOut(Sub, 2, 1);
ImportNodeOneOut<mlir::ONNXSubOp>(node, 2, 1, {
});
}else if (OpName == "Sum") {
OneOut(Sum, 1, 1);
ImportNodeOneOut<mlir::ONNXSumOp>(node, 1, 1, {
});
}else if (OpName == "Tan") {
OneOut(Tan, 1, 1);
ImportNodeOneOut<mlir::ONNXTanOp>(node, 1, 1, {
});
}else if (OpName == "Tanh") {
OneOut(Tanh, 1, 1);
ImportNodeOneOut<mlir::ONNXTanhOp>(node, 1, 1, {
});
}else if (OpName == "TfIdfVectorizer") {
OneOut(TfIdfVectorizer, 1, 1);
ImportNodeOneOut<mlir::ONNXTfIdfVectorizerOp>(node, 1, 1, {
{"max_gram_length","", ""}
,{"max_skip_count","", ""}
,{"min_gram_length","", ""}
,{"mode","", ""}
,{"ngram_counts","", ""}
,{"ngram_indexes","", ""}
,{"pool_int64s","", ""}
,{"pool_strings","", ""}
,{"weights","", ""}
});
}else if (OpName == "ThresholdedRelu") {
OneOut(ThresholdedRelu, 1, 1);
ImportNodeOneOut<mlir::ONNXThresholdedReluOp>(node, 1, 1, {
{"alpha","float","1.0"}
});
}else if (OpName == "Tile") {
OneOut(Tile, 2, 1);
ImportNodeOneOut<mlir::ONNXTileOp>(node, 2, 1, {
});
}else if (OpName == "TopK") {
MultipleOuts(TopK, 2, 2);
ImportNodeMultipleOuts<mlir::ONNXTopKOp>(node, 2, 2, {
{"axis","int","-1"}
,{"largest","int","1"}
,{"sorted","int","1"}
});
}else if (OpName == "Transpose") {
OneOut(Transpose, 1, 1);
ImportNodeOneOut<mlir::ONNXTransposeOp>(node, 1, 1, {
{"perm","", ""}
});
}else if (OpName == "Unique") {
MultipleOuts(Unique, 1, 4);
ImportNodeMultipleOuts<mlir::ONNXUniqueOp>(node, 1, 4, {
{"axis","", ""}
,{"sorted","int","1"}
});
}else if (OpName == "Unsqueeze") {
OneOut(Unsqueeze, 1, 1);
ImportNodeOneOut<mlir::ONNXUnsqueezeOp>(node, 1, 1, {
{"axes","ints", ""}
});
}else if (OpName == "Upsample") {
OneOut(Upsample, 2, 1);
ImportNodeOneOut<mlir::ONNXUpsampleOp>(node, 2, 1, {
{"mode","str","nearest"}
});
}else if (OpName == "Where") {
OneOut(Where, 3, 1);
ImportNodeOneOut<mlir::ONNXWhereOp>(node, 3, 1, {
});
}else if (OpName == "Xor") {
OneOut(Xor, 2, 1);
ImportNodeOneOut<mlir::ONNXXorOp>(node, 2, 1, {
});
}

View File

@ -352,6 +352,121 @@ def gen_schema(schema) :
return s
"""
special cases:
* Split: attr split default value: sizeof(output1) namely 1
* Conv: attr dilations default value is {num_dim of first input - 2, 1}
* Conv: attr kernel_shape type is ints
* Transpose: attr perm default value is {} empty int list
"""
def gen_code(schema,fefile) :
special_handler = dict([
("Conv", "ImportNodeConv"),
#("Transpose", "ImportNodeTranspose")
])
special_type = dict([
("AveragePool "+"kernel_shape", '"ints", ""'),
("MaxPool "+"kernel_shape", '"ints", ""'),
("Cast "+"to", '"int", "0"'),
("Concat "+"axis", '"int", "0"'),
("Conv "+"group", '"int", "1"'),
("Unsqueeze "+"axes", '"ints", ""'),
("RNN "+"activation_alpha", '"floats", "{}"'),
("RNN "+"activation_beta", '"floats", "{}"'),
("RNN "+"activations", '"", "{Tannh, Tanh}"'),
("LRN "+"size", '"int", ""')
])
line_indent = ' '
fefile.write(' '+'}else if (OpName == "'+schema.name+'") {\n')
op_type_str='mlir::ONNX'+schema.name+'Op'
if schema.name in special_handler :
fefile.write(' '+special_handler[schema.name]+'(node, '
+str(len(schema.inputs))
+', ' +str(len(schema.outputs))+', {\n')
elif len(schema.outputs) > 1 :
fefile.write(' '+'ImportNodeMultipleOuts<'+op_type_str+'>(node, '
+str(len(schema.inputs))
+', ' +str(len(schema.outputs))+', {\n')
else :
fefile.write(' '+'ImportNodeOneOut<'+op_type_str+'>(node, '
+str(len(schema.inputs))
+', ' +str(len(schema.outputs))+', {\n')
if schema.attributes:
first_attr = True
for _, attr in sorted(schema.attributes.items()):
attr_line = line_indent+line_indent+line_indent+line_indent
if not first_attr:
attr_line += ',{'
else :
attr_line += ' {'
first_attr = False
attr_line += '"'+attr.name+'",'
if schema.name+' '+attr.name in special_type:
attr_line += special_type[schema.name+' '+attr.name]
# option holds either required or default value
elif attr.required:
attr_line += '"", ""'
elif attr.default_value.name:
default_value = helper.get_attribute_value(attr.default_value)
def format_value(value): # type: (Any) -> Text
if isinstance(value, float):
formatted = str(np.round(value, 5))
# use default formatting, unless too long.
if (len(formatted) > 10):
formatted = str("({:e})".format(value))
return formatted
elif isinstance(value, (bytes, bytearray)) and sys.version_info[0] == 3:
return str(value.decode('utf-8'))
return str(value)
if isinstance(default_value, list):
value = default_value[0]
default_value = [format_value(val) for val in default_value]
# TODO the list type is homogenous or htergeneous?
if isinstance(value, float) :
attr_type_str = '"floats"'
elif isinstance(value, int) :
attr_type_str = '"ints"'
elif isinstance(value, str) :
attr_type_str = '"strs"'
elif isinstance(value, (bytes, bytearray)) :
attr_type_str = '"strs"'
else :
attr_type_str = '"unknowns"'
attr_option_str = '"{}"'.format(default_value)
attr_option_str = attr_option_str.replace('[', '{', 1)
attr_option_str = attr_option_str.replace(']', '}', 1)
else:
if isinstance(default_value, float) :
attr_type_str = '"float"'
elif isinstance(default_value, int) :
attr_type_str = '"int"'
elif isinstance(default_value, str) :
attr_type_str = '"str"'
elif isinstance(default_value, (bytes, bytearray)) :
attr_type_str = '"str"'
else :
attr_type_str = '"unknown"'
default_value = format_value(default_value)
attr_option_str = '"{}"'.format(default_value)
attr_line += attr_type_str+','+attr_option_str
else:
#TODO why?
attr_line += '"", ""'
attr_line += '}\n'
fefile.write(attr_line)
fefile.write(line_indent+line_indent+line_indent+'});\n')
def main(args): # type: (Type[Args]) -> None
with io.open(args.changelog, 'w', newline='') as fout:
fout.write('## Operator Changelog\n')
@ -453,6 +568,7 @@ def main(args): # type: (Type[Args]) -> None
fefile=io.open('op_build_table.inc', 'w', newline='')
firstfunc = True
fefile.write(' '+'if (OpName == "DUMMY") {\n')
for domain, supportmap in operator_schemas:
s = '## {}\n'.format(display_domain_short(domain))
fout.write(s)
@ -461,21 +577,8 @@ def main(args): # type: (Type[Args]) -> None
for op_type, schema, versions in namemap:
# op_type
#print("check point 1", schema.name, len(schema.inputs), len(schema.outputs))
if firstfunc :
fefile.write(' '+'if (OpName == "'+schema.name+'") {\n')
firstfunc = False
else :
fefile.write(' '+'}else if (OpName == "'+schema.name+'") {\n')
if len(schema.outputs) > 1 :
fefile.write(' '+'MultipleOuts('+schema.name+', '
+str(schema.since_version)+', '
+str(len(schema.inputs))+', '
+str(len(schema.outputs))+');\n')
else :
fefile.write(' '+'OneOut('+schema.name+', '
+str(schema.since_version)+', '
+str(len(schema.inputs))+', '
+str(len(schema.outputs))+');\n')
gen_code(schema, fefile)
r = gen_schema(schema)
tdfile.write(r)
s = ('### {}<a name="{}"></a><a name="{}">**{}**' + (' (deprecated)' if schema.deprecated else '') + '</a>\n').format(

View File

@ -71,4 +71,26 @@ def ONNXFullGemmOp: ONNX_Op<"FullGemm",
let results = (outs AnyTensor);
}
def ONNXConv1Op:ONNX_Op<"Conv1",
[NoSideEffect]> {
let summary = "ONNX Conv operation";
let description = [{
"The convolution operator consumes an input tensor and a filter, and"
"computes the output."
}];
let arguments = (ins AnyTensor:$X);
let results = (outs AnyTensor);
}
def ONNXConv3Op:ONNX_Op<"Conv3",
[NoSideEffect]> {
let summary = "ONNX Conv operation";
let description = [{
"The convolution operator consumes an input tensor and a filter, and"
"computes the output."
}];
let arguments = (ins AnyTensor:$X, AnyTensor:$W, AnyTensor:$B);
let results = (outs AnyTensor);
}
#endif // ONNX_OPS