diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp index 9166a59..a3a49a5 100644 --- a/src/dialect/onnx/onnx_ops.cpp +++ b/src/dialect/onnx/onnx_ops.cpp @@ -334,66 +334,63 @@ void ONNXMatMulOp::inferShapes() { lhsShape[0] != rhsShape[0]) emitError("Attempt to multiply incompatible matrices."); dims.emplace_back(1); + } else if (lhsShape.size() > 2 && rhsShape.size() == 2) { + // (s1 x s2 x... x sK x M x N) MATMUL (N x P) + // => + // (s1 x s2 x... x sK x M x P) + + // Check legality of matrix multiplication. + unsigned leftDims = lhsShape.size(); + if (lhsShape[leftDims - 1] != -1 && rhsShape[0] != -1 && + lhsShape[leftDims - 1] != rhsShape[0]) + emitError("Attempt to multiply incompatible matrices."); + + for (int i = 0; i < leftDims - 1; ++i) + dims.emplace_back(lhsShape[i]); + dims.emplace_back(rhsShape[1]); + } else if (lhsShape.size() == 2 && rhsShape.size() > 2) { + // (M x N) MATMUL (s1 x s2 x... x sK x N x P) + // => + // (s1 x s2 x... x sK x M x P) + + // Check legality of matrix multiplication. + unsigned rightDims = rhsShape.size(); + if (lhsShape[1] != -1 && rhsShape[rightDims - 2] != -1 && + lhsShape[1] != rhsShape[rightDims - 2]) + emitError("Attempt to multiply incompatible matrices."); + + for (int i = 0; i < rightDims - 2; ++i) + dims.emplace_back(rhsShape[i]); + dims.emplace_back(lhsShape[0]); + dims.emplace_back(rhsShape[rightDims - 1]); + } else if (lhsShape.size() > 2 && rhsShape.size() > 2) { + // (s1 x s2 x... x sK x M x N) MATMUL (t1 x t2 x... x tK x N x P) + // => + // (u1 x u2 x... x uK x M x P) + + // Check legality of matrix multiplication. + unsigned leftDims = lhsShape.size(); + unsigned rightDims = rhsShape.size(); + if (lhsShape[leftDims - 1] != -1 && rhsShape[rightDims - 2] != -1 && + lhsShape[leftDims - 1] != rhsShape[rightDims - 2]) + emitError("Attempt to multiply incompatible matrices."); + + // Check and perform broadcasting for the shapes. + SmallVector lhsBcastShape; + for (int i = 0; i < leftDims - 2; ++i) + lhsBcastShape.emplace_back(lhsShape[i]); + SmallVector rhsBcastShape; + for (int i = 0; i < rightDims - 2; ++i) + rhsBcastShape.emplace_back(rhsShape[i]); + if (!getBroadcastedShape(lhsBcastShape, rhsBcastShape, dims)) + emitError("Broadcasted dimensions are incompatible."); + + dims.emplace_back(lhsShape[leftDims - 2]); + dims.emplace_back(rhsShape[rightDims - 1]); } else { - // Special cases for when at least one matrix has more than two dimensions. - if (lhsShape.size() > 2 && rhsShape.size() == 2) { - // (s1 x s2 x... x sK x M x N) MATMUL (N x P) - // => - // (s1 x s2 x... x sK x M x P) - - // Check legality of matrix multiplication. - unsigned leftDims = lhsShape.size(); - if (lhsShape[leftDims - 1] != -1 && rhsShape[0] != -1 && - lhsShape[leftDims - 1] != rhsShape[0]) - emitError("Attempt to multiply incompatible matrices."); - - for (int i = 0; i < leftDims - 1; ++i) - dims.emplace_back(lhsShape[i]); - dims.emplace_back(rhsShape[1]); - } else if (lhsShape.size() == 2 && rhsShape.size() > 2) { - // (M x N) MATMUL (s1 x s2 x... x sK x N x P) - // => - // (s1 x s2 x... x sK x M x P) - - // Check legality of matrix multiplication. - unsigned rightDims = rhsShape.size(); - if (lhsShape[1] != -1 && rhsShape[rightDims - 2] != -1 && - lhsShape[1] != rhsShape[rightDims - 2]) - emitError("Attempt to multiply incompatible matrices."); - - for (int i = 0; i < rightDims - 2; ++i) - dims.emplace_back(rhsShape[i]); - dims.emplace_back(lhsShape[0]); - dims.emplace_back(rhsShape[rightDims - 1]); - } else if (lhsShape.size() > 2 && rhsShape.size() > 2) { - // (s1 x s2 x... x sK x M x N) MATMUL (t1 x t2 x... x tK x N x P) - // => - // (u1 x u2 x... x uK x M x P) - - // Check legality of matrix multiplication. - unsigned leftDims = lhsShape.size(); - unsigned rightDims = rhsShape.size(); - if (lhsShape[leftDims - 1] != -1 && rhsShape[rightDims - 2] != -1 && - lhsShape[leftDims - 1] != rhsShape[rightDims - 2]) - emitError("Attempt to multiply incompatible matrices."); - - // Check and perform broadcasting for the shapes. - SmallVector lhsBcastShape; - for (int i = 0; i < leftDims - 2; ++i) - lhsBcastShape.emplace_back(lhsShape[i]); - SmallVector rhsBcastShape; - for (int i = 0; i < rightDims - 2; ++i) - rhsBcastShape.emplace_back(rhsShape[i]); - if (!getBroadcastedShape(lhsBcastShape, rhsBcastShape, dims)) - emitError("Broadcasted dimensions are incompatible."); - - dims.emplace_back(lhsShape[leftDims - 2]); - dims.emplace_back(rhsShape[rightDims - 1]); - } else { - // This case covers all remaining combinations of 1 and 2-D matrices. - dims.emplace_back(lhsShape.size() == 1 ? 1 : lhsShape[0]); - dims.emplace_back(rhsShape.size() == 1 ? 1 : rhsShape[1]); - } + // This case covers all remaining combinations of 1 and 2-D matrices. + dims.emplace_back(lhsShape.size() == 1 ? 1 : lhsShape[0]); + dims.emplace_back(rhsShape.size() == 1 ? 1 : rhsShape[1]); } getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType()));