diff --git a/src/builder/frontend_dialect_transformer.cpp b/src/builder/frontend_dialect_transformer.cpp index 43f8aa9..31311ee 100644 --- a/src/builder/frontend_dialect_transformer.cpp +++ b/src/builder/frontend_dialect_transformer.cpp @@ -632,6 +632,21 @@ private: } } + /*! + * Special handle for MaxPool operations. + */ + void ImportNodeMaxPool( + onnx::NodeProto node, int nIn, + std::initializer_list> + attrs) { + int nOuts = node.output().size(); + if (nOuts == 1) { + ImportNodeOneOut(node, nIn, nOuts, attrs); + } else { + ImportNodeMultipleOuts(node, nIn, nOuts, attrs); + } + } + void ImportNode(onnx::NodeProto node) { std::vector inputs; for (auto item : node.input()) { diff --git a/src/builder/op_build_table.inc b/src/builder/op_build_table.inc index b9e5720..10adfde 100644 --- a/src/builder/op_build_table.inc +++ b/src/builder/op_build_table.inc @@ -303,7 +303,7 @@ ImportNodeOneOut(node, 1, 1, { }); }else if (OpName == "MaxPool") { - ImportNodeMultipleOuts(node, 1, 2, { + ImportNodeMaxPool(node, 1, { {"auto_pad","str","NOTSET"} ,{"ceil_mode","int","0"} ,{"dilations","", ""} diff --git a/src/dialect/onnx/gen_doc.py b/src/dialect/onnx/gen_doc.py index 0c1f4bc..6d986c2 100644 --- a/src/dialect/onnx/gen_doc.py +++ b/src/dialect/onnx/gen_doc.py @@ -365,6 +365,7 @@ special cases: def gen_code(schema,fefile) : special_handler = dict([ ("Conv", "ImportNodeConv"), + ("MaxPool", "ImportNodeMaxPool"), #("Transpose", "ImportNodeTranspose") ]) special_type = dict([ diff --git a/src/dialect/onnx/onnx.td b/src/dialect/onnx/onnx.td index 0e7f0d9..dba9f14 100644 --- a/src/dialect/onnx/onnx.td +++ b/src/dialect/onnx/onnx.td @@ -78,6 +78,17 @@ def ONNXEntryPointOp: ONNX_Op<"EntryPoint"> { }]; } +//===----------------------------------------------------------------------===// +// ONNX Operations for handling optional arguments +//===----------------------------------------------------------------------===// + +// To allow pattern matching on operations with optional arguments/outputs we +// implement variants of the original ONNX dialect operations. The ONNX +// operations automatically generated by the `gen_doc.py` script and included +// in the `onnxop.inc` file have all optional arguments and outputs present. +// In the operations below we include the variants with missing operands +// or outputs. This decision affects only ONNX operations with optional +// arguments not ONNX operations with variadic operands. def ONNXGemmNoBiasOp: ONNX_Op<"GemmNoBias", [NoSideEffect, DeclareOpInterfaceMethods]> { @@ -114,4 +125,15 @@ def ONNXConv3Op:ONNX_Op<"Conv3", let results = (outs AnyTensor); } +def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut", + [NoSideEffect]> { + let summary = "ONNX MaxPool operation with a single output."; + let description = [{ + "ONNX MaxPool operation with a single output." + "See ONNXMaxPoolOp for a full description of the MaxPool semantics." + }]; + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); +} + #endif // ONNX_OPS