diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp
index 33468f9..0a4fb5e 100644
--- a/src/dialect/onnx/onnx_ops.cpp
+++ b/src/dialect/onnx/onnx_ops.cpp
@@ -343,18 +343,99 @@ void ONNXMatMulOp::inferShapes() {
   if (!getOperand(0).getType().isa<RankedTensorType>() ||
       !getOperand(1).getType().isa<RankedTensorType>())
     return;
+
   auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
   auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
+
   SmallVector<int64_t, 2> dims;
-  dims.emplace_back(lhsTy.getShape()[0]);
-  dims.emplace_back(rhsTy.getShape()[1]);
+  auto lhsShape = lhsTy.getShape();
+  auto rhsShape = rhsTy.getShape();
+
+  if (lhsShape.size() < 1 && rhsShape.size() < 1) {
+    // Multiplication by scalars is not allowed.
+    emitError("Multiplication by scalar arguments not allowed.");
+  } else if (lhsShape.size() == 1 && rhsShape.size() == 1) {
+    // Special case when both arrays are 1-dimensional and according to
+    // numpy rules the types need to be extended to 1xN and Nx1. Helper sizes
+    // need to be removed after the multiplication but cannot be removed if all
+    // sizes are 1.
+    if (lhsShape[0] != -1 && rhsShape[0] != -1 &&
+        lhsShape[0] != rhsShape[0])
+      emitError("Attempt to multiply incompatible matrices.");
+    dims.emplace_back(1);
+  } else if (lhsShape.size() > 2 && rhsShape.size() == 2) {
+    // (s1 x s2 x... x sK x M x N) MATMUL (N x P)
+    // =>
+    // (s1 x s2 x... x sK x M x P)
+
+    // Check legality of matrix multiplication.
+    unsigned lhsRank = lhsShape.size();
+    if (lhsShape[lhsRank - 1] != -1 && rhsShape[0] != -1 &&
+        lhsShape[lhsRank - 1] != rhsShape[0])
+      emitError("Attempt to multiply incompatible matrices.");
+
+    for (int i = 0; i < lhsRank - 1; ++i)
+      dims.emplace_back(lhsShape[i]);
+    dims.emplace_back(rhsShape[1]);
+  } else if (lhsShape.size() == 2 && rhsShape.size() > 2) {
+    // (M x N) MATMUL (s1 x s2 x... x sK x N x P)
+    // =>
+    // (s1 x s2 x... x sK x M x P)
+
+    // Check legality of matrix multiplication.
+    unsigned rhsRank = rhsShape.size();
+    if (lhsShape[1] != -1 && rhsShape[rhsRank - 2] != -1 &&
+        lhsShape[1] != rhsShape[rhsRank - 2])
+      emitError("Attempt to multiply incompatible matrices.");
+
+    for (int i = 0; i < rhsRank - 2; ++i)
+      dims.emplace_back(rhsShape[i]);
+    dims.emplace_back(lhsShape[0]);
+    dims.emplace_back(rhsShape[rhsRank - 1]);
+  } else if (lhsShape.size() > 2 && rhsShape.size() > 2) {
+    // (s1 x s2 x... x sK x M x N) MATMUL (t1 x t2 x... x tK x N x P)
+    // =>
+    // (u1 x u2 x... x uK x M x P)
+
+    // Check legality of matrix multiplication.
+    unsigned lhsRank = lhsShape.size();
+    unsigned rhsRank = rhsShape.size();
+    if (lhsShape[lhsRank - 1] != -1 && rhsShape[rhsRank - 2] != -1 &&
+        lhsShape[lhsRank - 1] != rhsShape[rhsRank - 2])
+      emitError("Attempt to multiply incompatible matrices.");
+
+    // Check and perform broadcasting for the shapes.
+    SmallVector<int64_t, 2> lhsBcastShape;
+    for (int i = 0; i < lhsRank - 2; ++i)
+      lhsBcastShape.emplace_back(lhsShape[i]);
+    SmallVector<int64_t, 2> rhsBcastShape;
+    for (int i = 0; i < rhsRank - 2; ++i)
+      rhsBcastShape.emplace_back(rhsShape[i]);
+    if (!getBroadcastedShape(lhsBcastShape, rhsBcastShape, dims))
+      emitError("Broadcasted dimensions are incompatible.");
+
+    dims.emplace_back(lhsShape[lhsRank - 2]);
+    dims.emplace_back(rhsShape[rhsRank - 1]);
+  } else {
+    // This case covers all remaining combinations of 1 and 2-D matrices.
+    int64_t lhsDim = lhsShape[0];
+    int64_t rhsDim = rhsShape[0];
+    if (lhsShape.size() > 1) {
+      lhsDim = lhsShape[1];
+      dims.emplace_back(lhsShape[0]);
+    }
+
+    // Check legality of matrix multiplication.
+    if (lhsDim != -1 && rhsDim != -1 && lhsDim != rhsDim)
+      emitError("Attempt to multiply incompatible matrices.");
+
+    if (rhsShape.size() > 1)
+      dims.emplace_back(rhsShape[1]);
+  }
+
   getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
 }
 
