diff --git a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc
index bd0ec23..9d36a0f 100644
--- a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc
+++ b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc
@@ -648,6 +648,53 @@ Value MaterializeLgamma(ConversionPatternRewriter &rewriter, Location loc,
       lgamma);
 }
 
+// Express `cosh` as
+//   cosh(x) = (e^x + e^-x) / 2
+//           = e^(x + log(1/2)) + e^(-x + log(1/2))
+//
+// The second formulation avoids overflowing when e^x = inf but (e^x)/2 is not.
+//
+// This incorrectly overflows to inf for two f32 input values, namely
+// +/-89.4159851, due to rounding error when computing x +/- log(1/2). The
+// correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so
+// we deem this acceptable.
+Value MaterializeCoshApproximation(ConversionPatternRewriter &rewriter,
+                                   Location loc, ValueRange operands) {
+  CoshOp::Adaptor transformed(operands);
+  Value x = transformed.operand();
+
+  Value log_one_half =
+      rewriter.create<mhlo::LogOp>(loc, getConstantLike(rewriter, loc, 0.5, x));
+  Value exp_add = rewriter.create<mhlo::ExpOp>(
+      loc, rewriter.create<mhlo::AddOp>(loc, x, log_one_half));
+  Value exp_sub = rewriter.create<mhlo::ExpOp>(
+      loc, rewriter.create<mhlo::SubtractOp>(loc, log_one_half, x));
+  return rewriter.create<mhlo::AddOp>(loc, exp_add, exp_sub);
+}
+
+// Converts `chlo.cosh` on non-complex operands to MHLO ops by materializing
+// MaterializeCoshApproximation through MaterializeWithUpcast with f32.
+struct ConvertCoshOp : public OpConversionPattern<CoshOp> {
+  using OpConversionPattern<CoshOp>::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      CoshOp op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    CoshOp::Adaptor transformed(operands);
+    Value x = transformed.operand();
+    if (x.getType().cast<ShapedType>().getElementType().isa<ComplexType>()) {
+      // TODO(hinsu): Support operands with complex element types. Until then,
+      // fail the match so that `chlo.cosh` on complex operands is not
+      // rewritten by this pattern.
+      return failure();
+    }
+    rewriter.replaceOp(op,
+                       MaterializeWithUpcast(rewriter, op.getLoc(), operands,
+                                             rewriter.getF32Type(),
+                                             &MaterializeCoshApproximation));
+    return success();
+  }
+};
+
 // Compute the Digamma function using Lanczos' approximation from "A Precision
 // Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
 // series B. Vol. 1:
@@ -1318,7 +1365,8 @@ void PopulateDecomposeChloPatterns(MLIRContext *context,
   // Other patterns.
   // clang-format off
-  patterns->insertinsert;
-// Express `cosh` as
-//   cosh(x) = (e^x + e^-x) / 2
-//           = e^(x + log(1/2)) + e^(-x + log(1/2))
-//
-// The second formulation avoids overflowing when e^x = inf but (e^x)/2 is not.
-//
-// This incorrectly overflows to inf for two f32 input values, namely
-// +/-89.4159851, due to rounding error when computing x +/- log(1/2). The
-// correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so
-// we deem this acceptable.
-def : Pat<(HLOClient_CoshOp NonComplexElementType:$input),
-  (HLO_AddOp
-    (HLO_ExpOp
-      (HLO_AddOp
-        $input,
-        (HLO_LogOp
-          (HLO_ConstantLike<"0.5"> $input)
-        )
-      )
-    ),
-    (HLO_ExpOp
-      (HLO_AddOp
-        (HLO_NegOp $input),
-        (HLO_LogOp
-          (HLO_ConstantLike<"0.5"> $input)
-        )
-      )
-    )
-  )>;
 // Express `is_inf` as
 //   is_inf(x) = is_pos_inf(|x|)
 def : Pat<(HLOClient_IsInfOp NonComplexElementType:$input),
diff --git a/tests/chlo_legalize_to_mhlo.mlir b/tests/chlo_legalize_to_mhlo.mlir
index f94bb8d..81e18f1 100644
--- a/tests/chlo_legalize_to_mhlo.mlir
+++ b/tests/chlo_legalize_to_mhlo.mlir
@@ -2187,3 +2187,32 @@ func @sinh_complex(%x : tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> {
   %1 = chlo.sinh %x : tensor<2xcomplex<f32>> -> tensor<2xcomplex<f32>>
   return %1 : tensor<2xcomplex<f32>>
 }
+
+// ----
+
+// CHECK-LABEL: @cosh_f32
+// CHECK-SAME: (%[[X:.*]]: tensor<f32>)
+func @cosh_f32(%x : tensor<f32>) -> tensor<f32> {
+  // CHECK: %[[HALF:.*]] = mhlo.constant dense<5.000000e-01> : tensor<f32>
+  // CHECK: %[[LOG_HALF:.*]] = "mhlo.log"(%[[HALF]]) : (tensor<f32>) -> tensor<f32>
+  // CHECK: %[[X_PLUS_LOG_HALF:.*]] = mhlo.add %[[X]], %[[LOG_HALF]] : tensor<f32>
+  // CHECK: %[[EXP_1:.*]] = "mhlo.exponential"(%[[X_PLUS_LOG_HALF]]) : (tensor<f32>) -> tensor<f32>
+  // CHECK: %[[LOG_HALF_MINUS_X:.*]] = mhlo.subtract %[[LOG_HALF]], %[[X]] : tensor<f32>
+  // CHECK: %[[EXP_2:.*]] = "mhlo.exponential"(%[[LOG_HALF_MINUS_X]]) : (tensor<f32>) -> tensor<f32>
+  // CHECK: %[[RESULT:.*]] = mhlo.add %[[EXP_1]], %[[EXP_2]] : tensor<f32>
+  // CHECK: return %[[RESULT]] : tensor<f32>
+  %1 = chlo.cosh %x : tensor<f32> -> tensor<f32>
+  return %1 : tensor<f32>
+}
+
+// ----
+
+// CHECK-LABEL: @cosh_f16
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<f16>)
+func @cosh_f16(%x : tensor<f16>) -> tensor<f16> {
+  // CHECK: "mhlo.convert"(%[[ARG0]]) : (tensor<f16>) -> tensor<f32>
+  // CHECK: %[[RES:.*]] = "mhlo.convert"(%{{.*}}) : (tensor<f32>) -> tensor<f16>
+  // CHECK: return %[[RES]]
+  %1 = chlo.cosh %x : tensor<f16> -> tensor<f16>
+  return %1 : tensor<f16>
+}