[XLA:GPU] Migrate GEMM Thunk emission to MLIR.
- Map Custom call for GEMM in XLA HLO to Gemm/Gemm bias operations in LHLO GPU dialect.
- Make 'algorithm' an optional attribute to better match with XLA HLO backend config.
- Replace 'alpha' with 'alpha_real' and 'alpha_imag' to support complex GEMM correctly.
- Generate GemmThunk off of LHLO GPU Gemm operations.

PiperOrigin-RevId: 345250840
This commit is contained in:
		
							parent
							
								
									1b711670bc
								
							
						
					
					
						commit
						d7bd5233ab
					
				|  | @ -179,9 +179,10 @@ def LHLOGPU_GEMMOp : LHLOGPU_Op<"gemm"> { | ||||||
|     Arg<LHLO_Buffer, "", [MemRead]>:$rhs, |     Arg<LHLO_Buffer, "", [MemRead]>:$rhs, | ||||||
|     Arg<LHLO_Buffer, "", [MemRead]>:$output, |     Arg<LHLO_Buffer, "", [MemRead]>:$output, | ||||||
|     DotDimensionNumbers:$dot_dimension_numbers, |     DotDimensionNumbers:$dot_dimension_numbers, | ||||||
|     F64Attr:$alpha, |     F64Attr:$alpha_real, | ||||||
|  |     F64Attr:$alpha_imag, | ||||||
|     I64Attr:$batch_size, |     I64Attr:$batch_size, | ||||||
|     I64Attr:$algorithm); |     OptionalAttr<I64Attr>:$algorithm); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // output = alpha(lhs * rhs) + beta * bias | // output = alpha(lhs * rhs) + beta * bias | ||||||
|  | @ -192,10 +193,11 @@ def LHLOGPU_GEMM_BiasOp : LHLOGPU_Op<"gemm_bias"> { | ||||||
|     Arg<LHLO_Buffer, "", [MemRead]>:$bias, |     Arg<LHLO_Buffer, "", [MemRead]>:$bias, | ||||||
|     Arg<LHLO_Buffer, "", [MemRead]>:$output, |     Arg<LHLO_Buffer, "", [MemRead]>:$output, | ||||||
|     DotDimensionNumbers:$dot_dimension_numbers, |     DotDimensionNumbers:$dot_dimension_numbers, | ||||||
|     F64Attr:$alpha, |     F64Attr:$alpha_real, | ||||||
|  |     F64Attr:$alpha_imag, | ||||||
|     F64Attr:$beta, |     F64Attr:$beta, | ||||||
|     I64Attr:$batch_size, |     I64Attr:$batch_size, | ||||||
|     I64Attr:$algorithm); |     OptionalAttr<I64Attr>:$algorithm); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| def LHLOGPU_CholeskyOp : LHLOGPU_Op<"cholesky"> { | def LHLOGPU_CholeskyOp : LHLOGPU_Op<"cholesky"> { | ||||||
|  |  | ||||||
|  | @ -65,7 +65,8 @@ func @gemm(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>, %output:memref<5x5xf32> | ||||||
|        rhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>, |        rhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>, | ||||||
|        lhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>, |        lhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>, | ||||||
|        rhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>}, |        rhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>}, | ||||||
|        alpha = 0.5, |        alpha_real = 0.5, | ||||||
|  |        alpha_imag = 0.0, | ||||||
|        batch_size = 1, |        batch_size = 1, | ||||||
|        algorithm = 0} |        algorithm = 0} | ||||||
|     : (memref<5x4xf32>, memref<4x5xf32>, memref<5x5xf32>) -> () |     : (memref<5x4xf32>, memref<4x5xf32>, memref<5x5xf32>) -> () | ||||||
|  | @ -81,7 +82,8 @@ func @gemm_bias(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>, | ||||||
|        rhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>, |        rhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>, | ||||||
|        lhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>, |        lhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>, | ||||||
|        rhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>}, |        rhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>}, | ||||||
|        alpha = 0.5, |        alpha_real = 0.5, | ||||||
|  |        alpha_imag = 0.0, | ||||||
|        beta = 1.0, |        beta = 1.0, | ||||||
|        batch_size = 1, |        batch_size = 1, | ||||||
|        algorithm = 0} |        algorithm = 0} | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue