Add MLIR generated NegOp GPU kernel for complex types.

PiperOrigin-RevId: 379905236
This commit is contained in:
Adrian Kuegel 2021-06-17 01:30:08 -07:00 committed by TensorFlow MLIR Team
parent 8c8e81cb69
commit 73ed8cbf82
3 changed files with 30 additions and 2 deletions

View File

@ -671,8 +671,9 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::NegOp>(Location loc,
ArrayRef<Value> args,
OpBuilder* b) {
Type element_type = getElementTypeOrSelf(args.front().getType());
if (element_type.isa<FloatType>()) {
return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::NegFOp>{}(
if (element_type.isa<ComplexType, FloatType>()) {
return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::NegFOp, isComplexType,
::mlir::complex::NegOp>{}(
loc, result_types, arg_types, args, b);
}
if (element_type.isa<IntegerType>()) {

View File

@ -245,6 +245,17 @@ func @float_neg(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// -----
// Negation of a complex<f32> tensor through "mhlo.negate": the directives
// below expect the output to contain a linalg.generic op followed by a
// complex.neg op.
// CHECK-LABEL: func @complex_neg
func @complex_neg(%arg0: tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> {
// CHECK: linalg.generic
// CHECK: complex.neg
%0 = "mhlo.negate"(%arg0) : (tensor<2x2xcomplex<f32>>)
-> tensor<2x2xcomplex<f32>>
return %0 : tensor<2x2xcomplex<f32>>
}
// -----
// CHECK-LABEL: func @float_tanh
func @float_tanh(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic

View File

@ -690,6 +690,22 @@ func @negf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
// -----
// Negation of a complex<f32> memref through "lmhlo.negate"; the op takes both
// the input and the result buffer as operands and produces no SSA result.
// CHECK-LABEL: func @complex_neg
func @complex_neg(%input: memref<2x2xcomplex<f32>>,
%result: memref<2x2xcomplex<f32>>) {
"lmhlo.negate"(%input, %result) : (memref<2x2xcomplex<f32>>,
memref<2x2xcomplex<f32>>) -> ()
return
}
// Expected output: a linalg.generic whose block receives the complex<f32>
// operand, applies complex.neg to it, and yields the negated value.
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: complex<f32>, %[[RESULT_OUT:.*]]):
// CHECK-NEXT: %[[RESULT:.*]] = complex.neg %[[OPERAND_IN]] : complex<f32>
// CHECK-NEXT: linalg.yield %[[RESULT]] : complex<f32>
// -----
// -----
// CHECK-LABEL: func @negi
func @negi(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
"lmhlo.negate"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()