[XLA:GPU] Convert Cholesky custom call in XLA HLO to LHLO GPU Dialect.

- Restructured LHLO GPU Cholesky to better match XLA HLO by eliminating the
  untyped buffer and changing is_upper attribute to is_lower.
- Change LhloDialectEmitter to emit LHLO GPU Cholesky operation.

PiperOrigin-RevId: 343873516
This commit is contained in:
Rahul Joshi 2020-11-23 10:05:34 -08:00 committed by TensorFlow MLIR Team
parent aa4d33149a
commit ac54c5ccfa
2 changed files with 3 additions and 3 deletions

View File

@ -202,9 +202,9 @@ def LHLOGPU_CholeskyOp : LHLOGPU_Op<"cholesky"> {
let arguments = (ins let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$input, Arg<LHLO_Buffer, "", [MemRead]>:$input,
Arg<LHLO_Buffer, "", [MemWrite]>:$output, Arg<LHLO_Buffer, "", [MemWrite]>:$output,
Arg<UntypedBuffer, "", [MemWrite]>:$scratch, Arg<LHLO_Buffer, "", [MemWrite]>:$scratch,
Arg<I32Buffer, "", [MemWrite]>:$info, Arg<I32Buffer, "", [MemWrite]>:$info,
BoolAttr:$is_upper); BoolAttr:$is_lower);
} }
#endif // LHLO_GPU_OPS #endif // LHLO_GPU_OPS

View File

@ -93,7 +93,7 @@ func @gemm_bias(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>,
func @cholesky(%arg : memref<10x10xf32>, %out: memref<10x10xf32>) { func @cholesky(%arg : memref<10x10xf32>, %out: memref<10x10xf32>) {
%scratch = alloc() : memref<32xi8> %scratch = alloc() : memref<32xi8>
%info = alloc() : memref<32xi32> %info = alloc() : memref<32xi32>
"lmhlo_gpu.cholesky"(%arg, %out, %scratch, %info) { is_upper = true } "lmhlo_gpu.cholesky"(%arg, %out, %scratch, %info) { is_lower = true }
: (memref<10x10xf32>, memref<10x10xf32>, memref<32xi8>, memref<32xi32>) -> () : (memref<10x10xf32>, memref<10x10xf32>, memref<32xi8>, memref<32xi32>) -> ()
return return
} }