-// TODO:
-// Verify that matrix sizes are valid.
-// Take into account the dimensionality of the matrix.
-
 //===----------------------------------------------------------------------===//
 
 // Gemm
diff --git a/test/mlir/onnx/onnx_shape_inference.mlir b/test/mlir/onnx/onnx_shape_inference.mlir
index 6334c5f..5c9baf6 100644
--- a/test/mlir/onnx/onnx_shape_inference.mlir
+++ b/test/mlir/onnx/onnx_shape_inference.mlir
@@ -8,22 +8,114 @@
 func @test_default_transpose(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
   %0 = "onnx.Transpose"(%arg0) : (tensor<5x5x1x32xf32>) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()
-}
 
-// CHECK-LABEL: test_default_transpose
-// CHECK: [[RES:%.+]] = "onnx.Transpose"(%arg0) : (tensor<5x5x1x32xf32>) -> tensor<32x1x5x5xf32>
-// CHECK: return [[RES]] : tensor<32x1x5x5xf32>
+  // CHECK-LABEL: test_default_transpose
+  // CHECK: [[RES:%.+]] = "onnx.Transpose"(%arg0) : (tensor<5x5x1x32xf32>) -> tensor<32x1x5x5xf32>
+  // CHECK: return [[RES]] : tensor<32x1x5x5xf32>
+}
 
 /// Test shape inference for transposition when perm attribute is specified.
 func @test_transpose(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
   %0 = "onnx.Transpose"(%arg0) {perm = [2, 0, 3, 1]} : (tensor<5x5x1x32xf32>) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_transpose
+  // CHECK: [[RES_ATTR:%.+]] = "onnx.Transpose"(%arg0) {perm = [2, 0, 3, 1]} : (tensor<5x5x1x32xf32>) -> tensor<1x5x32x5xf32>
+  // CHECK: return [[RES_ATTR]] : tensor<1x5x32x5xf32>
 }
-// CHECK-LABEL: test_transpose
-// CHECK: [[RES_ATTR:%.+]] = "onnx.Transpose"(%arg0) {perm = [2, 0, 3, 1]} : (tensor<5x5x1x32xf32>) -> tensor<1x5x32x5xf32>
-// CHECK: return [[RES_ATTR]] : tensor<1x5x32x5xf32>
 
