[MLIR][KernelGen] Add cosh kernels and tests

Allow for relative tolerance in unary kernel tests. In case of the cosh kernels,
this allows to accept an observed difference of 5.6e-8 between the kernel and
the `std::cosh` reference (32829984.568665262 vs. 32829984.568665318) in one of
the test cases.

PiperOrigin-RevId: 351983698
This commit is contained in:
A. Unique TensorFlower 2021-01-15 04:30:14 -08:00 committed by TensorFlow MLIR Team
parent 9a1abaa212
commit 316f630728
3 changed files with 48 additions and 5 deletions

View File

@ -398,6 +398,19 @@ def HLOClient_ConjOp : HLOClient_UnaryElementwiseOp<"conj", [],
}]; }];
} }
def HLOClient_CoshOp : HLOClient_UnaryElementwiseOp<"cosh", [],
HLO_FpOrComplexTensor> {
let summary = "Cosh operator";
let description = [{
Returns `Cosh(operand)` element-wise.
$$
\cosh(x) = (e^x + e^-x) / 2
$$
}];
}
def HLOClient_SinhOp : HLOClient_UnaryElementwiseOp<"sinh", [], def HLOClient_SinhOp : HLOClient_UnaryElementwiseOp<"sinh", [],
HLO_FpOrComplexTensor> { HLO_FpOrComplexTensor> {
let summary = "Sinh operation"; let summary = "Sinh operation";

View File

@ -62,7 +62,7 @@ def : Pat<(HLOClient_AcosOp NonComplexElementType:$input),
// Expand asin to MHLO dialect as follows: // Expand asin to MHLO dialect as follows:
// asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2))) // asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2)))
def : Pat<(HLOClient_AsinOp $input), def : Pat<(HLOClient_AsinOp NonComplexElementType:$input),
(HLO_MulOp (HLO_MulOp
(HLO_ConstantLike<"2"> $input), (HLO_ConstantLike<"2"> $input),
(HLO_Atan2Op (HLO_Atan2Op
@ -92,6 +92,36 @@ def : Pat<(HLOClient_AtanOp $input),
def : Pat<(HLOClient_ConjOp $v), def : Pat<(HLOClient_ConjOp $v),
(HLO_ComplexOp (HLO_RealOp $v), (HLO_NegOp (HLO_ImagOp $v)))>; (HLO_ComplexOp (HLO_RealOp $v), (HLO_NegOp (HLO_ImagOp $v)))>;
// Express `cosh` as
// cosh(x) = (e^x + e^-x) / 2
// = e^(x + log(1/2)) + e^(-x + log(1/2))
//
// The second formulation avoids overflowing when e^x = inf but (e^x)/2 is not.
//
// This incorrectly overflows to inf for two f32 input values, namely
// +/-89.4159851, due to rounding error when computing x +/- log(1/2). The
// correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so
// we deem this acceptable.
def : Pat<(HLOClient_CoshOp NonComplexElementType:$input),
(HLO_AddOp
(HLO_ExpOp
(HLO_AddOp
$input,
(HLO_LogOp
(HLO_ConstantLike<"0.5"> $input)
)
)
),
(HLO_ExpOp
(HLO_AddOp
(HLO_NegOp $input),
(HLO_LogOp
(HLO_ConstantLike<"0.5"> $input)
)
)
)
)>;
// Express `sinh` as // Express `sinh` as
// sinh(x) = (e^x - e^-x) / 2 if |x| < 1 // sinh(x) = (e^x - e^-x) / 2 if |x| < 1
// = e^(x + log(1/2)) - e^(-x + log(1/2)) otherwise. // = e^(x + log(1/2)) - e^(-x + log(1/2)) otherwise.
@ -136,7 +166,7 @@ def : Pat<(HLOClient_SinhOp NonComplexElementType:$input),
// Express tan in MHLO dialect as // Express tan in MHLO dialect as
// tan(x) = sin(x) / cos(x). // tan(x) = sin(x) / cos(x).
def : Pat<(HLOClient_TanOp $input), def : Pat<(HLOClient_TanOp NonComplexElementType:$input),
(HLO_DivOp (HLO_DivOp
(HLO_SinOp $input), (HLO_SinOp $input),
(HLO_CosOp $input) (HLO_CosOp $input)

View File

@ -51,8 +51,8 @@ namespace {
// TODO(herhut): Generate these out of op definitions. // TODO(herhut): Generate these out of op definitions.
#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \ #define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \
fn(AcosOp) sep fn(AsinOp) sep fn(AtanOp) sep fn(ConjOp) sep fn(ErfOp) \ fn(AcosOp) sep fn(AsinOp) sep fn(AtanOp) sep fn(ConjOp) sep fn(CoshOp) \
sep fn(ErfcOp) sep fn(SinhOp) sep fn(TanOp) sep fn(ErfOp) sep fn(ErfcOp) sep fn(SinhOp) sep fn(TanOp)
template <typename OpTy> template <typename OpTy>
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) { inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {