From 0c4a0102836c9dcdfae49c5011fd96766714a78f Mon Sep 17 00:00:00 2001 From: "Tung D. Le" Date: Thu, 27 Feb 2020 01:40:52 +0900 Subject: [PATCH] Remove rank constraints in gemm fusion (#101) * Remove rank constraints in gemm fusion * Add an MLIR test Co-authored-by: Tian Jin Co-authored-by: Gheorghe-Teodor Bercea --- src/pass/onnx_combine.td | 2 +- test/mlir/onnx/onnx_canonicalization.mlir | 14 +++++++++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/src/pass/onnx_combine.td b/src/pass/onnx_combine.td index 3674fd3..37c9eee 100644 --- a/src/pass/onnx_combine.td +++ b/src/pass/onnx_combine.td @@ -45,7 +45,7 @@ def MulAddToGemmOptPattern : Pat<(ONNXAddOp (ONNXMatMulOp:$res $m1, $m2), $m3), // onnx.add(onnx.Gemm(%X, %Y, None), %Z) = onnx.Gemm(%X, %Y, %Z) def FuseGemmFollowedByAddition : Pat<(ONNXAddOp (ONNXGemmOp:$res $m1, $m2, $none, $alpha, $beta, $transA, $transB), $bias), (ONNXGemmOp $m1, $m2, $bias, $alpha, $beta, $transA, $transB), - [(HasOneUse $res), (HasRankOf<2> $m1), (HasRankOf<2> $m2), (HasNoneType $none)]>; + [(HasOneUse $res), (HasNoneType $none)]>; // ONNX_Op (onnx.Identity (%X)) = ONNX_Op (%X) def IdentityEliminationPattern : Pat<(ONNXIdentityOp $arg), diff --git a/test/mlir/onnx/onnx_canonicalization.mlir b/test/mlir/onnx/onnx_canonicalization.mlir index a535a33..7661f28 100644 --- a/test/mlir/onnx/onnx_canonicalization.mlir +++ b/test/mlir/onnx/onnx_canonicalization.mlir @@ -111,4 +111,16 @@ func @test_gemm_add_fusion(%arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32 // CHECK-NEXT: [[GEMM:%.+]] = "onnx.Gemm"(%{{.*}}, %{{.*}}, %{{.*}}) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = 0 : i64, transB = 0 : i64} : (tensor<128x128xf32>, tensor<128x128xf32>, tensor<128xf32>) -> tensor<*xf32> // return [[GEMM]] : tensor<*xf32> -} \ No newline at end of file +} + +// CHECK-LABEL: @test_gemm_add_fusion_rank3(%{{.*}}: tensor<128x128x256xf32>, %{{.*}}: tensor<128x128x256xf32>, %{{.*}}: tensor<256xf32>) -> tensor<*xf32> { 
+func @test_gemm_add_fusion_rank3(%arg0: tensor<128x128x256xf32>, %arg1: tensor<128x128x256xf32>, %arg2: tensor<256xf32>) -> tensor<*xf32> {
+  %cst = constant unit  // none-typed value passed as the third operand of onnx.Gemm below
+  %0 = "onnx.Gemm"(%arg0, %arg1, %cst) : (tensor<128x128x256xf32>, tensor<128x128x256xf32>, none) -> tensor<*xf32>
+  %1 = "onnx.Add"(%0, %arg2) : (tensor<*xf32>, tensor<256xf32>) -> tensor<*xf32>
+  return %1 : tensor<*xf32>
+
+  // CHECK-NEXT: [[GEMM:%.+]] = "onnx.Gemm"(%{{.*}}, %{{.*}}, %{{.*}}) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = 0 : i64, transB = 0 : i64} : (tensor<128x128x256xf32>, tensor<128x128x256xf32>, tensor<256xf32>) -> tensor<*xf32>
+  // CHECK-NEXT: return [[GEMM]] : tensor<*xf32>
+}
+