[XLA:GPU] Migrate GEMM Thunk emission to MLIR.
- Map Custom call for GEMM in XLA HLO to Gemm/Gemm bias operations in LHLO GPU dialect. - Make 'algorithm' an optional attribute to better match with XLA HLO backend config. - Replace 'alpha' with 'alpha_real' and 'alpha_complex' to support complex GEMM correctly. - Generate GemmThunk off of LHLO GPU Gemm operations. PiperOrigin-RevId: 345250840
This commit is contained in:
parent
1b711670bc
commit
d7bd5233ab
|
@ -179,9 +179,10 @@ def LHLOGPU_GEMMOp : LHLOGPU_Op<"gemm"> {
|
|||
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$output,
|
||||
DotDimensionNumbers:$dot_dimension_numbers,
|
||||
F64Attr:$alpha,
|
||||
F64Attr:$alpha_real,
|
||||
F64Attr:$alpha_imag,
|
||||
I64Attr:$batch_size,
|
||||
I64Attr:$algorithm);
|
||||
OptionalAttr<I64Attr>:$algorithm);
|
||||
}
|
||||
|
||||
// output = alpha(lhs * rhs) + beta * bias
|
||||
|
@ -192,10 +193,11 @@ def LHLOGPU_GEMM_BiasOp : LHLOGPU_Op<"gemm_bias"> {
|
|||
Arg<LHLO_Buffer, "", [MemRead]>:$bias,
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$output,
|
||||
DotDimensionNumbers:$dot_dimension_numbers,
|
||||
F64Attr:$alpha,
|
||||
F64Attr:$alpha_real,
|
||||
F64Attr:$alpha_imag,
|
||||
F64Attr:$beta,
|
||||
I64Attr:$batch_size,
|
||||
I64Attr:$algorithm);
|
||||
OptionalAttr<I64Attr>:$algorithm);
|
||||
}
|
||||
|
||||
def LHLOGPU_CholeskyOp : LHLOGPU_Op<"cholesky"> {
|
||||
|
|
|
@ -65,7 +65,8 @@ func @gemm(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>, %output:memref<5x5xf32>
|
|||
rhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>,
|
||||
lhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>,
|
||||
rhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>},
|
||||
alpha = 0.5,
|
||||
alpha_real = 0.5,
|
||||
alpha_imag = 0.0,
|
||||
batch_size = 1,
|
||||
algorithm = 0}
|
||||
: (memref<5x4xf32>, memref<4x5xf32>, memref<5x5xf32>) -> ()
|
||||
|
@ -81,7 +82,8 @@ func @gemm_bias(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>,
|
|||
rhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>,
|
||||
lhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>,
|
||||
rhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>},
|
||||
alpha = 0.5,
|
||||
alpha_real = 0.5,
|
||||
alpha_imag = 0.0,
|
||||
beta = 1.0,
|
||||
batch_size = 1,
|
||||
algorithm = 0}
|
||||
|
|
Loading…
Reference in New Issue