From d7bd5233ab153283e942d97bc09516115e7aa801 Mon Sep 17 00:00:00 2001 From: Rahul Joshi Date: Wed, 2 Dec 2020 09:42:26 -0800 Subject: [PATCH] [XLA:GPU] Migrate GEMM Thunk emission to MLIR. - Map Custom call for GEMM in XLA HLO to Gemm/Gemm bias operations in LHLO GPU dialect. - Make 'algorithm' an optional attribute to better match with XLA HLO backend config. - Replace 'alpha' with 'alpha_real' and 'alpha_imag' to support complex GEMM correctly. - Generate GemmThunk off of LHLO GPU Gemm operations. PiperOrigin-RevId: 345250840 --- include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td | 10 ++++++---- tests/lhlo_gpu_ops.mlir | 6 ++++-- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td index 4613627..f4f2d85 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td @@ -179,9 +179,10 @@ def LHLOGPU_GEMMOp : LHLOGPU_Op<"gemm"> { Arg:$rhs, Arg:$output, DotDimensionNumbers:$dot_dimension_numbers, - F64Attr:$alpha, + F64Attr:$alpha_real, + F64Attr:$alpha_imag, I64Attr:$batch_size, - I64Attr:$algorithm); + OptionalAttr:$algorithm); } // output = alpha(lhs * rhs) + beta * bias @@ -192,10 +193,11 @@ def LHLOGPU_GEMM_BiasOp : LHLOGPU_Op<"gemm_bias"> { Arg:$bias, Arg:$output, DotDimensionNumbers:$dot_dimension_numbers, - F64Attr:$alpha, + F64Attr:$alpha_real, + F64Attr:$alpha_imag, F64Attr:$beta, I64Attr:$batch_size, - I64Attr:$algorithm); + OptionalAttr:$algorithm); } def LHLOGPU_CholeskyOp : LHLOGPU_Op<"cholesky"> { diff --git a/tests/lhlo_gpu_ops.mlir b/tests/lhlo_gpu_ops.mlir index a939cab..bd5df38 100644 --- a/tests/lhlo_gpu_ops.mlir +++ b/tests/lhlo_gpu_ops.mlir @@ -65,7 +65,8 @@ func @gemm(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>, %output:memref<5x5xf32> rhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>, lhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>, rhs_contracting_dimensions = 
dense<[1,1]> : tensor<2xi64>}, - alpha = 0.5, + alpha_real = 0.5, + alpha_imag = 0.0, batch_size = 1, algorithm = 0} : (memref<5x4xf32>, memref<4x5xf32>, memref<5x5xf32>) -> () @@ -81,7 +82,8 @@ func @gemm_bias(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>, rhs_batching_dimensions = dense<[1,1]> : tensor<2xi64>, lhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>, rhs_contracting_dimensions = dense<[1,1]> : tensor<2xi64>}, - alpha = 0.5, + alpha_real = 0.5, + alpha_imag = 0.0, beta = 1.0, batch_size = 1, algorithm = 0}