diff --git a/src/Builder/FrontendDialectTransformer.cpp b/src/Builder/FrontendDialectTransformer.cpp index 6396f9c..5f382fe 100644 --- a/src/Builder/FrontendDialectTransformer.cpp +++ b/src/Builder/FrontendDialectTransformer.cpp @@ -25,7 +25,7 @@ namespace bstd = mpark; #include "FrontendDialectTransformer.hpp" namespace onnx_mlir { -namespace { +namespace detail { /*! * The list of tensors initialized by the ONNX model. @@ -37,6 +37,7 @@ public: FrontendGenImpl(mlir::MLIRContext &context) : context_(context), builder_(&context) { module_ = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)); + InitHandlerMap(); } mlir::ModuleOp ImportONNXModel(onnx::ModelProto model) { @@ -52,6 +53,11 @@ private: // mapping between string name and symbol OnnxMlirSymbolMapping frontend_symbols_; + typedef void (onnx_mlir::detail::FrontendGenImpl::*ImportHandlerType)( + const onnx::NodeProto &); + + std::map import_handler_map_; + mlir::Location UnknownLoc() { return mlir::UnknownLoc::get(&context_); } /*! @@ -329,7 +335,7 @@ private: /*! * Special handle for MaxPool operations. */ - void ImportNodeMaxPool(onnx::NodeProto node) { + void ImportNodeMaxPool(const onnx::NodeProto &node) { int nOuts = node.output().size(); if (nOuts == 1) { buildOperation(node); @@ -341,7 +347,7 @@ private: /*! * Special handle for BatchNormalization operations. */ - void ImportNodeBatchNormalization(onnx::NodeProto node) { + void ImportNodeBatchNormalization(const onnx::NodeProto &node) { int nOuts = node.output().size(); if (nOuts == 1) { // Test mode with one output. @@ -355,7 +361,7 @@ private: /*! * Special handle for Pad operations. */ - void ImportNodePad(onnx::NodeProto node) { + void ImportNodePad(const onnx::NodeProto &node) { int nOps = node.input().size(); if (nOps == 2) { @@ -400,12 +406,16 @@ private: // the generic operator is used // one known reeason is the optional input + (this->*(import_handler_map_[opName.str()]))(node); + } + + void InitHandlerMap() { #include "src/Builder/OpBuildTable.inc" } /*! * Import output tensor, by doing the following: - * - Add the type of this output tensor to a list of tensor + * - Add the t/yp this output tensor to a list of tensor * types representing return types of this graph function. * - Add this output tensor to the list of mlir::Value * to be returned by the function representing computation graph. @@ -499,7 +509,7 @@ private: mainFunc.setType(funcType); } }; // FrontendGenImpl class -} // namespace +} // namespace detail } // namespace onnx_mlir namespace onnx_mlir { @@ -512,7 +522,7 @@ void ImportFrontendModelFile(std::string model_fname, auto parse_success = model.ParseFromIstream(&input); assert(parse_success && "Onnx Model Parsing Failed."); - FrontendGenImpl myONNXGen(context); + detail::FrontendGenImpl myONNXGen(context); module = myONNXGen.ImportONNXModel(model); } } // namespace onnx_mlir diff --git a/src/Builder/OpBuildTable.inc b/src/Builder/OpBuildTable.inc index b13c619..dcd78af 100644 --- a/src/Builder/OpBuildTable.inc +++ b/src/Builder/OpBuildTable.inc @@ -4,351 +4,351 @@ // Details can be found in docs/ImportONNXDefs.md . //******************************************************** -if (opName == "Abs") - buildOperation(node); -if (opName == "Acos") - buildOperation(node); -if (opName == "Acosh") - buildOperation(node); -if (opName == "Add") - buildOperation(node); -if (opName == "And") - buildOperation(node); -if (opName == "ArgMax") - buildOperation(node); -if (opName == "ArgMin") - buildOperation(node); -if (opName == "Asin") - buildOperation(node); -if (opName == "Asinh") - buildOperation(node); -if (opName == "Atan") - buildOperation(node); -if (opName == "Atanh") - buildOperation(node); -if (opName == "AveragePool") - buildOperation(node); -if (opName == "BatchNormalization") - ImportNodeBatchNormalization(node); -if (opName == "BitShift") - buildOperation(node); -if (opName == "Cast") - buildOperation(node); -if (opName == "Ceil") - buildOperation(node); -if (opName == "Clip") - buildOperation(node); -if (opName == "Compress") - buildOperation(node); -if (opName == "Concat") - buildOperation(node); -if (opName == "ConcatFromSequence") - buildOperation(node); -if (opName == "Constant") - buildOperation(node); -if (opName == "ConstantOfShape") - buildOperation(node); -if (opName == "Conv") - buildOperation(node); -if (opName == "ConvInteger") - buildOperation(node); -if (opName == "ConvTranspose") - buildOperation(node); -if (opName == "Cos") - buildOperation(node); -if (opName == "Cosh") - buildOperation(node); -if (opName == "CumSum") - buildOperation(node); -if (opName == "DepthToSpace") - buildOperation(node); -if (opName == "DequantizeLinear") - buildOperation(node); -if (opName == "Det") - buildOperation(node); -if (opName == "Div") - buildOperation(node); -if (opName == "Dropout") - buildOperation(node); -if (opName == "DynamicQuantizeLinear") - buildOperation(node); -if (opName == "Elu") - buildOperation(node); -if (opName == "Equal") - buildOperation(node); -if (opName == "Erf") - buildOperation(node); -if (opName == "Exp") - buildOperation(node); -if (opName == "Expand") - buildOperation(node); -if (opName == "EyeLike") - buildOperation(node); -if (opName == "Flatten") - buildOperation(node); -if (opName == "Floor") - buildOperation(node); -if (opName == "GRU") - buildOperation(node); -if (opName == "Gather") - buildOperation(node); -if (opName == "GatherElements") - buildOperation(node); -if (opName == "GatherND") - buildOperation(node); -if (opName == "Gemm") - buildOperation(node); -if (opName == "GlobalAveragePool") - buildOperation(node); -if (opName == "GlobalLpPool") - buildOperation(node); -if (opName == "GlobalMaxPool") - buildOperation(node); -if (opName == "Greater") - buildOperation(node); -if (opName == "HardSigmoid") - buildOperation(node); -if (opName == "Hardmax") - buildOperation(node); -if (opName == "Identity") - buildOperation(node); -if (opName == "If") - buildOperation(node); -if (opName == "InstanceNormalization") - buildOperation(node); -if (opName == "IsInf") - buildOperation(node); -if (opName == "IsNaN") - buildOperation(node); -if (opName == "LRN") - buildOperation(node); -if (opName == "LSTM") - buildOperation(node); -if (opName == "LeakyRelu") - buildOperation(node); -if (opName == "Less") - buildOperation(node); -if (opName == "Log") - buildOperation(node); -if (opName == "LogSoftmax") - buildOperation(node); -if (opName == "Loop") - buildOperation(node); -if (opName == "LpNormalization") - buildOperation(node); -if (opName == "LpPool") - buildOperation(node); -if (opName == "MatMul") - buildOperation(node); -if (opName == "MatMulInteger") - buildOperation(node); -if (opName == "Max") - buildOperation(node); -if (opName == "MaxPool") - ImportNodeMaxPool(node); -if (opName == "MaxRoiPool") - buildOperation(node); -if (opName == "MaxUnpool") - buildOperation(node); -if (opName == "Mean") - buildOperation(node); -if (opName == "MeanVarianceNormalization") - buildOperation(node); -if (opName == "Min") - buildOperation(node); -if (opName == "Mod") - buildOperation(node); -if (opName == "Mul") - buildOperation(node); -if (opName == "Multinomial") - buildOperation(node); -if (opName == "Neg") - buildOperation(node); -if (opName == "NonMaxSuppression") - buildOperation(node); -if (opName == "NonZero") - buildOperation(node); -if (opName == "Not") - buildOperation(node); -if (opName == "OneHot") - buildOperation(node); -if (opName == "Or") - buildOperation(node); -if (opName == "PRelu") - buildOperation(node); -if (opName == "Pad") - ImportNodePad(node); -if (opName == "Pow") - buildOperation(node); -if (opName == "QLinearConv") - buildOperation(node); -if (opName == "QLinearMatMul") - buildOperation(node); -if (opName == "QuantizeLinear") - buildOperation(node); -if (opName == "RNN") - buildOperation(node); -if (opName == "RandomNormal") - buildOperation(node); -if (opName == "RandomNormalLike") - buildOperation(node); -if (opName == "RandomUniform") - buildOperation(node); -if (opName == "RandomUniformLike") - buildOperation(node); -if (opName == "Range") - buildOperation(node); -if (opName == "Reciprocal") - buildOperation(node); -if (opName == "ReduceL1") - buildOperation(node); -if (opName == "ReduceL2") - buildOperation(node); -if (opName == "ReduceLogSum") - buildOperation(node); -if (opName == "ReduceLogSumExp") - buildOperation(node); -if (opName == "ReduceMax") - buildOperation(node); -if (opName == "ReduceMean") - buildOperation(node); -if (opName == "ReduceMin") - buildOperation(node); -if (opName == "ReduceProd") - buildOperation(node); -if (opName == "ReduceSum") - buildOperation(node); -if (opName == "ReduceSumSquare") - buildOperation(node); -if (opName == "Relu") - buildOperation(node); -if (opName == "Reshape") - buildOperation(node); -if (opName == "Resize") - buildOperation(node); -if (opName == "ReverseSequence") - buildOperation(node); -if (opName == "RoiAlign") - buildOperation(node); -if (opName == "Round") - buildOperation(node); -if (opName == "Scan") - buildOperation(node); -if (opName == "Scatter") - buildOperation(node); -if (opName == "ScatterElements") - buildOperation(node); -if (opName == "ScatterND") - buildOperation(node); -if (opName == "Selu") - buildOperation(node); -if (opName == "SequenceAt") - buildOperation(node); -if (opName == "SequenceConstruct") - buildOperation(node); -if (opName == "SequenceEmpty") - buildOperation(node); -if (opName == "SequenceErase") - buildOperation(node); -if (opName == "SequenceInsert") - buildOperation(node); -if (opName == "SequenceLength") - buildOperation(node); -if (opName == "Shape") - buildOperation(node); -if (opName == "Shrink") - buildOperation(node); -if (opName == "Sigmoid") - buildOperation(node); -if (opName == "Sign") - buildOperation(node); -if (opName == "Sin") - buildOperation(node); -if (opName == "Sinh") - buildOperation(node); -if (opName == "Size") - buildOperation(node); -if (opName == "Slice") - buildOperation(node); -if (opName == "Softmax") - buildOperation(node); -if (opName == "Softplus") - buildOperation(node); -if (opName == "Softsign") - buildOperation(node); -if (opName == "SpaceToDepth") - buildOperation(node); -if (opName == "Split") - buildOperation(node); -if (opName == "SplitToSequence") - buildOperation(node); -if (opName == "Sqrt") - buildOperation(node); -if (opName == "Squeeze") - buildOperation(node); -if (opName == "StringNormalizer") - buildOperation(node); -if (opName == "Sub") - buildOperation(node); -if (opName == "Sum") - buildOperation(node); -if (opName == "Tan") - buildOperation(node); -if (opName == "Tanh") - buildOperation(node); -if (opName == "TfIdfVectorizer") - buildOperation(node); -if (opName == "ThresholdedRelu") - buildOperation(node); -if (opName == "Tile") - buildOperation(node); -if (opName == "TopK") - buildOperation(node); -if (opName == "Transpose") - buildOperation(node); -if (opName == "Unique") - buildOperation(node); -if (opName == "Unsqueeze") - buildOperation(node); -if (opName == "Upsample") - buildOperation(node); -if (opName == "Where") - buildOperation(node); -if (opName == "Xor") - buildOperation(node); -if (opName == "ArrayFeatureExtractor") - buildOperation(node); -if (opName == "Binarizer") - buildOperation(node); -if (opName == "CastMap") - buildOperation(node); -if (opName == "CategoryMapper") - buildOperation(node); -if (opName == "DictVectorizer") - buildOperation(node); -if (opName == "FeatureVectorizer") - buildOperation(node); -if (opName == "Imputer") - buildOperation(node); -if (opName == "LabelEncoder") - buildOperation(node); -if (opName == "LinearClassifier") - buildOperation(node); -if (opName == "LinearRegressor") - buildOperation(node); -if (opName == "Normalizer") - buildOperation(node); -if (opName == "OneHotEncoder") - buildOperation(node); -if (opName == "SVMClassifier") - buildOperation(node); -if (opName == "SVMRegressor") - buildOperation(node); -if (opName == "Scaler") - buildOperation(node); -if (opName == "TreeEnsembleClassifier") - buildOperation(node); -if (opName == "TreeEnsembleRegressor") - buildOperation(node); -if (opName == "ZipMap") - buildOperation(node); +import_handler_map_["Abs"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Acos"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Acosh"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Add"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["And"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["ArgMax"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["ArgMin"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Asin"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Asinh"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Atan"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Atanh"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["AveragePool"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["BatchNormalization"] = + &onnx_mlir::detail::FrontendGenImpl::ImportNodeBatchNormalization; +import_handler_map_["BitShift"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Cast"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Ceil"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Clip"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Compress"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Concat"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["ConcatFromSequence"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Constant"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["ConstantOfShape"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Conv"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["ConvInteger"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["ConvTranspose"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Cos"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Cosh"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["CumSum"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["DepthToSpace"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["DequantizeLinear"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Det"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Div"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Dropout"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["DynamicQuantizeLinear"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Elu"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Equal"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Erf"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Exp"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Expand"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["EyeLike"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Flatten"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Floor"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["GRU"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Gather"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["GatherElements"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["GatherND"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Gemm"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["GlobalAveragePool"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["GlobalLpPool"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["GlobalMaxPool"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Greater"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["HardSigmoid"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Hardmax"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Identity"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["If"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["InstanceNormalization"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["IsInf"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["IsNaN"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["LRN"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["LSTM"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["LeakyRelu"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Less"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Log"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["LogSoftmax"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Loop"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["LpNormalization"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["LpPool"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["MatMul"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["MatMulInteger"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Max"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["MaxPool"] = + &onnx_mlir::detail::FrontendGenImpl::ImportNodeMaxPool; +import_handler_map_["MaxRoiPool"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["MaxUnpool"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Mean"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["MeanVarianceNormalization"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Min"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Mod"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Mul"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Multinomial"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Neg"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["NonMaxSuppression"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["NonZero"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Not"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["OneHot"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Or"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["PRelu"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Pad"] = + &onnx_mlir::detail::FrontendGenImpl::ImportNodePad; +import_handler_map_["Pow"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["QLinearConv"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["QLinearMatMul"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["QuantizeLinear"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["RNN"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["RandomNormal"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["RandomNormalLike"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["RandomUniform"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["RandomUniformLike"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Range"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Reciprocal"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["ReduceL1"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["ReduceL2"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["ReduceLogSum"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["ReduceLogSumExp"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["ReduceMax"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["ReduceMean"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["ReduceMin"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["ReduceProd"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["ReduceSum"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["ReduceSumSquare"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Relu"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Reshape"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Resize"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["ReverseSequence"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["RoiAlign"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Round"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Scan"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Scatter"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["ScatterElements"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["ScatterND"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Selu"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["SequenceAt"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["SequenceConstruct"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["SequenceEmpty"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["SequenceErase"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["SequenceInsert"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["SequenceLength"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Shape"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Shrink"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Sigmoid"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Sign"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Sin"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Sinh"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Size"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Slice"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Softmax"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Softplus"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Softsign"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["SpaceToDepth"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Split"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["SplitToSequence"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Sqrt"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Squeeze"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["StringNormalizer"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Sub"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Sum"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Tan"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Tanh"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["TfIdfVectorizer"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["ThresholdedRelu"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Tile"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["TopK"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Transpose"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Unique"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Unsqueeze"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Upsample"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Where"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Xor"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["ArrayFeatureExtractor"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Binarizer"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["CastMap"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["CategoryMapper"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["DictVectorizer"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["FeatureVectorizer"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Imputer"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["LabelEncoder"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["LinearClassifier"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["LinearRegressor"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Normalizer"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["OneHotEncoder"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["SVMClassifier"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["SVMRegressor"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["Scaler"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["TreeEnsembleClassifier"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["TreeEnsembleRegressor"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; +import_handler_map_["ZipMap"] = + &onnx_mlir::detail::FrontendGenImpl::buildOperation; diff --git a/utils/gen_onnx_mlir.py b/utils/gen_onnx_mlir.py index 80369e9..d2d4360 100644 --- a/utils/gen_onnx_mlir.py +++ b/utils/gen_onnx_mlir.py @@ -954,7 +954,7 @@ special cases: def gen_op_importer(schema, file): indent = inc_indent() - s = indent + 'if (opName == "' + schema.name + '")\n' + s = indent + 'import_handler_map_["' + schema.name +'"] = \n ' expected_num_operands = len(schema.inputs) expected_num_results = len(schema.outputs) @@ -978,8 +978,8 @@ def gen_op_importer(schema, file): args.append( '/* expected_num_results = */ {}'.format(expected_num_results)) """ - s += inc_indent(indent) + " {}({});\n".format( - handler_func, ", ".join(args)) + s += inc_indent(indent) + '&onnx_mlir::detail::FrontendGenImpl::' + s += handler_func+';\n' file.write(s)