From 170296b7c696d1a2d3d818a86bd2921124557dae Mon Sep 17 00:00:00 2001
From: Doru Bercea
Date: Thu, 9 Jan 2020 15:30:57 -0500
Subject: [PATCH 01/16] Add special case for 1-D matrix multiplication.

---
 src/dialect/onnx/onnx_ops.cpp | 26 +++++++++++++++++++++-----
 1 file changed, 21 insertions(+), 5 deletions(-)

diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp
index 4b68fe7..2380f5a 100644
--- a/src/dialect/onnx/onnx_ops.cpp
+++ b/src/dialect/onnx/onnx_ops.cpp
@@ -318,12 +318,28 @@ void ONNXMatMulOp::inferShapes() {
   if (!getOperand(0).getType().isa<RankedTensorType>() ||
       !getOperand(1).getType().isa<RankedTensorType>())
     return;
-  auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
-  auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
+
+  auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>();
+  auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>();
+
   SmallVector<int64_t, 2> dims;
-  dims.emplace_back(lhsTy.getShape()[0]);
-  dims.emplace_back(rhsTy.getShape()[1]);
-  getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
+  auto lhsShape = lhsTy.getShape();
+  auto rhsShape = rhsTy.getShape();
+  if (lhsShape.size() == 1 && rhsShape.size() == 1) {
+    // Special case when both arrays are 1-dimensional and according to
+    // numpy rules the types need to be extended to 1xN and Nx1. Helper sizes
+    // need to be removed after the multiplication but cannot be removed if all
+    // remaining sizes are 1.
+    if (lhsShape[0] != -1 && rhsShape[0] != -1 &&
+        lhsShape[0] != rhsShape[0])
+      emitError("Attempt to multiply incompatible matrices.");
+    dims.emplace_back(1);
+  } else {
+    dims.emplace_back(lhsShape.size() == 1 ? 1 : lhsShape[0]);
+    dims.emplace_back(rhsShape.size() == 1 ? 1 : rhsShape[1]);
+  }
+
+  getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType()));
 }
 
 // TODO:

From d176b84506c47b5500230a16b862c062fd522af1 Mon Sep 17 00:00:00 2001
From: Doru Bercea
Date: Fri, 10 Jan 2020 11:22:41 -0500
Subject: [PATCH 02/16] Add support for broadcasting right matrix.

---
 src/dialect/onnx/onnx_ops.cpp | 23 ++++++++++++++++++++---
 1 file changed, 20 insertions(+), 3 deletions(-)

diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp
index 2380f5a..f19242b 100644
--- a/src/dialect/onnx/onnx_ops.cpp
+++ b/src/dialect/onnx/onnx_ops.cpp
@@ -329,14 +329,31 @@ void ONNXMatMulOp::inferShapes() {
     // Special case when both arrays are 1-dimensional and according to
     // numpy rules the types need to be extended to 1xN and Nx1. Helper sizes
     // need to be removed after the multiplication but cannot be removed if all
-    // remaining sizes are 1.
+    // sizes are 1.
     if (lhsShape[0] != -1 && rhsShape[0] != -1 &&
         lhsShape[0] != rhsShape[0])
       emitError("Attempt to multiply incompatible matrices.");
     dims.emplace_back(1);
   } else {
-    dims.emplace_back(lhsShape.size() == 1 ? 1 : lhsShape[0]);
-    dims.emplace_back(rhsShape.size() == 1 ? 1 : rhsShape[1]);
+    // Special cases for when at least one matrix has more than two dimensions.
+    if (lhsShape.size() > 2 && rhsShape.size() == 2) {
+      // (s1 x s2 x... x sKx x M x N) MATMUL (N x P)
+      // =>
+      // (s1 x s2 x... x sKx x M x P)
+
+      // Check legality of matrix multiplication.
+      unsigned leftDims = lhsShape.size();
+      if (lhsShape[leftDims - 2] != -1 && rhsShape[0] != -1 &&
+          lhsShape[leftDims - 2] != rhsShape[0])
+        emitError("Attempt to multiply incompatible matrices.");
+
+      for (int i = 0; i < leftDims - 1; ++i)
+        dims.emplace_back(lhsShape[i]);
+      dims.emplace_back(rhsShape[1]);
+    } else {
+      dims.emplace_back(lhsShape.size() == 1 ? 1 : lhsShape[0]);
+      dims.emplace_back(rhsShape.size() == 1 ? 1 : rhsShape[1]);
+    }
   }
 
   getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType()));

From 38bffee619874714123f8375c62eb1fc8db72a52 Mon Sep 17 00:00:00 2001
From: Doru Bercea
Date: Fri, 10 Jan 2020 11:34:26 -0500
Subject: [PATCH 03/16] Add support for broadcasting left matrix.

---
 src/dialect/onnx/onnx_ops.cpp | 19 +++++++++++++++++--
 1 file changed, 17 insertions(+), 2 deletions(-)

diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp
index f19242b..a449e05 100644
--- a/src/dialect/onnx/onnx_ops.cpp
+++ b/src/dialect/onnx/onnx_ops.cpp
@@ -337,9 +337,9 @@ void ONNXMatMulOp::inferShapes() {
   } else {
     // Special cases for when at least one matrix has more than two dimensions.
     if (lhsShape.size() > 2 && rhsShape.size() == 2) {
-      // (s1 x s2 x... x sKx x M x N) MATMUL (N x P)
+      // (s1 x s2 x... x sK x M x N) MATMUL (N x P)
       // =>
-      // (s1 x s2 x... x sKx x M x P)
+      // (s1 x s2 x... x sK x M x P)
 
       // Check legality of matrix multiplication.
       unsigned leftDims = lhsShape.size();
@@ -350,6 +350,21 @@ void ONNXMatMulOp::inferShapes() {
       for (int i = 0; i < leftDims - 1; ++i)
        dims.emplace_back(lhsShape[i]);
       dims.emplace_back(rhsShape[1]);
+    } else if (lhsShape.size() == 2 && rhsShape.size() > 2) {
+      // (M x N) MATMUL (s1 x s2 x... x sK x N x P)
+      // =>
+      // (s1 x s2 x... x sK x M x P)
+
+      // Check legality of matrix multiplication.
+      unsigned rightDims = rhsShape.size();
+      if (lhsShape[1] != -1 && rhsShape[rightDims - 2] != -1 &&
+          lhsShape[1] != rhsShape[rightDims - 2])
+        emitError("Attempt to multiply incompatible matrices.");
+
+      for (int i = 0; i < rightDims - 2; ++i)
+        dims.emplace_back(rhsShape[i]);
+      dims.emplace_back(lhsShape[0]);
+      dims.emplace_back(rhsShape[rightDims - 1]);
     } else {
       dims.emplace_back(lhsShape.size() == 1 ? 1 : lhsShape[0]);
       dims.emplace_back(rhsShape.size() == 1 ? 1 : rhsShape[1]);

From a3995b61e765ccd04467157becdf39f3e6bfe5db Mon Sep 17 00:00:00 2001
From: Doru Bercea
Date: Fri, 10 Jan 2020 12:27:34 -0500
Subject: [PATCH 04/16] Add support for shape broadcast.

---
 src/dialect/onnx/onnx_ops.cpp | 29 +++++++++++++++++++++++++++--
 1 file changed, 27 insertions(+), 2 deletions(-)

diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp
index a449e05..9166a59 100644
--- a/src/dialect/onnx/onnx_ops.cpp
+++ b/src/dialect/onnx/onnx_ops.cpp
@@ -343,8 +343,8 @@ void ONNXMatMulOp::inferShapes() {
 
       // Check legality of matrix multiplication.
       unsigned leftDims = lhsShape.size();
-      if (lhsShape[leftDims - 2] != -1 && rhsShape[0] != -1 &&
-          lhsShape[leftDims - 2] != rhsShape[0])
+      if (lhsShape[leftDims - 1] != -1 && rhsShape[0] != -1 &&
+          lhsShape[leftDims - 1] != rhsShape[0])
         emitError("Attempt to multiply incompatible matrices.");
 
       for (int i = 0; i < leftDims - 1; ++i)
@@ -365,7 +365,32 @@ void ONNXMatMulOp::inferShapes() {
       dims.emplace_back(rhsShape[i]);
       dims.emplace_back(lhsShape[0]);
       dims.emplace_back(rhsShape[rightDims - 1]);
+    } else if (lhsShape.size() > 2 && rhsShape.size() > 2) {
+      // (s1 x s2 x... x sK x M x N) MATMUL (t1 x t2 x... x tK x N x P)
+      // =>
+      // (u1 x u2 x... x uK x M x P)
+
+      // Check legality of matrix multiplication.
+      unsigned leftDims = lhsShape.size();
+      unsigned rightDims = rhsShape.size();
+      if (lhsShape[leftDims - 1] != -1 && rhsShape[rightDims - 2] != -1 &&
+          lhsShape[leftDims - 1] != rhsShape[rightDims - 2])
+        emitError("Attempt to multiply incompatible matrices.");
+
+      // Check and perform broadcasting for the shapes.
+      SmallVector<int64_t, 2> lhsBcastShape;
+      for (int i = 0; i < leftDims - 2; ++i)
+        lhsBcastShape.emplace_back(lhsShape[i]);
+      SmallVector<int64_t, 2> rhsBcastShape;
+      for (int i = 0; i < rightDims - 2; ++i)
+        rhsBcastShape.emplace_back(rhsShape[i]);
+      if (!getBroadcastedShape(lhsBcastShape, rhsBcastShape, dims))
+        emitError("Broadcasted dimensions are incompatible.");
+
+      dims.emplace_back(lhsShape[leftDims - 2]);
+      dims.emplace_back(rhsShape[rightDims - 1]);
     } else {
+      // This case covers all remaining combinations of 1 and 2-D matrices.
       dims.emplace_back(lhsShape.size() == 1 ? 1 : lhsShape[0]);
       dims.emplace_back(rhsShape.size() == 1 ? 1 : rhsShape[1]);
     }

From 96551ef71e93dfbd8f8008059c2edf790b5c456c Mon Sep 17 00:00:00 2001
From: Doru Bercea
Date: Fri, 10 Jan 2020 12:30:12 -0500
Subject: [PATCH 05/16] Fix conditions.

---
 src/dialect/onnx/onnx_ops.cpp | 115 +++++++++++++++++-----------------
 1 file changed, 56 insertions(+), 59 deletions(-)

diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp
index 9166a59..a3a49a5 100644
--- a/src/dialect/onnx/onnx_ops.cpp
+++ b/src/dialect/onnx/onnx_ops.cpp
@@ -334,66 +334,63 @@ void ONNXMatMulOp::inferShapes() {
         lhsShape[0] != rhsShape[0])
       emitError("Attempt to multiply incompatible matrices.");
     dims.emplace_back(1);
+  } else if (lhsShape.size() > 2 && rhsShape.size() == 2) {
+    // (s1 x s2 x... x sK x M x N) MATMUL (N x P)
+    // =>
+    // (s1 x s2 x... x sK x M x P)
+
+    // Check legality of matrix multiplication.
+    unsigned leftDims = lhsShape.size();
+    if (lhsShape[leftDims - 1] != -1 && rhsShape[0] != -1 &&
+        lhsShape[leftDims - 1] != rhsShape[0])
+      emitError("Attempt to multiply incompatible matrices.");
+
+    for (int i = 0; i < leftDims - 1; ++i)
+      dims.emplace_back(lhsShape[i]);
+    dims.emplace_back(rhsShape[1]);
+  } else if (lhsShape.size() == 2 && rhsShape.size() > 2) {
+    // (M x N) MATMUL (s1 x s2 x... x sK x N x P)
+    // =>
+    // (s1 x s2 x... x sK x M x P)
+
+    // Check legality of matrix multiplication.
+    unsigned rightDims = rhsShape.size();
+    if (lhsShape[1] != -1 && rhsShape[rightDims - 2] != -1 &&
+        lhsShape[1] != rhsShape[rightDims - 2])
+      emitError("Attempt to multiply incompatible matrices.");
+
+    for (int i = 0; i < rightDims - 2; ++i)
+      dims.emplace_back(rhsShape[i]);
+    dims.emplace_back(lhsShape[0]);
+    dims.emplace_back(rhsShape[rightDims - 1]);
+  } else if (lhsShape.size() > 2 && rhsShape.size() > 2) {
+    // (s1 x s2 x... x sK x M x N) MATMUL (t1 x t2 x... x tK x N x P)
+    // =>
+    // (u1 x u2 x... x uK x M x P)
+
+    // Check legality of matrix multiplication.
+    unsigned leftDims = lhsShape.size();
+    unsigned rightDims = rhsShape.size();
+    if (lhsShape[leftDims - 1] != -1 && rhsShape[rightDims - 2] != -1 &&
+        lhsShape[leftDims - 1] != rhsShape[rightDims - 2])
+      emitError("Attempt to multiply incompatible matrices.");
+
+    // Check and perform broadcasting for the shapes.
+    SmallVector<int64_t, 2> lhsBcastShape;
+    for (int i = 0; i < leftDims - 2; ++i)
+      lhsBcastShape.emplace_back(lhsShape[i]);
+    SmallVector<int64_t, 2> rhsBcastShape;
+    for (int i = 0; i < rightDims - 2; ++i)
+      rhsBcastShape.emplace_back(rhsShape[i]);
+    if (!getBroadcastedShape(lhsBcastShape, rhsBcastShape, dims))
+      emitError("Broadcasted dimensions are incompatible.");
+
+    dims.emplace_back(lhsShape[leftDims - 2]);
+    dims.emplace_back(rhsShape[rightDims - 1]);
   } else {
-    // Special cases for when at least one matrix has more than two dimensions.
-    if (lhsShape.size() > 2 && rhsShape.size() == 2) {
-      // (s1 x s2 x... x sK x M x N) MATMUL (N x P)
-      // =>
-      // (s1 x s2 x... x sK x M x P)
-
-      // Check legality of matrix multiplication.
-      unsigned leftDims = lhsShape.size();
-      if (lhsShape[leftDims - 1] != -1 && rhsShape[0] != -1 &&
-          lhsShape[leftDims - 1] != rhsShape[0])
-        emitError("Attempt to multiply incompatible matrices.");
-
-      for (int i = 0; i < leftDims - 1; ++i)
-        dims.emplace_back(lhsShape[i]);
-      dims.emplace_back(rhsShape[1]);
-    } else if (lhsShape.size() == 2 && rhsShape.size() > 2) {
-      // (M x N) MATMUL (s1 x s2 x... x sK x N x P)
-      // =>
-      // (s1 x s2 x... x sK x M x P)
-
-      // Check legality of matrix multiplication.
-      unsigned rightDims = rhsShape.size();
-      if (lhsShape[1] != -1 && rhsShape[rightDims - 2] != -1 &&
-          lhsShape[1] != rhsShape[rightDims - 2])
-        emitError("Attempt to multiply incompatible matrices.");
-
-      for (int i = 0; i < rightDims - 2; ++i)
-        dims.emplace_back(rhsShape[i]);
-      dims.emplace_back(lhsShape[0]);
-      dims.emplace_back(rhsShape[rightDims - 1]);
-    } else if (lhsShape.size() > 2 && rhsShape.size() > 2) {
-      // (s1 x s2 x... x sK x M x N) MATMUL (t1 x t2 x... x tK x N x P)
-      // =>
-      // (u1 x u2 x... x uK x M x P)
-
-      // Check legality of matrix multiplication.
-      unsigned leftDims = lhsShape.size();
-      unsigned rightDims = rhsShape.size();
-      if (lhsShape[leftDims - 1] != -1 && rhsShape[rightDims - 2] != -1 &&
-          lhsShape[leftDims - 1] != rhsShape[rightDims - 2])
-        emitError("Attempt to multiply incompatible matrices.");
-
-      // Check and perform broadcasting for the shapes.
-      SmallVector<int64_t, 2> lhsBcastShape;
-      for (int i = 0; i < leftDims - 2; ++i)
-        lhsBcastShape.emplace_back(lhsShape[i]);
-      SmallVector<int64_t, 2> rhsBcastShape;
-      for (int i = 0; i < rightDims - 2; ++i)
-        rhsBcastShape.emplace_back(rhsShape[i]);
-      if (!getBroadcastedShape(lhsBcastShape, rhsBcastShape, dims))
-        emitError("Broadcasted dimensions are incompatible.");
-
-      dims.emplace_back(lhsShape[leftDims - 2]);
-      dims.emplace_back(rhsShape[rightDims - 1]);
-    } else {
-      // This case covers all remaining combinations of 1 and 2-D matrices.
-      dims.emplace_back(lhsShape.size() == 1 ? 1 : lhsShape[0]);
-      dims.emplace_back(rhsShape.size() == 1 ? 1 : rhsShape[1]);
-    }
+    // This case covers all remaining combinations of 1 and 2-D matrices.
+    dims.emplace_back(lhsShape.size() == 1 ? 1 : lhsShape[0]);
+    dims.emplace_back(rhsShape.size() == 1 ? 1 : rhsShape[1]);
   }
 
   getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType()));

From 6478c88cdc51ac8e1d50d7eabb00eae1420bec1a Mon Sep 17 00:00:00 2001
From: Doru Bercea
Date: Fri, 10 Jan 2020 14:36:02 -0500
Subject: [PATCH 06/16] Add test for all one dimensional case.

---
 test/mlir/onnx/onnx_shape_inference.mlir | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/test/mlir/onnx/onnx_shape_inference.mlir b/test/mlir/onnx/onnx_shape_inference.mlir
index 4cb9bec..a0a2550 100644
--- a/test/mlir/onnx/onnx_shape_inference.mlir
+++ b/test/mlir/onnx/onnx_shape_inference.mlir
@@ -1,5 +1,6 @@
 // RUN: onnf-opt --shape-inference %s -split-input-file | FileCheck %s
 
+
 /// Test the default behavior of transpose when no information for the
 /// permutation of the axes is provided.
 func @test_default_transpose(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
@@ -9,4 +10,15 @@ func @test_default_transpose(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
 
 // CHECK-LABEL: test_default_transpose
 // CHECK: [[RES:%.+]] = "onnx.Transpose"(%arg0) : (tensor<5x5x1x32xf32>) -> tensor<32x1x5x5xf32>
-// CHECK: return [[RES]] : tensor<32x1x5x5xf32>
\ No newline at end of file
+// CHECK: return [[RES]] : tensor<32x1x5x5xf32>
+
+
+/// Test the shape inferencing scheme for the matmul operation.
+func @test_matmul_1(%arg0 : tensor<32xf32>, %arg1 : tensor<32xf32>) -> tensor<*xf32> {
+  %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<32xf32>, tensor<32xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+}
+
+// CHECK-LABEL: test_matmul_1
+// CHECK: [[RES1:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<32xf32>, tensor<32xf32>) -> tensor<1xf32>
+// CHECK: return [[RES1]] : tensor<1xf32>
\ No newline at end of file

From a5f1d39c20e8f7754bde041894d213f9d85267dc Mon Sep 17 00:00:00 2001
From: Doru Bercea
Date: Fri, 10 Jan 2020 14:55:54 -0500
Subject: [PATCH 07/16] Add tests for matrices and stack of matrices combinations.

---
 test/mlir/onnx/onnx_shape_inference.mlir | 37 ++++++++++++++++++------
 1 file changed, 28 insertions(+), 9 deletions(-)

diff --git a/test/mlir/onnx/onnx_shape_inference.mlir b/test/mlir/onnx/onnx_shape_inference.mlir
index a0a2550..7b47e60 100644
--- a/test/mlir/onnx/onnx_shape_inference.mlir
+++ b/test/mlir/onnx/onnx_shape_inference.mlir
@@ -1,24 +1,43 @@
 // RUN: onnf-opt --shape-inference %s -split-input-file | FileCheck %s
 
-
 /// Test the default behavior of transpose when no information for the
 /// permutation of the axes is provided.
 func @test_default_transpose(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
   %0 = "onnx.Transpose"(%arg0) : (tensor<5x5x1x32xf32>) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_default_transpose
+  // CHECK: [[RES:%.+]] = "onnx.Transpose"(%arg0) : (tensor<5x5x1x32xf32>) -> tensor<32x1x5x5xf32>
+  // CHECK: return [[RES]] : tensor<32x1x5x5xf32>
 }
 
-// CHECK-LABEL: test_default_transpose
-// CHECK: [[RES:%.+]] = "onnx.Transpose"(%arg0) : (tensor<5x5x1x32xf32>) -> tensor<32x1x5x5xf32>
-// CHECK: return [[RES]] : tensor<32x1x5x5xf32>
-
-
 /// Test the shape inferencing scheme for the matmul operation.
+/// MatMul: 1-D x 1-D
 func @test_matmul_1(%arg0 : tensor<32xf32>, %arg1 : tensor<32xf32>) -> tensor<*xf32> {
   %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<32xf32>, tensor<32xf32>) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_matmul_1
+  // CHECK: [[RES1:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<32xf32>, tensor<32xf32>) -> tensor<1xf32>
+  // CHECK: return [[RES1]] : tensor<1xf32>
 }
 
-// CHECK-LABEL: test_matmul_1
-// CHECK: [[RES1:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<32xf32>, tensor<32xf32>) -> tensor<1xf32>
-// CHECK: return [[RES1]] : tensor<1xf32>
\ No newline at end of file
+/// MatMul: K-D x 2-D (K > 2)
+func @test_matmul_2(%arg0 : tensor<16x?x64x42xf32>, %arg1 : tensor<42x32xf32>) -> tensor<*xf32> {
+  %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<16x?x64x42xf32>, tensor<42x32xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_matmul_2
+  // CHECK: [[RES2:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<16x?x64x42xf32>, tensor<42x32xf32>) -> tensor<16x?x64x32xf32>
+  // CHECK: return [[RES2]] : tensor<16x?x64x32xf32>
+}
+
+/// MatMul: 2-D x K-D (K > 2)
+func @test_matmul_3(%arg0 : tensor<64x42xf32>, %arg1 : tensor<16x?x42x32xf32>) -> tensor<*xf32> {
+  %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<64x42xf32>, tensor<16x?x42x32xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_matmul_3
+  // CHECK: [[RES3:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<64x42xf32>, tensor<16x?x42x32xf32>) -> tensor<16x?x64x32xf32>
+  // CHECK: return [[RES3]] : tensor<16x?x64x32xf32>
+}
\ No newline at end of file

From ae966cdee96cb47c21bad844e404bcccbd034d97 Mon Sep 17 00:00:00 2001
From: Doru Bercea
Date: Fri, 10 Jan 2020 14:56:01 -0500
Subject: [PATCH 08/16] Add tests for matrices and stack of matrices combinations.

---
 test/mlir/onnx/onnx_shape_inference.mlir | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/test/mlir/onnx/onnx_shape_inference.mlir b/test/mlir/onnx/onnx_shape_inference.mlir
index 7b47e60..3e4eb4a 100644
--- a/test/mlir/onnx/onnx_shape_inference.mlir
+++ b/test/mlir/onnx/onnx_shape_inference.mlir
@@ -40,4 +40,14 @@ func @test_matmul_3(%arg0 : tensor<64x42xf32>, %arg1 : tensor<16x?x42x32xf32>) -
   // CHECK-LABEL: test_matmul_3
   // CHECK: [[RES3:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<64x42xf32>, tensor<16x?x42x32xf32>) -> tensor<16x?x64x32xf32>
   // CHECK: return [[RES3]] : tensor<16x?x64x32xf32>
+}
+
+/// MatMul: 2-D x K-D (K > 2)
+func @test_matmul_4(%arg0 : tensor<64x42xf32>, %arg1 : tensor) -> tensor<*xf32> {
+  %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<64x42xf32>, tensor) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_matmul_4
+  // CHECK: [[RES4:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<64x42xf32>, tensor) -> tensor
+  // CHECK: return [[RES4]] : tensor
 }
\ No newline at end of file

From 95ebf3e23ad1c1508eae6d5a993eaaa0050bd21a Mon Sep 17 00:00:00 2001
From: Doru Bercea
Date: Fri, 10 Jan 2020 15:00:56 -0500
Subject: [PATCH 09/16] Add test for multiplying stacks of matrices.

---
 test/mlir/onnx/onnx_shape_inference.mlir | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/test/mlir/onnx/onnx_shape_inference.mlir b/test/mlir/onnx/onnx_shape_inference.mlir
index 3e4eb4a..7c44535 100644
--- a/test/mlir/onnx/onnx_shape_inference.mlir
+++ b/test/mlir/onnx/onnx_shape_inference.mlir
@@ -50,4 +50,14 @@ func @test_matmul_4(%arg0 : tensor<64x42xf32>, %arg1 : tensor) -> t
   // CHECK-LABEL: test_matmul_4
   // CHECK: [[RES4:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<64x42xf32>, tensor) -> tensor
   // CHECK: return [[RES4]] : tensor
+}
+
+/// MatMul: K1-D x K2-D (K1 > 2, K2 > 2)
+func @test_matmul_5(%arg0 : tensor<16x?x64x42xf32>, %arg1 : tensor<32x?x64x42x32xf32>) -> tensor<*xf32> {
+  %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<16x?x64x42xf32>, tensor<32x?x64x42x32xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_matmul_5
+  // CHECK: [[RES5:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<16x?x64x42xf32>, tensor<32x?x64x42x32xf32>) -> tensor<32x16x64x64x32xf32>
+  // CHECK: return [[RES5]] : tensor<32x16x64x64x32xf32>
 }
\ No newline at end of file

From da0e9b01b163d5ecf5833556df6bf0404de9c3bf Mon Sep 17 00:00:00 2001
From: Doru Bercea
Date: Fri, 10 Jan 2020 15:02:47 -0500
Subject: [PATCH 10/16] Add additional dynamic dimension.

---
 test/mlir/onnx/onnx_shape_inference.mlir | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/test/mlir/onnx/onnx_shape_inference.mlir b/test/mlir/onnx/onnx_shape_inference.mlir
index 7c44535..64f1533 100644
--- a/test/mlir/onnx/onnx_shape_inference.mlir
+++ b/test/mlir/onnx/onnx_shape_inference.mlir
@@ -53,11 +53,11 @@ func @test_matmul_4(%arg0 : tensor<64x42xf32>, %arg1 : tensor) -> t
 }
 
 /// MatMul: K1-D x K2-D (K1 > 2, K2 > 2)
-func @test_matmul_5(%arg0 : tensor<16x?x64x42xf32>, %arg1 : tensor<32x?x64x42x32xf32>) -> tensor<*xf32> {
-  %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<16x?x64x42xf32>, tensor<32x?x64x42x32xf32>) -> tensor<*xf32>
+func @test_matmul_5(%arg0 : tensor<16x?x?x42xf32>, %arg1 : tensor<32x?x64x42x32xf32>) -> tensor<*xf32> {
+  %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<16x?x?x42xf32>, tensor<32x?x64x42x32xf32>) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()
 
   // CHECK-LABEL: test_matmul_5
-  // CHECK: [[RES5:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<16x?x64x42xf32>, tensor<32x?x64x42x32xf32>) -> tensor<32x16x64x64x32xf32>
-  // CHECK: return [[RES5]] : tensor<32x16x64x64x32xf32>
+  // CHECK: [[RES5:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<16x?x?x42xf32>, tensor<32x?x64x42x32xf32>) -> tensor<32x16x64x?x32xf32>
+  // CHECK: return [[RES5]] : tensor<32x16x64x?x32xf32>
 }
\ No newline at end of file

From e0918258960dac1381a72d24667b9f06a240dfe1 Mon Sep 17 00:00:00 2001
From: Doru Bercea
Date: Fri, 10 Jan 2020 15:16:45 -0500
Subject: [PATCH 11/16] Fix 1 and 2 dimensional cases. Add test for 1 and 2 dimensional combinations.

---
 src/dialect/onnx/onnx_ops.cpp            |  7 ++++--
 test/mlir/onnx/onnx_shape_inference.mlir | 30 ++++++++++++++++++++++++
 2 files changed, 35 insertions(+), 2 deletions(-)

diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp
index a3a49a5..81f8887 100644
--- a/src/dialect/onnx/onnx_ops.cpp
+++ b/src/dialect/onnx/onnx_ops.cpp
@@ -389,8 +389,11 @@ void ONNXMatMulOp::inferShapes() {
     dims.emplace_back(rhsShape[rightDims - 1]);
   } else {
     // This case covers all remaining combinations of 1 and 2-D matrices.
-    dims.emplace_back(lhsShape.size() == 1 ? 1 : lhsShape[0]);
-    dims.emplace_back(rhsShape.size() == 1 ? 1 : rhsShape[1]);
+    if (lhsShape.size() != 1)
+      dims.emplace_back(lhsShape[0]);
+
+    if (rhsShape.size() != 1)
+      dims.emplace_back(rhsShape[1]);
   }
 
   getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType()));
diff --git a/test/mlir/onnx/onnx_shape_inference.mlir b/test/mlir/onnx/onnx_shape_inference.mlir
index 64f1533..26c0e1e 100644
--- a/test/mlir/onnx/onnx_shape_inference.mlir
+++ b/test/mlir/onnx/onnx_shape_inference.mlir
@@ -60,4 +60,34 @@ func @test_matmul_5(%arg0 : tensor<16x?x?x42xf32>, %arg1 : tensor<32x?x64x42x32x
   // CHECK-LABEL: test_matmul_5
   // CHECK: [[RES5:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<16x?x?x42xf32>, tensor<32x?x64x42x32xf32>) -> tensor<32x16x64x?x32xf32>
   // CHECK: return [[RES5]] : tensor<32x16x64x?x32xf32>
+}
+
+/// MatMul: 1-D x 2-D
+func @test_matmul_6(%arg0 : tensor<32xf32>, %arg1 : tensor<32x64xf32>) -> tensor<*xf32> {
+  %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<32xf32>, tensor<32x64xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_matmul_6
+  // CHECK: [[RES6:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<32xf32>, tensor<32x64xf32>) -> tensor<64xf32>
+  // CHECK: return [[RES6]] : tensor<64xf32>
+}
+
+/// MatMul: 2-D x 1-D
+func @test_matmul_7(%arg0 : tensor<32x64xf32>, %arg1 : tensor<64xf32>) -> tensor<*xf32> {
+  %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<32x64xf32>, tensor<64xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_matmul_7
+  // CHECK: [[RES7:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<32x64xf32>, tensor<64xf32>) -> tensor<32xf32>
+  // CHECK: return [[RES7]] : tensor<32xf32>
+}
+
+/// MatMul: 2-D x 2-D
+func @test_matmul_8(%arg0 : tensor<32x64xf32>, %arg1 : tensor<64x128xf32>) -> tensor<*xf32> {
+  %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<32x64xf32>, tensor<64x128xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_matmul_8
+  // CHECK: [[RES8:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<32x64xf32>, tensor<64x128xf32>) -> tensor<32x128xf32>
+  // CHECK: return [[RES8]] : tensor<32x128xf32>
 }
\ No newline at end of file

From 642f77abedc846342e485fd877bc1c5c5f265825 Mon Sep 17 00:00:00 2001
From: Doru Bercea
Date: Fri, 10 Jan 2020 15:26:29 -0500
Subject: [PATCH 12/16] Add check for matrix size match for 1 and 2 dimensional cases.

---
 src/dialect/onnx/onnx_ops.cpp | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp
index 81f8887..09f9ebc 100644
--- a/src/dialect/onnx/onnx_ops.cpp
+++ b/src/dialect/onnx/onnx_ops.cpp
@@ -389,10 +389,18 @@ void ONNXMatMulOp::inferShapes() {
     dims.emplace_back(rhsShape[rightDims - 1]);
   } else {
     // This case covers all remaining combinations of 1 and 2-D matrices.
-    if (lhsShape.size() != 1)
+    int64_t lhsDim = lhsShape[0];
+    int64_t rhsDim = rhsShape[0];
+    if (lhsShape.size() > 1) {
+      lhsDim = lhsShape[1];
       dims.emplace_back(lhsShape[0]);
+    }
 
-    if (rhsShape.size() != 1)
+    // Check legality of matrix multiplication.
+    if (lhsDim != -1 && rhsDim != -1 && lhsDim != rhsDim)
+      emitError("Attempt to multiply incompatible matrices.");
+
+    if (rhsShape.size() > 1)
       dims.emplace_back(rhsShape[1]);
   }
 

From 36475ac509b739843a92f388038c4a4d0c22cd4f Mon Sep 17 00:00:00 2001
From: Doru Bercea
Date: Fri, 10 Jan 2020 15:27:37 -0500
Subject: [PATCH 13/16] Code clean-up.

---
 src/dialect/onnx/onnx_ops.cpp | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp
index 09f9ebc..e0e25e2 100644
--- a/src/dialect/onnx/onnx_ops.cpp
+++ b/src/dialect/onnx/onnx_ops.cpp
@@ -407,10 +407,6 @@ void ONNXMatMulOp::inferShapes() {
 
   getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType()));
 }
 
-// TODO:
-// Verify that matrix sizes are valid.
-// Take into account the dimensionality of the matrix.
-
 //===----------------------------------------------------------------------===//
 // Gemm

From 1784ec2314f40860a8cc0617741e544351218674 Mon Sep 17 00:00:00 2001
From: Doru Bercea
Date: Wed, 22 Jan 2020 16:09:19 -0500
Subject: [PATCH 14/16] Fix reference error.

---
 src/dialect/onnx/onnx_ops.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp
index 3606fdd..e2e60c3 100644
--- a/src/dialect/onnx/onnx_ops.cpp
+++ b/src/dialect/onnx/onnx_ops.cpp
@@ -328,8 +328,8 @@ void ONNXMatMulOp::inferShapes() {
       !getOperand(1).getType().isa<RankedTensorType>())
     return;
 
-  auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>();
-  auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>();
+  auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
+  auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
 
   SmallVector<int64_t, 2> dims;
   auto lhsShape = lhsTy.getShape();
@@ -413,7 +413,7 @@ void ONNXMatMulOp::inferShapes() {
       dims.emplace_back(rhsShape[1]);
   }
 
-  getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType()));
+  getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
 }
 
 //===----------------------------------------------------------------------===//

From 050d7d277d119070550f7351086d89e272331d45 Mon Sep 17 00:00:00 2001
From: Doru Bercea
Date: Wed, 22 Jan 2020 16:12:09 -0500
Subject: [PATCH 15/16] Fix test.

---
 test/mlir/onnx/onnx_shape_inference.mlir | 12 ++++--------
 1 file changed, 4 insertions(+), 8 deletions(-)

diff --git a/test/mlir/onnx/onnx_shape_inference.mlir b/test/mlir/onnx/onnx_shape_inference.mlir
index e246ffc..7057582 100644
--- a/test/mlir/onnx/onnx_shape_inference.mlir
+++ b/test/mlir/onnx/onnx_shape_inference.mlir
@@ -11,19 +11,15 @@ func @test_default_transpose(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
   // CHECK: return [[RES]] : tensor<32x1x5x5xf32>
 }
 
-// CHECK-LABEL: test_default_transpose
-// CHECK: [[RES:%.+]] = "onnx.Transpose"(%arg0) : (tensor<5x5x1x32xf32>) -> tensor<32x1x5x5xf32>
-// CHECK: return [[RES]] : tensor<32x1x5x5xf32>
-
 /// Test shape inference for transposition when perm attribute is specified.
 func @test_transpose(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
   %0 = "onnx.Transpose"(%arg0) {perm = [2, 0, 3, 1]} : (tensor<5x5x1x32xf32>) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()
-}
 
-// CHECK-LABEL: test_transpose
-// CHECK: [[RES_ATTR:%.+]] = "onnx.Transpose"(%arg0) {perm = [2, 0, 3, 1]} : (tensor<5x5x1x32xf32>) -> tensor<1x5x32x5xf32>
-// CHECK: return [[RES_ATTR]] : tensor<1x5x32x5xf32>
+  // CHECK-LABEL: test_transpose
+  // CHECK: [[RES_ATTR:%.+]] = "onnx.Transpose"(%arg0) {perm = [2, 0, 3, 1]} : (tensor<5x5x1x32xf32>) -> tensor<1x5x32x5xf32>
+  // CHECK: return [[RES_ATTR]] : tensor<1x5x32x5xf32>
+}
 
 //===----------------------------------------------------------------------===//
 /// Test the shape inferencing scheme for the matmul operation.

From b450a763d19bcd37374c7c21cf70c826b0d37768 Mon Sep 17 00:00:00 2001
From: Doru Bercea
Date: Mon, 27 Jan 2020 12:08:23 -0500
Subject: [PATCH 16/16] Change variable names to use rank.

Add additional check for scalars.
---
 src/dialect/onnx/onnx_ops.cpp | 40 ++++++++++++++++++------------------
 1 file changed, 22 insertions(+), 18 deletions(-)

diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp
index 0aaa87b..0a4fb5e 100644
--- a/src/dialect/onnx/onnx_ops.cpp
+++ b/src/dialect/onnx/onnx_ops.cpp
@@ -350,7 +350,11 @@ void ONNXMatMulOp::inferShapes() {
   SmallVector<int64_t, 2> dims;
   auto lhsShape = lhsTy.getShape();
   auto rhsShape = rhsTy.getShape();
-  if (lhsShape.size() == 1 && rhsShape.size() == 1) {
+
+  if (lhsShape.size() < 1 && rhsShape.size() < 1) {
+    // Multiplication by scalars is not allowed.
+    emitError("Multiplication by scalar arguments not allowed.");
+  } else if (lhsShape.size() == 1 && rhsShape.size() == 1) {
     // Special case when both arrays are 1-dimensional and according to
     // numpy rules the types need to be extended to 1xN and Nx1. Helper sizes
     // need to be removed after the multiplication but cannot be removed if all
@@ -365,12 +369,12 @@ void ONNXMatMulOp::inferShapes() {
     // (s1 x s2 x... x sK x M x P)
 
     // Check legality of matrix multiplication.
-    unsigned leftDims = lhsShape.size();
-    if (lhsShape[leftDims - 1] != -1 && rhsShape[0] != -1 &&
-        lhsShape[leftDims - 1] != rhsShape[0])
+    unsigned lhsRank = lhsShape.size();
+    if (lhsShape[lhsRank - 1] != -1 && rhsShape[0] != -1 &&
+        lhsShape[lhsRank - 1] != rhsShape[0])
       emitError("Attempt to multiply incompatible matrices.");
 
-    for (int i = 0; i < leftDims - 1; ++i)
+    for (int i = 0; i < lhsRank - 1; ++i)
       dims.emplace_back(lhsShape[i]);
     dims.emplace_back(rhsShape[1]);
   } else if (lhsShape.size() == 2 && rhsShape.size() > 2) {
@@ -379,39 +383,39 @@ void ONNXMatMulOp::inferShapes() {
     // (s1 x s2 x... x sK x M x P)
 
    // Check legality of matrix multiplication.
-    unsigned rightDims = rhsShape.size();
-    if (lhsShape[1] != -1 && rhsShape[rightDims - 2] != -1 &&
-        lhsShape[1] != rhsShape[rightDims - 2])
+    unsigned rhsRank = rhsShape.size();
+    if (lhsShape[1] != -1 && rhsShape[rhsRank - 2] != -1 &&
+        lhsShape[1] != rhsShape[rhsRank - 2])
      emitError("Attempt to multiply incompatible matrices.");
 
-    for (int i = 0; i < rightDims - 2; ++i)
+    for (int i = 0; i < rhsRank - 2; ++i)
       dims.emplace_back(rhsShape[i]);
     dims.emplace_back(lhsShape[0]);
-    dims.emplace_back(rhsShape[rightDims - 1]);
+    dims.emplace_back(rhsShape[rhsRank - 1]);
   } else if (lhsShape.size() > 2 && rhsShape.size() > 2) {
     // (s1 x s2 x... x sK x M x N) MATMUL (t1 x t2 x... x tK x N x P)
     // =>
     // (u1 x u2 x... x uK x M x P)
 
     // Check legality of matrix multiplication.
-    unsigned leftDims = lhsShape.size();
-    unsigned rightDims = rhsShape.size();
-    if (lhsShape[leftDims - 1] != -1 && rhsShape[rightDims - 2] != -1 &&
-        lhsShape[leftDims - 1] != rhsShape[rightDims - 2])
+    unsigned lhsRank = lhsShape.size();
+    unsigned rhsRank = rhsShape.size();
+    if (lhsShape[lhsRank - 1] != -1 && rhsShape[rhsRank - 2] != -1 &&
+        lhsShape[lhsRank - 1] != rhsShape[rhsRank - 2])
      emitError("Attempt to multiply incompatible matrices.");
 
     // Check and perform broadcasting for the shapes.
     SmallVector<int64_t, 2> lhsBcastShape;
-    for (int i = 0; i < leftDims - 2; ++i)
+    for (int i = 0; i < lhsRank - 2; ++i)
       lhsBcastShape.emplace_back(lhsShape[i]);
     SmallVector<int64_t, 2> rhsBcastShape;
-    for (int i = 0; i < rightDims - 2; ++i)
+    for (int i = 0; i < rhsRank - 2; ++i)
       rhsBcastShape.emplace_back(rhsShape[i]);
     if (!getBroadcastedShape(lhsBcastShape, rhsBcastShape, dims))
       emitError("Broadcasted dimensions are incompatible.");
 
-    dims.emplace_back(lhsShape[leftDims - 2]);
-    dims.emplace_back(rhsShape[rightDims - 1]);
+    dims.emplace_back(lhsShape[lhsRank - 2]);
+    dims.emplace_back(rhsShape[rhsRank - 1]);
   } else {
     // This case covers all remaining combinations of 1 and 2-D matrices.
     int64_t lhsDim = lhsShape[0];
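Editorial note, not part of the patch series above: the following is a small, self-contained C++ sketch of the numpy-style MatMul shape rule that these patches implement inside ONNXMatMulOp::inferShapes(). It uses plain std::vector instead of MLIR types, marks dynamic dimensions with -1, and the helper names (broadcastBatchDims, inferMatMulShape) are hypothetical; the real code relies on MLIR's getBroadcastedShape() and emitError() instead of returning std::nullopt.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

using Shape = std::vector<int64_t>;

// Broadcast two stacks of batch dimensions, aligning them from the right.
// Returns std::nullopt when a pair of static dimensions is incompatible.
static std::optional<Shape> broadcastBatchDims(const Shape &a, const Shape &b) {
  Shape out(std::max(a.size(), b.size()));
  for (size_t i = 0; i < out.size(); ++i) {
    size_t padA = out.size() - a.size(), padB = out.size() - b.size();
    int64_t da = i < padA ? 1 : a[i - padA];
    int64_t db = i < padB ? 1 : b[i - padB];
    if (da == db)
      out[i] = da;
    else if (da == 1)
      out[i] = db;
    else if (db == 1)
      out[i] = da;
    else if (da == -1 || db == -1)
      out[i] = (da == -1) ? db : da; // keep the known extent, as in the tests
    else
      return std::nullopt; // incompatible broadcast
  }
  return out;
}

// Infer the result shape of lhs MATMUL rhs; std::nullopt signals an error.
static std::optional<Shape> inferMatMulShape(Shape lhs, Shape rhs) {
  if (lhs.empty() || rhs.empty())
    return std::nullopt; // multiplication by scalars is not allowed
  // Promote 1-D operands: lhs becomes 1xN, rhs becomes Nx1 (numpy rule).
  bool lhsVec = lhs.size() == 1, rhsVec = rhs.size() == 1;
  if (lhsVec)
    lhs.insert(lhs.begin(), 1);
  if (rhsVec)
    rhs.push_back(1);
  int64_t n1 = lhs[lhs.size() - 1], n2 = rhs[rhs.size() - 2];
  if (n1 != -1 && n2 != -1 && n1 != n2)
    return std::nullopt; // inner dimensions do not match
  auto batch = broadcastBatchDims(Shape(lhs.begin(), lhs.end() - 2),
                                  Shape(rhs.begin(), rhs.end() - 2));
  if (!batch)
    return std::nullopt;
  Shape result = *batch;
  if (!lhsVec)
    result.push_back(lhs[lhs.size() - 2]); // M
  if (!rhsVec)
    result.push_back(rhs[rhs.size() - 1]); // P
  if (result.empty())
    result.push_back(1); // 1-D x 1-D collapses to a single helper dimension
  return result;
}

int main() {
  // Mirrors test_matmul_5: (16 x ? x ? x 42) MATMUL (32 x ? x 64 x 42 x 32)
  // yields (32 x 16 x 64 x ? x 32); -1 is printed for the dynamic dimension.
  auto shape = inferMatMulShape({16, -1, -1, 42}, {32, -1, 64, 42, 32});
  for (int64_t d : *shape)
    std::cout << d << " ";
  std::cout << "\n";
  return 0;
}

Running the example prints "32 16 64 -1 32", matching the CHECK line of test_matmul_5 after patch 10 added the extra dynamic dimension.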