diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp index 2380f5a..f19242b 100644 --- a/src/dialect/onnx/onnx_ops.cpp +++ b/src/dialect/onnx/onnx_ops.cpp @@ -329,14 +329,31 @@ void ONNXMatMulOp::inferShapes() { // Special case when both arrays are 1-dimensional and according to // numpy rules the types need to be extended to 1xN and Nx1. Helper sizes // need to be removed after the multiplication but cannot be removed if all - // remaining sizes are 1. + // sizes are 1. if (lhsShape[0] != -1 && rhsShape[0] != -1 && lhsShape[0] != rhsShape[0]) emitError("Attempt to multiply incompatible matrices."); dims.emplace_back(1); } else { - dims.emplace_back(lhsShape.size() == 1 ? 1 : lhsShape[0]); - dims.emplace_back(rhsShape.size() == 1 ? 1 : rhsShape[1]); + // Special cases for when at least one matrix has more than two dimensions. + if (lhsShape.size() > 2 && rhsShape.size() == 2) { + // (s1 x s2 x... x sKx x M x N) MATMUL (N x P) + // => + // (s1 x s2 x... x sKx x M x P) + + // Check legality of matrix multiplication. + unsigned leftDims = lhsShape.size(); + if (lhsShape[leftDims - 2] != -1 && rhsShape[0] != -1 && + lhsShape[leftDims - 2] != rhsShape[0]) + emitError("Attempt to multiply incompatible matrices."); + + for (int i = 0; i < leftDims - 1; ++i) + dims.emplace_back(lhsShape[i]); + dims.emplace_back(rhsShape[1]); + } else { + dims.emplace_back(lhsShape.size() == 1 ? 1 : lhsShape[0]); + dims.emplace_back(rhsShape.size() == 1 ? 1 : rhsShape[1]); + } } getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType()));