Add MLIR generated SignOp GPU kernel for complex types.
PiperOrigin-RevId: 379924456
This commit is contained in:
		
							parent
							
								
									73ed8cbf82
								
							
						
					
					
						commit
						376da8592f
					
				|  | @ -887,6 +887,8 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::SignOp>(Location loc, | |||
|         b->create<::mlir::SignedShiftRightOp>(loc, args[0], bitwidth_minus_one); | ||||
|     Value or_op = b->create<::mlir::OrOp>(loc, ashr, one); | ||||
|     return b->create<::mlir::SelectOp>(loc, cmp, zero, or_op); | ||||
|   } else if (element_type.isa<ComplexType>()) { | ||||
|     return b->create<::mlir::complex::SignOp>(loc, element_type, args.front()); | ||||
|   } | ||||
|   return nullptr; | ||||
| } | ||||
|  |  | |||
|  | @ -256,6 +256,18 @@ func @complex_neg(%arg0: tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> { | |||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| // CHECK-LABEL: func @complex_sign | ||||
| func @complex_sign( | ||||
|     %arg0: tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> { | ||||
|   // CHECK: linalg.generic | ||||
|   // CHECK: complex.sign | ||||
|   %0 = "mhlo.sign"(%arg0) : (tensor<2x2xcomplex<f32>>) | ||||
|                           -> tensor<2x2xcomplex<f32>> | ||||
|   return %0 : tensor<2x2xcomplex<f32>> | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| // CHECK-LABEL: func @float_tanh | ||||
| func @float_tanh(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { | ||||
|   // CHECK: linalg.generic | ||||
|  |  | |||
|  | @ -810,6 +810,20 @@ func @sign_i16(%input: memref<2x2xi16>, %result: memref<2x2xi16>) { | |||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| // CHECK-LABEL: func @sign_complex | ||||
| func @sign_complex(%input: memref<2x2xcomplex<f32>>, | ||||
|                    %result: memref<2x2xcomplex<f32>>) { | ||||
|   "lmhlo.sign"(%input, %result) : (memref<2x2xcomplex<f32>>, | ||||
|                                    memref<2x2xcomplex<f32>>) -> () | ||||
|   return | ||||
| } | ||||
| // CHECK: linalg.generic | ||||
| // CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: complex<f32>, %[[RESULT_OUT:.*]]): | ||||
| // CHECK-NEXT:   %[[RESULT:.*]] = complex.sign %[[OPERAND_IN]] : complex<f32> | ||||
| // CHECK-NEXT:   linalg.yield %[[RESULT]] : complex<f32> | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| // CHECK-LABEL: func @sqrt | ||||
| func @sqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { | ||||
|   "lmhlo.sqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue