diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp index 5474856..78ba92b 100644 --- a/src/dialect/onnx/onnx_ops.cpp +++ b/src/dialect/onnx/onnx_ops.cpp @@ -781,7 +781,8 @@ void ONNXConvNoBiasOp::inferShapes() { int64_t group = ONNXConvNoBiasOp::group().getSExtValue(); //.getLimitedValue(); // Check that the X.shape[1] == (W.shape[1] * group) == C condition holds. - if (dataShape[1] != (weightShape[1] * group)) + if (dataShape[1] != -1 && weightShape[1] != -1 && + dataShape[1] != (weightShape[1] * group)) emitError("Channel dimension mismatch"); // Note: the value of the group attribut only impacts the way the