From e0918258960dac1381a72d24667b9f06a240dfe1 Mon Sep 17 00:00:00 2001 From: Doru Bercea Date: Fri, 10 Jan 2020 15:26:29 -0500 Subject: [PATCH] Add check for matrix size match for 1 and 2 dimenisional cases. --- src/dialect/onnx/onnx_ops.cpp | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp index 81f8887..09f9ebc 100644 --- a/src/dialect/onnx/onnx_ops.cpp +++ b/src/dialect/onnx/onnx_ops.cpp @@ -389,10 +389,18 @@ void ONNXMatMulOp::inferShapes() { dims.emplace_back(rhsShape[rightDims - 1]); } else { // This case covers all remaining combinations of 1 and 2-D matrices. - if (lhsShape.size() != 1) + int64_t lhsDim = lhsShape[0]; + int64_t rhsDim = rhsShape[0]; + if (lhsShape.size() > 1) { + lhsDim = lhsShape[1]; dims.emplace_back(lhsShape[0]); + } - if (rhsShape.size() != 1) + // Check legality of matrix multiplication. + if (lhsDim != -1 && rhsDim != -1 && lhsDim != rhsDim) + emitError("Attempt to multiply incompatible matrices."); + + if (rhsShape.size() > 1) dims.emplace_back(rhsShape[1]); }