[XLA:GPU] Convert Cholesky custom call in XLA HLO to LHLO GPU Dialect.
- Restructured LHLO GPU Cholesky to better match XLA HLO by eliminating the untyped buffer and changing is_upper attribute to is_lower. - Change LhloDialectEmitter to emit LHLO GPU Cholesky operation. PiperOrigin-RevId: 343873516
This commit is contained in:
parent
aa4d33149a
commit
ac54c5ccfa
|
@ -202,9 +202,9 @@ def LHLOGPU_CholeskyOp : LHLOGPU_Op<"cholesky"> {
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
Arg<LHLO_Buffer, "", [MemRead]>:$input,
|
Arg<LHLO_Buffer, "", [MemRead]>:$input,
|
||||||
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
||||||
Arg<UntypedBuffer, "", [MemWrite]>:$scratch,
|
Arg<LHLO_Buffer, "", [MemWrite]>:$scratch,
|
||||||
Arg<I32Buffer, "", [MemWrite]>:$info,
|
Arg<I32Buffer, "", [MemWrite]>:$info,
|
||||||
BoolAttr:$is_upper);
|
BoolAttr:$is_lower);
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // LHLO_GPU_OPS
|
#endif // LHLO_GPU_OPS
|
||||||
|
|
|
@ -93,7 +93,7 @@ func @gemm_bias(%lhs: memref<5x4xf32>, %rhs: memref<4x5xf32>,
|
||||||
func @cholesky(%arg : memref<10x10xf32>, %out: memref<10x10xf32>) {
|
func @cholesky(%arg : memref<10x10xf32>, %out: memref<10x10xf32>) {
|
||||||
%scratch = alloc() : memref<32xi8>
|
%scratch = alloc() : memref<32xi8>
|
||||||
%info = alloc() : memref<32xi32>
|
%info = alloc() : memref<32xi32>
|
||||||
"lmhlo_gpu.cholesky"(%arg, %out, %scratch, %info) { is_upper = true }
|
"lmhlo_gpu.cholesky"(%arg, %out, %scratch, %info) { is_lower = true }
|
||||||
: (memref<10x10xf32>, memref<10x10xf32>, memref<32xi8>, memref<32xi32>) -> ()
|
: (memref<10x10xf32>, memref<10x10xf32>, memref<32xi8>, memref<32xi32>) -> ()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue