From d61cf35471f7fad258ad403da58ea2423b1a9bf6 Mon Sep 17 00:00:00 2001 From: TUNG LEDUC Date: Tue, 26 Nov 2019 11:49:48 +0900 Subject: [PATCH] [MLIR] Add one more test case for MatMul-Add fusion (#380) * Add one more testcase for matmul-add fusion * Code format for identity elimination testcase --- test/mlir/onnx/onnx_canonicalization.mlir | 25 ++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/test/mlir/onnx/onnx_canonicalization.mlir b/test/mlir/onnx/onnx_canonicalization.mlir index b7414a0..ad4fc6a 100644 --- a/test/mlir/onnx/onnx_canonicalization.mlir +++ b/test/mlir/onnx/onnx_canonicalization.mlir @@ -8,11 +8,22 @@ func @test_matmul_add_simplification(%a0: tensor<10x10xf32>, %a1: tensor<10x10xf "std.return"(%1) : (tensor<10x10xf32>) -> () } -func @test_identity_identity(%a0: tensor<10x10xf32>, %a1: tensor<10x10xf32>) -> tensor<10x10xf32> { - // CHECK-LABEL: test_identity_identity - // CHECK-NEXT: %{{[0-9]+}} = "onnx.Add"(%{{.*}}, %{{.*}}) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> - %0 = "onnx.Identity"(%a0) : (tensor<10x10xf32>) -> tensor<10x10xf32> - %1 = "onnx.Identity"(%a1) : (tensor<10x10xf32>) -> tensor<10x10xf32> - %2 = "onnx.Add"(%0, %1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> - "std.return"(%2) : (tensor<10x10xf32>) -> () +// onnx.MatMul ops with more than one result uses should not get fused +// CHECK-LABEL: func @test_sigmoid_add(%{{.*}}: tensor<10x10xf32>, %{{.*}}: tensor<10x10xf32>, %{{.*}}: tensor<10x10xf32>) -> tensor<10x10xf32> +func @test_sigmoid_add(%a0: tensor<10x10xf32>, %a1: tensor<10x10xf32>, %a2: tensor<10x10xf32>) -> tensor<10x10xf32> { + // CHECK: %{{[0-9]+}} = "onnx.MatMul"(%{{.*}}, %{{.*}}) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> + %0 = "onnx.MatMul"(%a0, %a1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> + %1 = "onnx.Add"(%0, %a2) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> + %2 = "onnx.Add"(%0, %a1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> + %3 = "onnx.Add"(%1, %2) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> + "std.return"(%3) : (tensor<10x10xf32>) -> () +} + +// CHECK-LABEL: @test_identity_identity(%{{.*}}: tensor<10x10xf32>, %{{.*}}: tensor<10x10xf32>) -> tensor<10x10xf32> +func @test_identity_identity(%a0: tensor<10x10xf32>, %a1: tensor<10x10xf32>) -> tensor<10x10xf32> { + // CHECK-NEXT: %{{[0-9]+}} = "onnx.Add"(%{{.*}}, %{{.*}}) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> + %0 = "onnx.Identity"(%a0) : (tensor<10x10xf32>) -> tensor<10x10xf32> + %1 = "onnx.Identity"(%a1) : (tensor<10x10xf32>) -> tensor<10x10xf32> + %2 = "onnx.Add"(%0, %1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> + "std.return"(%2) : (tensor<10x10xf32>) -> () }