[XLA:GPU] Migrate GEMM Thunk emission to MLIR.
- Map Custom call for GEMM in XLA HLO to Gemm/Gemm bias operations in LHLO GPU dialect. - Make 'algorithm' an optional attribute to better match with XLA HLO backend config. - Replace 'alpha' with 'alpha_real' and 'alpha_imag' to support complex GEMM correctly. - Generate GemmThunk off of LHLO GPU Gemm operations. PiperOrigin-RevId: 345250840
This commit is contained in:
		
							parent
							
								
									1b711670bc
								
							
						
					
					
						commit
						d7bd5233ab
					
				|  | @ -179,9 +179,10 @@ def LHLOGPU_GEMMOp : LHLOGPU_Op<"gemm"> { | |||
|     Arg<LHLO_Buffer, "", [MemRead]>:$rhs, | ||||
|     Arg<LHLO_Buffer, "", [MemRead]>:$output, | ||||
|     DotDimensionNumbers:$dot_dimension_numbers, | ||||
|     F64Attr:$alpha, | ||||
|     F64Attr:$alpha_real, | ||||
|     F64Attr:$alpha_imag, | ||||
|     I64Attr:$batch_size, | ||||
|     I64Attr:$algorithm); | ||||
|     OptionalAttr<I64Attr>:$algorithm); | ||||
| } | ||||
| 
 | ||||
| // output = alpha(lhs * rhs) + beta * bias | ||||
|  | @ -192,10 +193,11 @@ def LHLOGPU_GEMM_BiasOp : LHLOGPU_Op<"gemm_bias"> { | |||
|     Arg<LHLO_Buffer, "", [MemRead]>:$bias, | ||||
|     Arg<LHLO_Buffer, "", [MemRead]>:$output, | ||||
|     DotDimensionNumbers:$dot_dimension_numbers, | ||||
|     F64Attr:$alpha, | ||||
|     F64Attr:$alpha_real, | ||||
|     F64Attr:$alpha_imag, | ||||
|     F64Attr:$beta, | ||||
|     I64Attr:$batch_size, | ||||
|     I64Attr:$algorithm); | ||||
|     OptionalAttr<I64Attr>:$algorithm); | ||||
| } | ||||
| 
 | ||||
| def LHLOGPU_CholeskyOp : LHLOGPU_Op<"cholesky"> { | ||||
|  |  | |||
|  | @ -65,7 +65,8 @@ func @gemm(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>, %output:memref<5x5xf32> | |||
|        rhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>, | ||||
|        lhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>, | ||||
|        rhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>}, | ||||
|        alpha = 0.5, | ||||
|        alpha_real = 0.5, | ||||
|        alpha_imag = 0.0, | ||||
|        batch_size = 1, | ||||
|        algorithm = 0} | ||||
|     : (memref<5x4xf32>, memref<4x5xf32>, memref<5x5xf32>) -> () | ||||
|  | @ -81,7 +82,8 @@ func @gemm_bias(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>, | |||
|        rhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>, | ||||
|        lhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>, | ||||
|        rhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>}, | ||||
|        alpha = 0.5, | ||||
|        alpha_real = 0.5, | ||||
|        alpha_imag = 0.0, | ||||
|        beta = 1.0, | ||||
|        batch_size = 1, | ||||
|        algorithm = 0} | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue