diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp index 81f8887..09f9ebc 100644 --- a/src/dialect/onnx/onnx_ops.cpp +++ b/src/dialect/onnx/onnx_ops.cpp @@ -389,10 +389,18 @@ void ONNXMatMulOp::inferShapes() { dims.emplace_back(rhsShape[rightDims - 1]); } else { // This case covers all remaining combinations of 1 and 2-D matrices. - if (lhsShape.size() != 1) + int64_t lhsDim = lhsShape[0]; + int64_t rhsDim = rhsShape[0]; + if (lhsShape.size() > 1) { + lhsDim = lhsShape[1]; dims.emplace_back(lhsShape[0]); + } - if (rhsShape.size() != 1) + // Check legality of matrix multiplication. + if (lhsDim != -1 && rhsDim != -1 && lhsDim != rhsDim) + emitError("Attempt to multiply incompatible matrices."); + + if (rhsShape.size() > 1) dims.emplace_back(rhsShape[1]); }