Remove rank constraints in gemm fusion (#101)
* Remove rank constraints in gemm fusion
* Add an MLIR test

Co-authored-by: Tian Jin <tjingrant@gmail.com>
Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
parent
24d89625e3
commit
0c4a010283
|
@ -45,7 +45,7 @@ def MulAddToGemmOptPattern : Pat<(ONNXAddOp (ONNXMatMulOp:$res $m1, $m2), $m3),
|
|||
// onnx.add(onnx.Gemm(%X, %Y, None), %Z) = onnx.Gemm(%X, %Y, %Z)
|
||||
// Rewrites an ONNXAddOp whose first operand is the result of an ONNXGemmOp
// into a single ONNXGemmOp that takes the Add's other operand ($bias) as its
// third operand; $alpha, $beta, $transA and $transB are forwarded unchanged.
// The constraint list gates the rewrite on HasOneUse for $res and HasNoneType
// for the Gemm's original third operand ($none).
def FuseGemmFollowedByAddition : Pat<(ONNXAddOp (ONNXGemmOp:$res $m1, $m2, $none, $alpha, $beta, $transA, $transB), $bias),
|
||||
(ONNXGemmOp $m1, $m2, $bias, $alpha, $beta, $transA, $transB),
|
||||
// NOTE(review): this looks like the pre-change side of the diff hunk (it still carries the HasRankOf<2> checks on $m1 and $m2) — confirm.
[(HasOneUse $res), (HasRankOf<2> $m1), (HasRankOf<2> $m2), (HasNoneType $none)]>;
|
||||
// NOTE(review): presumably the post-change constraint list, with the rank checks dropped — confirm.
[(HasOneUse $res), (HasNoneType $none)]>;
|
||||
|
||||
// ONNX_Op (onnx.Identity (%X)) = ONNX_Op (%X)
|
||||
def IdentityEliminationPattern : Pat<(ONNXIdentityOp $arg),
|
||||
|
|
|
@ -112,3 +112,15 @@ func @test_gemm_add_fusion(%arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32
|
|||
// CHECK-NEXT: [[GEMM:%.+]] = "onnx.Gemm"(%{{.*}}, %{{.*}}, %{{.*}}) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = 0 : i64, transB = 0 : i64} : (tensor<128x128xf32>, tensor<128x128xf32>, tensor<128xf32>) -> tensor<*xf32>
|
||||
// return [[GEMM]] : tensor<*xf32>
|
||||
}
|
||||
|
||||
// Gemm + Add fusion with rank-3 Gemm operands: the Add that consumes the
// bias-less Gemm result is expected to be folded into the Gemm's third
// operand, leaving a single onnx.Gemm whose result is returned directly.
//CHECK-LABEL: @test_gemm_add_fusion_rank3(%{{.*}}: tensor<128x128x256xf32>, %{{.*}}: tensor<128x128x256xf32>, %{{.*}}: tensor<256xf32>) -> tensor<*xf32> {
func @test_gemm_add_fusion_rank3(%arg0: tensor<128x128x256xf32>, %arg1: tensor<128x128x256xf32>, %arg2: tensor<256xf32>) -> tensor<*xf32> {
  %cst = constant unit
  %0 = "onnx.Gemm"(%arg0, %arg1, %cst) : (tensor<128x128x256xf32>, tensor<128x128x256xf32>, none) -> tensor<*xf32>
  %1 = "onnx.Add"(%0, %arg2) : (tensor<*xf32>, tensor<256xf32>) -> tensor<*xf32>
  return %1 : tensor<*xf32>

  // The fused Gemm must be the first op after the function header ...
  // CHECK-NEXT: [[GEMM:%.+]] = "onnx.Gemm"(%{{.*}}, %{{.*}}, %{{.*}}) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = 0 : i64, transB = 0 : i64} : (tensor<128x128x256xf32>, tensor<128x128x256xf32>, tensor<256xf32>) -> tensor<*xf32>
  // ... and its result must be what the function returns (previously this
  // line lacked a directive prefix, so it was never verified).
  // CHECK-NEXT: return [[GEMM]] : tensor<*xf32>
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue