diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp index 0aaa87b..0a4fb5e 100644 --- a/src/dialect/onnx/onnx_ops.cpp +++ b/src/dialect/onnx/onnx_ops.cpp @@ -350,7 +350,11 @@ void ONNXMatMulOp::inferShapes() { SmallVector dims; auto lhsShape = lhsTy.getShape(); auto rhsShape = rhsTy.getShape(); - if (lhsShape.size() == 1 && rhsShape.size() == 1) { + + if (lhsShape.size() < 1 && rhsShape.size() < 1) { + // Multiplication by scalars is not allowed. + emitError("Multiplication by scalar arguments not allowed."); + } else if (lhsShape.size() == 1 && rhsShape.size() == 1) { // Special case when both arrays are 1-dimensional and according to // numpy rules the types need to be extended to 1xN and Nx1. Helper sizes // need to be removed after the multiplication but cannot be removed if all @@ -365,12 +369,12 @@ void ONNXMatMulOp::inferShapes() { // (s1 x s2 x... x sK x M x P) // Check legality of matrix multiplication. - unsigned leftDims = lhsShape.size(); - if (lhsShape[leftDims - 1] != -1 && rhsShape[0] != -1 && - lhsShape[leftDims - 1] != rhsShape[0]) + unsigned lhsRank = lhsShape.size(); + if (lhsShape[lhsRank - 1] != -1 && rhsShape[0] != -1 && + lhsShape[lhsRank - 1] != rhsShape[0]) emitError("Attempt to multiply incompatible matrices."); - for (int i = 0; i < leftDims - 1; ++i) + for (int i = 0; i < lhsRank - 1; ++i) dims.emplace_back(lhsShape[i]); dims.emplace_back(rhsShape[1]); } else if (lhsShape.size() == 2 && rhsShape.size() > 2) { @@ -379,39 +383,39 @@ void ONNXMatMulOp::inferShapes() { // (s1 x s2 x... x sK x M x P) // Check legality of matrix multiplication. - unsigned rightDims = rhsShape.size(); - if (lhsShape[1] != -1 && rhsShape[rightDims - 2] != -1 && - lhsShape[1] != rhsShape[rightDims - 2]) + unsigned rhsRank = rhsShape.size(); + if (lhsShape[1] != -1 && rhsShape[rhsRank - 2] != -1 && + lhsShape[1] != rhsShape[rhsRank - 2]) emitError("Attempt to multiply incompatible matrices."); - for (int i = 0; i < rightDims - 2; ++i) + for (int i = 0; i < rhsRank - 2; ++i) dims.emplace_back(rhsShape[i]); dims.emplace_back(lhsShape[0]); - dims.emplace_back(rhsShape[rightDims - 1]); + dims.emplace_back(rhsShape[rhsRank - 1]); } else if (lhsShape.size() > 2 && rhsShape.size() > 2) { // (s1 x s2 x... x sK x M x N) MATMUL (t1 x t2 x... x tK x N x P) // => // (u1 x u2 x... x uK x M x P) // Check legality of matrix multiplication. - unsigned leftDims = lhsShape.size(); - unsigned rightDims = rhsShape.size(); - if (lhsShape[leftDims - 1] != -1 && rhsShape[rightDims - 2] != -1 && - lhsShape[leftDims - 1] != rhsShape[rightDims - 2]) + unsigned lhsRank = lhsShape.size(); + unsigned rhsRank = rhsShape.size(); + if (lhsShape[lhsRank - 1] != -1 && rhsShape[rhsRank - 2] != -1 && + lhsShape[lhsRank - 1] != rhsShape[rhsRank - 2]) emitError("Attempt to multiply incompatible matrices."); // Check and perform broadcasting for the shapes. SmallVector lhsBcastShape; - for (int i = 0; i < leftDims - 2; ++i) + for (int i = 0; i < lhsRank - 2; ++i) lhsBcastShape.emplace_back(lhsShape[i]); SmallVector rhsBcastShape; - for (int i = 0; i < rightDims - 2; ++i) + for (int i = 0; i < rhsRank - 2; ++i) rhsBcastShape.emplace_back(rhsShape[i]); if (!getBroadcastedShape(lhsBcastShape, rhsBcastShape, dims)) emitError("Broadcasted dimensions are incompatible."); - dims.emplace_back(lhsShape[leftDims - 2]); - dims.emplace_back(rhsShape[rightDims - 1]); + dims.emplace_back(lhsShape[lhsRank - 2]); + dims.emplace_back(rhsShape[rhsRank - 1]); } else { // This case covers all remaining combinations of 1 and 2-D matrices. int64_t lhsDim = lhsShape[0];