diff --git a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td index 9e43188..37e6727 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td @@ -398,6 +398,19 @@ def HLOClient_ConjOp : HLOClient_UnaryElementwiseOp<"conj", [], }]; } +def HLOClient_CoshOp : HLOClient_UnaryElementwiseOp<"cosh", [], + HLO_FpOrComplexTensor> { + let summary = "Cosh operator"; + + let description = [{ + Returns `Cosh(operand)` element-wise. + + $$ + \cosh(x) = (e^x + e^-x) / 2 + $$ + }]; +} + def HLOClient_SinhOp : HLOClient_UnaryElementwiseOp<"sinh", [], HLO_FpOrComplexTensor> { let summary = "Sinh operation"; diff --git a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td index 26a1da2..3d89f92 100644 --- a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td +++ b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td @@ -62,7 +62,7 @@ def : Pat<(HLOClient_AcosOp NonComplexElementType:$input), // Expand asin to MHLO dialect as follows: // asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2))) -def : Pat<(HLOClient_AsinOp $input), +def : Pat<(HLOClient_AsinOp NonComplexElementType:$input), (HLO_MulOp (HLO_ConstantLike<"2"> $input), (HLO_Atan2Op @@ -92,6 +92,36 @@ def : Pat<(HLOClient_AtanOp $input), def : Pat<(HLOClient_ConjOp $v), (HLO_ComplexOp (HLO_RealOp $v), (HLO_NegOp (HLO_ImagOp $v)))>; +// Express `cosh` as +// cosh(x) = (e^x + e^-x) / 2 +// = e^(x + log(1/2)) + e^(-x + log(1/2)) +// +// The second formulation avoids overflowing when e^x = inf but (e^x)/2 is not. +// +// This incorrectly overflows to inf for two f32 input values, namely +// +/-89.4159851, due to rounding error when computing x +/- log(1/2). The +// correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so +// we deem this acceptable. +def : Pat<(HLOClient_CoshOp NonComplexElementType:$input), + (HLO_AddOp + (HLO_ExpOp + (HLO_AddOp + $input, + (HLO_LogOp + (HLO_ConstantLike<"0.5"> $input) + ) + ) + ), + (HLO_ExpOp + (HLO_AddOp + (HLO_NegOp $input), + (HLO_LogOp + (HLO_ConstantLike<"0.5"> $input) + ) + ) + ) + )>; + // Express `sinh` as // sinh(x) = (e^x - e^-x) / 2 if |x| < 1 // = e^(x + log(1/2)) - e^(-x + log(1/2)) otherwise. @@ -136,7 +166,7 @@ def : Pat<(HLOClient_SinhOp NonComplexElementType:$input), // Express tan in MHLO dialect as // tan(x) = sin(x) / cos(x). -def : Pat<(HLOClient_TanOp $input), +def : Pat<(HLOClient_TanOp NonComplexElementType:$input), (HLO_DivOp (HLO_SinOp $input), (HLO_CosOp $input) diff --git a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc index eb78100..152be21 100644 --- a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc +++ b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc @@ -50,9 +50,9 @@ namespace { sep fn(ShiftRightLogicalOp) sep fn(SubOp) sep fn(XorOp) // TODO(herhut): Generate these out of op definitions. -#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \ - fn(AcosOp) sep fn(AsinOp) sep fn(AtanOp) sep fn(ConjOp) sep fn(ErfOp) \ - sep fn(ErfcOp) sep fn(SinhOp) sep fn(TanOp) +#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \ + fn(AcosOp) sep fn(AsinOp) sep fn(AtanOp) sep fn(ConjOp) sep fn(CoshOp) \ + sep fn(ErfOp) sep fn(ErfcOp) sep fn(SinhOp) sep fn(TanOp) template inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {