// RUN: onnf-opt --canonicalize %s -split-input-file | FileCheck %s

// An onnx.MatMul whose only use is an onnx.Add is expected to be fused into a
// single onnx.Gemm with alpha = beta = 1.0 and neither operand transposed.
func @test_matmul_add_simplification(%a0: tensor<10x10xf32>, %a1: tensor<10x10xf32>, %a2: tensor<10x10xf32>) -> tensor<10x10xf32> {
  // CHECK-LABEL: test_matmul_add_simplification
  // CHECK: %{{[0-9]+}} = "onnx.Gemm"(%{{.*}}, %{{.*}}, %{{.*}}) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = 0 : i64, transB = 0 : i64} : (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
  %0 = "onnx.MatMul"(%a0, %a1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
  %1 = "onnx.Add"(%0, %a2) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
  "std.return"(%1) : (tensor<10x10xf32>) -> ()
}

// onnx.MatMul ops whose result has more than one use should not get fused.
// The MatMul result %0 feeds two separate onnx.Add ops, so the MatMul must
// survive canonicalization and no Add may be rewritten into an onnx.Gemm.
// NOTE(review): the function name mentions sigmoid, but no sigmoid op appears
// in this test; the name is kept so the label below stays unchanged.
// CHECK-LABEL: func @test_sigmoid_add(%{{.*}}: tensor<10x10xf32>, %{{.*}}: tensor<10x10xf32>, %{{.*}}: tensor<10x10xf32>) -> tensor<10x10xf32>
func @test_sigmoid_add(%a0: tensor<10x10xf32>, %a1: tensor<10x10xf32>, %a2: tensor<10x10xf32>) -> tensor<10x10xf32> {
  // CHECK-NOT: "onnx.Gemm"
  // CHECK: %{{[0-9]+}} = "onnx.MatMul"(%{{.*}}, %{{.*}}) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
  // CHECK-NOT: "onnx.Gemm"
  %0 = "onnx.MatMul"(%a0, %a1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
  %1 = "onnx.Add"(%0, %a2) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
  %2 = "onnx.Add"(%0, %a1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
  %3 = "onnx.Add"(%1, %2) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
  "std.return"(%3) : (tensor<10x10xf32>) -> ()
}

// Both onnx.Identity ops are expected to be removed: the check below requires
// the onnx.Add to be on the line immediately after the function signature.
// CHECK-LABEL: @test_identity_identity(%{{.*}}: tensor<10x10xf32>, %{{.*}}: tensor<10x10xf32>) -> tensor<10x10xf32>
func @test_identity_identity(%a0: tensor<10x10xf32>, %a1: tensor<10x10xf32>) -> tensor<10x10xf32> {
  // CHECK-NEXT: %{{[0-9]+}} = "onnx.Add"(%{{.*}}, %{{.*}}) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
  %0 = "onnx.Identity"(%a0) : (tensor<10x10xf32>) -> tensor<10x10xf32>
  %1 = "onnx.Identity"(%a1) : (tensor<10x10xf32>) -> tensor<10x10xf32>
  %2 = "onnx.Add"(%0, %1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
  "std.return"(%2) : (tensor<10x10xf32>) -> ()
}