From 3f6efdf4a4e809f000e9529fc4b838d56df2417d Mon Sep 17 00:00:00 2001 From: Doru Bercea Date: Wed, 15 Jan 2020 15:16:45 -0500 Subject: [PATCH] Fix MaxPool translation to ONNX dialect. --- src/builder/frontend_dialect_transformer.cpp | 15 +++++++++++++++ src/builder/op_build_table.inc | 2 +- src/dialect/onnx/gen_doc.py | 1 + src/dialect/onnx/onnx.td | 11 +++++++++++ 4 files changed, 28 insertions(+), 1 deletion(-) diff --git a/src/builder/frontend_dialect_transformer.cpp b/src/builder/frontend_dialect_transformer.cpp index 43f8aa9..31311ee 100644 --- a/src/builder/frontend_dialect_transformer.cpp +++ b/src/builder/frontend_dialect_transformer.cpp @@ -632,6 +632,21 @@ private: } } + /*! + * Special handle for MaxPool operations. + */ + void ImportNodeMaxPool( + onnx::NodeProto node, int nIn, + std::initializer_list> + attrs) { + int nOuts = node.output().size(); + if (nOuts == 1) { + ImportNodeOneOut(node, nIn, nOuts, attrs); + } else { + ImportNodeMultipleOuts(node, nIn, nOuts, attrs); + } + } + void ImportNode(onnx::NodeProto node) { std::vector inputs; for (auto item : node.input()) { diff --git a/src/builder/op_build_table.inc b/src/builder/op_build_table.inc index b9e5720..10adfde 100644 --- a/src/builder/op_build_table.inc +++ b/src/builder/op_build_table.inc @@ -303,7 +303,7 @@ ImportNodeOneOut(node, 1, 1, { }); }else if (OpName == "MaxPool") { - ImportNodeMultipleOuts(node, 1, 2, { + ImportNodeMaxPool(node, 1, { {"auto_pad","str","NOTSET"} ,{"ceil_mode","int","0"} ,{"dilations","", ""} diff --git a/src/dialect/onnx/gen_doc.py b/src/dialect/onnx/gen_doc.py index 0c1f4bc..6d986c2 100644 --- a/src/dialect/onnx/gen_doc.py +++ b/src/dialect/onnx/gen_doc.py @@ -365,6 +365,7 @@ special cases: def gen_code(schema,fefile) : special_handler = dict([ ("Conv", "ImportNodeConv"), + ("MaxPool", "ImportNodeMaxPool"), #("Transpose", "ImportNodeTranspose") ]) special_type = dict([ diff --git a/src/dialect/onnx/onnx.td b/src/dialect/onnx/onnx.td index 87901ce..7396990 100644 --- a/src/dialect/onnx/onnx.td +++ b/src/dialect/onnx/onnx.td @@ -114,4 +114,15 @@ def ONNXConv3Op:ONNX_Op<"Conv3", let results = (outs AnyTensor); } +def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut", + [NoSideEffect]> { + let summary = "ONNX MaxPool operation with a single output."; + let description = [{ + "ONNX MaxPool operation with a single output." + "See ONNXMaxPoolOp for a full description of the MaxPool semantics." + }]; + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); +} + #endif // ONNX_OPS