Add MLIR generated NegOp GPU kernel for complex types.
PiperOrigin-RevId: 379905236
This commit is contained in:
parent
8c8e81cb69
commit
73ed8cbf82
|
@ -671,8 +671,9 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::NegOp>(Location loc,
|
|||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
Type element_type = getElementTypeOrSelf(args.front().getType());
|
||||
if (element_type.isa<FloatType>()) {
|
||||
return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::NegFOp>{}(
|
||||
if (element_type.isa<ComplexType, FloatType>()) {
|
||||
return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::NegFOp, isComplexType,
|
||||
::mlir::complex::NegOp>{}(
|
||||
loc, result_types, arg_types, args, b);
|
||||
}
|
||||
if (element_type.isa<IntegerType>()) {
|
||||
|
|
|
@ -245,6 +245,17 @@ func @float_neg(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @complex_neg
|
||||
func @complex_neg(%arg0: tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> {
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: complex.neg
|
||||
%0 = "mhlo.negate"(%arg0) : (tensor<2x2xcomplex<f32>>)
|
||||
-> tensor<2x2xcomplex<f32>>
|
||||
return %0 : tensor<2x2xcomplex<f32>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @float_tanh
|
||||
func @float_tanh(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: linalg.generic
|
||||
|
|
|
@ -690,6 +690,22 @@ func @negf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @complex_neg
|
||||
func @complex_neg(%input: memref<2x2xcomplex<f32>>,
|
||||
%result: memref<2x2xcomplex<f32>>) {
|
||||
"lmhlo.negate"(%input, %result) : (memref<2x2xcomplex<f32>>,
|
||||
memref<2x2xcomplex<f32>>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: complex<f32>, %[[RESULT_OUT:.*]]):
|
||||
// CHECK-NEXT: %[[RESULT:.*]] = complex.neg %[[OPERAND_IN]] : complex<f32>
|
||||
// CHECK-NEXT: linalg.yield %[[RESULT]] : complex<f32>
|
||||
|
||||
// -----
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @negi
|
||||
func @negi(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
|
||||
"lmhlo.negate"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
|
||||
|
|
Loading…
Reference in New Issue