+//===----------------------------------------------------------------------===//
+/// Test the shape inferencing scheme for the matmul operation.
+//===----------------------------------------------------------------------===//
+
+/// MatMul: 1-D x 1-D
+
+func @test_matmul_1(%arg0 : tensor<32xf32>, %arg1 : tensor<32xf32>) -> tensor<*xf32> {
+  %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<32xf32>, tensor<32xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_matmul_1
+  // CHECK: [[RES1:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<32xf32>, tensor<32xf32>) -> tensor<1xf32>
+  // CHECK: return [[RES1]] : tensor<1xf32>
+}
+
+/// MatMul: K-D x 2-D (K > 2)
+
+func @test_matmul_2(%arg0 : tensor<16x?x64x42xf32>, %arg1 : tensor<42x32xf32>) -> tensor<*xf32> {
+  %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<16x?x64x42xf32>, tensor<42x32xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_matmul_2
+  // CHECK: [[RES2:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<16x?x64x42xf32>, tensor<42x32xf32>) -> tensor<16x?x64x32xf32>
+  // CHECK: return [[RES2]] : tensor<16x?x64x32xf32>
+}
+
+/// MatMul: 2-D x K-D (K > 2)
+
+func @test_matmul_3(%arg0 : tensor<64x42xf32>, %arg1 : tensor<16x?x42x32xf32>) -> tensor<*xf32> {
+  %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<64x42xf32>, tensor<16x?x42x32xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_matmul_3
+  // CHECK: [[RES3:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<64x42xf32>, tensor<16x?x42x32xf32>) -> tensor<16x?x64x32xf32>
+  // CHECK: return [[RES3]] : tensor<16x?x64x32xf32>
+}
+
+/// MatMul: 2-D x K-D (K > 2)
+
+func @test_matmul_4(%arg0 : tensor<64x42xf32>, %arg1 : tensor<?x?x?x?xf32>) -> tensor<*xf32> {
+  %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<64x42xf32>, tensor<?x?x?x?xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_matmul_4
+  // CHECK: [[RES4:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<64x42xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x64x?xf32>
+  // CHECK: return [[RES4]] : tensor<?x?x64x?xf32>
+}
+
+/// MatMul: K1-D x K2-D (K1 > 2, K2 > 2)
+
+func @test_matmul_5(%arg0 : tensor<16x?x?x42xf32>, %arg1 : tensor<32x?x64x42x32xf32>) -> tensor<*xf32> {
+  %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<16x?x?x42xf32>, tensor<32x?x64x42x32xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_matmul_5
+  // CHECK: [[RES5:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<16x?x?x42xf32>, tensor<32x?x64x42x32xf32>) -> tensor<32x16x64x?x32xf32>
+  // CHECK: return [[RES5]] : tensor<32x16x64x?x32xf32>
+}
+
+/// MatMul: 1-D x 2-D
+
+func @test_matmul_6(%arg0 : tensor<32xf32>, %arg1 : tensor<32x64xf32>) -> tensor<*xf32> {
+  %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<32xf32>, tensor<32x64xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_matmul_6
+  // CHECK: [[RES6:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<32xf32>, tensor<32x64xf32>) -> tensor<64xf32>
+  // CHECK: return [[RES6]] : tensor<64xf32>
+}
+
+/// MatMul: 2-D x 1-D
+
+func @test_matmul_7(%arg0 : tensor<32x64xf32>, %arg1 : tensor<64xf32>) -> tensor<*xf32> {
+  %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<32x64xf32>, tensor<64xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_matmul_7
+  // CHECK: [[RES7:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<32x64xf32>, tensor<64xf32>) -> tensor<32xf32>
+  // CHECK: return [[RES7]] : tensor<32xf32>
+}
+
+/// MatMul: 2-D x 2-D
+
+func @test_matmul_8(%arg0 : tensor<32x64xf32>, %arg1 : tensor<64x128xf32>) -> tensor<*xf32> {
+  %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<32x64xf32>, tensor<64x128xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_matmul_8 + // CHECK: [[RES8:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<32x64xf32>, tensor<64x128xf32>) -> tensor<32x128xf32> + // CHECK: return [[RES8]] : tensor<32x128xf32> +} //===----------------------------------------------------------------------===// /// Test shape inference for ConvNoBias operation and all its attributes.