[MLIR][KernelGen] Add `tf.Atanh` kernels

PiperOrigin-RevId: 352393602
This commit is contained in:
A. Unique TensorFlower 2021-01-18 05:13:02 -08:00 committed by TensorFlow MLIR Team
parent ba2ee556f1
commit c11ea4ef5a
3 changed files with 41 additions and 3 deletions

View File

@ -397,6 +397,20 @@ def HLOClient_AtanOp : HLOClient_UnaryElementwiseOp<"atan", [],
}]; }];
} }
def HLOClient_AtanhOp : HLOClient_UnaryElementwiseOp<"atanh", [],
HLO_FpOrComplexTensor> {
let summary = "Atanh operator";
let description = [{
Returns `Atanh(operand)` element-wise.
$$
\atanh(x) = 0.5 * log((1 + x) / (1 - x)) if abs(x) <= 1
= nan otherwise
$$
}];
}
def HLOClient_ConjOp : HLOClient_UnaryElementwiseOp<"conj", [], def HLOClient_ConjOp : HLOClient_UnaryElementwiseOp<"conj", [],
HLO_FpOrComplexTensor> { HLO_FpOrComplexTensor> {
let summary = "Conj operator"; let summary = "Conj operator";

View File

@ -175,6 +175,29 @@ def : Pat<(HLOClient_AtanOp $input),
(HLO_ConstantLike<"1"> $input) (HLO_ConstantLike<"1"> $input)
)>; )>;
// Express `atanh` as follows:
// atanh(x) = 0.5 * log((1 + x) / (1 - x)) if abs(x) <= 1
// atanh(x) = nan otherwise
def : Pat<(HLOClient_AtanhOp NonComplexElementType:$input),
(HLO_SelectOp
(HLO_CompareOp
(HLO_AbsOp $input),
(HLO_ConstantLike<"1"> $input),
HLO_COMPARISON_DIRECTION_GT,
(HLO_DEFAULT_COMPARISON_TYPE)
),
(HLO_ConstantLike<"NAN"> $input),
(HLO_MulOp
(HLO_SubOp
(HLO_Log1pOp $input),
(HLO_Log1pOp
(HLO_NegOp $input)
)
),
(HLO_ConstantLike<"0.5"> $input)
)
)>;
// Express `conj` as // Express `conj` as
// conj(x) = (re(x), -im(x)). // conj(x) = (re(x), -im(x)).
def : Pat<(HLOClient_ConjOp $v), def : Pat<(HLOClient_ConjOp $v),

View File

@ -50,9 +50,10 @@ namespace {
sep fn(ShiftRightLogicalOp) sep fn(SubOp) sep fn(XorOp) sep fn(ShiftRightLogicalOp) sep fn(SubOp) sep fn(XorOp)
// TODO(herhut): Generate these out of op definitions. // TODO(herhut): Generate these out of op definitions.
#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \ #define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \
fn(AcosOp) sep fn(AsinOp) sep fn(AsinhOp) sep fn(AtanOp) sep fn(ConjOp) \ fn(AcosOp) sep fn(AsinOp) sep fn(AsinhOp) sep fn(AtanOp) sep fn(AtanhOp) \
sep fn(CoshOp) sep fn(ErfOp) sep fn(ErfcOp) sep fn(SinhOp) sep fn(TanOp) sep fn(ConjOp) sep fn(CoshOp) sep fn(ErfOp) sep fn(ErfcOp) \
sep fn(SinhOp) sep fn(TanOp)
template <typename OpTy> template <typename OpTy>
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) { inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {