Add MLIR generated SignOp GPU kernel for complex types.

PiperOrigin-RevId: 379924456
This commit is contained in:
Adrian Kuegel 2021-06-17 03:55:59 -07:00 committed by TensorFlow MLIR Team
parent 73ed8cbf82
commit 376da8592f
3 changed files with 28 additions and 0 deletions

View File

@ -887,6 +887,8 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::SignOp>(Location loc,
b->create<::mlir::SignedShiftRightOp>(loc, args[0], bitwidth_minus_one); b->create<::mlir::SignedShiftRightOp>(loc, args[0], bitwidth_minus_one);
Value or_op = b->create<::mlir::OrOp>(loc, ashr, one); Value or_op = b->create<::mlir::OrOp>(loc, ashr, one);
return b->create<::mlir::SelectOp>(loc, cmp, zero, or_op); return b->create<::mlir::SelectOp>(loc, cmp, zero, or_op);
} else if (element_type.isa<ComplexType>()) {
return b->create<::mlir::complex::SignOp>(loc, element_type, args.front());
} }
return nullptr; return nullptr;
} }

View File

@ -256,6 +256,18 @@ func @complex_neg(%arg0: tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> {
// ----- // -----
// CHECK-LABEL: func @complex_sign
func @complex_sign(
%arg0: tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> {
// CHECK: linalg.generic
// CHECK: complex.sign
%0 = "mhlo.sign"(%arg0) : (tensor<2x2xcomplex<f32>>)
-> tensor<2x2xcomplex<f32>>
return %0 : tensor<2x2xcomplex<f32>>
}
// -----
// CHECK-LABEL: func @float_tanh // CHECK-LABEL: func @float_tanh
func @float_tanh(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { func @float_tanh(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic // CHECK: linalg.generic

View File

@ -810,6 +810,20 @@ func @sign_i16(%input: memref<2x2xi16>, %result: memref<2x2xi16>) {
// ----- // -----
// CHECK-LABEL: func @sign_complex
func @sign_complex(%input: memref<2x2xcomplex<f32>>,
%result: memref<2x2xcomplex<f32>>) {
"lmhlo.sign"(%input, %result) : (memref<2x2xcomplex<f32>>,
memref<2x2xcomplex<f32>>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: complex<f32>, %[[RESULT_OUT:.*]]):
// CHECK-NEXT: %[[RESULT:.*]] = complex.sign %[[OPERAND_IN]] : complex<f32>
// CHECK-NEXT: linalg.yield %[[RESULT]] : complex<f32>
// -----
// CHECK-LABEL: func @sqrt // CHECK-LABEL: func @sqrt
func @sqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { func @sqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"lmhlo.sqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () "lmhlo.sqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()