[XLA:GPU] Migrate GEMM Thunk emission to MLIR.
- Map the custom call for GEMM in XLA HLO to the Gemm/Gemm bias operations in the LHLO GPU dialect.
- Make 'algorithm' an optional attribute to better match the XLA HLO backend config.
- Replace 'alpha' with 'alpha_real' and 'alpha_imag' to support complex GEMM correctly (see the illustrative sketch below).
- Generate GemmThunk off of the LHLO GPU Gemm operations.

PiperOrigin-RevId: 345250840
commit d7bd5233ab
parent 1b711670bc
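A single F64Attr holds one double, so the old 'alpha' attribute could not carry a complex scaling factor; splitting it into 'alpha_real' and 'alpha_imag' fixes that. As a minimal sketch of what this enables (not part of this change; the "lmhlo_gpu.gemm" spelling, the dot dimensions, and the shapes/values are illustrative assumptions), a complex GEMM computing output = (0.5 + 1.0i)(lhs * rhs) would carry the factor as:

// Hypothetical complex GEMM: output = (0.5 + 1.0i)(lhs * rhs).
"lmhlo_gpu.gemm"(%lhs, %rhs, %output) {
    // Plain 2-D matmul: no batching dims, contract lhs dim 1 with rhs dim 0.
    dot_dimension_numbers = {
      lhs_batching_dimensions = dense<[]> : tensor<0xi64>,
      rhs_batching_dimensions = dense<[]> : tensor<0xi64>,
      lhs_contracting_dimensions = dense<[1]> : tensor<1xi64>,
      rhs_contracting_dimensions = dense<[0]> : tensor<1xi64>},
    // alpha = 0.5 + 1.0i, split across the two F64 attributes.
    alpha_real = 0.5,
    alpha_imag = 1.0,
    batch_size = 1,
    algorithm = 0}
  : (memref<4x4xcomplex<f32>>, memref<4x4xcomplex<f32>>, memref<4x4xcomplex<f32>>) -> ()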
@@ -179,9 +179,10 @@ def LHLOGPU_GEMMOp : LHLOGPU_Op<"gemm"> {
   Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
   Arg<LHLO_Buffer, "", [MemRead]>:$output,
   DotDimensionNumbers:$dot_dimension_numbers,
-  F64Attr:$alpha,
+  F64Attr:$alpha_real,
+  F64Attr:$alpha_imag,
   I64Attr:$batch_size,
-  I64Attr:$algorithm);
+  OptionalAttr<I64Attr>:$algorithm);
 }

 // output = alpha(lhs * rhs) + beta * bias
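Since $algorithm is now an OptionalAttr, the attribute can be left off entirely when the XLA HLO backend config carries no selected algorithm. A hypothetical sketch (op spelling, dot dimensions, and shapes are illustrative, not from this change):

// `algorithm` omitted: permitted now that the attribute is optional.
"lmhlo_gpu.gemm"(%lhs, %rhs, %output) {
    dot_dimension_numbers = {
      lhs_batching_dimensions = dense<[]> : tensor<0xi64>,
      rhs_batching_dimensions = dense<[]> : tensor<0xi64>,
      lhs_contracting_dimensions = dense<[1]> : tensor<1xi64>,
      rhs_contracting_dimensions = dense<[0]> : tensor<1xi64>},
    alpha_real = 1.0,
    alpha_imag = 0.0,
    batch_size = 1}
  : (memref<5x4xf32>, memref<4x5xf32>, memref<5x5xf32>) -> ()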
@@ -192,10 +193,11 @@ def LHLOGPU_GEMM_BiasOp : LHLOGPU_Op<"gemm_bias"> {
   Arg<LHLO_Buffer, "", [MemRead]>:$bias,
   Arg<LHLO_Buffer, "", [MemRead]>:$output,
   DotDimensionNumbers:$dot_dimension_numbers,
-  F64Attr:$alpha,
+  F64Attr:$alpha_real,
+  F64Attr:$alpha_imag,
   F64Attr:$beta,
   I64Attr:$batch_size,
-  I64Attr:$algorithm);
+  OptionalAttr<I64Attr>:$algorithm);
 }

 def LHLOGPU_CholeskyOp : LHLOGPU_Op<"cholesky"> {
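The bias variant implements output = alpha(lhs * rhs) + beta * bias, per the comment above its definition. A full textual form would look roughly like this (a sketch under the same assumptions as above; the operand list follows the $lhs, $rhs, $bias, $output order of the op definition):

// output = 0.5 * (lhs * rhs) + 1.0 * bias, with a purely real alpha.
"lmhlo_gpu.gemm_bias"(%lhs, %rhs, %bias, %output) {
    dot_dimension_numbers = {
      lhs_batching_dimensions = dense<[]> : tensor<0xi64>,
      rhs_batching_dimensions = dense<[]> : tensor<0xi64>,
      lhs_contracting_dimensions = dense<[1]> : tensor<1xi64>,
      rhs_contracting_dimensions = dense<[0]> : tensor<1xi64>},
    alpha_real = 0.5,
    alpha_imag = 0.0,
    beta = 1.0,
    batch_size = 1,
    algorithm = 0}
  : (memref<5x4xf32>, memref<4x5xf32>, memref<5x5xf32>, memref<5x5xf32>) -> ()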
@@ -65,7 +65,8 @@ func @gemm(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>, %output:memref<5x5xf32>
   rhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>,
   lhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>,
   rhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>},
-  alpha = 0.5,
+  alpha_real = 0.5,
+  alpha_imag = 0.0,
   batch_size = 1,
   algorithm = 0}
   : (memref<5x4xf32>, memref<4x5xf32>, memref<5x5xf32>) -> ()
@@ -81,7 +82,8 @@ func @gemm_bias(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>,
   rhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>,
   lhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>,
   rhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>},
-  alpha = 0.5,
+  alpha_real = 0.5,
+  alpha_imag = 0.0,
   beta = 1.0,
   batch_size = 1,
   algorithm = 0}