Add MLIR generated NegOp GPU kernel for complex types.
PiperOrigin-RevId: 379905236
This commit is contained in:
		
							parent
							
								
									8c8e81cb69
								
							
						
					
					
						commit
						73ed8cbf82
					
				|  | @ -671,8 +671,9 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::NegOp>(Location loc, | ||||||
|                                                   ArrayRef<Value> args, |                                                   ArrayRef<Value> args, | ||||||
|                                                   OpBuilder* b) { |                                                   OpBuilder* b) { | ||||||
|   Type element_type = getElementTypeOrSelf(args.front().getType()); |   Type element_type = getElementTypeOrSelf(args.front().getType()); | ||||||
|   if (element_type.isa<FloatType>()) { |   if (element_type.isa<ComplexType, FloatType>()) { | ||||||
|     return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::NegFOp>{}( |     return MapLhloOpToScalarOpImpl<isFloatType, ::mlir::NegFOp, isComplexType, | ||||||
|  |                                    ::mlir::complex::NegOp>{}( | ||||||
|         loc, result_types, arg_types, args, b); |         loc, result_types, arg_types, args, b); | ||||||
|   } |   } | ||||||
|   if (element_type.isa<IntegerType>()) { |   if (element_type.isa<IntegerType>()) { | ||||||
|  |  | ||||||
|  | @ -245,6 +245,17 @@ func @float_neg(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { | ||||||
| 
 | 
 | ||||||
| // ----- | // ----- | ||||||
| 
 | 
 | ||||||
|  | // CHECK-LABEL: func @complex_neg | ||||||
|  | func @complex_neg(%arg0: tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> { | ||||||
|  |   // CHECK: linalg.generic | ||||||
|  |   // CHECK: complex.neg | ||||||
|  |   %0 = "mhlo.negate"(%arg0) : (tensor<2x2xcomplex<f32>>) | ||||||
|  |                             -> tensor<2x2xcomplex<f32>> | ||||||
|  |   return %0 : tensor<2x2xcomplex<f32>> | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // ----- | ||||||
|  | 
 | ||||||
| // CHECK-LABEL: func @float_tanh | // CHECK-LABEL: func @float_tanh | ||||||
| func @float_tanh(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { | func @float_tanh(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { | ||||||
|   // CHECK: linalg.generic |   // CHECK: linalg.generic | ||||||
|  |  | ||||||
|  | @ -690,6 +690,22 @@ func @negf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { | ||||||
| 
 | 
 | ||||||
| // ----- | // ----- | ||||||
| 
 | 
 | ||||||
|  | // CHECK-LABEL: func @complex_neg | ||||||
|  | func @complex_neg(%input: memref<2x2xcomplex<f32>>, | ||||||
|  |                   %result: memref<2x2xcomplex<f32>>) { | ||||||
|  |   "lmhlo.negate"(%input, %result) : (memref<2x2xcomplex<f32>>, | ||||||
|  |                                      memref<2x2xcomplex<f32>>) -> () | ||||||
|  |   return | ||||||
|  | } | ||||||
|  | // CHECK: linalg.generic | ||||||
|  | // CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: complex<f32>, %[[RESULT_OUT:.*]]): | ||||||
|  | // CHECK-NEXT:   %[[RESULT:.*]] = complex.neg %[[OPERAND_IN]] : complex<f32> | ||||||
|  | // CHECK-NEXT:   linalg.yield %[[RESULT]] : complex<f32> | ||||||
|  | 
 | ||||||
|  | // ----- | ||||||
|  | 
 | ||||||
|  | // ----- | ||||||
|  | 
 | ||||||
| // CHECK-LABEL: func @negi | // CHECK-LABEL: func @negi | ||||||
| func @negi(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { | func @negi(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { | ||||||
|   "lmhlo.negate"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () |   "lmhlo.negate"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue