[MLIR][KernelGen] Add `tf.Atanh` kernels
PiperOrigin-RevId: 352393602
This commit is contained in:
		
							parent
							
								
									ba2ee556f1
								
							
						
					
					
						commit
						c11ea4ef5a
					
				|  | @ -397,6 +397,20 @@ def HLOClient_AtanOp : HLOClient_UnaryElementwiseOp<"atan", [], | |||
|   }]; | ||||
| } | ||||
| 
 | ||||
| def HLOClient_AtanhOp : HLOClient_UnaryElementwiseOp<"atanh", [], | ||||
|     HLO_FpOrComplexTensor> { | ||||
|   let summary = "Atanh operator"; | ||||
| 
 | ||||
|   let description = [{ | ||||
|     Returns `Atanh(operand)` element-wise. | ||||
| 
 | ||||
|     $$ | ||||
|     \atanh(x) = 0.5 * log((1 + x) / (1 - x)) if abs(x) <= 1 | ||||
|               = nan                          otherwise | ||||
|     $$ | ||||
|   }]; | ||||
| } | ||||
| 
 | ||||
| def HLOClient_ConjOp : HLOClient_UnaryElementwiseOp<"conj", [], | ||||
|     HLO_FpOrComplexTensor> { | ||||
|   let summary = "Conj operator"; | ||||
|  |  | |||
|  | @ -175,6 +175,29 @@ def : Pat<(HLOClient_AtanOp $input), | |||
|     (HLO_ConstantLike<"1"> $input) | ||||
|   )>; | ||||
| 
 | ||||
| // Express `atanh` as follows: | ||||
| //   atanh(x) = 0.5 * log((1 + x) / (1 - x)) if abs(x) <= 1 | ||||
| //   atanh(x) = nan                          otherwise | ||||
| def : Pat<(HLOClient_AtanhOp NonComplexElementType:$input), | ||||
|   (HLO_SelectOp | ||||
|     (HLO_CompareOp | ||||
|       (HLO_AbsOp $input), | ||||
|       (HLO_ConstantLike<"1"> $input), | ||||
|       HLO_COMPARISON_DIRECTION_GT, | ||||
|       (HLO_DEFAULT_COMPARISON_TYPE) | ||||
|     ), | ||||
|     (HLO_ConstantLike<"NAN"> $input), | ||||
|     (HLO_MulOp | ||||
|       (HLO_SubOp | ||||
|         (HLO_Log1pOp $input), | ||||
|         (HLO_Log1pOp | ||||
|           (HLO_NegOp $input) | ||||
|         ) | ||||
|       ), | ||||
|       (HLO_ConstantLike<"0.5"> $input) | ||||
|     ) | ||||
|   )>; | ||||
| 
 | ||||
| // Express `conj` as | ||||
| //   conj(x) = (re(x), -im(x)). | ||||
| def : Pat<(HLOClient_ConjOp $v), | ||||
|  |  | |||
|  | @ -50,9 +50,10 @@ namespace { | |||
|               sep fn(ShiftRightLogicalOp) sep fn(SubOp) sep fn(XorOp) | ||||
| 
 | ||||
| // TODO(herhut): Generate these out of op definitions.
 | ||||
| #define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep)                           \ | ||||
|   fn(AcosOp) sep fn(AsinOp) sep fn(AsinhOp) sep fn(AtanOp) sep fn(ConjOp) \ | ||||
|       sep fn(CoshOp) sep fn(ErfOp) sep fn(ErfcOp) sep fn(SinhOp) sep fn(TanOp) | ||||
| #define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep)                            \ | ||||
|   fn(AcosOp) sep fn(AsinOp) sep fn(AsinhOp) sep fn(AtanOp) sep fn(AtanhOp) \ | ||||
|       sep fn(ConjOp) sep fn(CoshOp) sep fn(ErfOp) sep fn(ErfcOp)           \ | ||||
|           sep fn(SinhOp) sep fn(TanOp) | ||||
| 
 | ||||
| template <typename OpTy> | ||||
| inline void AddLegalOpOnRankedTensor(ConversionTarget *target) { | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue