Review notes (text before the first "diff --git" line is not part of the patch):
  * Line structure restored; the patch had been collapsed onto two lines.
  * Loop counters in the new MatMul branches are unsigned so they match the
    type of rhsRank/lhsRank; the "index" hashes below predate this change and
    the added comments.
  * NOTE(review): the dynamic operand/result types in test_matmul_9/10 had
    lost their "<...>" shapes. tensor<?x42x32xf32>, tensor<?x32xf32> and
    tensor<?x42xf32> are reconstructed from the inference rules in the C++
    hunk; confirm against the original test file.
  * NOTE(review): indentation of context lines is reconstructed; confirm
    against the tree before applying.
diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp
index edb2611..cedcea4 100644
--- a/src/dialect/onnx/onnx_ops.cpp
+++ b/src/dialect/onnx/onnx_ops.cpp
@@ -371,6 +371,46 @@ void ONNXMatMulOp::inferShapes() {
         lhsShape[0] != rhsShape[0])
       emitError("Attempt to multiply incompatible matrices.");
     dims.emplace_back(1);
+  } else if (lhsShape.size() == 1 && rhsShape.size() >= 2) {
+    // If the first argument is 1-D, it is promoted to a matrix by prepending a
+    // 1 to its dimensions. After matrix multiplication the prepended 1 is
+    // removed.
+    //
+    // N MATMUL (s1 x s2 x... x sK x N x P)
+    // =>
+    // (s1 x s2 x... x sK x P)
+
+    // Check legality of matrix multiplication: the contracted dimension N
+    // must agree on both sides whenever neither side is -1.
+    unsigned rhsRank = rhsShape.size();
+    if (lhsShape[0] != -1 && rhsShape[rhsRank - 2] != -1 &&
+        lhsShape[0] != rhsShape[rhsRank - 2])
+      emitError("Attempt to multiply incompatible matrices.");
+
+    // Result keeps s1..sK (unsigned index matches rhsRank), then P.
+    for (unsigned i = 0; i < rhsRank - 2; ++i)
+      dims.emplace_back(rhsShape[i]);
+    dims.emplace_back(rhsShape[rhsRank - 1]);
+  } else if (lhsShape.size() >= 2 && rhsShape.size() == 1) {
+    // If the second argument is 1-D, it is promoted to a matrix by appending a
+    // 1 to its dimensions. After matrix multiplication the appended 1 is
+    // removed.
+    //
+    // (s1 x s2 x... x sK x M x N) MATMUL N
+    // =>
+    // (s1 x s2 x... x sK x M)
+
+    // Check legality of matrix multiplication: the contracted dimension N
+    // must agree on both sides whenever neither side is -1.
+    unsigned lhsRank = lhsShape.size();
+    if (lhsShape[lhsRank - 1] != -1 && rhsShape[0] != -1 &&
+        lhsShape[lhsRank - 1] != rhsShape[0])
+      emitError("Attempt to multiply incompatible matrices.");
+
+    // Result keeps s1..sK (unsigned index matches lhsRank), then M.
+    for (unsigned i = 0; i < lhsRank - 2; ++i)
+      dims.emplace_back(lhsShape[i]);
+    dims.emplace_back(lhsShape[lhsRank - 2]);
   } else if (lhsShape.size() > 2 && rhsShape.size() == 2) {
     // (s1 x s2 x... x sK x M x N) MATMUL (N x P)
     // =>
diff --git a/test/mlir/onnx/onnx_shape_inference.mlir b/test/mlir/onnx/onnx_shape_inference.mlir
index 5c9baf6..14c575d 100644
--- a/test/mlir/onnx/onnx_shape_inference.mlir
+++ b/test/mlir/onnx/onnx_shape_inference.mlir
@@ -117,6 +117,28 @@ func @test_matmul_8(%arg0 : tensor<32x64xf32>, %arg1 : tensor<64x128xf32>) -> te
   // CHECK: return [[RES8]] : tensor<32x128xf32>
 }
 
+/// MatMul: 1-D x N-D
+
+func @test_matmul_9(%arg0 : tensor<42xf32>, %arg1 : tensor<?x42x32xf32>) -> tensor<*xf32> {
+  %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<42xf32>, tensor<?x42x32xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_matmul_9
+  // CHECK: [[RES1:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<42xf32>, tensor<?x42x32xf32>) -> tensor<?x32xf32>
+  // CHECK: return [[RES1]] : tensor<?x32xf32>
+}
+
+/// MatMul: N-D x 1-D
+
+func @test_matmul_10(%arg0 : tensor<?x42x32xf32>, %arg1 : tensor<32xf32>) -> tensor<*xf32> {
+  %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<?x42x32xf32>, tensor<32xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_matmul_10
+  // CHECK: [[RES1:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<?x42x32xf32>, tensor<32xf32>) -> tensor<?x42xf32>
+  // CHECK: return [[RES1]] : tensor<?x42xf32>
+}
+
 //===----------------------------------------------------------------------===//
 /// Test shape inference for ConvNoBias operation and all its attributes.
 //===----------------------------------------------------------------------===//