Add MLIR generated NegOp GPU kernel for complex types.
PiperOrigin-RevId: 379905236
This commit is contained in:
parent
8c8e81cb69
commit
73ed8cbf82
|
@ -671,8 +671,9 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::NegOp>(Location loc,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
Type element_type = getElementTypeOrSelf(args.front().getType());
|
Type element_type = getElementTypeOrSelf(args.front().getType());
|
||||||
if (element_type.isa<FloatType>()) {
|
if (element_type.isa<ComplexType, FloatType>()) {
|
||||||
return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::NegFOp>{}(
|
return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::NegFOp, isComplexType,
|
||||||
|
::mlir::complex::NegOp>{}(
|
||||||
loc, result_types, arg_types, args, b);
|
loc, result_types, arg_types, args, b);
|
||||||
}
|
}
|
||||||
if (element_type.isa<IntegerType>()) {
|
if (element_type.isa<IntegerType>()) {
|
||||||
|
|
|
@ -245,6 +245,17 @@ func @float_neg(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @complex_neg
|
||||||
|
func @complex_neg(%arg0: tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> {
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK: complex.neg
|
||||||
|
%0 = "mhlo.negate"(%arg0) : (tensor<2x2xcomplex<f32>>)
|
||||||
|
-> tensor<2x2xcomplex<f32>>
|
||||||
|
return %0 : tensor<2x2xcomplex<f32>>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @float_tanh
|
// CHECK-LABEL: func @float_tanh
|
||||||
func @float_tanh(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
func @float_tanh(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||||
// CHECK: linalg.generic
|
// CHECK: linalg.generic
|
||||||
|
|
|
@ -690,6 +690,22 @@ func @negf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @complex_neg
|
||||||
|
func @complex_neg(%input: memref<2x2xcomplex<f32>>,
|
||||||
|
%result: memref<2x2xcomplex<f32>>) {
|
||||||
|
"lmhlo.negate"(%input, %result) : (memref<2x2xcomplex<f32>>,
|
||||||
|
memref<2x2xcomplex<f32>>) -> ()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: complex<f32>, %[[RESULT_OUT:.*]]):
|
||||||
|
// CHECK-NEXT: %[[RESULT:.*]] = complex.neg %[[OPERAND_IN]] : complex<f32>
|
||||||
|
// CHECK-NEXT: linalg.yield %[[RESULT]] : complex<f32>
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @negi
|
// CHECK-LABEL: func @negi
|
||||||
func @negi(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
|
func @negi(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
|
||||||
"lmhlo.negate"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
|
"lmhlo.negate"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
|
||||||
|
|
Loading…
Reference in New Issue