From c8d591fb28ef5b527a6b716c3837366986c38a36 Mon Sep 17 00:00:00 2001 From: TONG CHEN Date: Sat, 21 Dec 2019 01:58:23 -0500 Subject: [PATCH] [MLIR] import attribute of onnx node (#383) * add attributes as NamedAttribute * support list value for attribute * use std::tie to avoid c++17 feature --- src/CMakeLists.txt | 3 + src/builder/CMakeLists.txt | 1 + src/builder/frontend_dialect_transformer.cpp | 541 +++++++++++++-- src/builder/op_build_table.inc | 689 ++++++++++++++----- src/compiler/dialect/onnx/gen_doc.py | 133 +++- src/compiler/dialect/onnx/onnx.td | 22 + 6 files changed, 1143 insertions(+), 246 deletions(-) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 74fa34a..3fe6903 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,9 +1,12 @@ add_executable(onnf main.cpp) target_link_libraries(onnf builder compiler ${MLIRLibs} ${Boost_LIBRARIES}) +set_target_properties(onnf PROPERTIES LINK_FLAGS "-lz") + whole_archive_link_mlir(onnf ${MLIRWholeArchiveLibs}) whole_archive_link_onnf(onnf onnf_transform) + target_include_directories(onnf PRIVATE ${CMAKE_SOURCE_DIR}) target_include_directories(onnf PRIVATE ${CMAKE_BINARY_DIR}) diff --git a/src/builder/CMakeLists.txt b/src/builder/CMakeLists.txt index 1cc773b..7d96296 100644 --- a/src/builder/CMakeLists.txt +++ b/src/builder/CMakeLists.txt @@ -1,5 +1,6 @@ add_library(builder frontend_dialect_transformer.cpp + frontend_dialect_transformer.hpp op_build_table.inc ) diff --git a/src/builder/frontend_dialect_transformer.cpp b/src/builder/frontend_dialect_transformer.cpp index 388650f..1d23802 100644 --- a/src/builder/frontend_dialect_transformer.cpp +++ b/src/builder/frontend_dialect_transformer.cpp @@ -71,7 +71,10 @@ struct OnnxOnnfSymbolMapping { * @param name onnx tensor name. * @return onnf tensor corresponding to `name`. */ - mlir::Value* GetTensorByOnnxName(std::string name) { + mlir::Value *GetTensorByOnnxName(std::string name) { + assert(onnx_name2onnf_tensor.find(legalize_name(name)) != + onnx_name2onnf_tensor.end() && + "Tensor not found"); return onnx_name2onnf_tensor.at(legalize_name(name)); } @@ -80,7 +83,9 @@ struct OnnxOnnfSymbolMapping { * @param name onnx tensor name. * @param tensor MLIR Value* pointer. */ - void AddMapping(std::string name, mlir::Value* tensor) { + void AddMapping(std::string name, mlir::Value *tensor) { + assert(onnx_name2onnf_tensor.count(legalize_name(name)) == 0 && + "Tensor already exists."); onnx_name2onnf_tensor.emplace(legalize_name(name), tensor); } @@ -88,7 +93,7 @@ struct OnnxOnnfSymbolMapping { return onnx_name2onnf_tensor.count(name) != 0; } - private: +private: /*! * mapping from onnx tensor names to MLIR tensor. */ @@ -96,8 +101,8 @@ struct OnnxOnnfSymbolMapping { }; class FrontendGenImpl { - public: - FrontendGenImpl(mlir::MLIRContext& context) +public: + FrontendGenImpl(mlir::MLIRContext &context) : context_(context), builder_(&context) { module_ = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)); } @@ -107,8 +112,8 @@ class FrontendGenImpl { return module_; } - private: - mlir::MLIRContext& context_; +private: + mlir::MLIRContext &context_; mlir::ModuleOp module_; mlir::OpBuilder builder_; // mapping between string name and symbol @@ -145,55 +150,31 @@ class FrontendGenImpl { case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX64: case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX128: case onnx::TensorProto_DataType::TensorProto_DataType_UNDEFINED: + assert(false && "Unsupported data type encountered."); return nullptr; } } -//if c++17 is used, these two def can be combined with 'if constexpr' -//leave n there for possible future use -//alternative is to use template and pass the outputTypes, inputs and attributes -//as parameter - -#define MultipleOuts(name, nIn, nOut)\ -{ \ - if (nIn == inputs.size() && nOut == outputTypes.size()) {\ - auto op = builder_.create(UnknownLoc(), outputTypes, inputs, attributes); \ - for (int i = 0; i < node.output().size(); i++) { \ - frontend_symbols_.AddMapping(\ - legalize_name(node.output()[i]), op.getResult(i));\ - }\ - return;\ - }\ -} - -#define OneOut(name, nIn, nOut)\ -{ \ - if (nIn == inputs.size() && nOut == outputTypes.size()) {\ - auto op = builder_.create(UnknownLoc(), outputTypes, inputs, attributes); \ - frontend_symbols_.AddMapping(\ - legalize_name(node.output()[0]), op.getResult());\ - return;\ - }\ -} - /*! * Import an onnx input tensor type by determining and recording its type * in a list of input tensor mlir types. * @param input onnx input tensor ValueInfoProto. * @param arg_types list of mlir types representing types of graph input. */ - void ImportInputTensorType(const onnx::ValueInfoProto& input, - llvm::SmallVector& arg_types) { + void ImportInputTensorType(const onnx::ValueInfoProto &input, + llvm::SmallVector &arg_types) { std::vector dims; auto shape_proto = input.type().tensor_type().shape(); auto input_tensor_legalized_name = legalize_name(input.name()); for (int i = 0; i < shape_proto.dim_size(); i++) { if (shape_proto.dim()[i].dim_value()) { int dim_numeric_size = shape_proto.dim()[i].dim_value(); + assert( + dim_numeric_size != 0 && "Parsed an input tensor with a dimension size of zero"); if (dim_numeric_size > 0) { dims.push_back(dim_numeric_size); - } else { // If dim_value < 0, then dim is parametric. - // TODO Verify the unknown dim size in MLIR + } else { // If dim_value < 0, then dim is parametric. + // TODO Verify the unknown dim size in MLIR dims.push_back(-1); } } else { @@ -216,8 +197,8 @@ class FrontendGenImpl { * @param input onnx input tensor ValueInfoProto. * @param symbol mlir input argument. */ - void ImportInputTensorSymbol( - const onnx::ValueInfoProto& input, mlir::Value* symbol) { + void ImportInputTensorSymbol(const onnx::ValueInfoProto &input, + mlir::Value *symbol) { auto input_tensor_legalized_name = legalize_name(input.name()); assert( !frontend_symbols_.ContainKey(input_tensor_legalized_name) && @@ -225,32 +206,286 @@ class FrontendGenImpl { frontend_symbols_.AddMapping(input_tensor_legalized_name, symbol); } - void ImportNode(onnx::NodeProto node) { - std::vector inputs; + template + T get_attr_generic(onnx::NodeProto &node, std::string name, + std::function attr_getter, + T default_val) { + for (int i = 0; i < node.attribute_size(); ++i) { + auto attr = node.attribute(i); + if (attr.name() == name) { + return attr_getter(attr); + } + } + return default_val; + } + + template + T get_attr_generic(onnx::NodeProto &node, std::string name, + std::function attr_getter) { + for (int i = 0; i < node.attribute_size(); ++i) { + auto attr = node.attribute(i); + if (attr.name() == name) { + return attr_getter(attr); + } + } + assert(false && "ONNX Node Attribute Not Found!"); + } + + auto get_attr_ints(onnx::NodeProto &node, std::string name, + std::vector default_val) { + std::function(onnx::AttributeProto &)> attr_getter = + [](onnx::AttributeProto &attr) { + std::vector ints(attr.ints_size()); + std::copy(attr.ints().begin(), attr.ints().end(), ints.begin()); + return ints; + }; + auto r = get_attr_generic(node, name, attr_getter, default_val); + auto dataType = + mlir::RankedTensorType::get(r.size(), builder_.getIntegerType(32)); + auto attr_v = mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(r)); + auto aname = node.op_type() + "." + name; + auto attr_output = builder_.getNamedAttr(aname, attr_v); + return attr_output; + } + + auto get_attr_ints(onnx::NodeProto &node, std::string name) { + std::function(onnx::AttributeProto &)> attr_getter = + [](onnx::AttributeProto &attr) { + std::vector ints(attr.ints_size()); + std::copy(attr.ints().begin(), attr.ints().end(), ints.begin()); + return ints; + }; + auto r = get_attr_generic(node, name, attr_getter); + auto dataType = + mlir::RankedTensorType::get(r.size(), builder_.getIntegerType(32)); + auto attr_v = mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(r)); + auto aname = node.op_type() + "." + name; + auto attr_output = builder_.getNamedAttr(aname, attr_v); + return attr_output; + } + + auto get_attr_floats(onnx::NodeProto &node, std::string name) { + std::function(onnx::AttributeProto &)> attr_getter = + [](onnx::AttributeProto &attr) { + std::vector floats(attr.floats_size()); + std::copy(attr.floats().begin(), attr.floats().end(), floats.begin()); + return floats; + }; + auto r = get_attr_generic(node, name, attr_getter); + auto dataType = + mlir::RankedTensorType::get(r.size(), builder_.getF32Type()); + auto attr_v = mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(r)); + auto aname = node.op_type() + "." + name; + auto attr_output = builder_.getNamedAttr(aname, attr_v); + return attr_output; + } + + auto get_attr_floats(onnx::NodeProto &node, std::string name, + std::vector default_val) { + std::function(onnx::AttributeProto &)> attr_getter = + [](onnx::AttributeProto &attr) { + std::vector floats(attr.floats_size()); + std::copy(attr.floats().begin(), attr.floats().end(), floats.begin()); + return floats; + }; + auto r = get_attr_generic(node, name, attr_getter, default_val); + auto dataType = + mlir::RankedTensorType::get(r.size(), builder_.getF32Type()); + auto attr_v = mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(r)); + auto aname = node.op_type() + "." + name; + auto attr_output = builder_.getNamedAttr(aname, attr_v); + return attr_output; + } + + auto get_attr_int(onnx::NodeProto &node, std::string name) { + std::function attr_getter = + [](onnx::AttributeProto &attr) { return attr.i(); }; + int r = get_attr_generic(node, name, attr_getter); + auto attr_v = builder_.getI32IntegerAttr(r); + auto aname = node.op_type() + "." + name; + auto attr_output = builder_.getNamedAttr(aname, attr_v); + return attr_output; + } + + auto get_attr_int(onnx::NodeProto &node, std::string name, int default_val) { + std::function attr_getter = + [](onnx::AttributeProto &attr) { return attr.i(); }; + int r = get_attr_generic(node, name, attr_getter, default_val); + auto attr_v = builder_.getI32IntegerAttr(r); + auto aname = node.op_type() + "." + name; + auto attr_output = builder_.getNamedAttr(aname, attr_v); + return attr_output; + } + + auto get_attr_float(onnx::NodeProto &node, std::string name) { + std::function attr_getter = + [](onnx::AttributeProto &attr) { return attr.f(); }; + auto r = get_attr_generic(node, name, attr_getter); + auto attr_v = builder_.getF32FloatAttr(r); + auto aname = node.op_type() + "." + name; + return builder_.getNamedAttr(aname, attr_v); + } + + auto get_attr_float(onnx::NodeProto &node, std::string name, + float default_val) { + std::function attr_getter = + [](onnx::AttributeProto &attr) { return attr.f(); }; + auto r = get_attr_generic(node, name, attr_getter, default_val); + auto attr_v = builder_.getF32FloatAttr(r); + auto aname = node.op_type() + "." + name; + return builder_.getNamedAttr(aname, attr_v); + } + + auto get_attr_string(onnx::NodeProto &node, std::string name) { + std::function attr_getter = + [](onnx::AttributeProto &attr) { return attr.s(); }; + auto r = get_attr_generic(node, name, attr_getter); + auto attr_v = builder_.getStringAttr(r); + auto aname = node.op_type() + "." + name; + return builder_.getNamedAttr(aname, attr_v); + } + + auto get_attr_string(onnx::NodeProto &node, std::string name, + std::string default_val) { + std::function attr_getter = + [](onnx::AttributeProto &attr) { return attr.s(); }; + auto r = get_attr_generic(node, name, attr_getter, default_val); + auto attr_v = builder_.getStringAttr(r); + auto aname = node.op_type() + "." + name; + return builder_.getNamedAttr(aname, attr_v); + } + + /* + auto get_attr_strings(onnx::NodeProto &node, std::string name) { + std::function(onnx::AttributeProto &)> + attr_getter = + [](onnx::AttributeProto &attr) { + std::vector strings(attr.strings_size()); + std::copy(attr.strings().begin(), attr.strings().end(), + strings.begin()); return strings; + }; + auto r = get_attr_generic(node, name, attr_getter); + return r; + return builder_.getNamedAttr(aname, attr_v); + auto dataType = + mlir::RankedTensorType::get(r.size(), builder_.get???Type()); + auto attr_v = mlir::DenseElementsAttr::get(dataType, + llvm::makeArrayRef(r)); auto aname = node.op_type() + "." + name; auto + attr_output = builder_.getNamedAttr(aname, attr_v); return attr_output; + } + */ + + auto get_default_ints(std::string default_str) { + std::vector r; + auto start = default_str.find("{"); + while (true) { + auto end = default_str.find(",", start + 1); + if (end == std::string::npos) { + end = default_str.find("}", start + 1); + if (end != std::string::npos && end > start+1) { + r.push_back(std::stoi(default_str.substr(start + 1, end))); + } + break; + } else { + r.push_back(std::stoi(default_str.substr(start + 1, end))); + } + start = end + 1; + } + return r; + } + + auto get_default_floats(std::string default_str) { + std::vector r; + auto start = default_str.find("{"); + while (true) { + auto end = default_str.find(",", start + 1); + if (end == std::string::npos) { + end = default_str.find("}", start + 1); + if (end != std::string::npos && end > start+1) { + r.push_back(std::stof(default_str.substr(start + 1, end))); + } + break; + } else { + r.push_back(std::stof(default_str.substr(start + 1, end))); + } + start = end + 1; + } + return r; + } + + auto get_default_strings(std::string default_str) { + std::vector r; + auto start = default_str.find("{"); + while (true) { + auto end = default_str.find(",", start + 1); + if (end == std::string::npos) { + end = default_str.find("}", start + 1); + if (end != std::string::npos && end > start+1) { + r.push_back(default_str.substr(start + 1, end)); + } + break; + } else { + r.push_back(default_str.substr(start + 1, end)); + } + start = end + 1; + } + return r; + } + + onnx::TensorProto get_attr_tensor(onnx::NodeProto &node, std::string name) { + std::function attr_getter = + [](onnx::AttributeProto &attr) { return attr.t(); }; + return get_attr_generic(node, name, attr_getter); + } + + auto ImportNodeAttr(onnx::NodeProto node, std::string attr_name, + std::string type_name, std::string default_str) { + if (default_str == "") { + if (type_name == "int") { + return get_attr_int(node, attr_name); + } else if (type_name == "float") { + return get_attr_float(node, attr_name); + } else if (type_name == "str") { + return get_attr_string(node, attr_name); + } else if (type_name == "ints") { + return get_attr_ints(node, attr_name); + } else if (type_name == "floats") { + return get_attr_floats(node, attr_name); + } else { + assert( + false && + "Got an empty initializer or initializer for this " + "datatype is not implemented. Something is wrong."); + } + } else { + // with default value + if (type_name == "int") { + return get_attr_int(node, attr_name, std::stoi(default_str)); + } else if (type_name == "float") { + return get_attr_float(node, attr_name, std::stof(default_str)); + } else if (type_name == "str") { + return get_attr_string(node, attr_name, default_str); + } else if (type_name == "ints") { + return get_attr_ints(node, attr_name, get_default_ints(default_str)); + } else if (type_name == "floats") { + return get_attr_floats(node, attr_name, + get_default_floats(default_str)); + } else { + assert( + false && + "Got an empty initializer or initializer for this " + "datatype is not implemented. Something is wrong."); + } + } + } + + void ImportNodeGeneric(onnx::NodeProto node) { + std::vector inputs; for (auto item : node.input()) { if (frontend_symbols_.ContainKey(legalize_name(item))) { inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item)); } } - - std::vector outputTypes; - for (auto item : node.output()) { - outputTypes.push_back(mlir::UnrankedTensorType::get(builder_.getF32Type())); - } - - std::vector attributes; - llvm::StringRef OpName = node.op_type(); - - //the following code is generated by gen_doc.py - //refer to dialect/onnx/onnx.td for details - //when the input or output of then op does not match the specification, - //the generic operator is used - //one known reeason is the optional input - -#include "src/builder/op_build_table.inc" - - - // Old way of doing things. mlir::OperationState result(UnknownLoc(), "frontend." + node.op_type()); for (auto item : node.output()) { result.addTypes(mlir::UnrankedTensorType::get(builder_.getF32Type())); @@ -261,8 +496,165 @@ class FrontendGenImpl { auto r = op->getResult(i); frontend_symbols_.AddMapping(legalize_name(node.output()[i]), r); } + } - // TODO more info from node: attributes + // if c++17 is used, ImportNodeOneOut and ImportNodeMultipleOuts can be + // combined with 'if constexpr' the issue is the type of the output is + // different. alternative way to use variadic output for all the op + + /*! + * Important onnx node which generates only one output + * @param node onnx node + * @param nIn number of expected inputs + * @param nOut number of expected outputs + * @param attrs list of desription for attributes with format {name, type, + * default} + */ + template + void ImportNodeOneOut( + onnx::NodeProto node, int nIn, int nOut, + std::initializer_list> + attrs) { + std::vector inputs; + for (auto item : node.input()) { + if (frontend_symbols_.ContainKey(legalize_name(item))) { + inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item)); + } + } + + std::vector outputTypes; + for (auto item : node.output()) { + outputTypes.push_back( + mlir::UnrankedTensorType::get(builder_.getF32Type())); + } + + std::vector attributes; + //for (auto [attr_name, attr_type, attr_default] : attrs) { + for (auto oneAttr: attrs) { + std::string attr_name; + std::string attr_type; + std::string attr_default; + std::tie (attr_name, attr_type, attr_default) = oneAttr; + if (attr_type != "") { + auto attr = ImportNodeAttr(node, attr_name, attr_type, attr_default); + attributes.push_back(attr); + } else { + // TODO: the attributes need special handling + //std::cout << "missing " << node.op_type() << " " << attr_name << std::endl; + } + } + + llvm::StringRef OpName = node.op_type(); + + if (nIn == inputs.size() && nOut == outputTypes.size()) { + auto op = + builder_.create(UnknownLoc(), outputTypes, inputs, attributes); + frontend_symbols_.AddMapping(legalize_name(node.output()[0]), + op.getResult()); + } else { + ImportNodeGeneric(node); + } + } + + template + void ImportNodeMultipleOuts( + onnx::NodeProto node, int nIn, int nOut, + std::initializer_list> + attrs) { + std::vector inputs; + for (auto item : node.input()) { + if (frontend_symbols_.ContainKey(legalize_name(item))) { + inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item)); + } + } + + std::vector outputTypes; + for (auto item : node.output()) { + outputTypes.push_back( + mlir::UnrankedTensorType::get(builder_.getF32Type())); + } + + std::vector attributes; + for (auto oneAttr: attrs) { + std::string attr_name; + std::string attr_type; + std::string attr_default; + std::tie (attr_name, attr_type, attr_default) = oneAttr; + if (attr_type != "") { + auto attr = ImportNodeAttr(node, attr_name, attr_type, attr_default); + attributes.push_back(attr); + } else { + // TODO: the attributes need special handling + //std::cout << "missing " << node.op_type() << " " << attr_name << std::endl; + } + } + + llvm::StringRef OpName = node.op_type(); + + if (nIn == inputs.size() && nOut == outputTypes.size()) { + auto op = + builder_.create(UnknownLoc(), outputTypes, inputs, attributes); + for (int i = 0; i < node.output().size(); i++) { + frontend_symbols_.AddMapping(legalize_name(node.output()[i]), + op.getResult(i)); + } + } else { + ImportNodeGeneric(node); + } + } + + /*! + * Special handle for Conv operations. + * c++ does not allow template specialization inside a class scope + * a specialized function is used + */ + void ImportNodeConv( + onnx::NodeProto node, int nIn, int nOut, + std::initializer_list> + attrs) { + + // Conv has attribute dilations, kernel_shape, pads, the default value of + // which is determined by the shape of first argument. However, since the + // shape is unknown now, these attributes can be not generated auto + // dilations_attr = get_attr_ints(node, "dilations", + // std::vector(inputs[0]->getType().cast.getDims()-2, + // 1)); + // attributes.push_back(dilations_attr) + // similar situation for pads, strides in AveragePool + // axes of ReduceSum, pads, strides, dilations and kernel_shape of MaxPool + // TODO: fix this after type inference + + if (node.input().size() == 1) { + ImportNodeOneOut(node, nIn, nOut, attrs); + } else { + ImportNodeOneOut(node, nIn, nOut, attrs); + } + } + + void ImportNode(onnx::NodeProto node) { + std::vector inputs; + for (auto item : node.input()) { + if (frontend_symbols_.ContainKey(legalize_name(item))) { + inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item)); + } + } + + std::vector outputTypes; + for (auto item : node.output()) { + outputTypes.push_back( + mlir::UnrankedTensorType::get(builder_.getF32Type())); + } + + std::vector attributes; + llvm::StringRef OpName = node.op_type(); + + // the following code is generated by gen_doc.py + // refer to dialect/onnx/onnx.td for details + // when the input or output of then op does not match the specification, + // the generic operator is used + // one known reeason is the optional input + +#include "src/builder/op_build_table.inc" } /*! @@ -277,9 +669,9 @@ class FrontendGenImpl { * @param ret_vals a vector of mlir Value* representing graph's * output tensor. */ - void ImportOutputTensor(const onnx::ValueInfoProto& output, - llvm::SmallVectorImpl& ret_types, - llvm::SmallVectorImpl& ret_vals) { + void ImportOutputTensor(const onnx::ValueInfoProto &output, + llvm::SmallVectorImpl &ret_types, + llvm::SmallVectorImpl &ret_vals) { auto output_tensor_legalized_name = legalize_name(output.name()); assert( frontend_symbols_.ContainKey(output_tensor_legalized_name) && @@ -291,8 +683,8 @@ class FrontendGenImpl { ret_vals.push_back(tensor_val); } - void ImportGraph( - const onnx::GraphProto& graph, const std::string& name = "main") { + void ImportGraph(const onnx::GraphProto &graph, + const std::string &name = "main") { // create a function for the graph // TODO: // * get name and type for the function. @@ -300,7 +692,7 @@ class FrontendGenImpl { llvm::SmallVector arg_types; // Import the input tensor types. - for (const auto& input : graph.input()) { + for (const auto &input : graph.input()) { ImportInputTensorType(input, arg_types); } @@ -308,7 +700,7 @@ class FrontendGenImpl { auto func_type = builder_.getFunctionType(arg_types, {}); auto main_func = mlir::FuncOp::create(UnknownLoc(), name, func_type, /* attrs = */ {}); - auto& entryBlock = *main_func.addEntryBlock(); + auto &entryBlock = *main_func.addEntryBlock(); builder_.setInsertionPointToStart(&entryBlock); module_.push_back(main_func); @@ -319,14 +711,14 @@ class FrontendGenImpl { // import nodes in the graph auto node = graph.node(); - for (const auto& item : node) { + for (const auto &item : node) { ImportNode(item); } llvm::SmallVector ret_types; - llvm::SmallVector ret_vals; + llvm::SmallVector ret_vals; // Import the output tensors - for (const auto& output : graph.output()) { + for (const auto &output : graph.output()) { ImportOutputTensor(output, ret_types, ret_vals); } @@ -337,7 +729,6 @@ class FrontendGenImpl { func_type = builder_.getFunctionType(arg_types, ret_types); main_func.setType(func_type); } - }; // FrontendGenImpl class } // namespace } // namespace onnf @@ -354,11 +745,13 @@ mlir::OwningModuleRef ImportFrontendModel(onnx::ModelProto model) { } void ImportFrontendModelFile(std::string model_fname, - mlir::MLIRContext& context, mlir::OwningModuleRef& module) { + mlir::MLIRContext &context, + mlir::OwningModuleRef &module) { onnx::ModelProto model; std::fstream input(model_fname, std::ios::in | std::ios::binary); auto parse_success = model.ParseFromIstream(&input); + assert(parse_success && "Onnx Model Parsing Failed."); FrontendGenImpl myONNXGen(context); module = myONNXGen.ImportONNXModel(model); diff --git a/src/builder/op_build_table.inc b/src/builder/op_build_table.inc index c8e75e0..b9e5720 100644 --- a/src/builder/op_build_table.inc +++ b/src/builder/op_build_table.inc @@ -1,313 +1,688 @@ - if (OpName == "Abs") { - OneOut(Abs, 1, 1); + if (OpName == "DUMMY") { + }else if (OpName == "Abs") { + ImportNodeOneOut(node, 1, 1, { + }); }else if (OpName == "Acos") { - OneOut(Acos, 1, 1); + ImportNodeOneOut(node, 1, 1, { + }); }else if (OpName == "Acosh") { - OneOut(Acosh, 1, 1); + ImportNodeOneOut(node, 1, 1, { + }); }else if (OpName == "Add") { - OneOut(Add, 2, 1); + ImportNodeOneOut(node, 2, 1, { + }); }else if (OpName == "And") { - OneOut(And, 2, 1); + ImportNodeOneOut(node, 2, 1, { + }); }else if (OpName == "ArgMax") { - OneOut(ArgMax, 1, 1); + ImportNodeOneOut(node, 1, 1, { + {"axis","int","0"} + ,{"keepdims","int","1"} + }); }else if (OpName == "ArgMin") { - OneOut(ArgMin, 1, 1); + ImportNodeOneOut(node, 1, 1, { + {"axis","int","0"} + ,{"keepdims","int","1"} + }); }else if (OpName == "Asin") { - OneOut(Asin, 1, 1); + ImportNodeOneOut(node, 1, 1, { + }); }else if (OpName == "Asinh") { - OneOut(Asinh, 1, 1); + ImportNodeOneOut(node, 1, 1, { + }); }else if (OpName == "Atan") { - OneOut(Atan, 1, 1); + ImportNodeOneOut(node, 1, 1, { + }); }else if (OpName == "Atanh") { - OneOut(Atanh, 1, 1); + ImportNodeOneOut(node, 1, 1, { + }); }else if (OpName == "AveragePool") { - OneOut(AveragePool, 1, 1); + ImportNodeOneOut(node, 1, 1, { + {"auto_pad","str","NOTSET"} + ,{"ceil_mode","int","0"} + ,{"count_include_pad","int","0"} + ,{"kernel_shape","ints", ""} + ,{"pads","", ""} + ,{"strides","", ""} + }); }else if (OpName == "BatchNormalization") { - MultipleOuts(BatchNormalization, 5, 5); + ImportNodeMultipleOuts(node, 5, 5, { + {"epsilon","float","1e-05"} + ,{"momentum","float","0.9"} + }); }else if (OpName == "BitShift") { - OneOut(BitShift, 2, 1); + ImportNodeOneOut(node, 2, 1, { + {"direction","", ""} + }); }else if (OpName == "Cast") { - OneOut(Cast, 1, 1); + ImportNodeOneOut(node, 1, 1, { + {"to","int", "0"} + }); }else if (OpName == "Ceil") { - OneOut(Ceil, 1, 1); + ImportNodeOneOut(node, 1, 1, { + }); }else if (OpName == "Clip") { - OneOut(Clip, 3, 1); + ImportNodeOneOut(node, 3, 1, { + }); }else if (OpName == "Compress") { - OneOut(Compress, 2, 1); + ImportNodeOneOut(node, 2, 1, { + {"axis","", ""} + }); }else if (OpName == "Concat") { - OneOut(Concat, 1, 1); + ImportNodeOneOut(node, 1, 1, { + {"axis","int", "0"} + }); }else if (OpName == "ConcatFromSequence") { - OneOut(ConcatFromSequence, 1, 1); + ImportNodeOneOut(node, 1, 1, { + {"axis","", ""} + ,{"new_axis","int","0"} + }); }else if (OpName == "Constant") { - OneOut(Constant, 0, 1); + ImportNodeOneOut(node, 0, 1, { + {"sparse_value","", ""} + ,{"value","", ""} + }); }else if (OpName == "ConstantOfShape") { - OneOut(ConstantOfShape, 1, 1); + ImportNodeOneOut(node, 1, 1, { + {"value","", ""} + }); }else if (OpName == "Conv") { - OneOut(Conv, 3, 1); + ImportNodeConv(node, 3, 1, { + {"auto_pad","str","NOTSET"} + ,{"dilations","", ""} + ,{"group","int", "1"} + ,{"kernel_shape","", ""} + ,{"pads","", ""} + ,{"strides","", ""} + }); }else if (OpName == "ConvInteger") { - OneOut(ConvInteger, 4, 1); + ImportNodeOneOut(node, 4, 1, { + {"auto_pad","str","NOTSET"} + ,{"dilations","", ""} + ,{"group","int","1"} + ,{"kernel_shape","", ""} + ,{"pads","", ""} + ,{"strides","", ""} + }); }else if (OpName == "ConvTranspose") { - OneOut(ConvTranspose, 3, 1); + ImportNodeOneOut(node, 3, 1, { + {"auto_pad","str","NOTSET"} + ,{"dilations","", ""} + ,{"group","int","1"} + ,{"kernel_shape","", ""} + ,{"output_padding","", ""} + ,{"output_shape","", ""} + ,{"pads","", ""} + ,{"strides","", ""} + }); }else if (OpName == "Cos") { - OneOut(Cos, 1, 1); + ImportNodeOneOut(node, 1, 1, { + }); }else if (OpName == "Cosh") { - OneOut(Cosh, 1, 1); + ImportNodeOneOut(node, 1, 1, { + }); }else if (OpName == "CumSum") { - OneOut(CumSum, 2, 1); + ImportNodeOneOut(node, 2, 1, { + {"exclusive","int","0"} + ,{"reverse","int","0"} + }); }else if (OpName == "DepthToSpace") { - OneOut(DepthToSpace, 1, 1); + ImportNodeOneOut(node, 1, 1, { + {"blocksize","", ""} + ,{"mode","str","DCR"} + }); }else if (OpName == "DequantizeLinear") { - OneOut(DequantizeLinear, 3, 1); + ImportNodeOneOut(node, 3, 1, { + }); }else if (OpName == "Det") { - OneOut(Det, 1, 1); + ImportNodeOneOut(node, 1, 1, { + }); }else if (OpName == "Div") { - OneOut(Div, 2, 1); + ImportNodeOneOut(node, 2, 1, { + }); }else if (OpName == "Dropout") { - MultipleOuts(Dropout, 1, 2); + ImportNodeMultipleOuts(node, 1, 2, { + {"ratio","float","0.5"} + }); }else if (OpName == "DynamicQuantizeLinear") { - MultipleOuts(DynamicQuantizeLinear, 1, 3); + ImportNodeMultipleOuts(node, 1, 3, { + }); }else if (OpName == "Elu") { - OneOut(Elu, 1, 1); + ImportNodeOneOut(node, 1, 1, { + {"alpha","float","1.0"} + }); }else if (OpName == "Equal") { - OneOut(Equal, 2, 1); + ImportNodeOneOut(node, 2, 1, { + }); }else if (OpName == "Erf") { - OneOut(Erf, 1, 1); + ImportNodeOneOut(node, 1, 1, { + }); }else if (OpName == "Exp") { - OneOut(Exp, 1, 1); + ImportNodeOneOut(node, 1, 1, { + }); }else if (OpName == "Expand") { - OneOut(Expand, 2, 1); + ImportNodeOneOut(node, 2, 1, { + }); }else if (OpName == "EyeLike") { - OneOut(EyeLike, 1, 1); + ImportNodeOneOut(node, 1, 1, { + {"dtype","", ""} + ,{"k","int","0"} + }); }else if (OpName == "Flatten") { - OneOut(Flatten, 1, 1); + ImportNodeOneOut(node, 1, 1, { + {"axis","int","1"} + }); }else if (OpName == "Floor") { - OneOut(Floor, 1, 1); + ImportNodeOneOut(node, 1, 1, { + }); }else if (OpName == "GRU") { - MultipleOuts(GRU, 6, 2); + ImportNodeMultipleOuts(node, 6, 2, { + {"activation_alpha","", ""} + ,{"activation_beta","", ""} + ,{"activations","", ""} + ,{"clip","", ""} + ,{"direction","str","forward"} + ,{"hidden_size","", ""} + ,{"linear_before_reset","int","0"} + }); }else if (OpName == "Gather") { - OneOut(Gather, 2, 1); + ImportNodeOneOut(node, 2, 1, { + {"axis","int","0"} + }); }else if (OpName == "GatherElements") { - OneOut(GatherElements, 2, 1); + ImportNodeOneOut(node, 2, 1, { + {"axis","int","0"} + }); }else if (OpName == "GatherND") { - OneOut(GatherND, 2, 1); + ImportNodeOneOut(node, 2, 1, { + }); }else if (OpName == "Gemm") { - OneOut(Gemm, 3, 1); + ImportNodeOneOut(node, 3, 1, { + {"alpha","float","1.0"} + ,{"beta","float","1.0"} + ,{"transA","int","0"} + ,{"transB","int","0"} + }); }else if (OpName == "GlobalAveragePool") { - OneOut(GlobalAveragePool, 1, 1); + ImportNodeOneOut(node, 1, 1, { + }); }else if (OpName == "GlobalLpPool") { - OneOut(GlobalLpPool, 1, 1); + ImportNodeOneOut(node, 1, 1, { + {"p","int","2"} + }); }else if (OpName == "GlobalMaxPool") { - OneOut(GlobalMaxPool, 1, 1); + ImportNodeOneOut(node, 1, 1, { + }); }else if (OpName == "Greater") { - OneOut(Greater, 2, 1); + ImportNodeOneOut(node, 2, 1, { + }); }else if (OpName == "HardSigmoid") { - OneOut(HardSigmoid, 1, 1); + ImportNodeOneOut(node, 1, 1, { + {"alpha","float","0.2"} + ,{"beta","float","0.5"} + }); }else if (OpName == "Hardmax") { - OneOut(Hardmax, 1, 1); + ImportNodeOneOut(node, 1, 1, { + {"axis","int","1"} + }); }else if (OpName == "Identity") { - OneOut(Identity, 1, 1); + ImportNodeOneOut(node, 1, 1, { + }); }else if (OpName == "If") { - OneOut(If, 1, 1); + ImportNodeOneOut(node, 1, 1, { + {"else_branch","", ""} + ,{"then_branch","", ""} + }); }else if (OpName == "InstanceNormalization") { - OneOut(InstanceNormalization, 3, 1); + ImportNodeOneOut(node, 3, 1, { + {"epsilon","float","1e-05"} + }); }else if (OpName == "IsInf") { - OneOut(IsInf, 1, 1); + ImportNodeOneOut(node, 1, 1, { + {"detect_negative","int","1"} + ,{"detect_positive","int","1"} + }); }else if (OpName == "IsNaN") { - OneOut(IsNaN, 1, 1); + ImportNodeOneOut(node, 1, 1, { + }); }else if (OpName == "LRN") { - OneOut(LRN, 1, 1); + ImportNodeOneOut(node, 1, 1, { + {"alpha","float","0.0001"} + ,{"beta","float","0.75"} + ,{"bias","float","1.0"} + ,{"size","int", ""} + }); }else if (OpName == "LSTM") { - MultipleOuts(LSTM, 8, 3); + ImportNodeMultipleOuts(node, 8, 3, { + {"activation_alpha","", ""} + ,{"activation_beta","", ""} + ,{"activations","", ""} + ,{"clip","", ""} + ,{"direction","str","forward"} + ,{"hidden_size","", ""} + ,{"input_forget","int","0"} + }); }else if (OpName == "LeakyRelu") { - OneOut(LeakyRelu, 1, 1); + ImportNodeOneOut(node, 1, 1, { + {"alpha","float","0.01"} + }); }else if (OpName == "Less") { - OneOut(Less, 2, 1); + ImportNodeOneOut(node, 2, 1, { + }); }else if (OpName == "Log") { - OneOut(Log, 1, 1); + ImportNodeOneOut(node, 1, 1, { + }); }else if (OpName == "LogSoftmax") { - OneOut(LogSoftmax, 1, 1); + ImportNodeOneOut(node, 1, 1, { + {"axis","int","1"} + }); }else if (OpName == "Loop") { - OneOut(Loop, 3, 1); + ImportNodeOneOut(node, 3, 1, { + {"body","", ""} + }); }else if (OpName == "LpNormalization") { - OneOut(LpNormalization, 1, 1); + ImportNodeOneOut(node, 1, 1, { + {"axis","int","-1"} + ,{"p","int","2"} + }); }else if (OpName == "LpPool") { - OneOut(LpPool, 1, 1); + ImportNodeOneOut(node, 1, 1, { + {"auto_pad","str","NOTSET"} + ,{"kernel_shape","", ""} + ,{"p","int","2"} + ,{"pads","", ""} + ,{"strides","", ""} + }); }else if (OpName == "MatMul") { - OneOut(MatMul, 2, 1); + ImportNodeOneOut(node, 2, 1, { + }); }else if (OpName == "MatMulInteger") { - OneOut(MatMulInteger, 4, 1); + ImportNodeOneOut(node, 4, 1, { + }); }else if (OpName == "Max") { - OneOut(Max, 1, 1); + ImportNodeOneOut(node, 1, 1, { + }); }else if (OpName == "MaxPool") { - MultipleOuts(MaxPool, 1, 2); + ImportNodeMultipleOuts(node, 1, 2, { + {"auto_pad","str","NOTSET"} + ,{"ceil_mode","int","0"} + ,{"dilations","", ""} + ,{"kernel_shape","ints", ""} + ,{"pads","", ""} + ,{"storage_order","int","0"} + ,{"strides","", ""} + }); }else if (OpName == "MaxRoiPool") { - OneOut(MaxRoiPool, 2, 1); + ImportNodeOneOut(node, 2, 1, { + {"pooled_shape","", ""} + ,{"spatial_scale","float","1.0"} + }); }else if (OpName == "MaxUnpool") { - OneOut(MaxUnpool, 3, 1); + ImportNodeOneOut(node, 3, 1, { + {"kernel_shape","", ""} + ,{"pads","", ""} + ,{"strides","", ""} + }); }else if (OpName == "Mean") { - OneOut(Mean, 1, 1); + ImportNodeOneOut(node, 1, 1, { + }); }else if (OpName == "MeanVarianceNormalization") { - OneOut(MeanVarianceNormalization, 1, 1); + ImportNodeOneOut(node, 1, 1, { + {"axes","ints","{'0', '2', '3'}"} + }); }else if (OpName == "Min") { - OneOut(Min, 1, 1); + ImportNodeOneOut(node, 1, 1, { + }); }else if (OpName == "Mod") { - OneOut(Mod, 2, 1); + ImportNodeOneOut(node, 2, 1, { + {"fmod","int","0"} + }); }else if (OpName == "Mul") { - OneOut(Mul, 2, 1); + ImportNodeOneOut(node, 2, 1, { + }); }else if (OpName == "Multinomial") { - OneOut(Multinomial, 1, 1); + ImportNodeOneOut(node, 1, 1, { + {"dtype","int","6"} + ,{"sample_size","int","1"} + ,{"seed","", ""} + }); }else if (OpName == "Neg") { - OneOut(Neg, 1, 1); + ImportNodeOneOut(node, 1, 1, { + }); }else if (OpName == "NonMaxSuppression") { - OneOut(NonMaxSuppression, 5, 1); + ImportNodeOneOut(node, 5, 1, { + {"center_point_box","int","0"} + }); }else if (OpName == "NonZero") { - OneOut(NonZero, 1, 1); + ImportNodeOneOut(node, 1, 1, { + }); }else if (OpName == "Not") { - OneOut(Not, 1, 1); + ImportNodeOneOut(node, 1, 1, { + }); }else if (OpName == "OneHot") { - OneOut(OneHot, 3, 1); + ImportNodeOneOut(node, 3, 1, { + {"axis","int","-1"} + }); }else if (OpName == "Or") { - OneOut(Or, 2, 1); + ImportNodeOneOut(node, 2, 1, { + }); }else if (OpName == "PRelu") { - OneOut(PRelu, 2, 1); + ImportNodeOneOut(node, 2, 1, { + }); }else if (OpName == "Pad") { - OneOut(Pad, 3, 1); + ImportNodeOneOut(node, 3, 1, { + {"mode","str","constant"} + }); }else if (OpName == "Pow") { - OneOut(Pow, 2, 1); + ImportNodeOneOut(node, 2, 1, { + }); }else if (OpName == "QLinearConv") { - OneOut(QLinearConv, 9, 1); + ImportNodeOneOut(node, 9, 1, { + {"auto_pad","str","NOTSET"} + ,{"dilations","", ""} + ,{"group","int","1"} + ,{"kernel_shape","", ""} + ,{"pads","", ""} + ,{"strides","", ""} + }); }else if (OpName == "QLinearMatMul") { - OneOut(QLinearMatMul, 8, 1); + ImportNodeOneOut(node, 8, 1, { + }); }else if (OpName == "QuantizeLinear") { - OneOut(QuantizeLinear, 3, 1); + ImportNodeOneOut(node, 3, 1, { + }); }else if (OpName == "RNN") { - MultipleOuts(RNN, 6, 2); + ImportNodeMultipleOuts(node, 6, 2, { + {"activation_alpha","floats", "{}"} + ,{"activation_beta","floats", "{}"} + ,{"activations","", "{Tannh, Tanh}"} + ,{"clip","", ""} + ,{"direction","str","forward"} + ,{"hidden_size","", ""} + }); }else if (OpName == "RandomNormal") { - OneOut(RandomNormal, 0, 1); + ImportNodeOneOut(node, 0, 1, { + {"dtype","int","1"} + ,{"mean","float","0.0"} + ,{"scale","float","1.0"} + ,{"seed","", ""} + ,{"shape","", ""} + }); }else if (OpName == "RandomNormalLike") { - OneOut(RandomNormalLike, 1, 1); + ImportNodeOneOut(node, 1, 1, { + {"dtype","", ""} + ,{"mean","float","0.0"} + ,{"scale","float","1.0"} + ,{"seed","", ""} + }); }else if (OpName == "RandomUniform") { - OneOut(RandomUniform, 0, 1); + ImportNodeOneOut(node, 0, 1, { + {"dtype","int","1"} + ,{"high","float","1.0"} + ,{"low","float","0.0"} + ,{"seed","", ""} + ,{"shape","", ""} + }); }else if (OpName == "RandomUniformLike") { - OneOut(RandomUniformLike, 1, 1); + ImportNodeOneOut(node, 1, 1, { + {"dtype","", ""} + ,{"high","float","1.0"} + ,{"low","float","0.0"} + ,{"seed","", ""} + }); }else if (OpName == "Range") { - OneOut(Range, 3, 1); + ImportNodeOneOut(node, 3, 1, { + }); }else if (OpName == "Reciprocal") { - OneOut(Reciprocal, 1, 1); + ImportNodeOneOut(node, 1, 1, { + }); }else if (OpName == "ReduceL1") { - OneOut(ReduceL1, 1, 1); + ImportNodeOneOut(node, 1, 1, { + {"axes","", ""} + ,{"keepdims","int","1"} + }); }else if (OpName == "ReduceL2") { - OneOut(ReduceL2, 1, 1); + ImportNodeOneOut(node, 1, 1, { + {"axes","", ""} + ,{"keepdims","int","1"} + }); }else if (OpName == "ReduceLogSum") { - OneOut(ReduceLogSum, 1, 1); + ImportNodeOneOut(node, 1, 1, { + {"axes","", ""} + ,{"keepdims","int","1"} + }); }else if (OpName == "ReduceLogSumExp") { - OneOut(ReduceLogSumExp, 1, 1); + ImportNodeOneOut(node, 1, 1, { + {"axes","", ""} + ,{"keepdims","int","1"} + }); }else if (OpName == "ReduceMax") { - OneOut(ReduceMax, 1, 1); + ImportNodeOneOut(node, 1, 1, { + {"axes","", ""} + ,{"keepdims","int","1"} + }); }else if (OpName == "ReduceMean") { - OneOut(ReduceMean, 1, 1); + ImportNodeOneOut(node, 1, 1, { + {"axes","", ""} + ,{"keepdims","int","1"} + }); }else if (OpName == "ReduceMin") { - OneOut(ReduceMin, 1, 1); + ImportNodeOneOut(node, 1, 1, { + {"axes","", ""} + ,{"keepdims","int","1"} + }); }else if (OpName == "ReduceProd") { - OneOut(ReduceProd, 1, 1); + ImportNodeOneOut(node, 1, 1, { + {"axes","", ""} + ,{"keepdims","int","1"} + }); }else if (OpName == "ReduceSum") { - OneOut(ReduceSum, 1, 1); + ImportNodeOneOut(node, 1, 1, { + {"axes","", ""} + ,{"keepdims","int","1"} + }); }else if (OpName == "ReduceSumSquare") { - OneOut(ReduceSumSquare, 1, 1); + ImportNodeOneOut(node, 1, 1, { + {"axes","", ""} + ,{"keepdims","int","1"} + }); }else if (OpName == "Relu") { - OneOut(Relu, 1, 1); + ImportNodeOneOut(node, 1, 1, { + }); }else if (OpName == "Reshape") { - OneOut(Reshape, 2, 1); + ImportNodeOneOut(node, 2, 1, { + }); }else if (OpName == "Resize") { - OneOut(Resize, 4, 1); + ImportNodeOneOut(node, 4, 1, { + {"coordinate_transformation_mode","str","half_pixel"} + ,{"cubic_coeff_a","float","-0.75"} + ,{"exclude_outside","int","0"} + ,{"extrapolation_value","float","0.0"} + ,{"mode","str","nearest"} + ,{"nearest_mode","str","round_prefer_floor"} + }); }else if (OpName == "ReverseSequence") { - OneOut(ReverseSequence, 2, 1); + ImportNodeOneOut(node, 2, 1, { + {"batch_axis","int","1"} + ,{"time_axis","int","0"} + }); }else if (OpName == "RoiAlign") { - OneOut(RoiAlign, 3, 1); + ImportNodeOneOut(node, 3, 1, { + {"mode","str","avg"} + ,{"output_height","int","1"} + ,{"output_width","int","1"} + ,{"sampling_ratio","int","0"} + ,{"spatial_scale","float","1.0"} + }); }else if (OpName == "Round") { - OneOut(Round, 1, 1); + ImportNodeOneOut(node, 1, 1, { + }); }else if (OpName == "Scan") { - OneOut(Scan, 1, 1); + ImportNodeOneOut(node, 1, 1, { + {"body","", ""} + ,{"num_scan_inputs","", ""} + ,{"scan_input_axes","", ""} + ,{"scan_input_directions","", ""} + ,{"scan_output_axes","", ""} + ,{"scan_output_directions","", ""} + }); }else if (OpName == "Scatter") { - OneOut(Scatter, 3, 1); + ImportNodeOneOut(node, 3, 1, { + {"axis","int","0"} + }); }else if (OpName == "ScatterElements") { - OneOut(ScatterElements, 3, 1); + ImportNodeOneOut(node, 3, 1, { + {"axis","int","0"} + }); }else if (OpName == "ScatterND") { - OneOut(ScatterND, 3, 1); + ImportNodeOneOut(node, 3, 1, { + }); }else if (OpName == "Selu") { - OneOut(Selu, 1, 1); + ImportNodeOneOut(node, 1, 1, { + {"alpha","float","1.67326"} + ,{"gamma","float","1.0507"} + }); }else if (OpName == "SequenceAt") { - OneOut(SequenceAt, 2, 1); + ImportNodeOneOut(node, 2, 1, { + }); }else if (OpName == "SequenceConstruct") { - OneOut(SequenceConstruct, 1, 1); + ImportNodeOneOut(node, 1, 1, { + }); }else if (OpName == "SequenceEmpty") { - OneOut(SequenceEmpty, 0, 1); + ImportNodeOneOut(node, 0, 1, { + {"dtype","", ""} + }); }else if (OpName == "SequenceErase") { - OneOut(SequenceErase, 2, 1); + ImportNodeOneOut(node, 2, 1, { + }); }else if (OpName == "SequenceInsert") { - OneOut(SequenceInsert, 3, 1); + ImportNodeOneOut(node, 3, 1, { + }); }else if (OpName == "SequenceLength") { - OneOut(SequenceLength, 1, 1); + ImportNodeOneOut(node, 1, 1, { + }); }else if (OpName == "Shape") { - OneOut(Shape, 1, 1); + ImportNodeOneOut(node, 1, 1, { + }); }else if (OpName == "Shrink") { - OneOut(Shrink, 1, 1); + ImportNodeOneOut(node, 1, 1, { + {"bias","float","0.0"} + ,{"lambd","float","0.5"} + }); }else if (OpName == "Sigmoid") { - OneOut(Sigmoid, 1, 1); + ImportNodeOneOut(node, 1, 1, { + }); }else if (OpName == "Sign") { - OneOut(Sign, 1, 1); + ImportNodeOneOut(node, 1, 1, { + }); }else if (OpName == "Sin") { - OneOut(Sin, 1, 1); + ImportNodeOneOut(node, 1, 1, { + }); }else if (OpName == "Sinh") { - OneOut(Sinh, 1, 1); + ImportNodeOneOut(node, 1, 1, { + }); }else if (OpName == "Size") { - OneOut(Size, 1, 1); + ImportNodeOneOut(node, 1, 1, { + }); }else if (OpName == "Slice") { - OneOut(Slice, 5, 1); + ImportNodeOneOut(node, 5, 1, { + }); }else if (OpName == "Softmax") { - OneOut(Softmax, 1, 1); + ImportNodeOneOut(node, 1, 1, { + {"axis","int","1"} + }); }else if (OpName == "Softplus") { - OneOut(Softplus, 1, 1); + ImportNodeOneOut(node, 1, 1, { + }); }else if (OpName == "Softsign") { - OneOut(Softsign, 1, 1); + ImportNodeOneOut(node, 1, 1, { + }); }else if (OpName == "SpaceToDepth") { - OneOut(SpaceToDepth, 1, 1); + ImportNodeOneOut(node, 1, 1, { + {"blocksize","", ""} + }); }else if (OpName == "Split") { - OneOut(Split, 1, 1); + ImportNodeOneOut(node, 1, 1, { + {"axis","int","0"} + ,{"split","", ""} + }); }else if (OpName == "SplitToSequence") { - OneOut(SplitToSequence, 2, 1); + ImportNodeOneOut(node, 2, 1, { + {"axis","int","0"} + ,{"keepdims","int","1"} + }); }else if (OpName == "Sqrt") { - OneOut(Sqrt, 1, 1); + ImportNodeOneOut(node, 1, 1, { + }); }else if (OpName == "Squeeze") { - OneOut(Squeeze, 1, 1); + ImportNodeOneOut(node, 1, 1, { + {"axes","", ""} + }); }else if (OpName == "StringNormalizer") { - OneOut(StringNormalizer, 1, 1); + ImportNodeOneOut(node, 1, 1, { + {"case_change_action","str","NONE"} + ,{"is_case_sensitive","int","0"} + ,{"locale","", ""} + ,{"stopwords","", ""} + }); }else if (OpName == "Sub") { - OneOut(Sub, 2, 1); + ImportNodeOneOut(node, 2, 1, { + }); }else if (OpName == "Sum") { - OneOut(Sum, 1, 1); + ImportNodeOneOut(node, 1, 1, { + }); }else if (OpName == "Tan") { - OneOut(Tan, 1, 1); + ImportNodeOneOut(node, 1, 1, { + }); }else if (OpName == "Tanh") { - OneOut(Tanh, 1, 1); + ImportNodeOneOut(node, 1, 1, { + }); }else if (OpName == "TfIdfVectorizer") { - OneOut(TfIdfVectorizer, 1, 1); + ImportNodeOneOut(node, 1, 1, { + {"max_gram_length","", ""} + ,{"max_skip_count","", ""} + ,{"min_gram_length","", ""} + ,{"mode","", ""} + ,{"ngram_counts","", ""} + ,{"ngram_indexes","", ""} + ,{"pool_int64s","", ""} + ,{"pool_strings","", ""} + ,{"weights","", ""} + }); }else if (OpName == "ThresholdedRelu") { - OneOut(ThresholdedRelu, 1, 1); + ImportNodeOneOut(node, 1, 1, { + {"alpha","float","1.0"} + }); }else if (OpName == "Tile") { - OneOut(Tile, 2, 1); + ImportNodeOneOut(node, 2, 1, { + }); }else if (OpName == "TopK") { - MultipleOuts(TopK, 2, 2); + ImportNodeMultipleOuts(node, 2, 2, { + {"axis","int","-1"} + ,{"largest","int","1"} + ,{"sorted","int","1"} + }); }else if (OpName == "Transpose") { - OneOut(Transpose, 1, 1); + ImportNodeOneOut(node, 1, 1, { + {"perm","", ""} + }); }else if (OpName == "Unique") { - MultipleOuts(Unique, 1, 4); + ImportNodeMultipleOuts(node, 1, 4, { + {"axis","", ""} + ,{"sorted","int","1"} + }); }else if (OpName == "Unsqueeze") { - OneOut(Unsqueeze, 1, 1); + ImportNodeOneOut(node, 1, 1, { + {"axes","ints", ""} + }); }else if (OpName == "Upsample") { - OneOut(Upsample, 2, 1); + ImportNodeOneOut(node, 2, 1, { + {"mode","str","nearest"} + }); }else if (OpName == "Where") { - OneOut(Where, 3, 1); + ImportNodeOneOut(node, 3, 1, { + }); }else if (OpName == "Xor") { - OneOut(Xor, 2, 1); + ImportNodeOneOut(node, 2, 1, { + }); } \ No newline at end of file diff --git a/src/compiler/dialect/onnx/gen_doc.py b/src/compiler/dialect/onnx/gen_doc.py index 3e22199..530ac9a 100644 --- a/src/compiler/dialect/onnx/gen_doc.py +++ b/src/compiler/dialect/onnx/gen_doc.py @@ -352,6 +352,121 @@ def gen_schema(schema) : return s +""" +special cases: +* Split: attr split default value: sizeof(output1) namely 1 +* Conv: attr dilations default value is {num_dim of first input - 2, 1} +* Conv: attr kernel_shape type is ints +* Transpose: attr perm default value is {} empty int list +""" + +def gen_code(schema,fefile) : + special_handler = dict([ + ("Conv", "ImportNodeConv"), + #("Transpose", "ImportNodeTranspose") + ]) + special_type = dict([ + ("AveragePool "+"kernel_shape", '"ints", ""'), + ("MaxPool "+"kernel_shape", '"ints", ""'), + ("Cast "+"to", '"int", "0"'), + ("Concat "+"axis", '"int", "0"'), + ("Conv "+"group", '"int", "1"'), + ("Unsqueeze "+"axes", '"ints", ""'), + ("RNN "+"activation_alpha", '"floats", "{}"'), + ("RNN "+"activation_beta", '"floats", "{}"'), + ("RNN "+"activations", '"", "{Tannh, Tanh}"'), + ("LRN "+"size", '"int", ""') + ]) + line_indent = ' ' + fefile.write(' '+'}else if (OpName == "'+schema.name+'") {\n') + op_type_str='mlir::ONNX'+schema.name+'Op' + if schema.name in special_handler : + fefile.write(' '+special_handler[schema.name]+'(node, ' + +str(len(schema.inputs)) + +', ' +str(len(schema.outputs))+', {\n') + elif len(schema.outputs) > 1 : + fefile.write(' '+'ImportNodeMultipleOuts<'+op_type_str+'>(node, ' + +str(len(schema.inputs)) + +', ' +str(len(schema.outputs))+', {\n') + else : + fefile.write(' '+'ImportNodeOneOut<'+op_type_str+'>(node, ' + +str(len(schema.inputs)) + +', ' +str(len(schema.outputs))+', {\n') + + + if schema.attributes: + first_attr = True + for _, attr in sorted(schema.attributes.items()): + attr_line = line_indent+line_indent+line_indent+line_indent + if not first_attr: + attr_line += ',{' + else : + attr_line += ' {' + first_attr = False + + attr_line += '"'+attr.name+'",' + + if schema.name+' '+attr.name in special_type: + attr_line += special_type[schema.name+' '+attr.name] + # option holds either required or default value + elif attr.required: + attr_line += '"", ""' + + elif attr.default_value.name: + default_value = helper.get_attribute_value(attr.default_value) + + def format_value(value): # type: (Any) -> Text + if isinstance(value, float): + formatted = str(np.round(value, 5)) + # use default formatting, unless too long. + if (len(formatted) > 10): + formatted = str("({:e})".format(value)) + return formatted + elif isinstance(value, (bytes, bytearray)) and sys.version_info[0] == 3: + return str(value.decode('utf-8')) + return str(value) + + if isinstance(default_value, list): + value = default_value[0] + default_value = [format_value(val) for val in default_value] + # TODO the list type is homogenous or htergeneous? + + if isinstance(value, float) : + attr_type_str = '"floats"' + elif isinstance(value, int) : + attr_type_str = '"ints"' + elif isinstance(value, str) : + attr_type_str = '"strs"' + elif isinstance(value, (bytes, bytearray)) : + attr_type_str = '"strs"' + else : + attr_type_str = '"unknowns"' + attr_option_str = '"{}"'.format(default_value) + attr_option_str = attr_option_str.replace('[', '{', 1) + attr_option_str = attr_option_str.replace(']', '}', 1) + else: + if isinstance(default_value, float) : + attr_type_str = '"float"' + elif isinstance(default_value, int) : + attr_type_str = '"int"' + elif isinstance(default_value, str) : + attr_type_str = '"str"' + elif isinstance(default_value, (bytes, bytearray)) : + attr_type_str = '"str"' + else : + attr_type_str = '"unknown"' + default_value = format_value(default_value) + attr_option_str = '"{}"'.format(default_value) + attr_line += attr_type_str+','+attr_option_str + else: + #TODO why? + attr_line += '"", ""' + attr_line += '}\n' + fefile.write(attr_line) + fefile.write(line_indent+line_indent+line_indent+'});\n') + + + def main(args): # type: (Type[Args]) -> None with io.open(args.changelog, 'w', newline='') as fout: fout.write('## Operator Changelog\n') @@ -453,6 +568,7 @@ def main(args): # type: (Type[Args]) -> None fefile=io.open('op_build_table.inc', 'w', newline='') firstfunc = True + fefile.write(' '+'if (OpName == "DUMMY") {\n') for domain, supportmap in operator_schemas: s = '## {}\n'.format(display_domain_short(domain)) fout.write(s) @@ -461,21 +577,8 @@ def main(args): # type: (Type[Args]) -> None for op_type, schema, versions in namemap: # op_type #print("check point 1", schema.name, len(schema.inputs), len(schema.outputs)) - if firstfunc : - fefile.write(' '+'if (OpName == "'+schema.name+'") {\n') - firstfunc = False - else : - fefile.write(' '+'}else if (OpName == "'+schema.name+'") {\n') - if len(schema.outputs) > 1 : - fefile.write(' '+'MultipleOuts('+schema.name+', ' - +str(schema.since_version)+', ' - +str(len(schema.inputs))+', ' - +str(len(schema.outputs))+');\n') - else : - fefile.write(' '+'OneOut('+schema.name+', ' - +str(schema.since_version)+', ' - +str(len(schema.inputs))+', ' - +str(len(schema.outputs))+');\n') + gen_code(schema, fefile) + r = gen_schema(schema) tdfile.write(r) s = ('### {}**{}**' + (' (deprecated)' if schema.deprecated else '') + '\n').format( diff --git a/src/compiler/dialect/onnx/onnx.td b/src/compiler/dialect/onnx/onnx.td index f50cdbd..4244d58 100644 --- a/src/compiler/dialect/onnx/onnx.td +++ b/src/compiler/dialect/onnx/onnx.td @@ -71,4 +71,26 @@ def ONNXFullGemmOp: ONNX_Op<"FullGemm", let results = (outs AnyTensor); } +def ONNXConv1Op:ONNX_Op<"Conv1", + [NoSideEffect]> { + let summary = "ONNX Conv operation"; + let description = [{ + "The convolution operator consumes an input tensor and a filter, and" + "computes the output." + }]; + let arguments = (ins AnyTensor:$X); + let results = (outs AnyTensor); +} + +def ONNXConv3Op:ONNX_Op<"Conv3", + [NoSideEffect]> { + let summary = "ONNX Conv operation"; + let description = [{ + "The convolution operator consumes an input tensor and a filter, and" + "computes the output." + }]; + let arguments = (ins AnyTensor:$X, AnyTensor:$W, AnyTensor:$B); + let results = (outs AnyTensor); +} + #endif // ONNX_OPS