diff --git a/src/Builder/FrontendDialectTransformer.cpp b/src/Builder/FrontendDialectTransformer.cpp index c806eef..21c66a0 100644 --- a/src/Builder/FrontendDialectTransformer.cpp +++ b/src/Builder/FrontendDialectTransformer.cpp @@ -128,6 +128,13 @@ private: case onnx::AttributeProto::TENSOR: mlirAttr = onnxTensorProtoToDenseElmAttr(builder_, attr.t()); break; + case onnx::AttributeProto::STRINGS: { + llvm::SmallVector vectorStringRef; + for (const auto &item : attr.strings()) { + vectorStringRef.push_back(llvm::StringRef(item)); + } + mlirAttr = builder_.getStrArrayAttr(llvm::makeArrayRef(vectorStringRef)); + } break; default: llvm_unreachable("datatype for attribute is not implemented"); break; diff --git a/src/Builder/MLOpBuildTable.inc b/src/Builder/MLOpBuildTable.inc deleted file mode 100644 index 58d3d3b..0000000 --- a/src/Builder/MLOpBuildTable.inc +++ /dev/null @@ -1,42 +0,0 @@ -//******************************************************** -// Do not modify this file directly. -// This file is automatically generated via script. -// Details can be found in docs/readonnxdefs.md . -//******************************************************** - -if (opName == "ArrayFeatureExtractor") - buildOperation(node); -if (opName == "Binarizer") - buildOperation(node); -if (opName == "CastMap") - buildOperation(node); -if (opName == "CategoryMapper") - buildOperation(node); -if (opName == "DictVectorizer") - buildOperation(node); -if (opName == "FeatureVectorizer") - buildOperation(node); -if (opName == "Imputer") - buildOperation(node); -if (opName == "LabelEncoder") - buildOperation(node); -if (opName == "LinearClassifier") - buildOperation(node); -if (opName == "LinearRegressor") - buildOperation(node); -if (opName == "Normalizer") - buildOperation(node); -if (opName == "OneHotEncoder") - buildOperation(node); -if (opName == "SVMClassifier") - buildOperation(node); -if (opName == "SVMRegressor") - buildOperation(node); -if (opName == "Scaler") - buildOperation(node); -if (opName == "TreeEnsembleClassifier") - buildOperation(node); -if (opName == "TreeEnsembleRegressor") - buildOperation(node); -if (opName == "ZipMap") - buildOperation(node); diff --git a/src/Dialect/ONNX/ONNXOps.td.inc b/src/Dialect/ONNX/ONNXOps.td.inc index 89d969a..7089c14 100644 --- a/src/Dialect/ONNX/ONNXOps.td.inc +++ b/src/Dialect/ONNX/ONNXOps.td.inc @@ -204,7 +204,7 @@ def ONNXArgMaxOp:ONNX_Op<"ArgMax", return 1; } static std::vector getTypeMap() { - return {-1}; + return {4}; } }]; } @@ -230,7 +230,7 @@ def ONNXArgMinOp:ONNX_Op<"ArgMin", return 1; } static std::vector getTypeMap() { - return {-1}; + return {4}; } }]; } @@ -944,7 +944,7 @@ def ONNXDequantizeLinearOp:ONNX_Op<"DequantizeLinear", return 1; } static std::vector getTypeMap() { - return {21}; + return {7}; } }]; } @@ -1091,7 +1091,7 @@ def ONNXDynamicQuantizeLinearOp:ONNX_Op<"DynamicQuantizeLinear", return 3; } static std::vector getTypeMap() { - return {1,-1,1}; + return {1,7,1}; } }]; } @@ -2914,7 +2914,7 @@ def ONNXNonMaxSuppressionOp:ONNX_Op<"NonMaxSuppression", return 1; } static std::vector getTypeMap() { - return {22}; + return {4}; } }]; } @@ -2938,7 +2938,7 @@ def ONNXNonZeroOp:ONNX_Op<"NonZero", return 1; } static std::vector getTypeMap() { - return {-1}; + return {4}; } }]; } @@ -5144,7 +5144,7 @@ def ONNXStringNormalizerOp:ONNX_Op<"StringNormalizer", return 1; } static std::vector getTypeMap() { - return {20}; + return {11}; } }]; } @@ -5526,7 +5526,7 @@ def ONNXUniqueOp:ONNX_Op<"Unique", return 4; } static std::vector getTypeMap() { - return {20,-1,-1,-1}; + return {20,4,4,4}; } }]; } @@ -5823,7 +5823,7 @@ def ONNXFeatureVectorizerOp:ONNX_Op<"FeatureVectorizer", return 1; } static std::vector getTypeMap() { - return {-1}; + return {7}; } }]; } @@ -5929,7 +5929,7 @@ def ONNXLinearClassifierOp:ONNX_Op<"LinearClassifier", return 2; } static std::vector getTypeMap() { - return {-1,-1}; + return {-1,7}; } }]; } @@ -5959,7 +5959,7 @@ def ONNXLinearRegressorOp:ONNX_Op<"LinearRegressor", return 1; } static std::vector getTypeMap() { - return {-1}; + return {7}; } }]; } @@ -5990,7 +5990,7 @@ def ONNXNormalizerOp:ONNX_Op<"Normalizer", return 1; } static std::vector getTypeMap() { - return {-1}; + return {7}; } }]; } @@ -6021,7 +6021,7 @@ def ONNXOneHotEncoderOp:ONNX_Op<"OneHotEncoder", return 1; } static std::vector getTypeMap() { - return {-1}; + return {7}; } }]; } @@ -6054,7 +6054,7 @@ def ONNXSVMClassifierOp:ONNX_Op<"SVMClassifier", return 2; } static std::vector getTypeMap() { - return {-1,-1}; + return {-1,7}; } }]; } @@ -6083,7 +6083,7 @@ def ONNXSVMRegressorOp:ONNX_Op<"SVMRegressor", return 1; } static std::vector getTypeMap() { - return {-1}; + return {7}; } }]; } @@ -6106,7 +6106,7 @@ def ONNXScalerOp:ONNX_Op<"Scaler", return 1; } static std::vector getTypeMap() { - return {-1}; + return {7}; } }]; } @@ -6153,7 +6153,7 @@ def ONNXTreeEnsembleClassifierOp:ONNX_Op<"TreeEnsembleClassifier", return 2; } static std::vector getTypeMap() { - return {-1,-1}; + return {-1,7}; } }]; } @@ -6200,7 +6200,7 @@ def ONNXTreeEnsembleRegressorOp:ONNX_Op<"TreeEnsembleRegressor", return 1; } static std::vector getTypeMap() { - return {-1}; + return {7}; } }]; } diff --git a/utils/gen_onnx_mlir.py b/utils/gen_onnx_mlir.py index 393c266..06c1e47 100644 --- a/utils/gen_onnx_mlir.py +++ b/utils/gen_onnx_mlir.py @@ -427,7 +427,15 @@ def get_allowed_elem_types(schema, input): # return allowed_types_str # TODO: enable type constraints. if input.typeStr : - tstr = input.typeStr + tstr = input.typeStr + structure, element = get_data_structure_element(tstr); + # In case the type is directly specified + if structure and element : + t = np_type_to_tblgen_attr_type(element) + if t == None : + return allowed_structure, None + else : + return structure, [t] else : return None if schema.type_constraints: