Use map when Import onnx node (#230)

* add map

* generate map
This commit is contained in:
chentong319 2020-07-27 12:25:21 -04:00 committed by GitHub
parent 8af5fdeb62
commit 32ceb6968a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 368 additions and 358 deletions

View File

@ -25,7 +25,7 @@ namespace bstd = mpark;
#include "FrontendDialectTransformer.hpp"
namespace onnx_mlir {
namespace {
namespace detail {
/*!
* The list of tensors initialized by the ONNX model.
@ -37,6 +37,7 @@ public:
FrontendGenImpl(mlir::MLIRContext &context)
: context_(context), builder_(&context) {
module_ = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
InitHandlerMap();
}
mlir::ModuleOp ImportONNXModel(onnx::ModelProto model) {
@ -52,6 +53,11 @@ private:
// mapping between string name and symbol
OnnxMlirSymbolMapping frontend_symbols_;
typedef void (onnx_mlir::detail::FrontendGenImpl::*ImportHandlerType)(
const onnx::NodeProto &);
std::map<std::string, ImportHandlerType> import_handler_map_;
mlir::Location UnknownLoc() { return mlir::UnknownLoc::get(&context_); }
/*!
@ -329,7 +335,7 @@ private:
/*!
* Special handle for MaxPool operations.
*/
void ImportNodeMaxPool(onnx::NodeProto node) {
void ImportNodeMaxPool(const onnx::NodeProto &node) {
int nOuts = node.output().size();
if (nOuts == 1) {
buildOperation<mlir::ONNXMaxPoolSingleOutOp>(node);
@ -341,7 +347,7 @@ private:
/*!
* Special handle for BatchNormalization operations.
*/
void ImportNodeBatchNormalization(onnx::NodeProto node) {
void ImportNodeBatchNormalization(const onnx::NodeProto &node) {
int nOuts = node.output().size();
if (nOuts == 1) {
// Test mode with one output.
@ -355,7 +361,7 @@ private:
/*!
* Special handle for Pad operations.
*/
void ImportNodePad(onnx::NodeProto node) {
void ImportNodePad(const onnx::NodeProto &node) {
int nOps = node.input().size();
if (nOps == 2) {
@ -400,12 +406,16 @@ private:
// the generic operator is used
// one known reeason is the optional input
(this->*(import_handler_map_[opName.str()]))(node);
}
void InitHandlerMap() {
#include "src/Builder/OpBuildTable.inc"
}
/*!
* Import output tensor, by doing the following:
* - Add the type of this output tensor to a list of tensor
* - Add the t/yp this output tensor to a list of tensor
* types representing return types of this graph function.
* - Add this output tensor to the list of mlir::Value
* to be returned by the function representing computation graph.
@ -499,7 +509,7 @@ private:
mainFunc.setType(funcType);
}
}; // FrontendGenImpl class
} // namespace
} // namespace detail
} // namespace onnx_mlir
namespace onnx_mlir {
@ -512,7 +522,7 @@ void ImportFrontendModelFile(std::string model_fname,
auto parse_success = model.ParseFromIstream(&input);
assert(parse_success && "Onnx Model Parsing Failed.");
FrontendGenImpl myONNXGen(context);
detail::FrontendGenImpl myONNXGen(context);
module = myONNXGen.ImportONNXModel(model);
}
} // namespace onnx_mlir

View File

@ -4,351 +4,351 @@
// Details can be found in docs/ImportONNXDefs.md .
//********************************************************
if (opName == "Abs")
buildOperation<mlir::ONNXAbsOp>(node);
if (opName == "Acos")
buildOperation<mlir::ONNXAcosOp>(node);
if (opName == "Acosh")
buildOperation<mlir::ONNXAcoshOp>(node);
if (opName == "Add")
buildOperation<mlir::ONNXAddOp>(node);
if (opName == "And")
buildOperation<mlir::ONNXAndOp>(node);
if (opName == "ArgMax")
buildOperation<mlir::ONNXArgMaxOp>(node);
if (opName == "ArgMin")
buildOperation<mlir::ONNXArgMinOp>(node);
if (opName == "Asin")
buildOperation<mlir::ONNXAsinOp>(node);
if (opName == "Asinh")
buildOperation<mlir::ONNXAsinhOp>(node);
if (opName == "Atan")
buildOperation<mlir::ONNXAtanOp>(node);
if (opName == "Atanh")
buildOperation<mlir::ONNXAtanhOp>(node);
if (opName == "AveragePool")
buildOperation<mlir::ONNXAveragePoolOp>(node);
if (opName == "BatchNormalization")
ImportNodeBatchNormalization(node);
if (opName == "BitShift")
buildOperation<mlir::ONNXBitShiftOp>(node);
if (opName == "Cast")
buildOperation<mlir::ONNXCastOp>(node);
if (opName == "Ceil")
buildOperation<mlir::ONNXCeilOp>(node);
if (opName == "Clip")
buildOperation<mlir::ONNXClipOp>(node);
if (opName == "Compress")
buildOperation<mlir::ONNXCompressOp>(node);
if (opName == "Concat")
buildOperation<mlir::ONNXConcatOp>(node);
if (opName == "ConcatFromSequence")
buildOperation<mlir::ONNXConcatFromSequenceOp>(node);
if (opName == "Constant")
buildOperation<mlir::ONNXConstantOp>(node);
if (opName == "ConstantOfShape")
buildOperation<mlir::ONNXConstantOfShapeOp>(node);
if (opName == "Conv")
buildOperation<mlir::ONNXConvOp>(node);
if (opName == "ConvInteger")
buildOperation<mlir::ONNXConvIntegerOp>(node);
if (opName == "ConvTranspose")
buildOperation<mlir::ONNXConvTransposeOp>(node);
if (opName == "Cos")
buildOperation<mlir::ONNXCosOp>(node);
if (opName == "Cosh")
buildOperation<mlir::ONNXCoshOp>(node);
if (opName == "CumSum")
buildOperation<mlir::ONNXCumSumOp>(node);
if (opName == "DepthToSpace")
buildOperation<mlir::ONNXDepthToSpaceOp>(node);
if (opName == "DequantizeLinear")
buildOperation<mlir::ONNXDequantizeLinearOp>(node);
if (opName == "Det")
buildOperation<mlir::ONNXDetOp>(node);
if (opName == "Div")
buildOperation<mlir::ONNXDivOp>(node);
if (opName == "Dropout")
buildOperation<mlir::ONNXDropoutOp>(node);
if (opName == "DynamicQuantizeLinear")
buildOperation<mlir::ONNXDynamicQuantizeLinearOp>(node);
if (opName == "Elu")
buildOperation<mlir::ONNXEluOp>(node);
if (opName == "Equal")
buildOperation<mlir::ONNXEqualOp>(node);
if (opName == "Erf")
buildOperation<mlir::ONNXErfOp>(node);
if (opName == "Exp")
buildOperation<mlir::ONNXExpOp>(node);
if (opName == "Expand")
buildOperation<mlir::ONNXExpandOp>(node);
if (opName == "EyeLike")
buildOperation<mlir::ONNXEyeLikeOp>(node);
if (opName == "Flatten")
buildOperation<mlir::ONNXFlattenOp>(node);
if (opName == "Floor")
buildOperation<mlir::ONNXFloorOp>(node);
if (opName == "GRU")
buildOperation<mlir::ONNXGRUOp>(node);
if (opName == "Gather")
buildOperation<mlir::ONNXGatherOp>(node);
if (opName == "GatherElements")
buildOperation<mlir::ONNXGatherElementsOp>(node);
if (opName == "GatherND")
buildOperation<mlir::ONNXGatherNDOp>(node);
if (opName == "Gemm")
buildOperation<mlir::ONNXGemmOp>(node);
if (opName == "GlobalAveragePool")
buildOperation<mlir::ONNXGlobalAveragePoolOp>(node);
if (opName == "GlobalLpPool")
buildOperation<mlir::ONNXGlobalLpPoolOp>(node);
if (opName == "GlobalMaxPool")
buildOperation<mlir::ONNXGlobalMaxPoolOp>(node);
if (opName == "Greater")
buildOperation<mlir::ONNXGreaterOp>(node);
if (opName == "HardSigmoid")
buildOperation<mlir::ONNXHardSigmoidOp>(node);
if (opName == "Hardmax")
buildOperation<mlir::ONNXHardmaxOp>(node);
if (opName == "Identity")
buildOperation<mlir::ONNXIdentityOp>(node);
if (opName == "If")
buildOperation<mlir::ONNXIfOp>(node);
if (opName == "InstanceNormalization")
buildOperation<mlir::ONNXInstanceNormalizationOp>(node);
if (opName == "IsInf")
buildOperation<mlir::ONNXIsInfOp>(node);
if (opName == "IsNaN")
buildOperation<mlir::ONNXIsNaNOp>(node);
if (opName == "LRN")
buildOperation<mlir::ONNXLRNOp>(node);
if (opName == "LSTM")
buildOperation<mlir::ONNXLSTMOp>(node);
if (opName == "LeakyRelu")
buildOperation<mlir::ONNXLeakyReluOp>(node);
if (opName == "Less")
buildOperation<mlir::ONNXLessOp>(node);
if (opName == "Log")
buildOperation<mlir::ONNXLogOp>(node);
if (opName == "LogSoftmax")
buildOperation<mlir::ONNXLogSoftmaxOp>(node);
if (opName == "Loop")
buildOperation<mlir::ONNXLoopOp>(node);
if (opName == "LpNormalization")
buildOperation<mlir::ONNXLpNormalizationOp>(node);
if (opName == "LpPool")
buildOperation<mlir::ONNXLpPoolOp>(node);
if (opName == "MatMul")
buildOperation<mlir::ONNXMatMulOp>(node);
if (opName == "MatMulInteger")
buildOperation<mlir::ONNXMatMulIntegerOp>(node);
if (opName == "Max")
buildOperation<mlir::ONNXMaxOp>(node);
if (opName == "MaxPool")
ImportNodeMaxPool(node);
if (opName == "MaxRoiPool")
buildOperation<mlir::ONNXMaxRoiPoolOp>(node);
if (opName == "MaxUnpool")
buildOperation<mlir::ONNXMaxUnpoolOp>(node);
if (opName == "Mean")
buildOperation<mlir::ONNXMeanOp>(node);
if (opName == "MeanVarianceNormalization")
buildOperation<mlir::ONNXMeanVarianceNormalizationOp>(node);
if (opName == "Min")
buildOperation<mlir::ONNXMinOp>(node);
if (opName == "Mod")
buildOperation<mlir::ONNXModOp>(node);
if (opName == "Mul")
buildOperation<mlir::ONNXMulOp>(node);
if (opName == "Multinomial")
buildOperation<mlir::ONNXMultinomialOp>(node);
if (opName == "Neg")
buildOperation<mlir::ONNXNegOp>(node);
if (opName == "NonMaxSuppression")
buildOperation<mlir::ONNXNonMaxSuppressionOp>(node);
if (opName == "NonZero")
buildOperation<mlir::ONNXNonZeroOp>(node);
if (opName == "Not")
buildOperation<mlir::ONNXNotOp>(node);
if (opName == "OneHot")
buildOperation<mlir::ONNXOneHotOp>(node);
if (opName == "Or")
buildOperation<mlir::ONNXOrOp>(node);
if (opName == "PRelu")
buildOperation<mlir::ONNXPReluOp>(node);
if (opName == "Pad")
ImportNodePad(node);
if (opName == "Pow")
buildOperation<mlir::ONNXPowOp>(node);
if (opName == "QLinearConv")
buildOperation<mlir::ONNXQLinearConvOp>(node);
if (opName == "QLinearMatMul")
buildOperation<mlir::ONNXQLinearMatMulOp>(node);
if (opName == "QuantizeLinear")
buildOperation<mlir::ONNXQuantizeLinearOp>(node);
if (opName == "RNN")
buildOperation<mlir::ONNXRNNOp>(node);
if (opName == "RandomNormal")
buildOperation<mlir::ONNXRandomNormalOp>(node);
if (opName == "RandomNormalLike")
buildOperation<mlir::ONNXRandomNormalLikeOp>(node);
if (opName == "RandomUniform")
buildOperation<mlir::ONNXRandomUniformOp>(node);
if (opName == "RandomUniformLike")
buildOperation<mlir::ONNXRandomUniformLikeOp>(node);
if (opName == "Range")
buildOperation<mlir::ONNXRangeOp>(node);
if (opName == "Reciprocal")
buildOperation<mlir::ONNXReciprocalOp>(node);
if (opName == "ReduceL1")
buildOperation<mlir::ONNXReduceL1Op>(node);
if (opName == "ReduceL2")
buildOperation<mlir::ONNXReduceL2Op>(node);
if (opName == "ReduceLogSum")
buildOperation<mlir::ONNXReduceLogSumOp>(node);
if (opName == "ReduceLogSumExp")
buildOperation<mlir::ONNXReduceLogSumExpOp>(node);
if (opName == "ReduceMax")
buildOperation<mlir::ONNXReduceMaxOp>(node);
if (opName == "ReduceMean")
buildOperation<mlir::ONNXReduceMeanOp>(node);
if (opName == "ReduceMin")
buildOperation<mlir::ONNXReduceMinOp>(node);
if (opName == "ReduceProd")
buildOperation<mlir::ONNXReduceProdOp>(node);
if (opName == "ReduceSum")
buildOperation<mlir::ONNXReduceSumOp>(node);
if (opName == "ReduceSumSquare")
buildOperation<mlir::ONNXReduceSumSquareOp>(node);
if (opName == "Relu")
buildOperation<mlir::ONNXReluOp>(node);
if (opName == "Reshape")
buildOperation<mlir::ONNXReshapeOp>(node);
if (opName == "Resize")
buildOperation<mlir::ONNXResizeOp>(node);
if (opName == "ReverseSequence")
buildOperation<mlir::ONNXReverseSequenceOp>(node);
if (opName == "RoiAlign")
buildOperation<mlir::ONNXRoiAlignOp>(node);
if (opName == "Round")
buildOperation<mlir::ONNXRoundOp>(node);
if (opName == "Scan")
buildOperation<mlir::ONNXScanOp>(node);
if (opName == "Scatter")
buildOperation<mlir::ONNXScatterOp>(node);
if (opName == "ScatterElements")
buildOperation<mlir::ONNXScatterElementsOp>(node);
if (opName == "ScatterND")
buildOperation<mlir::ONNXScatterNDOp>(node);
if (opName == "Selu")
buildOperation<mlir::ONNXSeluOp>(node);
if (opName == "SequenceAt")
buildOperation<mlir::ONNXSequenceAtOp>(node);
if (opName == "SequenceConstruct")
buildOperation<mlir::ONNXSequenceConstructOp>(node);
if (opName == "SequenceEmpty")
buildOperation<mlir::ONNXSequenceEmptyOp>(node);
if (opName == "SequenceErase")
buildOperation<mlir::ONNXSequenceEraseOp>(node);
if (opName == "SequenceInsert")
buildOperation<mlir::ONNXSequenceInsertOp>(node);
if (opName == "SequenceLength")
buildOperation<mlir::ONNXSequenceLengthOp>(node);
if (opName == "Shape")
buildOperation<mlir::ONNXShapeOp>(node);
if (opName == "Shrink")
buildOperation<mlir::ONNXShrinkOp>(node);
if (opName == "Sigmoid")
buildOperation<mlir::ONNXSigmoidOp>(node);
if (opName == "Sign")
buildOperation<mlir::ONNXSignOp>(node);
if (opName == "Sin")
buildOperation<mlir::ONNXSinOp>(node);
if (opName == "Sinh")
buildOperation<mlir::ONNXSinhOp>(node);
if (opName == "Size")
buildOperation<mlir::ONNXSizeOp>(node);
if (opName == "Slice")
buildOperation<mlir::ONNXSliceOp>(node);
if (opName == "Softmax")
buildOperation<mlir::ONNXSoftmaxOp>(node);
if (opName == "Softplus")
buildOperation<mlir::ONNXSoftplusOp>(node);
if (opName == "Softsign")
buildOperation<mlir::ONNXSoftsignOp>(node);
if (opName == "SpaceToDepth")
buildOperation<mlir::ONNXSpaceToDepthOp>(node);
if (opName == "Split")
buildOperation<mlir::ONNXSplitOp>(node);
if (opName == "SplitToSequence")
buildOperation<mlir::ONNXSplitToSequenceOp>(node);
if (opName == "Sqrt")
buildOperation<mlir::ONNXSqrtOp>(node);
if (opName == "Squeeze")
buildOperation<mlir::ONNXSqueezeOp>(node);
if (opName == "StringNormalizer")
buildOperation<mlir::ONNXStringNormalizerOp>(node);
if (opName == "Sub")
buildOperation<mlir::ONNXSubOp>(node);
if (opName == "Sum")
buildOperation<mlir::ONNXSumOp>(node);
if (opName == "Tan")
buildOperation<mlir::ONNXTanOp>(node);
if (opName == "Tanh")
buildOperation<mlir::ONNXTanhOp>(node);
if (opName == "TfIdfVectorizer")
buildOperation<mlir::ONNXTfIdfVectorizerOp>(node);
if (opName == "ThresholdedRelu")
buildOperation<mlir::ONNXThresholdedReluOp>(node);
if (opName == "Tile")
buildOperation<mlir::ONNXTileOp>(node);
if (opName == "TopK")
buildOperation<mlir::ONNXTopKOp>(node);
if (opName == "Transpose")
buildOperation<mlir::ONNXTransposeOp>(node);
if (opName == "Unique")
buildOperation<mlir::ONNXUniqueOp>(node);
if (opName == "Unsqueeze")
buildOperation<mlir::ONNXUnsqueezeOp>(node);
if (opName == "Upsample")
buildOperation<mlir::ONNXUpsampleOp>(node);
if (opName == "Where")
buildOperation<mlir::ONNXWhereOp>(node);
if (opName == "Xor")
buildOperation<mlir::ONNXXorOp>(node);
if (opName == "ArrayFeatureExtractor")
buildOperation<mlir::ONNXArrayFeatureExtractorOp>(node);
if (opName == "Binarizer")
buildOperation<mlir::ONNXBinarizerOp>(node);
if (opName == "CastMap")
buildOperation<mlir::ONNXCastMapOp>(node);
if (opName == "CategoryMapper")
buildOperation<mlir::ONNXCategoryMapperOp>(node);
if (opName == "DictVectorizer")
buildOperation<mlir::ONNXDictVectorizerOp>(node);
if (opName == "FeatureVectorizer")
buildOperation<mlir::ONNXFeatureVectorizerOp>(node);
if (opName == "Imputer")
buildOperation<mlir::ONNXImputerOp>(node);
if (opName == "LabelEncoder")
buildOperation<mlir::ONNXLabelEncoderOp>(node);
if (opName == "LinearClassifier")
buildOperation<mlir::ONNXLinearClassifierOp>(node);
if (opName == "LinearRegressor")
buildOperation<mlir::ONNXLinearRegressorOp>(node);
if (opName == "Normalizer")
buildOperation<mlir::ONNXNormalizerOp>(node);
if (opName == "OneHotEncoder")
buildOperation<mlir::ONNXOneHotEncoderOp>(node);
if (opName == "SVMClassifier")
buildOperation<mlir::ONNXSVMClassifierOp>(node);
if (opName == "SVMRegressor")
buildOperation<mlir::ONNXSVMRegressorOp>(node);
if (opName == "Scaler")
buildOperation<mlir::ONNXScalerOp>(node);
if (opName == "TreeEnsembleClassifier")
buildOperation<mlir::ONNXTreeEnsembleClassifierOp>(node);
if (opName == "TreeEnsembleRegressor")
buildOperation<mlir::ONNXTreeEnsembleRegressorOp>(node);
if (opName == "ZipMap")
buildOperation<mlir::ONNXZipMapOp>(node);
import_handler_map_["Abs"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXAbsOp>;
import_handler_map_["Acos"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXAcosOp>;
import_handler_map_["Acosh"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXAcoshOp>;
import_handler_map_["Add"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXAddOp>;
import_handler_map_["And"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXAndOp>;
import_handler_map_["ArgMax"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXArgMaxOp>;
import_handler_map_["ArgMin"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXArgMinOp>;
import_handler_map_["Asin"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXAsinOp>;
import_handler_map_["Asinh"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXAsinhOp>;
import_handler_map_["Atan"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXAtanOp>;
import_handler_map_["Atanh"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXAtanhOp>;
import_handler_map_["AveragePool"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXAveragePoolOp>;
import_handler_map_["BatchNormalization"] =
&onnx_mlir::detail::FrontendGenImpl::ImportNodeBatchNormalization;
import_handler_map_["BitShift"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXBitShiftOp>;
import_handler_map_["Cast"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXCastOp>;
import_handler_map_["Ceil"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXCeilOp>;
import_handler_map_["Clip"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXClipOp>;
import_handler_map_["Compress"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXCompressOp>;
import_handler_map_["Concat"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXConcatOp>;
import_handler_map_["ConcatFromSequence"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXConcatFromSequenceOp>;
import_handler_map_["Constant"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXConstantOp>;
import_handler_map_["ConstantOfShape"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXConstantOfShapeOp>;
import_handler_map_["Conv"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXConvOp>;
import_handler_map_["ConvInteger"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXConvIntegerOp>;
import_handler_map_["ConvTranspose"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXConvTransposeOp>;
import_handler_map_["Cos"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXCosOp>;
import_handler_map_["Cosh"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXCoshOp>;
import_handler_map_["CumSum"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXCumSumOp>;
import_handler_map_["DepthToSpace"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXDepthToSpaceOp>;
import_handler_map_["DequantizeLinear"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXDequantizeLinearOp>;
import_handler_map_["Det"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXDetOp>;
import_handler_map_["Div"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXDivOp>;
import_handler_map_["Dropout"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXDropoutOp>;
import_handler_map_["DynamicQuantizeLinear"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXDynamicQuantizeLinearOp>;
import_handler_map_["Elu"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXEluOp>;
import_handler_map_["Equal"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXEqualOp>;
import_handler_map_["Erf"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXErfOp>;
import_handler_map_["Exp"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXExpOp>;
import_handler_map_["Expand"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXExpandOp>;
import_handler_map_["EyeLike"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXEyeLikeOp>;
import_handler_map_["Flatten"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXFlattenOp>;
import_handler_map_["Floor"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXFloorOp>;
import_handler_map_["GRU"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXGRUOp>;
import_handler_map_["Gather"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXGatherOp>;
import_handler_map_["GatherElements"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXGatherElementsOp>;
import_handler_map_["GatherND"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXGatherNDOp>;
import_handler_map_["Gemm"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXGemmOp>;
import_handler_map_["GlobalAveragePool"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXGlobalAveragePoolOp>;
import_handler_map_["GlobalLpPool"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXGlobalLpPoolOp>;
import_handler_map_["GlobalMaxPool"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXGlobalMaxPoolOp>;
import_handler_map_["Greater"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXGreaterOp>;
import_handler_map_["HardSigmoid"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXHardSigmoidOp>;
import_handler_map_["Hardmax"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXHardmaxOp>;
import_handler_map_["Identity"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXIdentityOp>;
import_handler_map_["If"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXIfOp>;
import_handler_map_["InstanceNormalization"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXInstanceNormalizationOp>;
import_handler_map_["IsInf"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXIsInfOp>;
import_handler_map_["IsNaN"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXIsNaNOp>;
import_handler_map_["LRN"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXLRNOp>;
import_handler_map_["LSTM"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXLSTMOp>;
import_handler_map_["LeakyRelu"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXLeakyReluOp>;
import_handler_map_["Less"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXLessOp>;
import_handler_map_["Log"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXLogOp>;
import_handler_map_["LogSoftmax"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXLogSoftmaxOp>;
import_handler_map_["Loop"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXLoopOp>;
import_handler_map_["LpNormalization"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXLpNormalizationOp>;
import_handler_map_["LpPool"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXLpPoolOp>;
import_handler_map_["MatMul"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXMatMulOp>;
import_handler_map_["MatMulInteger"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXMatMulIntegerOp>;
import_handler_map_["Max"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXMaxOp>;
import_handler_map_["MaxPool"] =
&onnx_mlir::detail::FrontendGenImpl::ImportNodeMaxPool;
import_handler_map_["MaxRoiPool"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXMaxRoiPoolOp>;
import_handler_map_["MaxUnpool"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXMaxUnpoolOp>;
import_handler_map_["Mean"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXMeanOp>;
import_handler_map_["MeanVarianceNormalization"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXMeanVarianceNormalizationOp>;
import_handler_map_["Min"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXMinOp>;
import_handler_map_["Mod"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXModOp>;
import_handler_map_["Mul"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXMulOp>;
import_handler_map_["Multinomial"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXMultinomialOp>;
import_handler_map_["Neg"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXNegOp>;
import_handler_map_["NonMaxSuppression"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXNonMaxSuppressionOp>;
import_handler_map_["NonZero"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXNonZeroOp>;
import_handler_map_["Not"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXNotOp>;
import_handler_map_["OneHot"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXOneHotOp>;
import_handler_map_["Or"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXOrOp>;
import_handler_map_["PRelu"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXPReluOp>;
import_handler_map_["Pad"] =
&onnx_mlir::detail::FrontendGenImpl::ImportNodePad;
import_handler_map_["Pow"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXPowOp>;
import_handler_map_["QLinearConv"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXQLinearConvOp>;
import_handler_map_["QLinearMatMul"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXQLinearMatMulOp>;
import_handler_map_["QuantizeLinear"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXQuantizeLinearOp>;
import_handler_map_["RNN"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXRNNOp>;
import_handler_map_["RandomNormal"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXRandomNormalOp>;
import_handler_map_["RandomNormalLike"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXRandomNormalLikeOp>;
import_handler_map_["RandomUniform"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXRandomUniformOp>;
import_handler_map_["RandomUniformLike"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXRandomUniformLikeOp>;
import_handler_map_["Range"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXRangeOp>;
import_handler_map_["Reciprocal"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXReciprocalOp>;
import_handler_map_["ReduceL1"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXReduceL1Op>;
import_handler_map_["ReduceL2"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXReduceL2Op>;
import_handler_map_["ReduceLogSum"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXReduceLogSumOp>;
import_handler_map_["ReduceLogSumExp"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXReduceLogSumExpOp>;
import_handler_map_["ReduceMax"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXReduceMaxOp>;
import_handler_map_["ReduceMean"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXReduceMeanOp>;
import_handler_map_["ReduceMin"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXReduceMinOp>;
import_handler_map_["ReduceProd"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXReduceProdOp>;
import_handler_map_["ReduceSum"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXReduceSumOp>;
import_handler_map_["ReduceSumSquare"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXReduceSumSquareOp>;
import_handler_map_["Relu"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXReluOp>;
import_handler_map_["Reshape"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXReshapeOp>;
import_handler_map_["Resize"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXResizeOp>;
import_handler_map_["ReverseSequence"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXReverseSequenceOp>;
import_handler_map_["RoiAlign"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXRoiAlignOp>;
import_handler_map_["Round"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXRoundOp>;
import_handler_map_["Scan"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXScanOp>;
import_handler_map_["Scatter"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXScatterOp>;
import_handler_map_["ScatterElements"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXScatterElementsOp>;
import_handler_map_["ScatterND"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXScatterNDOp>;
import_handler_map_["Selu"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSeluOp>;
import_handler_map_["SequenceAt"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSequenceAtOp>;
import_handler_map_["SequenceConstruct"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSequenceConstructOp>;
import_handler_map_["SequenceEmpty"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSequenceEmptyOp>;
import_handler_map_["SequenceErase"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSequenceEraseOp>;
import_handler_map_["SequenceInsert"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSequenceInsertOp>;
import_handler_map_["SequenceLength"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSequenceLengthOp>;
import_handler_map_["Shape"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXShapeOp>;
import_handler_map_["Shrink"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXShrinkOp>;
import_handler_map_["Sigmoid"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSigmoidOp>;
import_handler_map_["Sign"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSignOp>;
import_handler_map_["Sin"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSinOp>;
import_handler_map_["Sinh"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSinhOp>;
import_handler_map_["Size"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSizeOp>;
import_handler_map_["Slice"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSliceOp>;
import_handler_map_["Softmax"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSoftmaxOp>;
import_handler_map_["Softplus"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSoftplusOp>;
import_handler_map_["Softsign"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSoftsignOp>;
import_handler_map_["SpaceToDepth"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSpaceToDepthOp>;
import_handler_map_["Split"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSplitOp>;
import_handler_map_["SplitToSequence"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSplitToSequenceOp>;
import_handler_map_["Sqrt"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSqrtOp>;
import_handler_map_["Squeeze"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSqueezeOp>;
import_handler_map_["StringNormalizer"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXStringNormalizerOp>;
import_handler_map_["Sub"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSubOp>;
import_handler_map_["Sum"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSumOp>;
import_handler_map_["Tan"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXTanOp>;
import_handler_map_["Tanh"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXTanhOp>;
import_handler_map_["TfIdfVectorizer"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXTfIdfVectorizerOp>;
import_handler_map_["ThresholdedRelu"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXThresholdedReluOp>;
import_handler_map_["Tile"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXTileOp>;
import_handler_map_["TopK"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXTopKOp>;
import_handler_map_["Transpose"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXTransposeOp>;
import_handler_map_["Unique"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXUniqueOp>;
import_handler_map_["Unsqueeze"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXUnsqueezeOp>;
import_handler_map_["Upsample"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXUpsampleOp>;
import_handler_map_["Where"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXWhereOp>;
import_handler_map_["Xor"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXXorOp>;
import_handler_map_["ArrayFeatureExtractor"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXArrayFeatureExtractorOp>;
import_handler_map_["Binarizer"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXBinarizerOp>;
import_handler_map_["CastMap"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXCastMapOp>;
import_handler_map_["CategoryMapper"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXCategoryMapperOp>;
import_handler_map_["DictVectorizer"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXDictVectorizerOp>;
import_handler_map_["FeatureVectorizer"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXFeatureVectorizerOp>;
import_handler_map_["Imputer"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXImputerOp>;
import_handler_map_["LabelEncoder"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXLabelEncoderOp>;
import_handler_map_["LinearClassifier"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXLinearClassifierOp>;
import_handler_map_["LinearRegressor"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXLinearRegressorOp>;
import_handler_map_["Normalizer"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXNormalizerOp>;
import_handler_map_["OneHotEncoder"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXOneHotEncoderOp>;
import_handler_map_["SVMClassifier"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSVMClassifierOp>;
import_handler_map_["SVMRegressor"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSVMRegressorOp>;
import_handler_map_["Scaler"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXScalerOp>;
import_handler_map_["TreeEnsembleClassifier"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXTreeEnsembleClassifierOp>;
import_handler_map_["TreeEnsembleRegressor"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXTreeEnsembleRegressorOp>;
import_handler_map_["ZipMap"] =
&onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXZipMapOp>;

View File

@ -954,7 +954,7 @@ special cases:
def gen_op_importer(schema, file):
indent = inc_indent()
s = indent + 'if (opName == "' + schema.name + '")\n'
s = indent + 'import_handler_map_["' + schema.name +'"] = \n '
expected_num_operands = len(schema.inputs)
expected_num_results = len(schema.outputs)
@ -978,8 +978,8 @@ def gen_op_importer(schema, file):
args.append(
'/* expected_num_results = */ {}'.format(expected_num_results))
"""
s += inc_indent(indent) + " {}({});\n".format(
handler_func, ", ".join(args))
s += inc_indent(indent) + '&onnx_mlir::detail::FrontendGenImpl::'
s += handler_func+';\n'
file.write(s)