diff --git a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td index 4ac5f69..48adffc 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td @@ -554,6 +554,15 @@ def HLOClient_IsPosInfOp : HLOClient_UnaryElementwiseOp<"is_pos_inf", }]; } +def HLOClient_LgammaOp : HLOClient_UnaryElementwiseOp<"lgamma", + [SameOperandsAndResultType], HLO_FpTensor, HLO_FpTensor> { + let summary = "Lgamma function"; + + let description = [{ + Returns `Lgamma(operand)` element-wise. + }]; +} + //===----------------------------------------------------------------------===// // Broadcasting compare op //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc index a15e240..a45b91a 100644 --- a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc +++ b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc @@ -471,6 +471,191 @@ struct ConvertErfcOp : public OpConversionPattern { } }; +// Coefficients for the Lanczos approximation of the gamma function. The +// coefficients are uniquely determined by the choice of g and n (kLanczosGamma +// and kLanczosCoefficients.size() + 1). The coefficients below correspond to +// [7, 9]. [5, 7], [7, 9], [9, 10], and [607/128.0, 15] were evaluated and +// [7, 9] seemed to be the least sensitive to the quality of the log function. +// In particular, [5, 7] is the only choice where -1.5e-5 <= lgamma(2) <= 1.5e-5 +// for a particularly inaccurate log function. 
+constexpr double kLanczosGamma = 7; // aka g +constexpr double kBaseLanczosCoeff = 0.99999999999980993227684700473478; +constexpr std::array kLanczosCoefficients = { + 676.520368121885098567009190444019, -1259.13921672240287047156078755283, + 771.3234287776530788486528258894, -176.61502916214059906584551354, + 12.507343278686904814458936853, -0.13857109526572011689554707, + 9.984369578019570859563e-6, 1.50563273514931155834e-7}; + +// Compute the Lgamma function using Lanczos' approximation from "A Precision +// Approximation of the Gamma Function". SIAM Journal on Numerical Analysis +// series B. Vol. 1: +// lgamma(z + 1) = (log(2) + log(pi)) / 2 +// + (z + 1/2) * log(t(z)) +// - t(z) + log(a(z)) +// with t(z) = z + kLanczosGamma + 1/2 +// a(z) = kBaseLanczosCoeff +// + sum(k = 1, n, kLanczosCoefficients[k - 1] / (z + k)) +Value MaterializeLgamma(ConversionPatternRewriter &rewriter, Location loc, + Value x) { + // If the input is less than 0.5 use Euler's reflection formula. + // gamma(x) = pi / (sin(pi * x) * gamma(1 - x)) + // Let z be + // z = -x if x < 1/2 + // z = x - 1 otherwise + const StringAttr kLT = rewriter.getStringAttr( + mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT)); + Value half = getConstantLike(rewriter, loc, 0.5, x); + Value need_to_reflect = rewriter.create(loc, x, half, kLT); + Value neg_x = rewriter.create(loc, x); + Value one = getConstantLike(rewriter, loc, 1, x); + Value x_sub_one = rewriter.create(loc, x, one); + Value z = + rewriter.create(loc, need_to_reflect, neg_x, x_sub_one); + + // Materialize + // a(z) = kBaseLanczosCoeff + // + sum(k = 1, n, kLanczosCoefficients[k - 1] / (z + k)) + Value a = getConstantLike(rewriter, loc, kBaseLanczosCoeff, x); + for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) { + Value coeff = getConstantLike(rewriter, loc, kLanczosCoefficients[i], x); + Value one_based_index = getConstantLike(rewriter, loc, i + 1, x); + Value quotient = rewriter.create( + loc, coeff, 
rewriter.create(loc, z, one_based_index)); + a = rewriter.create(loc, a, quotient); + } + + // To improve accuracy on platforms with less-precise log implementations, + // compute log(kLanczosGamma + 1/2) at compile time and use log1p on the + // device. + // Materialize as + // log(t) = log(kLanczosGamma + 1/2 + z) + // = log(kLanczosGamma + 1/2) + log1p(z / (kLanczosGamma + 1/2)). + Value lanczos_plus_half = + getConstantLike(rewriter, loc, kLanczosGamma + 0.5, x); + Value t = rewriter.create(loc, lanczos_plus_half, z); + Value log_term = + getConstantLike(rewriter, loc, std::log(kLanczosGamma + 0.5), x); + Value log1p_term = rewriter.create( + loc, rewriter.create(loc, z, lanczos_plus_half)); + Value log_t = rewriter.create(loc, log_term, log1p_term); + + // Note that t(z) may be large and we need to be careful not to overflow to + // infinity in the relevant term + // r = (z + 1/2) * log(t(z)) - t(z). + // Therefore, we compute this as + // r = (z + 1/2 - t(z) / log(t(z))) * log(t(z)). + Value t_div_log_t = rewriter.create(loc, t, log_t); + Value sum = rewriter.create( + loc, rewriter.create(loc, z, half), t_div_log_t); + Value r = rewriter.create(loc, sum, log_t); + + // Compute the final result (modulo reflection) as + // lgamma(z + 1) = (log(2) + log(pi)) / 2 + r + log(a(z)). + Value log_a = rewriter.create(loc, a); + Value lgamma = rewriter.create( + loc, + rewriter.create( + loc, + getConstantLike(rewriter, loc, (std::log(2) + std::log(M_PI)) / 2, x), + r), + log_a); + + // Compute the reflected value for x < 0.5 as + // lgamma(x) = log(pi) - lgamma(1-x) - log(abs(sin(pi * x))). + // + // The abs is needed because lgamma is the log of the absolute value of the + // gamma function. + // + // We have to be careful when computing the final term above. gamma(x) goes + // to +/-inf at every integer x < 0, and this is controlled by the sin(pi * x) + // term. The slope is large, so precision is particularly important. 
+ // + // Because abs(sin(pi * x)) has period of 1 we can equivalently use + // abs(sin(pi * frac(x))) where frac(x) is the fractional part of x. This is + // more numerically accurate: It doesn't overflow to inf like pi * x would and + // if x is an integer it evaluates to exactly 0 which is important because we + // then take the log of this value, and log(0) is -inf. + // + // We don't have a frac(x) primitive in HLO and computing it is tricky, but + // because abs(sin(pi * x)) = abs(sin(pi * abs(x))), it's good enough for our + // purposes to use abs(frac(x)) = abs(x) - floor(abs(x)). + // + // Furthermore, pi * abs(frac(x)) loses precision when abs(frac(x)) is close + // to 1. To remedy this, we can use the fact that sin(pi * x) in the domain + // [0, 1] is symmetric about the line x = 0.5. + // + + // Convert values of abs_frac > 0.5 to (1 - abs_frac) to improve precision of + // pi * abs_frac for values of abs_frac close to 1. + Value abs = rewriter.create(loc, x); + Value abs_frac = rewriter.create( + loc, abs, rewriter.create(loc, abs)); + Value reduce_abs_frac = + rewriter.create(loc, half, abs_frac, kLT); + abs_frac = rewriter.create( + loc, reduce_abs_frac, rewriter.create(loc, one, abs_frac), + abs_frac); + + // Materialize reflection. + Value reflection_denom = rewriter.create( + loc, + rewriter.create( + loc, rewriter.create( + loc, getConstantLike(rewriter, loc, M_PI, x), abs_frac))); + Value lgamma_reflection = rewriter.create( + loc, + rewriter.create( + loc, getConstantLike(rewriter, loc, std::log(M_PI), x), + reflection_denom), + lgamma); + + // Avoid computing -inf - inf, which is nan. If reflection_denom is +/-inf, + // then it "wins" and the result is +/-inf. 
+ Value finite_reflection_denom = + rewriter.create(loc, reflection_denom); + Value neg_reflection_denom = + rewriter.create(loc, reflection_denom); + lgamma_reflection = rewriter.create( + loc, finite_reflection_denom, lgamma_reflection, neg_reflection_denom); + + // Select whether or not to rely on the reflection. + lgamma = rewriter.create(loc, need_to_reflect, + lgamma_reflection, lgamma); + + // Materialize +/-inf behavior as + // lgamma(+/-inf) = +inf. + Value x_is_inf = rewriter.create(loc, x); + return rewriter.create( + loc, x_is_inf, + chlo::getConstantLikeInfValue(rewriter, loc, x, /*negative=*/false), + lgamma); +} + +struct ConvertLgammaOp : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + LgammaOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + LgammaOp::Adaptor transformed(operands); + Value x = transformed.operand(); + Type ty = getElementTypeOrSelf(op.getType()); + + if (ty.isF32() || ty.isF64()) { + rewriter.replaceOp(op, MaterializeLgamma(rewriter, loc, x)); + return success(); + } + + // Materialize lgamma with upcast to f32. + x = rewriter.create(loc, x, rewriter.getF32Type()); + Value result = MaterializeLgamma(rewriter, loc, x); + result = rewriter.create(loc, result, ty); + + rewriter.replaceOp(op, result); + return success(); + } +}; + // Converts binary ops that statically are determined to not broadcast directly // to the corresponding mhlo non-broadcasting op. template @@ -622,7 +807,8 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context, context, patterns, 5); // Other patterns. 
- patterns->insert(context); + patterns->insert(context); } } // namespace chlo diff --git a/tests/chlo_legalize_to_mhlo.mlir b/tests/chlo_legalize_to_mhlo.mlir index 026cea0..aad137f 100644 --- a/tests/chlo_legalize_to_mhlo.mlir +++ b/tests/chlo_legalize_to_mhlo.mlir @@ -683,3 +683,195 @@ func @is_neg_inf_f32(%arg : tensor) -> tensor { %1 = chlo.is_neg_inf %arg : tensor -> tensor return %1 : tensor } + +// CHECK-LABEL: @lgamma_f64 +// CHECK-SAME: (%[[ARG:.*]]: tensor) +func @lgamma_f64(%arg : tensor) -> tensor { + // CHECK: %[[TMP_1:.*]] = mhlo.constant dense<5.000000e-01> + // CHECK: %[[TMP_9:.*]] = "mhlo.compare"(%[[ARG]], %[[TMP_1]]) {comparison_direction = "LT"} + // CHECK: %[[TMP_10:.*]] = "mhlo.negate"(%[[ARG]]) + // CHECK: %[[TMP_2:.*]] = mhlo.constant dense<1.000000e+00> + // CHECK: %[[TMP_11:.*]] = mhlo.subtract %[[ARG]], %[[TMP_2]] + // CHECK: %[[TMP_12:.*]] = "mhlo.select"(%[[TMP_9]], %[[TMP_10]], %[[TMP_11]]) + // CHECK: %[[TMP_8:.*]] = mhlo.constant dense<0.99999999999980993> + // CHECK: %[[TMP_13:.*]] = mhlo.constant dense<676.5203681218851> + // CHECK: %[[TMP_14:.*]] = mhlo.constant dense<1.000000e+00> + // CHECK: %[[TMP_15:.*]] = mhlo.add %[[TMP_12]], %[[TMP_14]] + // CHECK: %[[TMP_16:.*]] = mhlo.divide %[[TMP_13]], %[[TMP_15]] + // CHECK: %[[TMP_17:.*]] = mhlo.add %[[TMP_8]], %[[TMP_16]] + // CHECK: %[[TMP_18:.*]] = mhlo.constant dense<-1259.1392167224028> + // CHECK: %[[TMP_19:.*]] = mhlo.constant dense<2.000000e+00> + // CHECK: %[[TMP_20:.*]] = mhlo.add %[[TMP_12]], %[[TMP_19]] + // CHECK: %[[TMP_21:.*]] = mhlo.divide %[[TMP_18]], %[[TMP_20]] + // CHECK: %[[TMP_22:.*]] = mhlo.add %[[TMP_17]], %[[TMP_21]] + // CHECK: %[[TMP_23:.*]] = mhlo.constant dense<771.32342877765313> + // CHECK: %[[TMP_24:.*]] = mhlo.constant dense<3.000000e+00> + // CHECK: %[[TMP_25:.*]] = mhlo.add %[[TMP_12]], %[[TMP_24]] + // CHECK: %[[TMP_26:.*]] = mhlo.divide %[[TMP_23]], %[[TMP_25]] + // CHECK: %[[TMP_27:.*]] = mhlo.add %[[TMP_22]], %[[TMP_26]] + // CHECK: 
%[[TMP_28:.*]] = mhlo.constant dense<-176.61502916214059> + // CHECK: %[[TMP_29:.*]] = mhlo.constant dense<4.000000e+00> + // CHECK: %[[TMP_30:.*]] = mhlo.add %[[TMP_12]], %[[TMP_29]] + // CHECK: %[[TMP_31:.*]] = mhlo.divide %[[TMP_28]], %[[TMP_30]] + // CHECK: %[[TMP_32:.*]] = mhlo.add %[[TMP_27]], %[[TMP_31]] + // CHECK: %[[TMP_33:.*]] = mhlo.constant dense<12.507343278686905> + // CHECK: %[[TMP_34:.*]] = mhlo.constant dense<5.000000e+00> + // CHECK: %[[TMP_35:.*]] = mhlo.add %[[TMP_12]], %[[TMP_34]] + // CHECK: %[[TMP_36:.*]] = mhlo.divide %[[TMP_33]], %[[TMP_35]] + // CHECK: %[[TMP_37:.*]] = mhlo.add %[[TMP_32]], %[[TMP_36]] + // CHECK: %[[TMP_38:.*]] = mhlo.constant dense<-0.13857109526572012> + // CHECK: %[[TMP_39:.*]] = mhlo.constant dense<6.000000e+00> + // CHECK: %[[TMP_40:.*]] = mhlo.add %[[TMP_12]], %[[TMP_39]] + // CHECK: %[[TMP_41:.*]] = mhlo.divide %[[TMP_38]], %[[TMP_40]] + // CHECK: %[[TMP_42:.*]] = mhlo.add %[[TMP_37]], %[[TMP_41]] + // CHECK: %[[TMP_43:.*]] = mhlo.constant dense<9.9843695780195716E-6> + // CHECK: %[[TMP_44:.*]] = mhlo.constant dense<7.000000e+00> + // CHECK: %[[TMP_45:.*]] = mhlo.add %[[TMP_12]], %[[TMP_44]] + // CHECK: %[[TMP_46:.*]] = mhlo.divide %[[TMP_43]], %[[TMP_45]] + // CHECK: %[[TMP_47:.*]] = mhlo.add %[[TMP_42]], %[[TMP_46]] + // CHECK: %[[TMP_48:.*]] = mhlo.constant dense<1.5056327351493116E-7> + // CHECK: %[[TMP_49:.*]] = mhlo.constant dense<8.000000e+00> + // CHECK: %[[TMP_50:.*]] = mhlo.add %[[TMP_12]], %[[TMP_49]] + // CHECK: %[[TMP_51:.*]] = mhlo.divide %[[TMP_48]], %[[TMP_50]] + // CHECK: %[[TMP_52:.*]] = mhlo.add %[[TMP_47]], %[[TMP_51]] + // CHECK: %[[TMP_6:.*]] = mhlo.constant dense<7.500000e+00> + // CHECK: %[[TMP_53:.*]] = mhlo.add %[[TMP_6]], %[[TMP_12]] + // CHECK: %[[TMP_7:.*]] = mhlo.constant dense<2.0149030205422647> + // CHECK: %[[TMP_54:.*]] = mhlo.divide %[[TMP_12]], %[[TMP_6]] + // CHECK: %[[TMP_55:.*]] = "mhlo.log_plus_one"(%[[TMP_54]]) + // CHECK: %[[TMP_56:.*]] = mhlo.add %[[TMP_7]], %[[TMP_55]] + 
// CHECK: %[[TMP_57:.*]] = mhlo.divide %[[TMP_53]], %[[TMP_56]] + // CHECK: %[[TMP_58:.*]] = mhlo.add %[[TMP_12]], %[[TMP_1]] + // CHECK: %[[TMP_59:.*]] = mhlo.subtract %[[TMP_58]], %[[TMP_57]] + // CHECK: %[[TMP_60:.*]] = mhlo.multiply %[[TMP_59]], %[[TMP_56]] + // CHECK: %[[TMP_61:.*]] = "mhlo.log"(%[[TMP_52]]) + // CHECK: %[[TMP_5:.*]] = mhlo.constant dense<0.91893853320467266> + // CHECK: %[[TMP_62:.*]] = mhlo.add %[[TMP_5]], %[[TMP_60]] + // CHECK: %[[TMP_63:.*]] = mhlo.add %[[TMP_62]], %[[TMP_61]] + // CHECK: %[[TMP_64:.*]] = "mhlo.abs"(%[[ARG]]) + // CHECK: %[[TMP_65:.*]] = "mhlo.floor"(%[[TMP_64]]) + // CHECK: %[[TMP_66:.*]] = mhlo.subtract %[[TMP_64]], %[[TMP_65]] + // CHECK: %[[TMP_67:.*]] = "mhlo.compare"(%[[TMP_1]], %[[TMP_66]]) {comparison_direction = "LT"} + // CHECK: %[[TMP_68:.*]] = mhlo.subtract %[[TMP_2]], %[[TMP_66]] + // CHECK: %[[TMP_69:.*]] = "mhlo.select"(%[[TMP_67]], %[[TMP_68]], %[[TMP_66]]) + // CHECK: %[[TMP_3:.*]] = mhlo.constant dense<3.1415926535897931> + // CHECK: %[[TMP_70:.*]] = mhlo.multiply %[[TMP_3]], %[[TMP_69]] + // CHECK: %[[TMP_71:.*]] = "mhlo.sine"(%[[TMP_70]]) + // CHECK: %[[TMP_72:.*]] = "mhlo.log"(%[[TMP_71]]) + // CHECK: %[[TMP_4:.*]] = mhlo.constant dense<1.1447298858494002> + // CHECK: %[[TMP_75:.*]] = mhlo.subtract %[[TMP_4]], %[[TMP_72]] + // CHECK: %[[TMP_76:.*]] = mhlo.subtract %[[TMP_75]], %[[TMP_63]] + // CHECK: %[[TMP_73:.*]] = "mhlo.is_finite"(%[[TMP_72]]) + // CHECK: %[[TMP_74:.*]] = "mhlo.negate"(%[[TMP_72]]) + // CHECK: %[[TMP_77:.*]] = "mhlo.select"(%[[TMP_73]], %[[TMP_76]], %[[TMP_74]]) + // CHECK: %[[TMP_78:.*]] = "mhlo.select"(%[[TMP_9]], %[[TMP_77]], %[[TMP_63]]) + // CHECK: %[[TMP_79:.*]] = "mhlo.abs"(%[[ARG]]) + // CHECK: %[[TMP_80:.*]] = mhlo.constant dense<0x7FF0000000000000> + // CHECK: %[[TMP_81:.*]] = "mhlo.compare"(%[[TMP_79]], %[[TMP_80]]) {comparison_direction = "EQ"} + // CHECK: %[[TMP_0:.*]] = mhlo.constant dense<0x7FF0000000000000> + // CHECK: %[[TMP_82:.*]] = "mhlo.select"(%[[TMP_81]], 
%[[TMP_0]], %[[TMP_78]]) + // CHECK: return %[[TMP_82]] + %1 = chlo.lgamma %arg : tensor -> tensor + return %1 : tensor +} + +// CHECK-LABEL: @lgamma_f32 +// CHECK-SAME: (%[[ARG:.*]]: tensor) +func @lgamma_f32(%arg : tensor) -> tensor { + // CHECK: %[[TMP_1:.*]] = mhlo.constant dense<5.000000e-01> + // CHECK: %[[TMP_9:.*]] = "mhlo.compare"(%[[ARG]], %[[TMP_1]]) {comparison_direction = "LT"} + // CHECK: %[[TMP_10:.*]] = "mhlo.negate"(%[[ARG]]) + // CHECK: %[[TMP_2:.*]] = mhlo.constant dense<1.000000e+00> + // CHECK: %[[TMP_11:.*]] = mhlo.subtract %[[ARG]], %[[TMP_2]] + // CHECK: %[[TMP_12:.*]] = "mhlo.select"(%[[TMP_9]], %[[TMP_10]], %[[TMP_11]]) + // CHECK: %[[TMP_8:.*]] = mhlo.constant dense<1.000000e+00> + // CHECK: %[[TMP_13:.*]] = mhlo.constant dense<676.520386> + // CHECK: %[[TMP_14:.*]] = mhlo.constant dense<1.000000e+00> + // CHECK: %[[TMP_15:.*]] = mhlo.add %[[TMP_12]], %[[TMP_14]] + // CHECK: %[[TMP_16:.*]] = mhlo.divide %[[TMP_13]], %[[TMP_15]] + // CHECK: %[[TMP_17:.*]] = mhlo.add %[[TMP_8]], %[[TMP_16]] + // CHECK: %[[TMP_18:.*]] = mhlo.constant dense<-1259.13916> + // CHECK: %[[TMP_19:.*]] = mhlo.constant dense<2.000000e+00> + // CHECK: %[[TMP_20:.*]] = mhlo.add %[[TMP_12]], %[[TMP_19]] + // CHECK: %[[TMP_21:.*]] = mhlo.divide %[[TMP_18]], %[[TMP_20]] + // CHECK: %[[TMP_22:.*]] = mhlo.add %[[TMP_17]], %[[TMP_21]] + // CHECK: %[[TMP_23:.*]] = mhlo.constant dense<771.323425> + // CHECK: %[[TMP_24:.*]] = mhlo.constant dense<3.000000e+00> + // CHECK: %[[TMP_25:.*]] = mhlo.add %[[TMP_12]], %[[TMP_24]] + // CHECK: %[[TMP_26:.*]] = mhlo.divide %[[TMP_23]], %[[TMP_25]] + // CHECK: %[[TMP_27:.*]] = mhlo.add %[[TMP_22]], %[[TMP_26]] + // CHECK: %[[TMP_28:.*]] = mhlo.constant dense<-176.615036> + // CHECK: %[[TMP_29:.*]] = mhlo.constant dense<4.000000e+00> + // CHECK: %[[TMP_30:.*]] = mhlo.add %[[TMP_12]], %[[TMP_29]] + // CHECK: %[[TMP_31:.*]] = mhlo.divide %[[TMP_28]], %[[TMP_30]] + // CHECK: %[[TMP_32:.*]] = mhlo.add %[[TMP_27]], %[[TMP_31]] + // CHECK: 
%[[TMP_33:.*]] = mhlo.constant dense<12.5073433> + // CHECK: %[[TMP_34:.*]] = mhlo.constant dense<5.000000e+00> + // CHECK: %[[TMP_35:.*]] = mhlo.add %[[TMP_12]], %[[TMP_34]] + // CHECK: %[[TMP_36:.*]] = mhlo.divide %[[TMP_33]], %[[TMP_35]] + // CHECK: %[[TMP_37:.*]] = mhlo.add %[[TMP_32]], %[[TMP_36]] + // CHECK: %[[TMP_38:.*]] = mhlo.constant dense<-0.138571098> + // CHECK: %[[TMP_39:.*]] = mhlo.constant dense<6.000000e+00> + // CHECK: %[[TMP_40:.*]] = mhlo.add %[[TMP_12]], %[[TMP_39]] + // CHECK: %[[TMP_41:.*]] = mhlo.divide %[[TMP_38]], %[[TMP_40]] + // CHECK: %[[TMP_42:.*]] = mhlo.add %[[TMP_37]], %[[TMP_41]] + // CHECK: %[[TMP_43:.*]] = mhlo.constant dense<9.98436917E-6> + // CHECK: %[[TMP_44:.*]] = mhlo.constant dense<7.000000e+00> + // CHECK: %[[TMP_45:.*]] = mhlo.add %[[TMP_12]], %[[TMP_44]] + // CHECK: %[[TMP_46:.*]] = mhlo.divide %[[TMP_43]], %[[TMP_45]] + // CHECK: %[[TMP_47:.*]] = mhlo.add %[[TMP_42]], %[[TMP_46]] + // CHECK: %[[TMP_48:.*]] = mhlo.constant dense<1.50563267E-7> + // CHECK: %[[TMP_49:.*]] = mhlo.constant dense<8.000000e+00> + // CHECK: %[[TMP_50:.*]] = mhlo.add %[[TMP_12]], %[[TMP_49]] + // CHECK: %[[TMP_51:.*]] = mhlo.divide %[[TMP_48]], %[[TMP_50]] + // CHECK: %[[TMP_52:.*]] = mhlo.add %[[TMP_47]], %[[TMP_51]] + // CHECK: %[[TMP_6:.*]] = mhlo.constant dense<7.500000e+00> + // CHECK: %[[TMP_53:.*]] = mhlo.add %[[TMP_6]], %[[TMP_12]] + // CHECK: %[[TMP_7:.*]] = mhlo.constant dense<2.01490307> + // CHECK: %[[TMP_54:.*]] = mhlo.divide %[[TMP_12]], %[[TMP_6]] + // CHECK: %[[TMP_55:.*]] = "mhlo.log_plus_one"(%[[TMP_54]]) + // CHECK: %[[TMP_56:.*]] = mhlo.add %[[TMP_7]], %[[TMP_55]] + // CHECK: %[[TMP_57:.*]] = mhlo.divide %[[TMP_53]], %[[TMP_56]] + // CHECK: %[[TMP_58:.*]] = mhlo.add %[[TMP_12]], %[[TMP_1]] + // CHECK: %[[TMP_59:.*]] = mhlo.subtract %[[TMP_58]], %[[TMP_57]] + // CHECK: %[[TMP_60:.*]] = mhlo.multiply %[[TMP_59]], %[[TMP_56]] + // CHECK: %[[TMP_61:.*]] = "mhlo.log"(%[[TMP_52]]) + // CHECK: %[[TMP_5:.*]] = mhlo.constant 
dense<0.918938517> + // CHECK: %[[TMP_62:.*]] = mhlo.add %[[TMP_5]], %[[TMP_60]] + // CHECK: %[[TMP_63:.*]] = mhlo.add %[[TMP_62]], %[[TMP_61]] + // CHECK: %[[TMP_64:.*]] = "mhlo.abs"(%[[ARG]]) + // CHECK: %[[TMP_65:.*]] = "mhlo.floor"(%[[TMP_64]]) + // CHECK: %[[TMP_66:.*]] = mhlo.subtract %[[TMP_64]], %[[TMP_65]] + // CHECK: %[[TMP_67:.*]] = "mhlo.compare"(%[[TMP_1]], %[[TMP_66]]) {comparison_direction = "LT"} + // CHECK: %[[TMP_68:.*]] = mhlo.subtract %[[TMP_2]], %[[TMP_66]] + // CHECK: %[[TMP_69:.*]] = "mhlo.select"(%[[TMP_67]], %[[TMP_68]], %[[TMP_66]]) + // CHECK: %[[TMP_3:.*]] = mhlo.constant dense<3.14159274> + // CHECK: %[[TMP_70:.*]] = mhlo.multiply %[[TMP_3]], %[[TMP_69]] + // CHECK: %[[TMP_71:.*]] = "mhlo.sine"(%[[TMP_70]]) + // CHECK: %[[TMP_72:.*]] = "mhlo.log"(%[[TMP_71]]) + // CHECK: %[[TMP_4:.*]] = mhlo.constant dense<1.14472985> + // CHECK: %[[TMP_75:.*]] = mhlo.subtract %[[TMP_4]], %[[TMP_72]] + // CHECK: %[[TMP_76:.*]] = mhlo.subtract %[[TMP_75]], %[[TMP_63]] + // CHECK: %[[TMP_73:.*]] = "mhlo.is_finite"(%[[TMP_72]]) + // CHECK: %[[TMP_74:.*]] = "mhlo.negate"(%[[TMP_72]]) + // CHECK: %[[TMP_77:.*]] = "mhlo.select"(%[[TMP_73]], %[[TMP_76]], %[[TMP_74]]) + // CHECK: %[[TMP_78:.*]] = "mhlo.select"(%[[TMP_9]], %[[TMP_77]], %[[TMP_63]]) + // CHECK: %[[TMP_79:.*]] = "mhlo.abs"(%[[ARG]]) + // CHECK: %[[TMP_80:.*]] = mhlo.constant dense<0x7F800000> + // CHECK: %[[TMP_81:.*]] = "mhlo.compare"(%[[TMP_79]], %[[TMP_80]]) {comparison_direction = "EQ"} + // CHECK: %[[TMP_0:.*]] = mhlo.constant dense<0x7F800000> + // CHECK: %[[TMP_82:.*]] = "mhlo.select"(%[[TMP_81]], %[[TMP_0]], %[[TMP_78]]) + // CHECK: return %[[TMP_82]] + %1 = chlo.lgamma %arg : tensor -> tensor + return %1 : tensor +} + +// CHECK-LABEL: @lgamma_f16 +// CHECK-SAME: (%[[ARG:.*]]: tensor) +func @lgamma_f16(%arg : tensor) -> tensor { + // CHECK: "mhlo.convert"(%[[ARG]]) : (tensor) -> tensor + // CHECK: %[[RES:.*]] = "mhlo.convert"(%{{.*}}) : (tensor) -> tensor + // CHECK: return %[[RES]] + %1 = 
chlo.lgamma %arg : tensor -> tensor + return %1 : tensor +}