[XLA:GPU] Convert Cholesky custom call in XLA HLO to LHLO GPU Dialect.
- Restructured LHLO GPU Cholesky to better match XLA HLO by eliminating the untyped buffer and changing is_upper attribute to is_lower. - Change LhloDialectEmitter to emit LHLO GPU Cholesky operation. PiperOrigin-RevId: 343873516
This commit is contained in:
parent
aa4d33149a
commit
ac54c5ccfa
|
@ -202,9 +202,9 @@ def LHLOGPU_CholeskyOp : LHLOGPU_Op<"cholesky"> {
|
|||
let arguments = (ins
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$input,
|
||||
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
||||
Arg<UntypedBuffer, "", [MemWrite]>:$scratch,
|
||||
Arg<LHLO_Buffer, "", [MemWrite]>:$scratch,
|
||||
Arg<I32Buffer, "", [MemWrite]>:$info,
|
||||
BoolAttr:$is_upper);
|
||||
BoolAttr:$is_lower);
|
||||
}
|
||||
|
||||
#endif // LHLO_GPU_OPS
|
||||
|
|
|
@ -93,7 +93,7 @@ func @gemm_bias(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>,
|
|||
func @cholesky(%arg : memref<10x10xf32>, %out: memref<10x10xf32>) {
|
||||
%scratch = alloc() : memref<32xi8>
|
||||
%info = alloc() : memref<32xi32>
|
||||
"lmhlo_gpu.cholesky"(%arg, %out, %scratch, %info) { is_upper = true }
|
||||
"lmhlo_gpu.cholesky"(%arg, %out, %scratch, %info) { is_lower = true }
|
||||
: (memref<10x10xf32>, memref<10x10xf32>, memref<32xi8>, memref<32xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue