[XLA:GPU] Migrate GEMM Thunk emission to MLIR.

- Map the GEMM custom call in XLA HLO to the Gemm/Gemm-bias operations in the
  LHLO GPU dialect.
- Make 'algorithm' an optional attribute to better match the XLA HLO backend
  config.
- Replace 'alpha' with 'alpha_real' and 'alpha_imag' to support complex GEMM
  correctly (see the sketch after this list).
- Generate GemmThunk from LHLO GPU Gemm operations.
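
As a sketch of the alpha split: a single F64 'alpha' attribute cannot carry the
imaginary part of a complex scaling factor. Below is a hypothetical complex GEMM
under the new attributes; it is not part of this commit, the complex<f32> operand
types and all attribute values are illustrative, and the dot_dimension_numbers
mirror the round-trip test further down.

  func @gemm_complex(%lhs: memref<5x4xcomplex<f32>>, %rhs: memref<4x5xcomplex<f32>>,
                     %output: memref<5x5xcomplex<f32>>) {
    // alpha = 0.5 + 2.0i, carried as two F64 attributes; the old single
    // 'alpha' F64 attribute would have dropped the imaginary part.
    "lmhlo_gpu.gemm"(%lhs, %rhs, %output) { dot_dimension_numbers = {
         lhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>,
         rhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>,
         lhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>,
         rhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>},
         alpha_real = 0.5,
         alpha_imag = 2.0,
         batch_size = 1,
         algorithm = 0}
      : (memref<5x4xcomplex<f32>>, memref<4x5xcomplex<f32>>, memref<5x5xcomplex<f32>>) -> ()
    return
  }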

PiperOrigin-RevId: 345250840
Rahul Joshi, 2020-12-02 09:42:26 -08:00 (committed by the TensorFlow MLIR Team)
parent 1b711670bc
commit d7bd5233ab
2 changed files with 10 additions and 6 deletions

@@ -179,9 +179,10 @@ def LHLOGPU_GEMMOp : LHLOGPU_Op<"gemm"> {
     Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
     Arg<LHLO_Buffer, "", [MemRead]>:$output,
     DotDimensionNumbers:$dot_dimension_numbers,
-    F64Attr:$alpha,
+    F64Attr:$alpha_real,
+    F64Attr:$alpha_imag,
     I64Attr:$batch_size,
-    I64Attr:$algorithm);
+    OptionalAttr<I64Attr>:$algorithm);
 }
 
 // output = alpha(lhs * rhs) + beta * bias
@@ -192,10 +193,11 @@ def LHLOGPU_GEMM_BiasOp : LHLOGPU_Op<"gemm_bias"> {
     Arg<LHLO_Buffer, "", [MemRead]>:$bias,
     Arg<LHLO_Buffer, "", [MemRead]>:$output,
     DotDimensionNumbers:$dot_dimension_numbers,
-    F64Attr:$alpha,
+    F64Attr:$alpha_real,
+    F64Attr:$alpha_imag,
     F64Attr:$beta,
     I64Attr:$batch_size,
-    I64Attr:$algorithm);
+    OptionalAttr<I64Attr>:$algorithm);
 }
 
 def LHLOGPU_CholeskyOp : LHLOGPU_Op<"cholesky"> {
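
Since 'algorithm' is now an OptionalAttr, the attribute can be omitted entirely,
matching HLO instructions whose backend config leaves the algorithm unset. A
hypothetical example follows; that the emitted GemmThunk then falls back to the
backend's default algorithm selection is an assumption, not something this diff
shows.

  func @gemm_no_algorithm(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>,
                          %output: memref<5x5xf32>) {
    // No 'algorithm' attribute: mirrors a backend config with the
    // algorithm left unset.
    "lmhlo_gpu.gemm"(%lhs, %rhs, %output) { dot_dimension_numbers = {
         lhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>,
         rhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>,
         lhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>,
         rhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>},
         alpha_real = 0.5,
         alpha_imag = 0.0,
         batch_size = 1}
      : (memref<5x4xf32>, memref<4x5xf32>, memref<5x5xf32>) -> ()
    return
  }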

@@ -65,7 +65,8 @@ func @gemm(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>, %output:memref<5x5xf32>
        rhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>,
        lhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>,
        rhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>},
-       alpha = 0.5,
+       alpha_real = 0.5,
+       alpha_imag = 0.0,
        batch_size = 1,
        algorithm = 0}
     : (memref<5x4xf32>, memref<4x5xf32>, memref<5x5xf32>) -> ()
@@ -81,7 +82,8 @@ func @gemm_bias(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>,
        rhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>,
        lhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>,
        rhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>},
-       alpha = 0.5,
+       alpha_real = 0.5,
+       alpha_imag = 0.0,
        beta = 1.0,
        batch_size = 1,
        algorithm = 0}
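
For reference, with the attributes in this test the bias form computes
output = alpha(lhs * rhs) + beta * bias (per the comment on the op definition)
with alpha = alpha_real + i * alpha_imag = 0.5 and beta = 1.0, i.e.
output = 0.5 * (lhs * rhs) + bias.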