[KERNEL_GEN] Add unranked Conj kernel.
PiperOrigin-RevId: 344243271
This commit is contained in:
		
							parent
							
								
									a6948f6b41
								
							
						
					
					
						commit
						5583c63cab
					
				|  | @ -372,6 +372,19 @@ def HLOClient_AtanOp : HLOClient_UnaryElementwiseOp<"atan", [], | |||
|   }]; | ||||
| } | ||||
| 
 | ||||
| def HLOClient_ConjOp : HLOClient_UnaryElementwiseOp<"conj", [], | ||||
|     HLO_FpOrComplexTensor> { | ||||
|   let summary = "Conj operator"; | ||||
| 
 | ||||
|   let description = [{ | ||||
|     Returns `Conj(operand)` element-wise. | ||||
| 
 | ||||
|     $$ | ||||
|     \conj(x) = (\real(x), \neg(\imag(x))) | ||||
|     $$ | ||||
|   }]; | ||||
| } | ||||
| 
 | ||||
| def HLOClient_SinhOp : HLOClient_UnaryElementwiseOp<"sinh", [], | ||||
|     HLO_FpOrComplexTensor> { | ||||
|   let summary = "Sinh operation"; | ||||
|  |  | |||
|  | @ -60,6 +60,11 @@ def : Pat<(HLOClient_AtanOp $input), | |||
|     (HLO_ConstantLike<"1"> $input) | ||||
|   )>; | ||||
| 
 | ||||
| // Express `conj` as | ||||
| //   conj(x) = (re(x), -im(x)). | ||||
| def : Pat<(HLOClient_ConjOp $v), | ||||
|           (HLO_ComplexOp (HLO_RealOp $v), (HLO_NegOp (HLO_ImagOp $v)))>; | ||||
| 
 | ||||
| // Express `sinh` as | ||||
| //   sinh(x) = (e^x - e^-x) / 2                     if |x| < 1 | ||||
| //           = e^(x + log(1/2)) - e^(-x + log(1/2)) otherwise. | ||||
|  |  | |||
|  | @ -50,8 +50,8 @@ namespace { | |||
| 
 | ||||
| // TODO(herhut): Generate these out of op definitions.
 | ||||
| #define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep)                         \ | ||||
|   fn(AcosOp) sep fn(AtanOp) sep fn(ErfOp) sep fn(ErfcOp) sep fn(SinhOp) \ | ||||
|       sep fn(TanOp) | ||||
|   fn(AcosOp) sep fn(AtanOp) sep fn(ConjOp) sep fn(ErfOp) sep fn(ErfcOp) \ | ||||
|       sep fn(SinhOp) sep fn(TanOp) | ||||
| 
 | ||||
| template <typename OpTy> | ||||
| inline void AddLegalOpOnRankedTensor(ConversionTarget *target) { | ||||
|  |  | |||
|  | @ -24,3 +24,14 @@ func @constant_like_dynamic_shape(%arg : tensor<?x?xi64>) -> tensor<?x?xf32> { | |||
|   return %result : tensor<?x?xf32> | ||||
| } | ||||
| 
 | ||||
| // CHECK-LABEL: func @conj | ||||
| func @conj(%arg0: tensor<3xcomplex<f32>>) -> tensor<3xcomplex<f32>> { | ||||
|   // CHECK-SAME: ([[INPUT:%.*]]: tensor | ||||
|   // CHECK-NEXT: [[R1:%.*]] = "mhlo.real"([[INPUT]]) | ||||
|   // CHECK-NEXT: [[R2:%.*]] = "mhlo.imag"([[INPUT]]) | ||||
|   // CHECK-NEXT: [[R3:%.*]] = "mhlo.negate"([[R2]]) | ||||
|   // CHECK-NEXT: [[R4:%.*]] = "mhlo.complex"([[R1]], [[R3]]) | ||||
|   %1 = "chlo.conj"(%arg0) : (tensor<3xcomplex<f32>>) -> tensor<3xcomplex<f32>> | ||||
|   return %1 : tensor<3xcomplex<f32>> | ||||
| } | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue