Add MLIR generated SignOp GPU kernel for complex types.
PiperOrigin-RevId: 379924456
This commit is contained in:
parent
73ed8cbf82
commit
376da8592f
|
@ -887,6 +887,8 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::SignOp>(Location loc,
|
|||
b->create<::mlir::SignedShiftRightOp>(loc, args[0], bitwidth_minus_one);
|
||||
Value or_op = b->create<::mlir::OrOp>(loc, ashr, one);
|
||||
return b->create<::mlir::SelectOp>(loc, cmp, zero, or_op);
|
||||
} else if (element_type.isa<ComplexType>()) {
|
||||
return b->create<::mlir::complex::SignOp>(loc, element_type, args.front());
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -256,6 +256,18 @@ func @complex_neg(%arg0: tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @complex_sign
|
||||
func @complex_sign(
|
||||
%arg0: tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> {
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: complex.sign
|
||||
%0 = "mhlo.sign"(%arg0) : (tensor<2x2xcomplex<f32>>)
|
||||
-> tensor<2x2xcomplex<f32>>
|
||||
return %0 : tensor<2x2xcomplex<f32>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @float_tanh
|
||||
func @float_tanh(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: linalg.generic
|
||||
|
|
|
@ -810,6 +810,20 @@ func @sign_i16(%input: memref<2x2xi16>, %result: memref<2x2xi16>) {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @sign_complex
|
||||
func @sign_complex(%input: memref<2x2xcomplex<f32>>,
|
||||
%result: memref<2x2xcomplex<f32>>) {
|
||||
"lmhlo.sign"(%input, %result) : (memref<2x2xcomplex<f32>>,
|
||||
memref<2x2xcomplex<f32>>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: complex<f32>, %[[RESULT_OUT:.*]]):
|
||||
// CHECK-NEXT: %[[RESULT:.*]] = complex.sign %[[OPERAND_IN]] : complex<f32>
|
||||
// CHECK-NEXT: linalg.yield %[[RESULT]] : complex<f32>
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @sqrt
|
||||
func @sqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
"lmhlo.sqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
|
|
Loading…
Reference in New Issue