diff --git a/src/builder/frontend_dialect_transformer.cpp b/src/builder/frontend_dialect_transformer.cpp index 43f8aa9..040aa89 100644 --- a/src/builder/frontend_dialect_transformer.cpp +++ b/src/builder/frontend_dialect_transformer.cpp @@ -611,7 +611,7 @@ private: * a specialized function is used */ void ImportNodeConv( - onnx::NodeProto node, int nIn, int nOut, + onnx::NodeProto node, int nOut, std::initializer_list> attrs) { // Conv has attribute dilations, kernel_shape, pads, the default value of @@ -624,12 +624,12 @@ private: // similar situation for pads, strides in AveragePool // axes of ReduceSum, pads, strides, dilations and kernel_shape of MaxPool // TODO: fix this after type inference + int nOps = node.input().size(); - if (node.input().size() == 1) { - ImportNodeOneOut(node, nIn, nOut, attrs); - } else { - ImportNodeOneOut(node, nIn, nOut, attrs); - } + if (nOps == 2) + ImportNodeOneOut(node, nOps, nOut, attrs); + else + ImportNodeOneOut(node, nOps, nOut, attrs); } void ImportNode(onnx::NodeProto node) { diff --git a/src/builder/op_build_table.inc b/src/builder/op_build_table.inc index b9e5720..78c4f98 100644 --- a/src/builder/op_build_table.inc +++ b/src/builder/op_build_table.inc @@ -87,7 +87,7 @@ {"value","", ""} }); }else if (OpName == "Conv") { - ImportNodeConv(node, 3, 1, { + ImportNodeConv(node, 1, { {"auto_pad","str","NOTSET"} ,{"dilations","", ""} ,{"group","int", "1"} diff --git a/src/dialect/onnx/onnx.td b/src/dialect/onnx/onnx.td index 87901ce..734668f 100644 --- a/src/dialect/onnx/onnx.td +++ b/src/dialect/onnx/onnx.td @@ -92,26 +92,15 @@ def ONNXFullGemmOp: ONNX_Op<"FullGemm", let results = (outs AnyTensor); } -def ONNXConv1Op:ONNX_Op<"Conv1", +def ONNXConvNoBiasOp:ONNX_Op<"ConvNoBias", [NoSideEffect]> { - let summary = "ONNX Conv operation"; + let summary = "ONNX Conv operation with no Bias operand."; let description = [{ "The convolution operator consumes an input tensor and a filter, and" "computes the output." }]; - let arguments = (ins AnyTensor:$X); - let results = (outs AnyTensor); -} - -def ONNXConv3Op:ONNX_Op<"Conv3", - [NoSideEffect]> { - let summary = "ONNX Conv operation"; - let description = [{ - "The convolution operator consumes an input tensor and a filter, and" - "computes the output." - }]; - let arguments = (ins AnyTensor:$X, AnyTensor:$W, AnyTensor:$B); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$W); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } #endif // ONNX_OPS diff --git a/src/dialect/onnx/onnxop.inc b/src/dialect/onnx/onnxop.inc index 3ec669a..16ad979 100644 --- a/src/dialect/onnx/onnxop.inc +++ b/src/dialect/onnx/onnxop.inc @@ -315,14 +315,14 @@ def ONNXConstantOfShapeOp:ONNX_Op<"ConstantOfShape", let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } -def ONNXConvOp:ONNX_Op<"Conv", +def ONNXConvOp:ONNX_Op<"Conv", [NoSideEffect]> { let summary = "ONNX Conv operation"; let description = [{ "The convolution operator consumes an input tensor and a filter, and" "computes the output." }]; - let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$W, Variadic>:$B); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$W, AnyTypeOf<[AnyMemRef, AnyTensor]>:$B); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); }