From da0e9b01b163d5ecf5833556df6bf0404de9c3bf Mon Sep 17 00:00:00 2001 From: Doru Bercea Date: Fri, 10 Jan 2020 15:16:45 -0500 Subject: [PATCH] Fix 1 and 2 dimensional cases. Add test for 1 and 2 dimensional combinations. --- src/dialect/onnx/onnx_ops.cpp | 7 ++++-- test/mlir/onnx/onnx_shape_inference.mlir | 30 ++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp index a3a49a5..81f8887 100644 --- a/src/dialect/onnx/onnx_ops.cpp +++ b/src/dialect/onnx/onnx_ops.cpp @@ -389,8 +389,11 @@ void ONNXMatMulOp::inferShapes() { dims.emplace_back(rhsShape[rightDims - 1]); } else { // This case covers all remaining combinations of 1 and 2-D matrices. - dims.emplace_back(lhsShape.size() == 1 ? 1 : lhsShape[0]); - dims.emplace_back(rhsShape.size() == 1 ? 1 : rhsShape[1]); + if (lhsShape.size() != 1) + dims.emplace_back(lhsShape[0]); + + if (rhsShape.size() != 1) + dims.emplace_back(rhsShape[1]); } getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType())); diff --git a/test/mlir/onnx/onnx_shape_inference.mlir b/test/mlir/onnx/onnx_shape_inference.mlir index 64f1533..26c0e1e 100644 --- a/test/mlir/onnx/onnx_shape_inference.mlir +++ b/test/mlir/onnx/onnx_shape_inference.mlir @@ -60,4 +60,34 @@ func @test_matmul_5(%arg0 : tensor<16x?x?x42xf32>, %arg1 : tensor<32x?x64x42x32x // CHECK-LABEL: test_matmul_5 // CHECK: [[RES5:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<16x?x?x42xf32>, tensor<32x?x64x42x32xf32>) -> tensor<32x16x64x?x32xf32> // CHECK: return [[RES5]] : tensor<32x16x64x?x32xf32> +} + +/// MatMul: 1-D x 2-D +func @test_matmul_6(%arg0 : tensor<32xf32>, %arg1 : tensor<32x64xf32>) -> tensor<*xf32> { + %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<32xf32>, tensor<32x64xf32>) -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_matmul_6 + // CHECK: [[RES6:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<32xf32>, tensor<32x64xf32>) -> tensor<64xf32> + // CHECK: return [[RES6]] : tensor<64xf32> +} + +/// MatMul: 2-D x 1-D +func @test_matmul_7(%arg0 : tensor<32x64xf32>, %arg1 : tensor<64xf32>) -> tensor<*xf32> { + %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<32x64xf32>, tensor<64xf32>) -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_matmul_7 + // CHECK: [[RES7:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<32x64xf32>, tensor<64xf32>) -> tensor<32xf32> + // CHECK: return [[RES7]] : tensor<32xf32> +} + +/// MatMul: 2-D x 2-D +func @test_matmul_8(%arg0 : tensor<32x64xf32>, %arg1 : tensor<64x128xf32>) -> tensor<*xf32> { + %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<32x64xf32>, tensor<64x128xf32>) -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_matmul_8 + // CHECK: [[RES8:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<32x64xf32>, tensor<64x128xf32>) -> tensor<32x128xf32> + // CHECK: return [[RES8]] : tensor<32x128xf32> } \ No newline at end of file