diff --git a/.gitmodules b/.gitmodules index 285a7ac..2293919 100644 --- a/.gitmodules +++ b/.gitmodules @@ -7,3 +7,6 @@ [submodule "third_party/pybind11"] path = third_party/pybind11 url = https://github.com/pybind/pybind11.git +[submodule "third_party/variant"] + path = third_party/variant + url = git@github.com:mpark/variant.git diff --git a/CMakeLists.txt b/CMakeLists.txt index ab9e1d7..7ec7054 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -22,6 +22,7 @@ include(MLIR.cmake) add_subdirectory(third_party/onnx) add_subdirectory(third_party/benchmark) add_subdirectory(third_party/pybind11) +add_subdirectory(third_party/variant) set(CMAKE_CXX_STANDARD 14) add_subdirectory(src) diff --git a/src/builder/CMakeLists.txt b/src/builder/CMakeLists.txt index 7d96296..6033e52 100644 --- a/src/builder/CMakeLists.txt +++ b/src/builder/CMakeLists.txt @@ -7,8 +7,9 @@ add_library(builder target_include_directories(builder PRIVATE ${CMAKE_SOURCE_DIR}) target_include_directories(builder PRIVATE ${CMAKE_BINARY_DIR}) -target_link_libraries(builder compiler onnx ${MLIRLibs} curses) +target_link_libraries(builder compiler onnx ${MLIRLibs} curses mpark_variant) target_include_directories(builder PRIVATE ${CMAKE_SOURCE_DIR}/third_party/onnx + ${CMAKE_SOURCE_DIR}/third_party/variant ${CMAKE_SOURCE_DIR}) diff --git a/src/builder/frontend_dialect_transformer.cpp b/src/builder/frontend_dialect_transformer.cpp index 7dd52c0..6065cd3 100644 --- a/src/builder/frontend_dialect_transformer.cpp +++ b/src/builder/frontend_dialect_transformer.cpp @@ -1,6 +1,6 @@ //===- frontend_dialect_transformer.cpp - MLIR Operations -----------------===// // -// Copyright 2019 The IBM Research Authors. +// Copyright 2019 The IBM Research Authors. // // ============================================================================= // @@ -14,11 +14,16 @@ // //===----------------------------------------------------------------------===// +#include #include #include #include #include -#include + +// Using backported variant. +// bstd = backported standard library. +#include +namespace bstd = mpark; #include "mlir/Analysis/Verifier.h" #include "mlir/Dialect/StandardOps/Ops.h" @@ -42,15 +47,15 @@ namespace onnf { namespace { -void replaceAll( - std::string& str, const std::string& from, const std::string& to) { +void replaceAll(std::string &str, const std::string &from, + const std::string &to) { if (from.empty()) return; size_t start_pos = 0; while ((start_pos = str.find(from, start_pos)) != std::string::npos) { str.replace(start_pos, from.length(), to); - start_pos += to.length(); // In case 'to' contains 'from', like replacing - // 'x' with 'yx' + start_pos += to.length(); // In case 'to' contains 'from', like replacing + // 'x' with 'yx' } } @@ -71,10 +76,10 @@ struct OnnxOnnfSymbolMapping { * @param name onnx tensor name. * @return onnf tensor corresponding to `name`. */ - mlir::Value GetTensorByOnnxName(std::string name) { + mlir::Value GetTensorByOnnxName(const std::string &name) { assert(onnx_name2onnf_tensor.find(legalize_name(name)) != - onnx_name2onnf_tensor.end() && - "Tensor not found"); + onnx_name2onnf_tensor.end() && + "Tensor not found"); return onnx_name2onnf_tensor.at(legalize_name(name)); } @@ -83,9 +88,9 @@ struct OnnxOnnfSymbolMapping { * @param name onnx tensor name. * @param tensor MLIR Value pointer. 
*/ - void AddMapping(std::string name, mlir::Value tensor) { + void AddMapping(const std::string &name, mlir::Value tensor) { assert(onnx_name2onnf_tensor.count(legalize_name(name)) == 0 && - "Tensor already exists."); + "Tensor already exists."); onnx_name2onnf_tensor.emplace(legalize_name(name), tensor); } @@ -124,34 +129,34 @@ private: // Convert type to MLIR type. // A complete list of types can be found in: // /third_party/onnx/onnx/onnx.pb.h - mlir::Type TypeConvert(onnx::TensorProto_DataType intype) { - switch (intype) { - case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16: - return builder_.getF16Type(); - case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT: - return builder_.getF32Type(); - case onnx::TensorProto_DataType::TensorProto_DataType_DOUBLE: - return builder_.getF64Type(); - case onnx::TensorProto_DataType::TensorProto_DataType_INT8: - case onnx::TensorProto_DataType::TensorProto_DataType_UINT8: - return builder_.getIntegerType(8); - case onnx::TensorProto_DataType::TensorProto_DataType_INT16: - case onnx::TensorProto_DataType::TensorProto_DataType_UINT16: - return builder_.getIntegerType(16); - case onnx::TensorProto_DataType::TensorProto_DataType_INT32: - case onnx::TensorProto_DataType::TensorProto_DataType_UINT32: - return builder_.getIntegerType(32); - case onnx::TensorProto_DataType::TensorProto_DataType_INT64: - case onnx::TensorProto_DataType::TensorProto_DataType_UINT64: - return builder_.getIntegerType(64); - case onnx::TensorProto_DataType::TensorProto_DataType_BOOL: - return builder_.getI1Type(); - case onnx::TensorProto_DataType::TensorProto_DataType_STRING: - case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX64: - case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX128: - case onnx::TensorProto_DataType::TensorProto_DataType_UNDEFINED: - assert(false && "Unsupported data type encountered."); - return nullptr; + mlir::Type convertONNXTypeToMLIRType(onnx::TensorProto_DataType onnxType) { + switch (onnxType) { + case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16: + return builder_.getF16Type(); + case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT: + return builder_.getF32Type(); + case onnx::TensorProto_DataType::TensorProto_DataType_DOUBLE: + return builder_.getF64Type(); + case onnx::TensorProto_DataType::TensorProto_DataType_INT8: + case onnx::TensorProto_DataType::TensorProto_DataType_UINT8: + return builder_.getIntegerType(8); + case onnx::TensorProto_DataType::TensorProto_DataType_INT16: + case onnx::TensorProto_DataType::TensorProto_DataType_UINT16: + return builder_.getIntegerType(16); + case onnx::TensorProto_DataType::TensorProto_DataType_INT32: + case onnx::TensorProto_DataType::TensorProto_DataType_UINT32: + return builder_.getIntegerType(32); + case onnx::TensorProto_DataType::TensorProto_DataType_INT64: + case onnx::TensorProto_DataType::TensorProto_DataType_UINT64: + return builder_.getIntegerType(64); + case onnx::TensorProto_DataType::TensorProto_DataType_BOOL: + return builder_.getI1Type(); + case onnx::TensorProto_DataType::TensorProto_DataType_STRING: + case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX64: + case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX128: + case onnx::TensorProto_DataType::TensorProto_DataType_UNDEFINED: + assert(false && "Unsupported data type encountered."); + return nullptr; } } @@ -169,8 +174,8 @@ private: for (int i = 0; i < shape_proto.dim_size(); i++) { if (shape_proto.dim()[i].dim_value()) { int dim_numeric_size = 
shape_proto.dim()[i].dim_value(); - assert( - dim_numeric_size != 0 && "Parsed an input tensor with a dimension size of zero"); + assert(dim_numeric_size != 0 && + "Parsed an input tensor with a dimension size of zero"); if (dim_numeric_size > 0) { dims.push_back(dim_numeric_size); } else { // If dim_value < 0, then dim is parametric. @@ -184,7 +189,7 @@ private: } mlir::Type elementType = - TypeConvert(input.type().tensor_type().elem_type()); + convertONNXTypeToMLIRType(input.type().tensor_type().elem_type()); llvm::ArrayRef tensor_dims(dims.data(), dims.size()); arg_types.emplace_back( mlir::RankedTensorType::get(tensor_dims, elementType)); @@ -200,288 +205,111 @@ private: void ImportInputTensorSymbol(const onnx::ValueInfoProto &input, mlir::Value symbol) { auto input_tensor_legalized_name = legalize_name(input.name()); - assert( - !frontend_symbols_.ContainKey(input_tensor_legalized_name) && - "Found duplicate legalized input tensor names."); + assert(!frontend_symbols_.ContainKey(input_tensor_legalized_name) && + "Found duplicate legalized input tensor names."); frontend_symbols_.AddMapping(input_tensor_legalized_name, symbol); } - template - T get_attr_generic(onnx::NodeProto &node, std::string name, - std::function attr_getter, - T default_val) { + typedef bstd::variant, float, + std::vector, std::string, + std::vector> + AttrValueType; + + struct ONNXAttrVisitor { + ONNXAttrVisitor(std::string name, mlir::OpBuilder &builder) + : _builder(builder), _name(std::move(name)) {} + + // Op builder. + mlir::OpBuilder &_builder; + + // Name of the attribute being inspected. + std::string _name; + + mlir::NamedAttribute operator()(int64_t const &r) { + auto val = _builder.getI32IntegerAttr(r); + return _builder.getNamedAttr(_name, val); + } + + mlir::NamedAttribute operator()(std::vector const &ints) { + auto val = _builder.getI64ArrayAttr(ints); + return _builder.getNamedAttr(_name, val); + } + + mlir::NamedAttribute operator()(float const &r) { + auto val = _builder.getF32FloatAttr(r); + return _builder.getNamedAttr(_name, val); + } + + mlir::NamedAttribute operator()(std::vector const &floats) { + auto val = _builder.getF32ArrayAttr(floats); + return _builder.getNamedAttr(_name, val); + } + + mlir::NamedAttribute operator()(std::string const &s) { + auto val = _builder.getStringAttr(s); + return _builder.getNamedAttr(_name, val); + } + + mlir::NamedAttribute operator()(std::vector const &r) { + assert(false && "type of attribute value is not implemented"); + auto val = _builder.getI32IntegerAttr(1); + return _builder.getNamedAttr(_name, val); + }; + }; + + mlir::NamedAttribute convertNameValuePairToNamedAttribute( + std::pair nameAndVal) { + auto visitor = ONNXAttrVisitor(nameAndVal.first, builder_); + return mpark::visit(visitor, nameAndVal.second); + } + + static std::pair + convertAttributeProtoToNameValuePair(onnx::AttributeProto &attr) { + AttrValueType val; + switch (attr.type()) { + case onnx::AttributeProto::FLOAT: + return std::make_pair(attr.name(), AttrValueType(attr.f())); + case onnx::AttributeProto::INT: + return std::make_pair(attr.name(), AttrValueType(attr.i())); + case onnx::AttributeProto::STRING: + return std::make_pair(attr.name(), AttrValueType(attr.s())); + case onnx::AttributeProto::FLOATS: + val = AttrValueType( + std::vector(attr.floats().begin(), attr.floats().end())); + return std::make_pair(attr.name(), val); + case onnx::AttributeProto::INTS: + val = AttrValueType( + std::vector(attr.ints().begin(), attr.ints().end())); + return std::make_pair(attr.name(), 
val); + default: + assert(false && "datatype for attribute is not implemented"); + break; + } + } + + std::vector ImportNodeAttributes( + const onnx::NodeProto &node, + std::initializer_list> + defaultAttrList) { + std::vector attributes; + std::set definedAttributeSet; for (int i = 0; i < node.attribute_size(); ++i) { auto attr = node.attribute(i); - if (attr.name() == name) { - return attr_getter(attr); - } + auto nameValPair = convertAttributeProtoToNameValuePair(attr); + attributes.push_back(convertNameValuePairToNamedAttribute(nameValPair)); + definedAttributeSet.insert(attr.name()); } - return default_val; - } - - template - T get_attr_generic(onnx::NodeProto &node, std::string name, - std::function attr_getter) { - for (int i = 0; i < node.attribute_size(); ++i) { - auto attr = node.attribute(i); - if (attr.name() == name) { - return attr_getter(attr); - } + for (const auto &defaultAttr : defaultAttrList) { + if (definedAttributeSet.find(defaultAttr.first) == + definedAttributeSet.end()) + attributes.push_back(convertNameValuePairToNamedAttribute(defaultAttr)); } - assert(false && "ONNX Node Attribute Not Found!"); + return attributes; } - auto get_attr_ints(onnx::NodeProto &node, std::string name, - std::vector default_val) { - std::function(onnx::AttributeProto &)> attr_getter = - [](onnx::AttributeProto &attr) { - std::vector ints(attr.ints_size()); - std::copy(attr.ints().begin(), attr.ints().end(), ints.begin()); - return ints; - }; - auto r = get_attr_generic(node, name, attr_getter, default_val); - auto dataType = - mlir::RankedTensorType::get(r.size(), builder_.getIntegerType(32)); - auto attr_v = mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(r)); - auto aname = node.op_type() + "." + name; - auto attr_output = builder_.getNamedAttr(aname, attr_v); - return attr_output; - } - - auto get_attr_ints(onnx::NodeProto &node, std::string name) { - std::function(onnx::AttributeProto &)> attr_getter = - [](onnx::AttributeProto &attr) { - std::vector ints(attr.ints_size()); - std::copy(attr.ints().begin(), attr.ints().end(), ints.begin()); - return ints; - }; - auto r = get_attr_generic(node, name, attr_getter); - auto dataType = - mlir::RankedTensorType::get(r.size(), builder_.getIntegerType(32)); - auto attr_v = mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(r)); - auto aname = node.op_type() + "." + name; - auto attr_output = builder_.getNamedAttr(aname, attr_v); - return attr_output; - } - - auto get_attr_floats(onnx::NodeProto &node, std::string name) { - std::function(onnx::AttributeProto &)> attr_getter = - [](onnx::AttributeProto &attr) { - std::vector floats(attr.floats_size()); - std::copy(attr.floats().begin(), attr.floats().end(), floats.begin()); - return floats; - }; - auto r = get_attr_generic(node, name, attr_getter); - auto dataType = - mlir::RankedTensorType::get(r.size(), builder_.getF32Type()); - auto attr_v = mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(r)); - auto aname = node.op_type() + "." 
+ name; - auto attr_output = builder_.getNamedAttr(aname, attr_v); - return attr_output; - } - - auto get_attr_floats(onnx::NodeProto &node, std::string name, - std::vector default_val) { - std::function(onnx::AttributeProto &)> attr_getter = - [](onnx::AttributeProto &attr) { - std::vector floats(attr.floats_size()); - std::copy(attr.floats().begin(), attr.floats().end(), floats.begin()); - return floats; - }; - auto r = get_attr_generic(node, name, attr_getter, default_val); - auto dataType = - mlir::RankedTensorType::get(r.size(), builder_.getF32Type()); - auto attr_v = mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(r)); - auto aname = node.op_type() + "." + name; - auto attr_output = builder_.getNamedAttr(aname, attr_v); - return attr_output; - } - - auto get_attr_int(onnx::NodeProto &node, std::string name) { - std::function attr_getter = - [](onnx::AttributeProto &attr) { return attr.i(); }; - int r = get_attr_generic(node, name, attr_getter); - auto attr_v = builder_.getI32IntegerAttr(r); - auto aname = node.op_type() + "." + name; - auto attr_output = builder_.getNamedAttr(aname, attr_v); - return attr_output; - } - - auto get_attr_int(onnx::NodeProto &node, std::string name, int default_val) { - std::function attr_getter = - [](onnx::AttributeProto &attr) { return attr.i(); }; - int r = get_attr_generic(node, name, attr_getter, default_val); - auto attr_v = builder_.getI32IntegerAttr(r); - auto aname = node.op_type() + "." + name; - auto attr_output = builder_.getNamedAttr(aname, attr_v); - return attr_output; - } - - auto get_attr_float(onnx::NodeProto &node, std::string name) { - std::function attr_getter = - [](onnx::AttributeProto &attr) { return attr.f(); }; - auto r = get_attr_generic(node, name, attr_getter); - auto attr_v = builder_.getF32FloatAttr(r); - auto aname = node.op_type() + "." + name; - return builder_.getNamedAttr(aname, attr_v); - } - - auto get_attr_float(onnx::NodeProto &node, std::string name, - float default_val) { - std::function attr_getter = - [](onnx::AttributeProto &attr) { return attr.f(); }; - auto r = get_attr_generic(node, name, attr_getter, default_val); - auto attr_v = builder_.getF32FloatAttr(r); - auto aname = node.op_type() + "." + name; - return builder_.getNamedAttr(aname, attr_v); - } - - auto get_attr_string(onnx::NodeProto &node, std::string name) { - std::function attr_getter = - [](onnx::AttributeProto &attr) { return attr.s(); }; - auto r = get_attr_generic(node, name, attr_getter); - auto attr_v = builder_.getStringAttr(r); - auto aname = node.op_type() + "." + name; - return builder_.getNamedAttr(aname, attr_v); - } - - auto get_attr_string(onnx::NodeProto &node, std::string name, - std::string default_val) { - std::function attr_getter = - [](onnx::AttributeProto &attr) { return attr.s(); }; - auto r = get_attr_generic(node, name, attr_getter, default_val); - auto attr_v = builder_.getStringAttr(r); - auto aname = node.op_type() + "." 
+ name; - return builder_.getNamedAttr(aname, attr_v); - } - - /* - auto get_attr_strings(onnx::NodeProto &node, std::string name) { - std::function(onnx::AttributeProto &)> - attr_getter = - [](onnx::AttributeProto &attr) { - std::vector strings(attr.strings_size()); - std::copy(attr.strings().begin(), attr.strings().end(), - strings.begin()); return strings; - }; - auto r = get_attr_generic(node, name, attr_getter); - return r; - return builder_.getNamedAttr(aname, attr_v); - auto dataType = - mlir::RankedTensorType::get(r.size(), builder_.get???Type()); - auto attr_v = mlir::DenseElementsAttr::get(dataType, - llvm::makeArrayRef(r)); auto aname = node.op_type() + "." + name; auto - attr_output = builder_.getNamedAttr(aname, attr_v); return attr_output; - } - */ - - auto get_default_ints(std::string default_str) { - std::vector r; - auto start = default_str.find("{"); - while (true) { - auto end = default_str.find(",", start + 1); - if (end == std::string::npos) { - end = default_str.find("}", start + 1); - if (end != std::string::npos && end > start + 1) { - r.push_back(std::stoi(default_str.substr(start + 1, end))); - } - break; - } else { - r.push_back(std::stoi(default_str.substr(start + 1, end))); - } - start = end + 1; - } - return r; - } - - auto get_default_floats(std::string default_str) { - std::vector r; - auto start = default_str.find("{"); - while (true) { - auto end = default_str.find(",", start + 1); - if (end == std::string::npos) { - end = default_str.find("}", start + 1); - if (end != std::string::npos && end > start + 1) { - r.push_back(std::stof(default_str.substr(start + 1, end))); - } - break; - } else { - r.push_back(std::stof(default_str.substr(start + 1, end))); - } - start = end + 1; - } - return r; - } - - auto get_default_strings(std::string default_str) { - std::vector r; - auto start = default_str.find("{"); - while (true) { - auto end = default_str.find(",", start + 1); - if (end == std::string::npos) { - end = default_str.find("}", start + 1); - if (end != std::string::npos && end > start + 1) { - r.push_back(default_str.substr(start + 1, end)); - } - break; - } else { - r.push_back(default_str.substr(start + 1, end)); - } - start = end + 1; - } - return r; - } - - onnx::TensorProto get_attr_tensor(onnx::NodeProto &node, std::string name) { - std::function attr_getter = - [](onnx::AttributeProto &attr) { return attr.t(); }; - return get_attr_generic(node, name, attr_getter); - } - - auto ImportNodeAttr(onnx::NodeProto node, std::string attr_name, - std::string type_name, std::string default_str) { - if (default_str == "") { - if (type_name == "int") { - return get_attr_int(node, attr_name); - } else if (type_name == "float") { - return get_attr_float(node, attr_name); - } else if (type_name == "str") { - return get_attr_string(node, attr_name); - } else if (type_name == "ints") { - return get_attr_ints(node, attr_name); - } else if (type_name == "floats") { - return get_attr_floats(node, attr_name); - } else { - assert( - false && - "Got an empty initializer or initializer for this " - "datatype is not implemented. 
Something is wrong."); - } - } else { - // with default value - if (type_name == "int") { - return get_attr_int(node, attr_name, std::stoi(default_str)); - } else if (type_name == "float") { - return get_attr_float(node, attr_name, std::stof(default_str)); - } else if (type_name == "str") { - return get_attr_string(node, attr_name, default_str); - } else if (type_name == "ints") { - return get_attr_ints(node, attr_name, get_default_ints(default_str)); - } else if (type_name == "floats") { - return get_attr_floats(node, attr_name, - get_default_floats(default_str)); - } else { - assert( - false && - "Got an empty initializer or initializer for this " - "datatype is not implemented. Something is wrong."); - } - } - } - - void ImportNodeGeneric(onnx::NodeProto node) { + void ImportNodeGeneric(const onnx::NodeProto &node) { std::vector inputs; - for (auto item : node.input()) { + for (const auto &item : node.input()) { if (frontend_symbols_.ContainKey(legalize_name(item))) { inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item)); } @@ -511,12 +339,12 @@ private: * default} */ template - void ImportNodeOneOut( - onnx::NodeProto node, int nIn, int nOut, - std::initializer_list> - attrs) { + void + ImportNodeOneOut(const onnx::NodeProto &node, int nIn, int nOut, + std::initializer_list> + defaultAttrList) { std::vector inputs; - for (auto item : node.input()) { + for (const auto &item : node.input()) { if (frontend_symbols_.ContainKey(legalize_name(item))) { inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item)); } @@ -528,22 +356,7 @@ private: mlir::UnrankedTensorType::get(builder_.getF32Type())); } - std::vector attributes; - // for (auto [attr_name, attr_type, attr_default] : attrs) { - for (auto oneAttr : attrs) { - std::string attr_name; - std::string attr_type; - std::string attr_default; - std::tie(attr_name, attr_type, attr_default) = oneAttr; - if (attr_type != "") { - auto attr = ImportNodeAttr(node, attr_name, attr_type, attr_default); - attributes.push_back(attr); - } else { - // TODO: the attributes need special handling - // std::cout << "missing " << node.op_type() << " " << attr_name << - // std::endl; - } - } + auto attributes = ImportNodeAttributes(node, defaultAttrList); llvm::StringRef OpName = node.op_type(); @@ -559,11 +372,11 @@ private: template void ImportNodeMultipleOuts( - onnx::NodeProto node, int nIn, int nOut, - std::initializer_list> - attrs) { + const onnx::NodeProto &node, int nIn, int nOut, + std::initializer_list> + defaultAttrList) { std::vector inputs; - for (auto item : node.input()) { + for (const auto &item : node.input()) { if (frontend_symbols_.ContainKey(legalize_name(item))) { inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item)); } @@ -575,21 +388,7 @@ private: mlir::UnrankedTensorType::get(builder_.getF32Type())); } - std::vector attributes; - for (auto oneAttr : attrs) { - std::string attr_name; - std::string attr_type; - std::string attr_default; - std::tie(attr_name, attr_type, attr_default) = oneAttr; - if (attr_type != "") { - auto attr = ImportNodeAttr(node, attr_name, attr_type, attr_default); - attributes.push_back(attr); - } else { - // TODO: the attributes need special handling - // std::cout << "missing " << node.op_type() << " " << attr_name << - // std::endl; - } - } + auto attributes = ImportNodeAttributes(node, defaultAttrList); llvm::StringRef OpName = node.op_type(); @@ -610,10 +409,10 @@ private: * c++ does not allow template specialization inside a class scope * a specialized function is used */ - void 
ImportNodeConv( - onnx::NodeProto node, int nOut, - std::initializer_list> - attrs) { + void + ImportNodeConv(onnx::NodeProto node, int nIn, int nOut, + std::initializer_list> + defaultAttrList) { // Conv has attribute dilations, kernel_shape, pads, the default value of // which is determined by the shape of first argument. However, since the // shape is unknown now, these attributes can be not generated auto @@ -627,29 +426,32 @@ private: int nOps = node.input().size(); if (nOps == 2) - ImportNodeOneOut(node, nOps, nOut, attrs); + ImportNodeOneOut( + node, nOps, nOut, defaultAttrList); else - ImportNodeOneOut(node, nOps, nOut, attrs); + ImportNodeOneOut(node, nOps, nOut, defaultAttrList); } /*! * Special handle for MaxPool operations. */ void ImportNodeMaxPool( - onnx::NodeProto node, int nIn, - std::initializer_list> - attrs) { + onnx::NodeProto node, int nIn, int nOut, + std::initializer_list> + defaultAttrList) { int nOuts = node.output().size(); if (nOuts == 1) { - ImportNodeOneOut(node, nIn, nOuts, attrs); + ImportNodeOneOut( + node, nIn, nOuts, defaultAttrList); } else { - ImportNodeMultipleOuts(node, nIn, nOuts, attrs); + ImportNodeMultipleOuts( + node, nIn, nOuts, defaultAttrList); } } - void ImportNode(onnx::NodeProto node) { + void ImportNode(const onnx::NodeProto &node) { std::vector inputs; - for (auto item : node.input()) { + for (const auto &item : node.input()) { if (frontend_symbols_.ContainKey(legalize_name(item))) { inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item)); } @@ -689,9 +491,8 @@ private: llvm::SmallVectorImpl &ret_types, llvm::SmallVectorImpl &ret_vals) { auto output_tensor_legalized_name = legalize_name(output.name()); - assert( - frontend_symbols_.ContainKey(output_tensor_legalized_name) && - "Output tensor not found"); + assert(frontend_symbols_.ContainKey(output_tensor_legalized_name) && + "Output tensor not found"); auto tensor_val = frontend_symbols_.GetTensorByOnnxName(output_tensor_legalized_name); @@ -750,9 +551,9 @@ private: funcType = builder_.getFunctionType(arg_types, ret_types); mainFunc.setType(funcType); } -}; // FrontendGenImpl class -} // namespace -} // namespace onnf +}; // FrontendGenImpl class +} // namespace +} // namespace onnf namespace onnf { @@ -775,4 +576,4 @@ void ImportFrontendModelFile(std::string model_fname, FrontendGenImpl myONNXGen(context); module = myONNXGen.ImportONNXModel(model); } -} // namespace onnf +} // namespace onnf diff --git a/src/builder/op_build_table.inc b/src/builder/op_build_table.inc index 29a48cc..0e7f20e 100644 --- a/src/builder/op_build_table.inc +++ b/src/builder/op_build_table.inc @@ -16,13 +16,13 @@ }); }else if (OpName == "ArgMax") { ImportNodeOneOut(node, 1, 1, { - {"axis","int","0"} - ,{"keepdims","int","1"} + {"axis", 0} + ,{"keepdims", 1} }); }else if (OpName == "ArgMin") { ImportNodeOneOut(node, 1, 1, { - {"axis","int","0"} - ,{"keepdims","int","1"} + {"axis", 0} + ,{"keepdims", 1} }); }else if (OpName == "Asin") { ImportNodeOneOut(node, 1, 1, { @@ -38,25 +38,22 @@ }); }else if (OpName == "AveragePool") { ImportNodeOneOut(node, 1, 1, { - {"auto_pad","str","NOTSET"} - ,{"ceil_mode","int","0"} - ,{"count_include_pad","int","0"} - ,{"kernel_shape","ints", ""} - ,{"pads","", ""} - ,{"strides","", ""} + {"auto_pad", "NOTSET"} + ,{"ceil_mode", 0} + ,{"count_include_pad", 0} + ,{"kernel_shape", std::vector {}} }); }else if (OpName == "BatchNormalization") { ImportNodeMultipleOuts(node, 5, 5, { - {"epsilon","float","1e-05"} - ,{"momentum","float","0.9"} + {"epsilon", (float)1e-05} + 
,{"momentum", (float)0.9} }); }else if (OpName == "BitShift") { ImportNodeOneOut(node, 2, 1, { - {"direction","", ""} }); }else if (OpName == "Cast") { ImportNodeOneOut(node, 1, 1, { - {"to","int", "0"} + {"to", 0} }); }else if (OpName == "Ceil") { ImportNodeOneOut(node, 1, 1, { @@ -66,54 +63,35 @@ }); }else if (OpName == "Compress") { ImportNodeOneOut(node, 2, 1, { - {"axis","", ""} }); }else if (OpName == "Concat") { ImportNodeOneOut(node, 1, 1, { - {"axis","int", "0"} + {"axis", 0} }); }else if (OpName == "ConcatFromSequence") { ImportNodeOneOut(node, 1, 1, { - {"axis","", ""} - ,{"new_axis","int","0"} + {"new_axis", 0} }); }else if (OpName == "Constant") { ImportNodeOneOut(node, 0, 1, { - {"sparse_value","", ""} - ,{"value","", ""} }); }else if (OpName == "ConstantOfShape") { ImportNodeOneOut(node, 1, 1, { - {"value","", ""} }); }else if (OpName == "Conv") { - ImportNodeConv(node, 1, { - {"auto_pad","str","NOTSET"} - ,{"dilations","", ""} - ,{"group","int", "1"} - ,{"kernel_shape","", ""} - ,{"pads","", ""} - ,{"strides","", ""} + ImportNodeConv(node, 3, 1, { + {"auto_pad", "NOTSET"} + ,{"group", 1} }); }else if (OpName == "ConvInteger") { ImportNodeOneOut(node, 4, 1, { - {"auto_pad","str","NOTSET"} - ,{"dilations","", ""} - ,{"group","int","1"} - ,{"kernel_shape","", ""} - ,{"pads","", ""} - ,{"strides","", ""} + {"auto_pad", "NOTSET"} + ,{"group", 1} }); }else if (OpName == "ConvTranspose") { ImportNodeOneOut(node, 3, 1, { - {"auto_pad","str","NOTSET"} - ,{"dilations","", ""} - ,{"group","int","1"} - ,{"kernel_shape","", ""} - ,{"output_padding","", ""} - ,{"output_shape","", ""} - ,{"pads","", ""} - ,{"strides","", ""} + {"auto_pad", "NOTSET"} + ,{"group", 1} }); }else if (OpName == "Cos") { ImportNodeOneOut(node, 1, 1, { @@ -123,13 +101,12 @@ }); }else if (OpName == "CumSum") { ImportNodeOneOut(node, 2, 1, { - {"exclusive","int","0"} - ,{"reverse","int","0"} + {"exclusive", 0} + ,{"reverse", 0} }); }else if (OpName == "DepthToSpace") { ImportNodeOneOut(node, 1, 1, { - {"blocksize","", ""} - ,{"mode","str","DCR"} + {"mode", "DCR"} }); }else if (OpName == "DequantizeLinear") { ImportNodeOneOut(node, 3, 1, { @@ -142,14 +119,14 @@ }); }else if (OpName == "Dropout") { ImportNodeMultipleOuts(node, 1, 2, { - {"ratio","float","0.5"} + {"ratio", (float)0.5} }); }else if (OpName == "DynamicQuantizeLinear") { ImportNodeMultipleOuts(node, 1, 3, { }); }else if (OpName == "Elu") { ImportNodeOneOut(node, 1, 1, { - {"alpha","float","1.0"} + {"alpha", (float)1.0} }); }else if (OpName == "Equal") { ImportNodeOneOut(node, 2, 1, { @@ -165,50 +142,44 @@ }); }else if (OpName == "EyeLike") { ImportNodeOneOut(node, 1, 1, { - {"dtype","", ""} - ,{"k","int","0"} + {"k", 0} }); }else if (OpName == "Flatten") { ImportNodeOneOut(node, 1, 1, { - {"axis","int","1"} + {"axis", 1} }); }else if (OpName == "Floor") { ImportNodeOneOut(node, 1, 1, { }); }else if (OpName == "GRU") { ImportNodeMultipleOuts(node, 6, 2, { - {"activation_alpha","", ""} - ,{"activation_beta","", ""} - ,{"activations","", ""} - ,{"clip","", ""} - ,{"direction","str","forward"} - ,{"hidden_size","", ""} - ,{"linear_before_reset","int","0"} + {"direction", "forward"} + ,{"linear_before_reset", 0} }); }else if (OpName == "Gather") { ImportNodeOneOut(node, 2, 1, { - {"axis","int","0"} + {"axis", 0} }); }else if (OpName == "GatherElements") { ImportNodeOneOut(node, 2, 1, { - {"axis","int","0"} + {"axis", 0} }); }else if (OpName == "GatherND") { ImportNodeOneOut(node, 2, 1, { }); }else if (OpName == "Gemm") { ImportNodeOneOut(node, 3, 1, { - 
{"alpha","float","1.0"} - ,{"beta","float","1.0"} - ,{"transA","int","0"} - ,{"transB","int","0"} + {"alpha", (float)1.0} + ,{"beta", (float)1.0} + ,{"transA", 0} + ,{"transB", 0} }); }else if (OpName == "GlobalAveragePool") { ImportNodeOneOut(node, 1, 1, { }); }else if (OpName == "GlobalLpPool") { ImportNodeOneOut(node, 1, 1, { - {"p","int","2"} + {"p", 2} }); }else if (OpName == "GlobalMaxPool") { ImportNodeOneOut(node, 1, 1, { @@ -218,53 +189,45 @@ }); }else if (OpName == "HardSigmoid") { ImportNodeOneOut(node, 1, 1, { - {"alpha","float","0.2"} - ,{"beta","float","0.5"} + {"alpha", (float)0.2} + ,{"beta", (float)0.5} }); }else if (OpName == "Hardmax") { ImportNodeOneOut(node, 1, 1, { - {"axis","int","1"} + {"axis", 1} }); }else if (OpName == "Identity") { ImportNodeOneOut(node, 1, 1, { }); }else if (OpName == "If") { ImportNodeOneOut(node, 1, 1, { - {"else_branch","", ""} - ,{"then_branch","", ""} }); }else if (OpName == "InstanceNormalization") { ImportNodeOneOut(node, 3, 1, { - {"epsilon","float","1e-05"} + {"epsilon", (float)1e-05} }); }else if (OpName == "IsInf") { ImportNodeOneOut(node, 1, 1, { - {"detect_negative","int","1"} - ,{"detect_positive","int","1"} + {"detect_negative", 1} + ,{"detect_positive", 1} }); }else if (OpName == "IsNaN") { ImportNodeOneOut(node, 1, 1, { }); }else if (OpName == "LRN") { ImportNodeOneOut(node, 1, 1, { - {"alpha","float","0.0001"} - ,{"beta","float","0.75"} - ,{"bias","float","1.0"} - ,{"size","int", ""} + {"alpha", (float)0.0001} + ,{"beta", (float)0.75} + ,{"bias", (float)1.0} }); }else if (OpName == "LSTM") { ImportNodeMultipleOuts(node, 8, 3, { - {"activation_alpha","", ""} - ,{"activation_beta","", ""} - ,{"activations","", ""} - ,{"clip","", ""} - ,{"direction","str","forward"} - ,{"hidden_size","", ""} - ,{"input_forget","int","0"} + {"direction", "forward"} + ,{"input_forget", 0} }); }else if (OpName == "LeakyRelu") { ImportNodeOneOut(node, 1, 1, { - {"alpha","float","0.01"} + {"alpha", (float)0.01} }); }else if (OpName == "Less") { ImportNodeOneOut(node, 2, 1, { @@ -274,24 +237,20 @@ }); }else if (OpName == "LogSoftmax") { ImportNodeOneOut(node, 1, 1, { - {"axis","int","1"} + {"axis", 1} }); }else if (OpName == "Loop") { ImportNodeOneOut(node, 3, 1, { - {"body","", ""} }); }else if (OpName == "LpNormalization") { ImportNodeOneOut(node, 1, 1, { - {"axis","int","-1"} - ,{"p","int","2"} + {"axis", -1} + ,{"p", 2} }); }else if (OpName == "LpPool") { ImportNodeOneOut(node, 1, 1, { - {"auto_pad","str","NOTSET"} - ,{"kernel_shape","", ""} - ,{"p","int","2"} - ,{"pads","", ""} - ,{"strides","", ""} + {"auto_pad", "NOTSET"} + ,{"p", 2} }); }else if (OpName == "MatMul") { ImportNodeOneOut(node, 2, 1, { @@ -303,55 +262,47 @@ ImportNodeOneOut(node, 1, 1, { }); }else if (OpName == "MaxPool") { - ImportNodeMaxPool(node, 1, { - {"auto_pad","str","NOTSET"} - ,{"ceil_mode","int","0"} - ,{"dilations","", ""} - ,{"kernel_shape","ints", ""} - ,{"pads","", ""} - ,{"storage_order","int","0"} - ,{"strides","", ""} + ImportNodeMaxPool(node, 1, 2, { + {"auto_pad", "NOTSET"} + ,{"ceil_mode", 0} + ,{"kernel_shape", std::vector {}} + ,{"storage_order", 0} }); }else if (OpName == "MaxRoiPool") { ImportNodeOneOut(node, 2, 1, { - {"pooled_shape","", ""} - ,{"spatial_scale","float","1.0"} + {"spatial_scale", (float)1.0} }); }else if (OpName == "MaxUnpool") { ImportNodeOneOut(node, 3, 1, { - {"kernel_shape","", ""} - ,{"pads","", ""} - ,{"strides","", ""} }); }else if (OpName == "Mean") { ImportNodeOneOut(node, 1, 1, { }); }else if (OpName == "MeanVarianceNormalization") 
{ ImportNodeOneOut(node, 1, 1, { - {"axes","ints","{'0', '2', '3'}"} + {"axes", std::vector{0, 2, 3}} }); }else if (OpName == "Min") { ImportNodeOneOut(node, 1, 1, { }); }else if (OpName == "Mod") { ImportNodeOneOut(node, 2, 1, { - {"fmod","int","0"} + {"fmod", 0} }); }else if (OpName == "Mul") { ImportNodeOneOut(node, 2, 1, { }); }else if (OpName == "Multinomial") { ImportNodeOneOut(node, 1, 1, { - {"dtype","int","6"} - ,{"sample_size","int","1"} - ,{"seed","", ""} + {"dtype", 6} + ,{"sample_size", 1} }); }else if (OpName == "Neg") { ImportNodeOneOut(node, 1, 1, { }); }else if (OpName == "NonMaxSuppression") { ImportNodeOneOut(node, 5, 1, { - {"center_point_box","int","0"} + {"center_point_box", 0} }); }else if (OpName == "NonZero") { ImportNodeOneOut(node, 1, 1, { @@ -361,7 +312,7 @@ }); }else if (OpName == "OneHot") { ImportNodeOneOut(node, 3, 1, { - {"axis","int","-1"} + {"axis", -1} }); }else if (OpName == "Or") { ImportNodeOneOut(node, 2, 1, { @@ -371,19 +322,15 @@ }); }else if (OpName == "Pad") { ImportNodeOneOut(node, 3, 1, { - {"mode","str","constant"} + {"mode", "constant"} }); }else if (OpName == "Pow") { ImportNodeOneOut(node, 2, 1, { }); }else if (OpName == "QLinearConv") { ImportNodeOneOut(node, 9, 1, { - {"auto_pad","str","NOTSET"} - ,{"dilations","", ""} - ,{"group","int","1"} - ,{"kernel_shape","", ""} - ,{"pads","", ""} - ,{"strides","", ""} + {"auto_pad", "NOTSET"} + ,{"group", 1} }); }else if (OpName == "QLinearMatMul") { ImportNodeOneOut(node, 8, 1, { @@ -393,42 +340,32 @@ }); }else if (OpName == "RNN") { ImportNodeMultipleOuts(node, 6, 2, { - {"activation_alpha","floats", "{}"} - ,{"activation_beta","floats", "{}"} - ,{"activations","", "{Tannh, Tanh}"} - ,{"clip","", ""} - ,{"direction","str","forward"} - ,{"hidden_size","", ""} + {"activation_alpha", std::vector {}} + ,{"activation_beta", std::vector {}} + ,{"activations", std::vector{"Tanh", "Tanh"}} + ,{"direction", "forward"} }); }else if (OpName == "RandomNormal") { ImportNodeOneOut(node, 0, 1, { - {"dtype","int","1"} - ,{"mean","float","0.0"} - ,{"scale","float","1.0"} - ,{"seed","", ""} - ,{"shape","", ""} + {"dtype", 1} + ,{"mean", (float)0.0} + ,{"scale", (float)1.0} }); }else if (OpName == "RandomNormalLike") { ImportNodeOneOut(node, 1, 1, { - {"dtype","", ""} - ,{"mean","float","0.0"} - ,{"scale","float","1.0"} - ,{"seed","", ""} + {"mean", (float)0.0} + ,{"scale", (float)1.0} }); }else if (OpName == "RandomUniform") { ImportNodeOneOut(node, 0, 1, { - {"dtype","int","1"} - ,{"high","float","1.0"} - ,{"low","float","0.0"} - ,{"seed","", ""} - ,{"shape","", ""} + {"dtype", 1} + ,{"high", (float)1.0} + ,{"low", (float)0.0} }); }else if (OpName == "RandomUniformLike") { ImportNodeOneOut(node, 1, 1, { - {"dtype","", ""} - ,{"high","float","1.0"} - ,{"low","float","0.0"} - ,{"seed","", ""} + {"high", (float)1.0} + ,{"low", (float)0.0} }); }else if (OpName == "Range") { ImportNodeOneOut(node, 3, 1, { @@ -438,53 +375,43 @@ }); }else if (OpName == "ReduceL1") { ImportNodeOneOut(node, 1, 1, { - {"axes","", ""} - ,{"keepdims","int","1"} + {"keepdims", 1} }); }else if (OpName == "ReduceL2") { ImportNodeOneOut(node, 1, 1, { - {"axes","", ""} - ,{"keepdims","int","1"} + {"keepdims", 1} }); }else if (OpName == "ReduceLogSum") { ImportNodeOneOut(node, 1, 1, { - {"axes","", ""} - ,{"keepdims","int","1"} + {"keepdims", 1} }); }else if (OpName == "ReduceLogSumExp") { ImportNodeOneOut(node, 1, 1, { - {"axes","", ""} - ,{"keepdims","int","1"} + {"keepdims", 1} }); }else if (OpName == "ReduceMax") { ImportNodeOneOut(node, 1, 1, 
{ - {"axes","", ""} - ,{"keepdims","int","1"} + {"keepdims", 1} }); }else if (OpName == "ReduceMean") { ImportNodeOneOut(node, 1, 1, { - {"axes","", ""} - ,{"keepdims","int","1"} + {"keepdims", 1} }); }else if (OpName == "ReduceMin") { ImportNodeOneOut(node, 1, 1, { - {"axes","", ""} - ,{"keepdims","int","1"} + {"keepdims", 1} }); }else if (OpName == "ReduceProd") { ImportNodeOneOut(node, 1, 1, { - {"axes","", ""} - ,{"keepdims","int","1"} + {"keepdims", 1} }); }else if (OpName == "ReduceSum") { ImportNodeOneOut(node, 1, 1, { - {"axes","", ""} - ,{"keepdims","int","1"} + {"keepdims", 1} }); }else if (OpName == "ReduceSumSquare") { ImportNodeOneOut(node, 1, 1, { - {"axes","", ""} - ,{"keepdims","int","1"} + {"keepdims", 1} }); }else if (OpName == "Relu") { ImportNodeOneOut(node, 1, 1, { @@ -494,53 +421,47 @@ }); }else if (OpName == "Resize") { ImportNodeOneOut(node, 4, 1, { - {"coordinate_transformation_mode","str","half_pixel"} - ,{"cubic_coeff_a","float","-0.75"} - ,{"exclude_outside","int","0"} - ,{"extrapolation_value","float","0.0"} - ,{"mode","str","nearest"} - ,{"nearest_mode","str","round_prefer_floor"} + {"coordinate_transformation_mode", "half_pixel"} + ,{"cubic_coeff_a", (float)-0.75} + ,{"exclude_outside", 0} + ,{"extrapolation_value", (float)0.0} + ,{"mode", "nearest"} + ,{"nearest_mode", "round_prefer_floor"} }); }else if (OpName == "ReverseSequence") { ImportNodeOneOut(node, 2, 1, { - {"batch_axis","int","1"} - ,{"time_axis","int","0"} + {"batch_axis", 1} + ,{"time_axis", 0} }); }else if (OpName == "RoiAlign") { ImportNodeOneOut(node, 3, 1, { - {"mode","str","avg"} - ,{"output_height","int","1"} - ,{"output_width","int","1"} - ,{"sampling_ratio","int","0"} - ,{"spatial_scale","float","1.0"} + {"mode", "avg"} + ,{"output_height", 1} + ,{"output_width", 1} + ,{"sampling_ratio", 0} + ,{"spatial_scale", (float)1.0} }); }else if (OpName == "Round") { ImportNodeOneOut(node, 1, 1, { }); }else if (OpName == "Scan") { ImportNodeOneOut(node, 1, 1, { - {"body","", ""} - ,{"num_scan_inputs","", ""} - ,{"scan_input_axes","", ""} - ,{"scan_input_directions","", ""} - ,{"scan_output_axes","", ""} - ,{"scan_output_directions","", ""} }); }else if (OpName == "Scatter") { ImportNodeOneOut(node, 3, 1, { - {"axis","int","0"} + {"axis", 0} }); }else if (OpName == "ScatterElements") { ImportNodeOneOut(node, 3, 1, { - {"axis","int","0"} + {"axis", 0} }); }else if (OpName == "ScatterND") { ImportNodeOneOut(node, 3, 1, { }); }else if (OpName == "Selu") { ImportNodeOneOut(node, 1, 1, { - {"alpha","float","1.67326"} - ,{"gamma","float","1.0507"} + {"alpha", (float)1.67326} + ,{"gamma", (float)1.0507} }); }else if (OpName == "SequenceAt") { ImportNodeOneOut(node, 2, 1, { @@ -550,7 +471,6 @@ }); }else if (OpName == "SequenceEmpty") { ImportNodeOneOut(node, 0, 1, { - {"dtype","", ""} }); }else if (OpName == "SequenceErase") { ImportNodeOneOut(node, 2, 1, { @@ -566,8 +486,8 @@ }); }else if (OpName == "Shrink") { ImportNodeOneOut(node, 1, 1, { - {"bias","float","0.0"} - ,{"lambd","float","0.5"} + {"bias", (float)0.0} + ,{"lambd", (float)0.5} }); }else if (OpName == "Sigmoid") { ImportNodeOneOut(node, 1, 1, { @@ -589,7 +509,7 @@ }); }else if (OpName == "Softmax") { ImportNodeOneOut(node, 1, 1, { - {"axis","int","1"} + {"axis", 1} }); }else if (OpName == "Softplus") { ImportNodeOneOut(node, 1, 1, { @@ -599,31 +519,26 @@ }); }else if (OpName == "SpaceToDepth") { ImportNodeOneOut(node, 1, 1, { - {"blocksize","", ""} }); }else if (OpName == "Split") { ImportNodeOneOut(node, 1, 1, { - {"axis","int","0"} - 
,{"split","", ""} + {"axis", 0} }); }else if (OpName == "SplitToSequence") { ImportNodeOneOut(node, 2, 1, { - {"axis","int","0"} - ,{"keepdims","int","1"} + {"axis", 0} + ,{"keepdims", 1} }); }else if (OpName == "Sqrt") { ImportNodeOneOut(node, 1, 1, { }); }else if (OpName == "Squeeze") { ImportNodeOneOut(node, 1, 1, { - {"axes","", ""} }); }else if (OpName == "StringNormalizer") { ImportNodeOneOut(node, 1, 1, { - {"case_change_action","str","NONE"} - ,{"is_case_sensitive","int","0"} - ,{"locale","", ""} - ,{"stopwords","", ""} + {"case_change_action", "NONE"} + ,{"is_case_sensitive", 0} }); }else if (OpName == "Sub") { ImportNodeOneOut(node, 2, 1, { @@ -639,45 +554,34 @@ }); }else if (OpName == "TfIdfVectorizer") { ImportNodeOneOut(node, 1, 1, { - {"max_gram_length","", ""} - ,{"max_skip_count","", ""} - ,{"min_gram_length","", ""} - ,{"mode","", ""} - ,{"ngram_counts","", ""} - ,{"ngram_indexes","", ""} - ,{"pool_int64s","", ""} - ,{"pool_strings","", ""} - ,{"weights","", ""} }); }else if (OpName == "ThresholdedRelu") { ImportNodeOneOut(node, 1, 1, { - {"alpha","float","1.0"} + {"alpha", (float)1.0} }); }else if (OpName == "Tile") { ImportNodeOneOut(node, 2, 1, { }); }else if (OpName == "TopK") { ImportNodeMultipleOuts(node, 2, 2, { - {"axis","int","-1"} - ,{"largest","int","1"} - ,{"sorted","int","1"} + {"axis", -1} + ,{"largest", 1} + ,{"sorted", 1} }); }else if (OpName == "Transpose") { ImportNodeOneOut(node, 1, 1, { - {"perm","", ""} }); }else if (OpName == "Unique") { ImportNodeMultipleOuts(node, 1, 4, { - {"axis","", ""} - ,{"sorted","int","1"} + {"sorted", 1} }); }else if (OpName == "Unsqueeze") { ImportNodeOneOut(node, 1, 1, { - {"axes","ints", ""} + {"axes", std::vector {}} }); }else if (OpName == "Upsample") { ImportNodeOneOut(node, 2, 1, { - {"mode","str","nearest"} + {"mode", "nearest"} }); }else if (OpName == "Where") { ImportNodeOneOut(node, 3, 1, { @@ -685,4 +589,4 @@ }else if (OpName == "Xor") { ImportNodeOneOut(node, 2, 1, { }); - } \ No newline at end of file + } diff --git a/src/dialect/onnx/gen_doc.py b/src/dialect/onnx/gen_doc.py index 4141556..ed99e57 100644 --- a/src/dialect/onnx/gen_doc.py +++ b/src/dialect/onnx/gen_doc.py @@ -368,17 +368,17 @@ def gen_code(schema,fefile) : ("MaxPool", "ImportNodeMaxPool"), #("Transpose", "ImportNodeTranspose") ]) - special_type = dict([ - ("AveragePool "+"kernel_shape", '"ints", ""'), - ("MaxPool "+"kernel_shape", '"ints", ""'), - ("Cast "+"to", '"int", "0"'), - ("Concat "+"axis", '"int", "0"'), - ("Conv "+"group", '"int", "1"'), - ("Unsqueeze "+"axes", '"ints", ""'), - ("RNN "+"activation_alpha", '"floats", "{}"'), - ("RNN "+"activation_beta", '"floats", "{}"'), - ("RNN "+"activations", '"", "{Tannh, Tanh}"'), - ("LRN "+"size", '"int", ""') + list_str = 'std::vector' + empty_ints = list_str+' {}' + empty_floats = list_str+' {}' + special_default = dict([ + ("AveragePool "+"kernel_shape", empty_ints), + ("MaxPool "+"kernel_shape", empty_ints), + ("Cast "+"to", '0'), + ("Concat "+"axis", '0'), + ("Unsqueeze "+"axes", empty_ints), + ("RNN "+"activation_alpha", empty_floats), + ("RNN "+"activation_beta", empty_floats) ]) line_indent = ' ' fefile.write(' '+'}else if (OpName == "'+schema.name+'") {\n') @@ -400,21 +400,9 @@ def gen_code(schema,fefile) : if schema.attributes: first_attr = True for _, attr in sorted(schema.attributes.items()): - attr_line = line_indent+line_indent+line_indent+line_indent - if not first_attr: - attr_line += ',{' - else : - attr_line += ' {' - first_attr = False - - attr_line += '"'+attr.name+'",' - - if 
schema.name+' '+attr.name in special_type: - attr_line += special_type[schema.name+' '+attr.name] - # option holds either required or default value - elif attr.required: - attr_line += '"", ""' - + #only generate default attr list + if schema.name+' '+attr.name in special_default: + attr_value = special_default[schema.name+' '+attr.name] elif attr.default_value.name: default_value = helper.get_attribute_value(attr.default_value) @@ -430,28 +418,35 @@ def gen_code(schema,fefile) : return str(value) if isinstance(default_value, list): + value = default_value[0] default_value = [format_value(val) for val in default_value] + attr_option_str = '{}'.format(default_value) + attr_option_str = attr_option_str.replace('[', '{', 1) + attr_option_str = attr_option_str.replace(']', '}', 1) # TODO the list type is homogenous or htergeneous? if isinstance(value, float) : - attr_type_str = '"floats"' + attr_type_str = list_str+'' + attr_option_str = attr_option_str.replace("'", '') elif isinstance(value, int) : - attr_type_str = '"ints"' + attr_type_str = list_str+'' + attr_option_str = attr_option_str.replace("'", '') elif isinstance(value, str) : - attr_type_str = '"strs"' + attr_type_str = list_str+'' + attr_option_str = attr_option_str.replace("'", '"') elif isinstance(value, (bytes, bytearray)) : - attr_type_str = '"strs"' + attr_type_str = list_str+'' + attr_option_str = attr_option_str.replace("'", '"') else : attr_type_str = '"unknowns"' - attr_option_str = '"{}"'.format(default_value) - attr_option_str = attr_option_str.replace('[', '{', 1) - attr_option_str = attr_option_str.replace(']', '}', 1) else: if isinstance(default_value, float) : - attr_type_str = '"float"' + attr_type_str = '(float)' + attr_option_str = default_value elif isinstance(default_value, int) : - attr_type_str = '"int"' + attr_option_str = default_value + attr_type_str='' elif isinstance(default_value, str) : attr_type_str = '"str"' elif isinstance(default_value, (bytes, bytearray)) : @@ -459,11 +454,25 @@ def gen_code(schema,fefile) : else : attr_type_str = '"unknown"' default_value = format_value(default_value) - attr_option_str = '"{}"'.format(default_value) - attr_line += attr_type_str+','+attr_option_str + if attr_type_str == '"str"' : + attr_option_str = '"'+default_value+'"' + attr_type_str='' + else : + attr_option_str = default_value + attr_value = attr_type_str+attr_option_str else: - #TODO why? 
- attr_line += '"", ""' + #no default value + continue + + attr_line = line_indent+line_indent+line_indent+line_indent + if not first_attr: + attr_line += ',{' + else : + attr_line += ' {' + first_attr = False + + attr_line += '"'+attr.name+'", ' + attr_line += attr_value attr_line += '}\n' fefile.write(attr_line) fefile.write(line_indent+line_indent+line_indent+'});\n') diff --git a/src/pass/lower_frontend_to_krnl.cpp b/src/pass/lower_frontend_to_krnl.cpp index 3d899ee..58d603a 100644 --- a/src/pass/lower_frontend_to_krnl.cpp +++ b/src/pass/lower_frontend_to_krnl.cpp @@ -420,8 +420,8 @@ Value mapToLowerScalarOp( // Constant 1) auto loc = op->getLoc(); Value operand = operands[0]; - auto alphaAttr = op->getAttrOfType("HardSigmoid.alpha"); - auto betaAttr = op->getAttrOfType("HardSigmoid.beta"); + auto alphaAttr = op->getAttrOfType("alpha"); + auto betaAttr = op->getAttrOfType("beta"); auto elementType = result_types[0]; auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); @@ -455,7 +455,7 @@ Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, Value operand = operands[0]; auto elementType = result_types[0]; - auto alphaAttr = op->getAttrOfType("Elu.alpha"); + auto alphaAttr = op->getAttrOfType("alpha"); auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); auto one = rewriter.create(loc, FloatAttr::get(elementType, 1)); auto alpha = rewriter.create(loc, alphaAttr); @@ -508,7 +508,7 @@ Value mapToLowerScalarOp(Operation *op, Value operand = operands[0]; auto elementType = result_types[0]; - auto alphaAttr = op->getAttrOfType("LeakyRelu.alpha"); + auto alphaAttr = op->getAttrOfType("alpha"); auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); auto alpha = rewriter.create(loc, alphaAttr); auto lessThanZero = @@ -533,8 +533,8 @@ Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, // alpha))) auto loc = op->getLoc(); Value operand = operands[0]; - auto alphaAttr = op->getAttrOfType("Selu.alpha"); - auto gammaAttr = op->getAttrOfType("Selu.gamma"); + auto alphaAttr = op->getAttrOfType("alpha"); + auto gammaAttr = op->getAttrOfType("gamma"); auto elementType = result_types[0]; auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); @@ -836,7 +836,7 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern { // exp_x / sum auto tensorType = (*op->result_type_begin()).cast(); int64_t rank = tensorType.getRank(); - int64_t axis = op->getAttrOfType("Softmax.axis").getInt(); + int64_t axis = op->getAttrOfType("axis").getInt(); axis = axis >= 0 ? 
axis : rank + axis; assert(axis >= -rank && axis <= rank - 1); diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir index 3ffce9a..c6c5927 100644 --- a/test/mlir/onnx/onnx_lowering.mlir +++ b/test/mlir/onnx/onnx_lowering.mlir @@ -385,7 +385,7 @@ func @test_min(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tensor<* } func @test_elu(%arg0 : tensor) -> tensor<*xf32> { - %0 = "onnx.Elu"(%arg0) {Elu.alpha=2.0:f32} : (tensor) -> tensor<*xf32> + %0 = "onnx.Elu"(%arg0) {alpha=2.0:f32} : (tensor) -> tensor<*xf32> "std.return"(%0) : (tensor<*xf32>) -> () // CHECK-LABEL: test_elu @@ -411,7 +411,7 @@ func @test_elu(%arg0 : tensor) -> tensor<*xf32> { } func @test_leakyrelu(%arg0 : tensor) -> tensor<*xf32> { - %0 = "onnx.LeakyRelu"(%arg0) {LeakyRelu.alpha=1.0:f32} : (tensor) -> tensor<*xf32> + %0 = "onnx.LeakyRelu"(%arg0) {alpha=1.0:f32} : (tensor) -> tensor<*xf32> "std.return"(%0) : (tensor<*xf32>) -> () // CHECK-LABEL: test_leakyrelu @@ -434,7 +434,7 @@ func @test_leakyrelu(%arg0 : tensor) -> tensor<*xf32> { } func @test_selu(%arg0 : tensor) -> tensor<*xf32> { - %0 = "onnx.Selu"(%arg0) {Selu.alpha=1.0:f32, Selu.gamma=2.0:f32} : (tensor) -> tensor<*xf32> + %0 = "onnx.Selu"(%arg0) {alpha=1.0:f32, gamma=2.0:f32} : (tensor) -> tensor<*xf32> "std.return"(%0) : (tensor<*xf32>) -> () // CHECK-LABEL: test_selu @@ -461,7 +461,7 @@ func @test_selu(%arg0 : tensor) -> tensor<*xf32> { } func @test_hardsigmoid(%arg0 : tensor) -> tensor<*xf32> { - %0 = "onnx.HardSigmoid"(%arg0) {HardSigmoid.alpha=1.0:f32, HardSigmoid.beta=2.0:f32} : (tensor) -> tensor<*xf32> + %0 = "onnx.HardSigmoid"(%arg0) {alpha=1.0:f32, beta=2.0:f32} : (tensor) -> tensor<*xf32> "std.return"(%0) : (tensor<*xf32>) -> () // CHECK-LABEL: test_hardsigmoid @@ -535,7 +535,7 @@ func @test_add_with_broadcasting(%arg0 : tensor, %arg1 : tensor } func @test_softmax(%arg0 : tensor<10x10xf32>) -> tensor<*xf32> { - %0 = "onnx.Softmax"(%arg0) {Softmax.axis=1:i32} : (tensor<10x10xf32>) -> tensor<*xf32> + %0 = "onnx.Softmax"(%arg0) {axis=1:i32} : (tensor<10x10xf32>) -> tensor<*xf32> "std.return"(%0) : (tensor<*xf32>) -> () // CHECK-LABEL: test_softmax diff --git a/test/mlir/onnx/onnx_lowering_with_dealloc.mlir b/test/mlir/onnx/onnx_lowering_with_dealloc.mlir index 385dc3c..1286041 100644 --- a/test/mlir/onnx/onnx_lowering_with_dealloc.mlir +++ b/test/mlir/onnx/onnx_lowering_with_dealloc.mlir @@ -648,8 +648,8 @@ func @test_min_min(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tens } func @test_elu_elu(%arg0 : tensor) -> tensor<*xf32> { - %0 = "onnx.Elu"(%arg0) {Elu.alpha=2.0:f32} : (tensor) -> tensor<*xf32> - %1 = "onnx.Elu"(%0) {Elu.alpha=2.0:f32} : (tensor<*xf32>) -> tensor<*xf32> + %0 = "onnx.Elu"(%arg0) {alpha=2.0:f32} : (tensor) -> tensor<*xf32> + %1 = "onnx.Elu"(%0) {alpha=2.0:f32} : (tensor<*xf32>) -> tensor<*xf32> "std.return"(%1) : (tensor<*xf32>) -> () // CHECK-LABEL: test_elu_elu @@ -701,8 +701,8 @@ func @test_elu_elu(%arg0 : tensor) -> tensor<*xf32> { } func @test_leakyrelu_leakyrelu(%arg0 : tensor) -> tensor<*xf32> { - %0 = "onnx.LeakyRelu"(%arg0) {LeakyRelu.alpha=1.0:f32} : (tensor) -> tensor<*xf32> - %1 = "onnx.LeakyRelu"(%0) {LeakyRelu.alpha=1.0:f32} : (tensor<*xf32>) -> tensor<*xf32> + %0 = "onnx.LeakyRelu"(%arg0) {alpha=1.0:f32} : (tensor) -> tensor<*xf32> + %1 = "onnx.LeakyRelu"(%0) {alpha=1.0:f32} : (tensor<*xf32>) -> tensor<*xf32> "std.return"(%1) : (tensor<*xf32>) -> () // CHECK-LABEL: test_leakyrelu_leakyrelu @@ -748,8 +748,8 @@ func @test_leakyrelu_leakyrelu(%arg0 : tensor) -> tensor<*xf32> 
{ } func @test_selu_selu(%arg0 : tensor) -> tensor<*xf32> { - %0 = "onnx.Selu"(%arg0) {Selu.alpha=1.0:f32, Selu.gamma=2.0:f32} : (tensor) -> tensor<*xf32> - %1 = "onnx.Selu"(%0) {Selu.alpha=1.0:f32, Selu.gamma=2.0:f32} : (tensor<*xf32>) -> tensor<*xf32> + %0 = "onnx.Selu"(%arg0) {alpha=1.0:f32, gamma=2.0:f32} : (tensor) -> tensor<*xf32> + %1 = "onnx.Selu"(%0) {alpha=1.0:f32, gamma=2.0:f32} : (tensor<*xf32>) -> tensor<*xf32> "std.return"(%1) : (tensor<*xf32>) -> () // CHECK-LABEL: test_selu_selu @@ -803,8 +803,8 @@ func @test_selu_selu(%arg0 : tensor) -> tensor<*xf32> { } func @test_hardsigmoid_hardsigmoid(%arg0 : tensor) -> tensor<*xf32> { - %0 = "onnx.HardSigmoid"(%arg0) {HardSigmoid.alpha=1.0:f32, HardSigmoid.beta=2.0:f32} : (tensor) -> tensor<*xf32> - %1 = "onnx.HardSigmoid"(%0) {HardSigmoid.alpha=1.0:f32, HardSigmoid.beta=2.0:f32} : (tensor<*xf32>) -> tensor<*xf32> + %0 = "onnx.HardSigmoid"(%arg0) {alpha=1.0:f32, beta=2.0:f32} : (tensor) -> tensor<*xf32> + %1 = "onnx.HardSigmoid"(%0) {alpha=1.0:f32, beta=2.0:f32} : (tensor<*xf32>) -> tensor<*xf32> "std.return"(%1) : (tensor<*xf32>) -> () // CHECK-LABEL: test_hardsigmoid_hardsigmoid diff --git a/third_party/variant b/third_party/variant new file mode 160000 index 0000000..3c7fc82 --- /dev/null +++ b/third_party/variant @@ -0,0 +1 @@ +Subproject commit 3c7fc8266bb46046b42c2dc2663f9f505f0cec28
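
Note on the attribute refactoring above: the importer now stores each ONNX attribute value in a bstd::variant (mpark::variant, pulled in as a submodule because the project builds with C++14) and converts it to an mlir::NamedAttribute through a visitor, replacing the old per-type get_attr_* helpers. The fragment below is a minimal, self-contained sketch of that visitor-dispatch pattern only; it is not the importer's code. It uses std::variant from C++17 and plain strings in place of the mlir::OpBuilder calls, and all names in it are illustrative.

    #include <cstdint>
    #include <iostream>
    #include <string>
    #include <variant>
    #include <vector>

    // Stand-in for the variant of supported ONNX attribute payloads.
    using AttrValue =
        std::variant<int64_t, std::vector<int64_t>, float, std::vector<float>,
                     std::string>;

    // Visitor: one overload per alternative, mirroring the role of
    // ONNXAttrVisitor. It renders a textual "named attribute" instead of
    // calling mlir::OpBuilder, which keeps the sketch self-contained.
    struct PrintVisitor {
      std::string name;

      std::string operator()(int64_t v) const {
        return name + " = int " + std::to_string(v);
      }
      std::string operator()(const std::vector<int64_t> &v) const {
        return name + " = ints[" + std::to_string(v.size()) + "]";
      }
      std::string operator()(float v) const {
        return name + " = float " + std::to_string(v);
      }
      std::string operator()(const std::vector<float> &v) const {
        return name + " = floats[" + std::to_string(v.size()) + "]";
      }
      std::string operator()(const std::string &v) const {
        return name + " = str \"" + v + "\"";
      }
    };

    int main() {
      // A default attribute list in the style of the op_build_table.inc entries.
      std::vector<std::pair<std::string, AttrValue>> defaults = {
          {"auto_pad", std::string("NOTSET")},
          {"group", int64_t(1)},
          {"kernel_shape", std::vector<int64_t>{}}};

      for (const auto &nameAndVal : defaults)
        std::cout << std::visit(PrintVisitor{nameAndVal.first},
                                nameAndVal.second)
                  << "\n";
    }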
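
The new ImportNodeAttributes also merges the attributes actually present on the node with the per-op default list, adding a default only when the node does not define an attribute of that name. A reduced sketch of that merge follows, again with plain strings standing in for the MLIR attribute types; the function and variable names are illustrative, not taken from the source.

    #include <iostream>
    #include <set>
    #include <string>
    #include <utility>
    #include <vector>

    // Merge node-defined attributes with per-op defaults: a default entry is
    // appended only if the node did not already define that attribute.
    std::vector<std::pair<std::string, std::string>> mergeAttributes(
        const std::vector<std::pair<std::string, std::string>> &fromNode,
        const std::vector<std::pair<std::string, std::string>> &defaults) {
      std::vector<std::pair<std::string, std::string>> result = fromNode;
      std::set<std::string> defined;
      for (const auto &attr : fromNode)
        defined.insert(attr.first);
      for (const auto &attr : defaults)
        if (defined.find(attr.first) == defined.end())
          result.push_back(attr);
      return result;
    }

    int main() {
      auto merged = mergeAttributes({{"group", "2"}},
                                    {{"auto_pad", "NOTSET"}, {"group", "1"}});
      for (const auto &attr : merged)
        std::cout << attr.first << " = " << attr.second << "\n";
      // Prints "group = 2" then "auto_pad = NOTSET": the node value wins
      // over the default, and the missing default is filled in.
    }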