Remove rank constraints in gemm fusion (#101)
* Remove rank constraints in gemm fusion * Add an MLIR test Co-authored-by: Tian Jin <tjingrant@gmail.com> Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
		
							parent
							
								
									24d89625e3
								
							
						
					
					
						commit
						0c4a010283
					
				|  | @ -45,7 +45,7 @@ def MulAddToGemmOptPattern : Pat<(ONNXAddOp (ONNXMatMulOp:$res $m1, $m2), $m3), | |||
| // onnx.add(onnx.Gemm(%X, %Y, None), %Z) = onnx.Gemm(%X, %Y, %Z) | ||||
| def FuseGemmFollowedByAddition : Pat<(ONNXAddOp (ONNXGemmOp:$res $m1, $m2, $none, $alpha, $beta, $transA, $transB), $bias), | ||||
|                                      (ONNXGemmOp $m1, $m2, $bias, $alpha, $beta, $transA, $transB), | ||||
|                                      [(HasOneUse $res), (HasRankOf<2> $m1), (HasRankOf<2> $m2), (HasNoneType $none)]>; | ||||
|                                      [(HasOneUse $res), (HasNoneType $none)]>; | ||||
| 
 | ||||
| // ONNX_Op (onnx.Identity (%X)) = ONNX_Op (%X) | ||||
| def IdentityEliminationPattern : Pat<(ONNXIdentityOp $arg), | ||||
|  |  | |||
|  | @ -111,4 +111,16 @@ func @test_gemm_add_fusion(%arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32 | |||
| 
 | ||||
|   // CHECK-NEXT: [[GEMM:%.+]] = "onnx.Gemm"(%{{.*}}, %{{.*}}, %{{.*}}) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = 0 : i64, transB = 0 : i64} : (tensor<128x128xf32>, tensor<128x128xf32>, tensor<128xf32>) -> tensor<*xf32> | ||||
|   // return [[GEMM]] : tensor<*xf32> | ||||
| } | ||||
| } | ||||
| 
 | ||||
| //CHECK-LABEL: @test_gemm_add_fusion_rank3(%{{.*}}: tensor<128x128x256xf32>, %{{.*}}: tensor<128x128x256xf32>, %{{.*}}: tensor<256xf32>) -> tensor<*xf32> { | ||||
| func @test_gemm_add_fusion_rank3(%arg0: tensor<128x128x256xf32>, %arg1: tensor<128x128x256xf32>, %arg2: tensor<256xf32>) -> tensor<*xf32> { | ||||
|   %cst = constant unit | ||||
|   %0 = "onnx.Gemm"(%arg0, %arg1, %cst) : (tensor<128x128x256xf32>, tensor<128x128x256xf32>, none) -> tensor<*xf32> | ||||
|   %1 = "onnx.Add"(%0, %arg2) : (tensor<*xf32>, tensor<256xf32>) -> tensor<*xf32> | ||||
|   return %1 : tensor<*xf32> | ||||
| 
 | ||||
|   // CHECK-NEXT: [[GEMM:%.+]] = "onnx.Gemm"(%{{.*}}, %{{.*}}, %{{.*}}) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = 0 : i64, transB = 0 : i64} : (tensor<128x128x256xf32>, tensor<128x128x256xf32>, tensor<256xf32>) -> tensor<*xf32> | ||||
|   // return [[GEMM]] : tensor<*xf32> | ||||
| } | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue