diff --git a/doc/gen_doc.py b/doc/gen_doc.py index bd126d5..2b150d6 100644 --- a/doc/gen_doc.py +++ b/doc/gen_doc.py @@ -36,6 +36,7 @@ special_attr_defaults = dict([ special_op_handler = dict([ ("Conv", "ImportNodeConv"), ("MaxPool", "ImportNodeMaxPool"), + ("Gemm", "ImportNodeGemm"), #("Transpose", "ImportNodeTranspose") ]) diff --git a/src/builder/frontend_dialect_transformer.cpp b/src/builder/frontend_dialect_transformer.cpp index 7058cdc..10fc2c9 100644 --- a/src/builder/frontend_dialect_transformer.cpp +++ b/src/builder/frontend_dialect_transformer.cpp @@ -434,6 +434,18 @@ private: } } + /*! + * Special handle for Gemm operations. + */ + void ImportNodeGemm(onnx::NodeProto node, int nIn, int nOut) { + int nOps = node.input().size(); + if (nOps == 2) { + ImportNodeOneOut(node, 2, nOut); + } else { + ImportNodeOneOut(node, nIn, nOut); + } + } + void ImportNode(const onnx::NodeProto &node) { std::vector inputs; for (const auto &item : node.input()) { diff --git a/src/builder/op_build_table.inc b/src/builder/op_build_table.inc index 963b3b4..d7dea0f 100644 --- a/src/builder/op_build_table.inc +++ b/src/builder/op_build_table.inc @@ -98,7 +98,7 @@ }else if (OpName == "GatherND") { ImportNodeOneOut(node, 2, 1); }else if (OpName == "Gemm") { - ImportNodeOneOut(node, 3, 1); + ImportNodeGemm(node, 3, 1); }else if (OpName == "GlobalAveragePool") { ImportNodeOneOut(node, 1, 1); }else if (OpName == "GlobalLpPool") {