Use map when Import onnx node (#230)

* add map

* generate map
This commit is contained in:
chentong319 2020-07-27 12:25:21 -04:00 committed by GitHub
parent 8af5fdeb62
commit 32ceb6968a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 368 additions and 358 deletions

View File

@ -25,7 +25,7 @@ namespace bstd = mpark;
#include "FrontendDialectTransformer.hpp" #include "FrontendDialectTransformer.hpp"
namespace onnx_mlir { namespace onnx_mlir {
namespace { namespace detail {
/*! /*!
* The list of tensors initialized by the ONNX model. * The list of tensors initialized by the ONNX model.
@ -37,6 +37,7 @@ public:
FrontendGenImpl(mlir::MLIRContext &context) FrontendGenImpl(mlir::MLIRContext &context)
: context_(context), builder_(&context) { : context_(context), builder_(&context) {
module_ = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)); module_ = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
InitHandlerMap();
} }
mlir::ModuleOp ImportONNXModel(onnx::ModelProto model) { mlir::ModuleOp ImportONNXModel(onnx::ModelProto model) {
@ -52,6 +53,11 @@ private:
// mapping between string name and symbol // mapping between string name and symbol
OnnxMlirSymbolMapping frontend_symbols_; OnnxMlirSymbolMapping frontend_symbols_;
typedef void (onnx_mlir::detail::FrontendGenImpl::*ImportHandlerType)(
const onnx::NodeProto &);
std::map<std::string, ImportHandlerType> import_handler_map_;
mlir::Location UnknownLoc() { return mlir::UnknownLoc::get(&context_); } mlir::Location UnknownLoc() { return mlir::UnknownLoc::get(&context_); }
/*! /*!
@ -329,7 +335,7 @@ private:
/*! /*!
* Special handle for MaxPool operations. * Special handle for MaxPool operations.
*/ */
void ImportNodeMaxPool(onnx::NodeProto node) { void ImportNodeMaxPool(const onnx::NodeProto &node) {
int nOuts = node.output().size(); int nOuts = node.output().size();
if (nOuts == 1) { if (nOuts == 1) {
buildOperation<mlir::ONNXMaxPoolSingleOutOp>(node); buildOperation<mlir::ONNXMaxPoolSingleOutOp>(node);
@ -341,7 +347,7 @@ private:
/*! /*!
* Special handle for BatchNormalization operations. * Special handle for BatchNormalization operations.
*/ */
void ImportNodeBatchNormalization(onnx::NodeProto node) { void ImportNodeBatchNormalization(const onnx::NodeProto &node) {
int nOuts = node.output().size(); int nOuts = node.output().size();
if (nOuts == 1) { if (nOuts == 1) {
// Test mode with one output. // Test mode with one output.
@ -355,7 +361,7 @@ private:
/*! /*!
* Special handle for Pad operations. * Special handle for Pad operations.
*/ */
void ImportNodePad(onnx::NodeProto node) { void ImportNodePad(const onnx::NodeProto &node) {
int nOps = node.input().size(); int nOps = node.input().size();
if (nOps == 2) { if (nOps == 2) {
@ -400,12 +406,16 @@ private:
// the generic operator is used // the generic operator is used
// one known reeason is the optional input // one known reeason is the optional input
(this->*(import_handler_map_[opName.str()]))(node);
}
void InitHandlerMap() {
#include "src/Builder/OpBuildTable.inc" #include "src/Builder/OpBuildTable.inc"
} }
/*! /*!
* Import output tensor, by doing the following: * Import output tensor, by doing the following:
* - Add the type of this output tensor to a list of tensor * - Add the t/yp this output tensor to a list of tensor
* types representing return types of this graph function. * types representing return types of this graph function.
* - Add this output tensor to the list of mlir::Value * - Add this output tensor to the list of mlir::Value
* to be returned by the function representing computation graph. * to be returned by the function representing computation graph.
@ -499,7 +509,7 @@ private:
mainFunc.setType(funcType); mainFunc.setType(funcType);
} }
}; // FrontendGenImpl class }; // FrontendGenImpl class
} // namespace } // namespace detail
} // namespace onnx_mlir } // namespace onnx_mlir
namespace onnx_mlir { namespace onnx_mlir {
@ -512,7 +522,7 @@ void ImportFrontendModelFile(std::string model_fname,
auto parse_success = model.ParseFromIstream(&input); auto parse_success = model.ParseFromIstream(&input);
assert(parse_success && "Onnx Model Parsing Failed."); assert(parse_success && "Onnx Model Parsing Failed.");
FrontendGenImpl myONNXGen(context); detail::FrontendGenImpl myONNXGen(context);
module = myONNXGen.ImportONNXModel(model); module = myONNXGen.ImportONNXModel(model);
} }
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@ -4,351 +4,351 @@
// Details can be found in docs/ImportONNXDefs.md . // Details can be found in docs/ImportONNXDefs.md .
//******************************************************** //********************************************************
if (opName == "Abs") import_handler_map_["Abs"] =
buildOperation<mlir::ONNXAbsOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXAbsOp>;
if (opName == "Acos") import_handler_map_["Acos"] =
buildOperation<mlir::ONNXAcosOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXAcosOp>;
if (opName == "Acosh") import_handler_map_["Acosh"] =
buildOperation<mlir::ONNXAcoshOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXAcoshOp>;
if (opName == "Add") import_handler_map_["Add"] =
buildOperation<mlir::ONNXAddOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXAddOp>;
if (opName == "And") import_handler_map_["And"] =
buildOperation<mlir::ONNXAndOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXAndOp>;
if (opName == "ArgMax") import_handler_map_["ArgMax"] =
buildOperation<mlir::ONNXArgMaxOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXArgMaxOp>;
if (opName == "ArgMin") import_handler_map_["ArgMin"] =
buildOperation<mlir::ONNXArgMinOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXArgMinOp>;
if (opName == "Asin") import_handler_map_["Asin"] =
buildOperation<mlir::ONNXAsinOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXAsinOp>;
if (opName == "Asinh") import_handler_map_["Asinh"] =
buildOperation<mlir::ONNXAsinhOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXAsinhOp>;
if (opName == "Atan") import_handler_map_["Atan"] =
buildOperation<mlir::ONNXAtanOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXAtanOp>;
if (opName == "Atanh") import_handler_map_["Atanh"] =
buildOperation<mlir::ONNXAtanhOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXAtanhOp>;
if (opName == "AveragePool") import_handler_map_["AveragePool"] =
buildOperation<mlir::ONNXAveragePoolOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXAveragePoolOp>;
if (opName == "BatchNormalization") import_handler_map_["BatchNormalization"] =
ImportNodeBatchNormalization(node); &onnx_mlir::detail::FrontendGenImpl::ImportNodeBatchNormalization;
if (opName == "BitShift") import_handler_map_["BitShift"] =
buildOperation<mlir::ONNXBitShiftOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXBitShiftOp>;
if (opName == "Cast") import_handler_map_["Cast"] =
buildOperation<mlir::ONNXCastOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXCastOp>;
if (opName == "Ceil") import_handler_map_["Ceil"] =
buildOperation<mlir::ONNXCeilOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXCeilOp>;
if (opName == "Clip") import_handler_map_["Clip"] =
buildOperation<mlir::ONNXClipOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXClipOp>;
if (opName == "Compress") import_handler_map_["Compress"] =
buildOperation<mlir::ONNXCompressOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXCompressOp>;
if (opName == "Concat") import_handler_map_["Concat"] =
buildOperation<mlir::ONNXConcatOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXConcatOp>;
if (opName == "ConcatFromSequence") import_handler_map_["ConcatFromSequence"] =
buildOperation<mlir::ONNXConcatFromSequenceOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXConcatFromSequenceOp>;
if (opName == "Constant") import_handler_map_["Constant"] =
buildOperation<mlir::ONNXConstantOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXConstantOp>;
if (opName == "ConstantOfShape") import_handler_map_["ConstantOfShape"] =
buildOperation<mlir::ONNXConstantOfShapeOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXConstantOfShapeOp>;
if (opName == "Conv") import_handler_map_["Conv"] =
buildOperation<mlir::ONNXConvOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXConvOp>;
if (opName == "ConvInteger") import_handler_map_["ConvInteger"] =
buildOperation<mlir::ONNXConvIntegerOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXConvIntegerOp>;
if (opName == "ConvTranspose") import_handler_map_["ConvTranspose"] =
buildOperation<mlir::ONNXConvTransposeOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXConvTransposeOp>;
if (opName == "Cos") import_handler_map_["Cos"] =
buildOperation<mlir::ONNXCosOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXCosOp>;
if (opName == "Cosh") import_handler_map_["Cosh"] =
buildOperation<mlir::ONNXCoshOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXCoshOp>;
if (opName == "CumSum") import_handler_map_["CumSum"] =
buildOperation<mlir::ONNXCumSumOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXCumSumOp>;
if (opName == "DepthToSpace") import_handler_map_["DepthToSpace"] =
buildOperation<mlir::ONNXDepthToSpaceOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXDepthToSpaceOp>;
if (opName == "DequantizeLinear") import_handler_map_["DequantizeLinear"] =
buildOperation<mlir::ONNXDequantizeLinearOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXDequantizeLinearOp>;
if (opName == "Det") import_handler_map_["Det"] =
buildOperation<mlir::ONNXDetOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXDetOp>;
if (opName == "Div") import_handler_map_["Div"] =
buildOperation<mlir::ONNXDivOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXDivOp>;
if (opName == "Dropout") import_handler_map_["Dropout"] =
buildOperation<mlir::ONNXDropoutOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXDropoutOp>;
if (opName == "DynamicQuantizeLinear") import_handler_map_["DynamicQuantizeLinear"] =
buildOperation<mlir::ONNXDynamicQuantizeLinearOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXDynamicQuantizeLinearOp>;
if (opName == "Elu") import_handler_map_["Elu"] =
buildOperation<mlir::ONNXEluOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXEluOp>;
if (opName == "Equal") import_handler_map_["Equal"] =
buildOperation<mlir::ONNXEqualOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXEqualOp>;
if (opName == "Erf") import_handler_map_["Erf"] =
buildOperation<mlir::ONNXErfOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXErfOp>;
if (opName == "Exp") import_handler_map_["Exp"] =
buildOperation<mlir::ONNXExpOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXExpOp>;
if (opName == "Expand") import_handler_map_["Expand"] =
buildOperation<mlir::ONNXExpandOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXExpandOp>;
if (opName == "EyeLike") import_handler_map_["EyeLike"] =
buildOperation<mlir::ONNXEyeLikeOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXEyeLikeOp>;
if (opName == "Flatten") import_handler_map_["Flatten"] =
buildOperation<mlir::ONNXFlattenOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXFlattenOp>;
if (opName == "Floor") import_handler_map_["Floor"] =
buildOperation<mlir::ONNXFloorOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXFloorOp>;
if (opName == "GRU") import_handler_map_["GRU"] =
buildOperation<mlir::ONNXGRUOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXGRUOp>;
if (opName == "Gather") import_handler_map_["Gather"] =
buildOperation<mlir::ONNXGatherOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXGatherOp>;
if (opName == "GatherElements") import_handler_map_["GatherElements"] =
buildOperation<mlir::ONNXGatherElementsOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXGatherElementsOp>;
if (opName == "GatherND") import_handler_map_["GatherND"] =
buildOperation<mlir::ONNXGatherNDOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXGatherNDOp>;
if (opName == "Gemm") import_handler_map_["Gemm"] =
buildOperation<mlir::ONNXGemmOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXGemmOp>;
if (opName == "GlobalAveragePool") import_handler_map_["GlobalAveragePool"] =
buildOperation<mlir::ONNXGlobalAveragePoolOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXGlobalAveragePoolOp>;
if (opName == "GlobalLpPool") import_handler_map_["GlobalLpPool"] =
buildOperation<mlir::ONNXGlobalLpPoolOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXGlobalLpPoolOp>;
if (opName == "GlobalMaxPool") import_handler_map_["GlobalMaxPool"] =
buildOperation<mlir::ONNXGlobalMaxPoolOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXGlobalMaxPoolOp>;
if (opName == "Greater") import_handler_map_["Greater"] =
buildOperation<mlir::ONNXGreaterOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXGreaterOp>;
if (opName == "HardSigmoid") import_handler_map_["HardSigmoid"] =
buildOperation<mlir::ONNXHardSigmoidOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXHardSigmoidOp>;
if (opName == "Hardmax") import_handler_map_["Hardmax"] =
buildOperation<mlir::ONNXHardmaxOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXHardmaxOp>;
if (opName == "Identity") import_handler_map_["Identity"] =
buildOperation<mlir::ONNXIdentityOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXIdentityOp>;
if (opName == "If") import_handler_map_["If"] =
buildOperation<mlir::ONNXIfOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXIfOp>;
if (opName == "InstanceNormalization") import_handler_map_["InstanceNormalization"] =
buildOperation<mlir::ONNXInstanceNormalizationOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXInstanceNormalizationOp>;
if (opName == "IsInf") import_handler_map_["IsInf"] =
buildOperation<mlir::ONNXIsInfOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXIsInfOp>;
if (opName == "IsNaN") import_handler_map_["IsNaN"] =
buildOperation<mlir::ONNXIsNaNOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXIsNaNOp>;
if (opName == "LRN") import_handler_map_["LRN"] =
buildOperation<mlir::ONNXLRNOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXLRNOp>;
if (opName == "LSTM") import_handler_map_["LSTM"] =
buildOperation<mlir::ONNXLSTMOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXLSTMOp>;
if (opName == "LeakyRelu") import_handler_map_["LeakyRelu"] =
buildOperation<mlir::ONNXLeakyReluOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXLeakyReluOp>;
if (opName == "Less") import_handler_map_["Less"] =
buildOperation<mlir::ONNXLessOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXLessOp>;
if (opName == "Log") import_handler_map_["Log"] =
buildOperation<mlir::ONNXLogOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXLogOp>;
if (opName == "LogSoftmax") import_handler_map_["LogSoftmax"] =
buildOperation<mlir::ONNXLogSoftmaxOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXLogSoftmaxOp>;
if (opName == "Loop") import_handler_map_["Loop"] =
buildOperation<mlir::ONNXLoopOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXLoopOp>;
if (opName == "LpNormalization") import_handler_map_["LpNormalization"] =
buildOperation<mlir::ONNXLpNormalizationOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXLpNormalizationOp>;
if (opName == "LpPool") import_handler_map_["LpPool"] =
buildOperation<mlir::ONNXLpPoolOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXLpPoolOp>;
if (opName == "MatMul") import_handler_map_["MatMul"] =
buildOperation<mlir::ONNXMatMulOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXMatMulOp>;
if (opName == "MatMulInteger") import_handler_map_["MatMulInteger"] =
buildOperation<mlir::ONNXMatMulIntegerOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXMatMulIntegerOp>;
if (opName == "Max") import_handler_map_["Max"] =
buildOperation<mlir::ONNXMaxOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXMaxOp>;
if (opName == "MaxPool") import_handler_map_["MaxPool"] =
ImportNodeMaxPool(node); &onnx_mlir::detail::FrontendGenImpl::ImportNodeMaxPool;
if (opName == "MaxRoiPool") import_handler_map_["MaxRoiPool"] =
buildOperation<mlir::ONNXMaxRoiPoolOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXMaxRoiPoolOp>;
if (opName == "MaxUnpool") import_handler_map_["MaxUnpool"] =
buildOperation<mlir::ONNXMaxUnpoolOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXMaxUnpoolOp>;
if (opName == "Mean") import_handler_map_["Mean"] =
buildOperation<mlir::ONNXMeanOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXMeanOp>;
if (opName == "MeanVarianceNormalization") import_handler_map_["MeanVarianceNormalization"] =
buildOperation<mlir::ONNXMeanVarianceNormalizationOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXMeanVarianceNormalizationOp>;
if (opName == "Min") import_handler_map_["Min"] =
buildOperation<mlir::ONNXMinOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXMinOp>;
if (opName == "Mod") import_handler_map_["Mod"] =
buildOperation<mlir::ONNXModOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXModOp>;
if (opName == "Mul") import_handler_map_["Mul"] =
buildOperation<mlir::ONNXMulOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXMulOp>;
if (opName == "Multinomial") import_handler_map_["Multinomial"] =
buildOperation<mlir::ONNXMultinomialOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXMultinomialOp>;
if (opName == "Neg") import_handler_map_["Neg"] =
buildOperation<mlir::ONNXNegOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXNegOp>;
if (opName == "NonMaxSuppression") import_handler_map_["NonMaxSuppression"] =
buildOperation<mlir::ONNXNonMaxSuppressionOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXNonMaxSuppressionOp>;
if (opName == "NonZero") import_handler_map_["NonZero"] =
buildOperation<mlir::ONNXNonZeroOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXNonZeroOp>;
if (opName == "Not") import_handler_map_["Not"] =
buildOperation<mlir::ONNXNotOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXNotOp>;
if (opName == "OneHot") import_handler_map_["OneHot"] =
buildOperation<mlir::ONNXOneHotOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXOneHotOp>;
if (opName == "Or") import_handler_map_["Or"] =
buildOperation<mlir::ONNXOrOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXOrOp>;
if (opName == "PRelu") import_handler_map_["PRelu"] =
buildOperation<mlir::ONNXPReluOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXPReluOp>;
if (opName == "Pad") import_handler_map_["Pad"] =
ImportNodePad(node); &onnx_mlir::detail::FrontendGenImpl::ImportNodePad;
if (opName == "Pow") import_handler_map_["Pow"] =
buildOperation<mlir::ONNXPowOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXPowOp>;
if (opName == "QLinearConv") import_handler_map_["QLinearConv"] =
buildOperation<mlir::ONNXQLinearConvOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXQLinearConvOp>;
if (opName == "QLinearMatMul") import_handler_map_["QLinearMatMul"] =
buildOperation<mlir::ONNXQLinearMatMulOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXQLinearMatMulOp>;
if (opName == "QuantizeLinear") import_handler_map_["QuantizeLinear"] =
buildOperation<mlir::ONNXQuantizeLinearOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXQuantizeLinearOp>;
if (opName == "RNN") import_handler_map_["RNN"] =
buildOperation<mlir::ONNXRNNOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXRNNOp>;
if (opName == "RandomNormal") import_handler_map_["RandomNormal"] =
buildOperation<mlir::ONNXRandomNormalOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXRandomNormalOp>;
if (opName == "RandomNormalLike") import_handler_map_["RandomNormalLike"] =
buildOperation<mlir::ONNXRandomNormalLikeOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXRandomNormalLikeOp>;
if (opName == "RandomUniform") import_handler_map_["RandomUniform"] =
buildOperation<mlir::ONNXRandomUniformOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXRandomUniformOp>;
if (opName == "RandomUniformLike") import_handler_map_["RandomUniformLike"] =
buildOperation<mlir::ONNXRandomUniformLikeOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXRandomUniformLikeOp>;
if (opName == "Range") import_handler_map_["Range"] =
buildOperation<mlir::ONNXRangeOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXRangeOp>;
if (opName == "Reciprocal") import_handler_map_["Reciprocal"] =
buildOperation<mlir::ONNXReciprocalOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXReciprocalOp>;
if (opName == "ReduceL1") import_handler_map_["ReduceL1"] =
buildOperation<mlir::ONNXReduceL1Op>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXReduceL1Op>;
if (opName == "ReduceL2") import_handler_map_["ReduceL2"] =
buildOperation<mlir::ONNXReduceL2Op>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXReduceL2Op>;
if (opName == "ReduceLogSum") import_handler_map_["ReduceLogSum"] =
buildOperation<mlir::ONNXReduceLogSumOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXReduceLogSumOp>;
if (opName == "ReduceLogSumExp") import_handler_map_["ReduceLogSumExp"] =
buildOperation<mlir::ONNXReduceLogSumExpOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXReduceLogSumExpOp>;
if (opName == "ReduceMax") import_handler_map_["ReduceMax"] =
buildOperation<mlir::ONNXReduceMaxOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXReduceMaxOp>;
if (opName == "ReduceMean") import_handler_map_["ReduceMean"] =
buildOperation<mlir::ONNXReduceMeanOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXReduceMeanOp>;
if (opName == "ReduceMin") import_handler_map_["ReduceMin"] =
buildOperation<mlir::ONNXReduceMinOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXReduceMinOp>;
if (opName == "ReduceProd") import_handler_map_["ReduceProd"] =
buildOperation<mlir::ONNXReduceProdOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXReduceProdOp>;
if (opName == "ReduceSum") import_handler_map_["ReduceSum"] =
buildOperation<mlir::ONNXReduceSumOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXReduceSumOp>;
if (opName == "ReduceSumSquare") import_handler_map_["ReduceSumSquare"] =
buildOperation<mlir::ONNXReduceSumSquareOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXReduceSumSquareOp>;
if (opName == "Relu") import_handler_map_["Relu"] =
buildOperation<mlir::ONNXReluOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXReluOp>;
if (opName == "Reshape") import_handler_map_["Reshape"] =
buildOperation<mlir::ONNXReshapeOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXReshapeOp>;
if (opName == "Resize") import_handler_map_["Resize"] =
buildOperation<mlir::ONNXResizeOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXResizeOp>;
if (opName == "ReverseSequence") import_handler_map_["ReverseSequence"] =
buildOperation<mlir::ONNXReverseSequenceOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXReverseSequenceOp>;
if (opName == "RoiAlign") import_handler_map_["RoiAlign"] =
buildOperation<mlir::ONNXRoiAlignOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXRoiAlignOp>;
if (opName == "Round") import_handler_map_["Round"] =
buildOperation<mlir::ONNXRoundOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXRoundOp>;
if (opName == "Scan") import_handler_map_["Scan"] =
buildOperation<mlir::ONNXScanOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXScanOp>;
if (opName == "Scatter") import_handler_map_["Scatter"] =
buildOperation<mlir::ONNXScatterOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXScatterOp>;
if (opName == "ScatterElements") import_handler_map_["ScatterElements"] =
buildOperation<mlir::ONNXScatterElementsOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXScatterElementsOp>;
if (opName == "ScatterND") import_handler_map_["ScatterND"] =
buildOperation<mlir::ONNXScatterNDOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXScatterNDOp>;
if (opName == "Selu") import_handler_map_["Selu"] =
buildOperation<mlir::ONNXSeluOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSeluOp>;
if (opName == "SequenceAt") import_handler_map_["SequenceAt"] =
buildOperation<mlir::ONNXSequenceAtOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSequenceAtOp>;
if (opName == "SequenceConstruct") import_handler_map_["SequenceConstruct"] =
buildOperation<mlir::ONNXSequenceConstructOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSequenceConstructOp>;
if (opName == "SequenceEmpty") import_handler_map_["SequenceEmpty"] =
buildOperation<mlir::ONNXSequenceEmptyOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSequenceEmptyOp>;
if (opName == "SequenceErase") import_handler_map_["SequenceErase"] =
buildOperation<mlir::ONNXSequenceEraseOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSequenceEraseOp>;
if (opName == "SequenceInsert") import_handler_map_["SequenceInsert"] =
buildOperation<mlir::ONNXSequenceInsertOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSequenceInsertOp>;
if (opName == "SequenceLength") import_handler_map_["SequenceLength"] =
buildOperation<mlir::ONNXSequenceLengthOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSequenceLengthOp>;
if (opName == "Shape") import_handler_map_["Shape"] =
buildOperation<mlir::ONNXShapeOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXShapeOp>;
if (opName == "Shrink") import_handler_map_["Shrink"] =
buildOperation<mlir::ONNXShrinkOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXShrinkOp>;
if (opName == "Sigmoid") import_handler_map_["Sigmoid"] =
buildOperation<mlir::ONNXSigmoidOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSigmoidOp>;
if (opName == "Sign") import_handler_map_["Sign"] =
buildOperation<mlir::ONNXSignOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSignOp>;
if (opName == "Sin") import_handler_map_["Sin"] =
buildOperation<mlir::ONNXSinOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSinOp>;
if (opName == "Sinh") import_handler_map_["Sinh"] =
buildOperation<mlir::ONNXSinhOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSinhOp>;
if (opName == "Size") import_handler_map_["Size"] =
buildOperation<mlir::ONNXSizeOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSizeOp>;
if (opName == "Slice") import_handler_map_["Slice"] =
buildOperation<mlir::ONNXSliceOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSliceOp>;
if (opName == "Softmax") import_handler_map_["Softmax"] =
buildOperation<mlir::ONNXSoftmaxOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSoftmaxOp>;
if (opName == "Softplus") import_handler_map_["Softplus"] =
buildOperation<mlir::ONNXSoftplusOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSoftplusOp>;
if (opName == "Softsign") import_handler_map_["Softsign"] =
buildOperation<mlir::ONNXSoftsignOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSoftsignOp>;
if (opName == "SpaceToDepth") import_handler_map_["SpaceToDepth"] =
buildOperation<mlir::ONNXSpaceToDepthOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSpaceToDepthOp>;
if (opName == "Split") import_handler_map_["Split"] =
buildOperation<mlir::ONNXSplitOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSplitOp>;
if (opName == "SplitToSequence") import_handler_map_["SplitToSequence"] =
buildOperation<mlir::ONNXSplitToSequenceOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSplitToSequenceOp>;
if (opName == "Sqrt") import_handler_map_["Sqrt"] =
buildOperation<mlir::ONNXSqrtOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSqrtOp>;
if (opName == "Squeeze") import_handler_map_["Squeeze"] =
buildOperation<mlir::ONNXSqueezeOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSqueezeOp>;
if (opName == "StringNormalizer") import_handler_map_["StringNormalizer"] =
buildOperation<mlir::ONNXStringNormalizerOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXStringNormalizerOp>;
if (opName == "Sub") import_handler_map_["Sub"] =
buildOperation<mlir::ONNXSubOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSubOp>;
if (opName == "Sum") import_handler_map_["Sum"] =
buildOperation<mlir::ONNXSumOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSumOp>;
if (opName == "Tan") import_handler_map_["Tan"] =
buildOperation<mlir::ONNXTanOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXTanOp>;
if (opName == "Tanh") import_handler_map_["Tanh"] =
buildOperation<mlir::ONNXTanhOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXTanhOp>;
if (opName == "TfIdfVectorizer") import_handler_map_["TfIdfVectorizer"] =
buildOperation<mlir::ONNXTfIdfVectorizerOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXTfIdfVectorizerOp>;
if (opName == "ThresholdedRelu") import_handler_map_["ThresholdedRelu"] =
buildOperation<mlir::ONNXThresholdedReluOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXThresholdedReluOp>;
if (opName == "Tile") import_handler_map_["Tile"] =
buildOperation<mlir::ONNXTileOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXTileOp>;
if (opName == "TopK") import_handler_map_["TopK"] =
buildOperation<mlir::ONNXTopKOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXTopKOp>;
if (opName == "Transpose") import_handler_map_["Transpose"] =
buildOperation<mlir::ONNXTransposeOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXTransposeOp>;
if (opName == "Unique") import_handler_map_["Unique"] =
buildOperation<mlir::ONNXUniqueOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXUniqueOp>;
if (opName == "Unsqueeze") import_handler_map_["Unsqueeze"] =
buildOperation<mlir::ONNXUnsqueezeOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXUnsqueezeOp>;
if (opName == "Upsample") import_handler_map_["Upsample"] =
buildOperation<mlir::ONNXUpsampleOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXUpsampleOp>;
if (opName == "Where") import_handler_map_["Where"] =
buildOperation<mlir::ONNXWhereOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXWhereOp>;
if (opName == "Xor") import_handler_map_["Xor"] =
buildOperation<mlir::ONNXXorOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXXorOp>;
if (opName == "ArrayFeatureExtractor") import_handler_map_["ArrayFeatureExtractor"] =
buildOperation<mlir::ONNXArrayFeatureExtractorOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXArrayFeatureExtractorOp>;
if (opName == "Binarizer") import_handler_map_["Binarizer"] =
buildOperation<mlir::ONNXBinarizerOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXBinarizerOp>;
if (opName == "CastMap") import_handler_map_["CastMap"] =
buildOperation<mlir::ONNXCastMapOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXCastMapOp>;
if (opName == "CategoryMapper") import_handler_map_["CategoryMapper"] =
buildOperation<mlir::ONNXCategoryMapperOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXCategoryMapperOp>;
if (opName == "DictVectorizer") import_handler_map_["DictVectorizer"] =
buildOperation<mlir::ONNXDictVectorizerOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXDictVectorizerOp>;
if (opName == "FeatureVectorizer") import_handler_map_["FeatureVectorizer"] =
buildOperation<mlir::ONNXFeatureVectorizerOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXFeatureVectorizerOp>;
if (opName == "Imputer") import_handler_map_["Imputer"] =
buildOperation<mlir::ONNXImputerOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXImputerOp>;
if (opName == "LabelEncoder") import_handler_map_["LabelEncoder"] =
buildOperation<mlir::ONNXLabelEncoderOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXLabelEncoderOp>;
if (opName == "LinearClassifier") import_handler_map_["LinearClassifier"] =
buildOperation<mlir::ONNXLinearClassifierOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXLinearClassifierOp>;
if (opName == "LinearRegressor") import_handler_map_["LinearRegressor"] =
buildOperation<mlir::ONNXLinearRegressorOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXLinearRegressorOp>;
if (opName == "Normalizer") import_handler_map_["Normalizer"] =
buildOperation<mlir::ONNXNormalizerOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXNormalizerOp>;
if (opName == "OneHotEncoder") import_handler_map_["OneHotEncoder"] =
buildOperation<mlir::ONNXOneHotEncoderOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXOneHotEncoderOp>;
if (opName == "SVMClassifier") import_handler_map_["SVMClassifier"] =
buildOperation<mlir::ONNXSVMClassifierOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSVMClassifierOp>;
if (opName == "SVMRegressor") import_handler_map_["SVMRegressor"] =
buildOperation<mlir::ONNXSVMRegressorOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXSVMRegressorOp>;
if (opName == "Scaler") import_handler_map_["Scaler"] =
buildOperation<mlir::ONNXScalerOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXScalerOp>;
if (opName == "TreeEnsembleClassifier") import_handler_map_["TreeEnsembleClassifier"] =
buildOperation<mlir::ONNXTreeEnsembleClassifierOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXTreeEnsembleClassifierOp>;
if (opName == "TreeEnsembleRegressor") import_handler_map_["TreeEnsembleRegressor"] =
buildOperation<mlir::ONNXTreeEnsembleRegressorOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXTreeEnsembleRegressorOp>;
if (opName == "ZipMap") import_handler_map_["ZipMap"] =
buildOperation<mlir::ONNXZipMapOp>(node); &onnx_mlir::detail::FrontendGenImpl::buildOperation<mlir::ONNXZipMapOp>;

View File

@ -954,7 +954,7 @@ special cases:
def gen_op_importer(schema, file): def gen_op_importer(schema, file):
indent = inc_indent() indent = inc_indent()
s = indent + 'if (opName == "' + schema.name + '")\n' s = indent + 'import_handler_map_["' + schema.name +'"] = \n '
expected_num_operands = len(schema.inputs) expected_num_operands = len(schema.inputs)
expected_num_results = len(schema.outputs) expected_num_results = len(schema.outputs)
@ -978,8 +978,8 @@ def gen_op_importer(schema, file):
args.append( args.append(
'/* expected_num_results = */ {}'.format(expected_num_results)) '/* expected_num_results = */ {}'.format(expected_num_results))
""" """
s += inc_indent(indent) + " {}({});\n".format( s += inc_indent(indent) + '&onnx_mlir::detail::FrontendGenImpl::'
handler_func, ", ".join(args)) s += handler_func+';\n'
file.write(s) file.write(s)