diff --git a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td index a8a214e..e34a82e 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td @@ -555,6 +555,15 @@ def HLOClient_ConstantLikeOp : HLOClient_Op<"constant_like", let hasCanonicalizer = 1; } +def HLOClient_DigammaOp : HLOClient_UnaryElementwiseOp<"digamma", + [SameOperandsAndResultType], HLO_FpTensor, HLO_FpTensor> { + let summary = "Digamma function"; + + let description = [{ + Returns `Digamma(operand)` element-wise. + }]; +} + def HLOClient_ErfOp : HLOClient_UnaryElementwiseOp<"erf", [SameOperandsAndResultType], HLO_FpTensor, HLO_FpTensor> { let summary = "Erfc operator"; diff --git a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc index 9705b5c..794c2a3 100644 --- a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc +++ b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc @@ -625,6 +625,119 @@ Value MaterializeLgamma(ConversionPatternRewriter &rewriter, Location loc, lgamma); } +// Compute the Digamma function using Lanczos' approximation from "A Precision +// Approximation of the Gamma Function". SIAM Journal on Numerical Analysis +// series B. Vol. 1: +// digamma(z + 1) = log(t(z)) + a'(z) / a(z) - kLanczosGamma / t(z) +// with t(z) = z + kLanczosGamma + 1/2 +// a(z) = kBaseLanczosCoeff +// + sum(k = 1, n, kLanczosCoefficients[i] / (z + k)) +// a'(z) = - sum(k = 1, n, kLanczosCoefficients[i] / (z + k) / (z + k)) +Value MaterializeDigamma(ConversionPatternRewriter &rewriter, Location loc, + Value x) { + // If the input is less than 0.5 use Euler's reflection formula. + // digamma(x) = digamma(1 - x) - pi * cot(pi * x) + // Let z be + // z = -x if x < 1/2 + // z = x - 1 otheriwse + const StringAttr kLT = rewriter.getStringAttr( + mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT)); + Value half = getConstantLike(rewriter, loc, 0.5, x); + Value need_to_reflect = rewriter.create(loc, x, half, kLT); + Value neg_x = rewriter.create(loc, x); + Value one = getConstantLike(rewriter, loc, 1, x); + Value x_sub_one = rewriter.create(loc, x, one); + Value z = + rewriter.create(loc, need_to_reflect, neg_x, x_sub_one); + + // Materialize + // a(z) = kBaseLanczosCoeff + // + sum(k = 1, n, kLanczosCoefficients[i] / (z + k)) + // a'(z) = - sum(k = 1, n, kLanczosCoefficients[i] / (z + k) / (z + k)) + Value zero = getConstantLike(rewriter, loc, 0.0, x); + Value a = getConstantLike(rewriter, loc, kBaseLanczosCoeff, x); + Value a_prime = zero; + for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) { + Value coeff = getConstantLike(rewriter, loc, kLanczosCoefficients[i], x); + Value one_based_index = getConstantLike(rewriter, loc, i + 1, x); + Value z_term = rewriter.create(loc, z, one_based_index); + a_prime = rewriter.create( + loc, a_prime, + rewriter.create( + loc, coeff, rewriter.create(loc, z_term, z_term))); + a = rewriter.create( + loc, a, rewriter.create(loc, coeff, z_term)); + } + + // To improve accuracy on platforms with less-precise log implementations, + // compute log(kLanczosGamma + 1/2) at compile time and use log1p on the + // device. + // Materialize as + // log(t) = log(kLanczosGamma + 1/2 + z) + // = log(kLanczosGamma + 1/2) + log1p(z / (kLanczosGamma + 1/2)). + Value lanczos_plus_half = + getConstantLike(rewriter, loc, kLanczosGamma + 0.5, x); + Value t = rewriter.create(loc, lanczos_plus_half, z); + Value log_term = + getConstantLike(rewriter, loc, std::log(kLanczosGamma + 0.5), x); + Value log1p_term = rewriter.create( + loc, rewriter.create(loc, z, lanczos_plus_half)); + Value log_t = rewriter.create(loc, log_term, log1p_term); + + // Materialize the final result (modulo reflection) as + // digamma(z + 1) = log(t(z)) + a'(z) / a(z) - kLanczosGamma / t(z). + Value a_prime_div_a = rewriter.create(loc, a_prime, a); + Value lanczos_gamma_div_t = rewriter.create( + loc, getConstantLike(rewriter, loc, kLanczosGamma, x), t); + Value digamma = rewriter.create( + loc, rewriter.create(loc, log_t, a_prime_div_a), + lanczos_gamma_div_t); + + // We need to be careful how we compute cot(pi * input) below: For + // near-integral arguments, pi * input can lose precision. + // + // Input is already known to be less than 0.5 (otherwise we don't have to + // reflect). We shift values smaller than -0.5 into the range [-0.5, 0.5] to + // increase precision of pi * x and the resulting cotangent. + Value reduced_x = rewriter.create( + loc, x, + rewriter.create( + loc, rewriter.create( + loc, rewriter.create( + loc, x, getConstantLike(rewriter, loc, 0.5, x))))); + + // Materialize reflection for inputs less than 0.5 as + // digamma(x) = digamma(1 - x) - pi * cot(pi * x) + // = digamma(1 - x) - pi * cos(pi * x) / sin(pi * x) + Value pi = getConstantLike(rewriter, loc, M_PI, x); + Value pi_mul_reduced_x = rewriter.create(loc, pi, reduced_x); + Value cos = rewriter.create(loc, pi_mul_reduced_x); + Value sin = rewriter.create(loc, pi_mul_reduced_x); + Value reflection = rewriter.create( + loc, digamma, + rewriter.create( + loc, rewriter.create(loc, pi, cos), sin)); + + // Select whether or not to rely on the reflection. + digamma = rewriter.create(loc, need_to_reflect, reflection, + digamma); + + // Digamma has poles at negative integers and zero; return nan for those. + const StringAttr kLE = rewriter.getStringAttr( + mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LE)); + Value is_le_zero = rewriter.create(loc, x, zero, kLE); + const StringAttr kEQ = rewriter.getStringAttr( + mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::EQ)); + Value is_int = rewriter.create( + loc, x, rewriter.create(loc, x), kEQ); + Value is_pole = rewriter.create(loc, is_le_zero, is_int); + return rewriter.create( + loc, is_pole, + getConstantLike(rewriter, loc, std::numeric_limits::quiet_NaN(), + x), + digamma); +} + struct ConvertLgammaOp : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( @@ -639,6 +752,20 @@ struct ConvertLgammaOp : public OpConversionPattern { } }; +struct ConvertDigammaOp : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + DigammaOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + DigammaOp::Adaptor transformed(operands); + FloatType min_precision_ty = rewriter.getF32Type(); + rewriter.replaceOp( + op, MaterializeWithUpcast(rewriter, op.getLoc(), transformed.operand(), + min_precision_ty, &MaterializeDigamma)); + return success(); + } +}; + // Converts binary ops that statically are determined to not broadcast directly // to the corresponding mhlo non-broadcasting op. template @@ -790,8 +917,13 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context, context, patterns, 5); // Other patterns. - patterns->insertinsert(context); + // clang-format on } } // namespace chlo diff --git a/tests/chlo_legalize_to_mhlo.mlir b/tests/chlo_legalize_to_mhlo.mlir index aad137f..3e7d931 100644 --- a/tests/chlo_legalize_to_mhlo.mlir +++ b/tests/chlo_legalize_to_mhlo.mlir @@ -875,3 +875,233 @@ func @lgamma_f16(%arg : tensor) -> tensor { %1 = chlo.lgamma %arg : tensor -> tensor return %1 : tensor } + +// CHECK-LABEL: @digamma_f64 +// CHECK-SAME: (%[[ARG:.*]]: tensor) +func @digamma_f64(%arg : tensor) -> tensor { + // CHECK: %[[TMP_0:.*]] = mhlo.constant dense<5.000000e-01> + // CHECK: %[[TMP_1:.*]] = "mhlo.compare"(%arg0, %[[TMP_0]]) {comparison_direction = "LT"} + // CHECK: %[[TMP_2:.*]] = "mhlo.negate"(%arg0) + // CHECK: %[[TMP_3:.*]] = mhlo.constant dense<1.000000e+00> + // CHECK: %[[TMP_4:.*]] = mhlo.subtract %arg0, %[[TMP_3]] + // CHECK: %[[TMP_5:.*]] = "mhlo.select"(%[[TMP_1]], %[[TMP_2]], %[[TMP_4]]) + // CHECK: %[[TMP_6:.*]] = mhlo.constant dense<0.000000e+00> + // CHECK: %[[TMP_7:.*]] = mhlo.constant dense<0.99999999999980993> + // CHECK: %[[TMP_8:.*]] = mhlo.constant dense<676.5203681218851> + // CHECK: %[[TMP_9:.*]] = mhlo.constant dense<1.000000e+00> + // CHECK: %[[TMP_10:.*]] = mhlo.add %[[TMP_5]], %[[TMP_9]] + // CHECK: %[[TMP_11:.*]] = mhlo.multiply %[[TMP_10]], %[[TMP_10]] + // CHECK: %[[TMP_12:.*]] = mhlo.divide %[[TMP_8]], %[[TMP_11]] + // CHECK: %[[TMP_13:.*]] = mhlo.subtract %[[TMP_6]], %[[TMP_12]] + // CHECK: %[[TMP_14:.*]] = mhlo.divide %[[TMP_8]], %[[TMP_10]] + // CHECK: %[[TMP_15:.*]] = mhlo.add %[[TMP_7]], %[[TMP_14]] + // CHECK: %[[TMP_16:.*]] = mhlo.constant dense<-1259.1392167224028> + // CHECK: %[[TMP_17:.*]] = mhlo.constant dense<2.000000e+00> + // CHECK: %[[TMP_18:.*]] = mhlo.add %[[TMP_5]], %[[TMP_17]] + // CHECK: %[[TMP_19:.*]] = mhlo.multiply %[[TMP_18]], %[[TMP_18]] + // CHECK: %[[TMP_20:.*]] = mhlo.divide %[[TMP_16]], %[[TMP_19]] + // CHECK: %[[TMP_21:.*]] = mhlo.subtract %[[TMP_13]], %[[TMP_20]] + // CHECK: %[[TMP_22:.*]] = mhlo.divide %[[TMP_16]], %[[TMP_18]] + // CHECK: %[[TMP_23:.*]] = mhlo.add %[[TMP_15]], %[[TMP_22]] + // CHECK: %[[TMP_24:.*]] = mhlo.constant dense<771.32342877765313> + // CHECK: %[[TMP_25:.*]] = mhlo.constant dense<3.000000e+00> + // CHECK: %[[TMP_26:.*]] = mhlo.add %[[TMP_5]], %[[TMP_25]] + // CHECK: %[[TMP_27:.*]] = mhlo.multiply %[[TMP_26]], %[[TMP_26]] + // CHECK: %[[TMP_28:.*]] = mhlo.divide %[[TMP_24]], %[[TMP_27]] + // CHECK: %[[TMP_29:.*]] = mhlo.subtract %[[TMP_21]], %[[TMP_28]] + // CHECK: %[[TMP_30:.*]] = mhlo.divide %[[TMP_24]], %[[TMP_26]] + // CHECK: %[[TMP_31:.*]] = mhlo.add %[[TMP_23]], %[[TMP_30]] + // CHECK: %[[TMP_32:.*]] = mhlo.constant dense<-176.61502916214059> + // CHECK: %[[TMP_33:.*]] = mhlo.constant dense<4.000000e+00> + // CHECK: %[[TMP_34:.*]] = mhlo.add %[[TMP_5]], %[[TMP_33]] + // CHECK: %[[TMP_35:.*]] = mhlo.multiply %[[TMP_34]], %[[TMP_34]] + // CHECK: %[[TMP_36:.*]] = mhlo.divide %[[TMP_32]], %[[TMP_35]] + // CHECK: %[[TMP_37:.*]] = mhlo.subtract %[[TMP_29]], %[[TMP_36]] + // CHECK: %[[TMP_38:.*]] = mhlo.divide %[[TMP_32]], %[[TMP_34]] + // CHECK: %[[TMP_39:.*]] = mhlo.add %[[TMP_31]], %[[TMP_38]] + // CHECK: %[[TMP_40:.*]] = mhlo.constant dense<12.507343278686905> + // CHECK: %[[TMP_41:.*]] = mhlo.constant dense<5.000000e+00> + // CHECK: %[[TMP_42:.*]] = mhlo.add %[[TMP_5]], %[[TMP_41]] + // CHECK: %[[TMP_43:.*]] = mhlo.multiply %[[TMP_42]], %[[TMP_42]] + // CHECK: %[[TMP_44:.*]] = mhlo.divide %[[TMP_40]], %[[TMP_43]] + // CHECK: %[[TMP_45:.*]] = mhlo.subtract %[[TMP_37]], %[[TMP_44]] + // CHECK: %[[TMP_46:.*]] = mhlo.divide %[[TMP_40]], %[[TMP_42]] + // CHECK: %[[TMP_47:.*]] = mhlo.add %[[TMP_39]], %[[TMP_46]] + // CHECK: %[[TMP_48:.*]] = mhlo.constant dense<-0.13857109526572012> + // CHECK: %[[TMP_49:.*]] = mhlo.constant dense<6.000000e+00> + // CHECK: %[[TMP_50:.*]] = mhlo.add %[[TMP_5]], %[[TMP_49]] + // CHECK: %[[TMP_51:.*]] = mhlo.multiply %[[TMP_50]], %[[TMP_50]] + // CHECK: %[[TMP_52:.*]] = mhlo.divide %[[TMP_48]], %[[TMP_51]] + // CHECK: %[[TMP_53:.*]] = mhlo.subtract %[[TMP_45]], %[[TMP_52]] + // CHECK: %[[TMP_54:.*]] = mhlo.divide %[[TMP_48]], %[[TMP_50]] + // CHECK: %[[TMP_55:.*]] = mhlo.add %[[TMP_47]], %[[TMP_54]] + // CHECK: %[[TMP_56:.*]] = mhlo.constant dense<9.9843695780195716E-6> + // CHECK: %[[TMP_57:.*]] = mhlo.constant dense<7.000000e+00> + // CHECK: %[[TMP_58:.*]] = mhlo.add %[[TMP_5]], %[[TMP_57]] + // CHECK: %[[TMP_59:.*]] = mhlo.multiply %[[TMP_58]], %[[TMP_58]] + // CHECK: %[[TMP_60:.*]] = mhlo.divide %[[TMP_56]], %[[TMP_59]] + // CHECK: %[[TMP_61:.*]] = mhlo.subtract %[[TMP_53]], %[[TMP_60]] + // CHECK: %[[TMP_62:.*]] = mhlo.divide %[[TMP_56]], %[[TMP_58]] + // CHECK: %[[TMP_63:.*]] = mhlo.add %[[TMP_55]], %[[TMP_62]] + // CHECK: %[[TMP_64:.*]] = mhlo.constant dense<1.5056327351493116E-7> + // CHECK: %[[TMP_65:.*]] = mhlo.constant dense<8.000000e+00> + // CHECK: %[[TMP_66:.*]] = mhlo.add %[[TMP_5]], %[[TMP_65]] + // CHECK: %[[TMP_67:.*]] = mhlo.multiply %[[TMP_66]], %[[TMP_66]] + // CHECK: %[[TMP_68:.*]] = mhlo.divide %[[TMP_64]], %[[TMP_67]] + // CHECK: %[[TMP_69:.*]] = mhlo.subtract %[[TMP_61]], %[[TMP_68]] + // CHECK: %[[TMP_70:.*]] = mhlo.divide %[[TMP_64]], %[[TMP_66]] + // CHECK: %[[TMP_71:.*]] = mhlo.add %[[TMP_63]], %[[TMP_70]] + // CHECK: %[[TMP_72:.*]] = mhlo.constant dense<7.500000e+00> + // CHECK: %[[TMP_73:.*]] = mhlo.add %[[TMP_72]], %[[TMP_5]] + // CHECK: %[[TMP_74:.*]] = mhlo.constant dense<2.0149030205422647> + // CHECK: %[[TMP_75:.*]] = mhlo.divide %[[TMP_5]], %[[TMP_72]] + // CHECK: %[[TMP_76:.*]] = "mhlo.log_plus_one"(%[[TMP_75]]) + // CHECK: %[[TMP_77:.*]] = mhlo.add %[[TMP_74]], %[[TMP_76]] + // CHECK: %[[TMP_78:.*]] = mhlo.divide %[[TMP_69]], %[[TMP_71]] + // CHECK: %[[TMP_79:.*]] = mhlo.constant dense<7.000000e+00> + // CHECK: %[[TMP_80:.*]] = mhlo.divide %[[TMP_79]], %[[TMP_73]] + // CHECK: %[[TMP_81:.*]] = mhlo.add %[[TMP_77]], %[[TMP_78]] + // CHECK: %[[TMP_82:.*]] = mhlo.subtract %[[TMP_81]], %[[TMP_80]] + // CHECK: %[[TMP_83:.*]] = mhlo.constant dense<5.000000e-01> + // CHECK: %[[TMP_84:.*]] = mhlo.add %arg0, %[[TMP_83]] + // CHECK: %[[TMP_85:.*]] = "mhlo.floor"(%[[TMP_84]]) + // CHECK: %[[TMP_86:.*]] = "mhlo.abs"(%[[TMP_85]]) + // CHECK: %[[TMP_87:.*]] = mhlo.add %arg0, %[[TMP_86]] + // CHECK: %[[TMP_88:.*]] = mhlo.constant dense<3.1415926535897931> + // CHECK: %[[TMP_89:.*]] = mhlo.multiply %[[TMP_88]], %[[TMP_87]] + // CHECK: %[[TMP_90:.*]] = "mhlo.cosine"(%[[TMP_89]]) + // CHECK: %[[TMP_92:.*]] = "mhlo.sine"(%[[TMP_89]]) + // CHECK: %[[TMP_91:.*]] = mhlo.multiply %[[TMP_88]], %[[TMP_90]] + // CHECK: %[[TMP_93:.*]] = mhlo.divide %[[TMP_91]], %[[TMP_92]] + // CHECK: %[[TMP_94:.*]] = mhlo.subtract %[[TMP_82]], %[[TMP_93]] + // CHECK: %[[TMP_95:.*]] = "mhlo.select"(%[[TMP_1]], %[[TMP_94]], %[[TMP_82]]) + // CHECK: %[[TMP_96:.*]] = "mhlo.compare"(%arg0, %[[TMP_6]]) {comparison_direction = "LE"} + // CHECK: %[[TMP_97:.*]] = "mhlo.floor"(%arg0) + // CHECK: %[[TMP_98:.*]] = "mhlo.compare"(%arg0, %[[TMP_97]]) {comparison_direction = "EQ"} + // CHECK: %[[TMP_99:.*]] = mhlo.and %[[TMP_96]], %[[TMP_98]] + // CHECK: %[[TMP_100:.*]] = mhlo.constant dense<0x7FF8000000000000> + // CHECK: %[[RES:.*]] = "mhlo.select"(%[[TMP_99]], %[[TMP_100]], %[[TMP_95]]) + // CHECK: return %[[RES]] + %1 = chlo.digamma %arg : tensor -> tensor + return %1 : tensor +} + +// CHECK-LABEL: @digamma_f32 +// CHECK-SAME: (%[[ARG:.*]]: tensor) +func @digamma_f32(%arg : tensor) -> tensor { + // CHECK: %[[TMP_0:.*]] = mhlo.constant dense<5.000000e-01> + // CHECK: %[[TMP_1:.*]] = "mhlo.compare"(%arg0, %[[TMP_0]]) {comparison_direction = "LT"} + // CHECK: %[[TMP_2:.*]] = "mhlo.negate"(%arg0) + // CHECK: %[[TMP_3:.*]] = mhlo.constant dense<1.000000e+00> + // CHECK: %[[TMP_4:.*]] = mhlo.subtract %arg0, %[[TMP_3]] + // CHECK: %[[TMP_5:.*]] = "mhlo.select"(%[[TMP_1]], %[[TMP_2]], %[[TMP_4]]) + // CHECK: %[[TMP_6:.*]] = mhlo.constant dense<0.000000e+00> + // CHECK: %[[TMP_7:.*]] = mhlo.constant dense<1.000000e+00> + // CHECK: %[[TMP_8:.*]] = mhlo.constant dense<676.520386> + // CHECK: %[[TMP_9:.*]] = mhlo.constant dense<1.000000e+00> + // CHECK: %[[TMP_10:.*]] = mhlo.add %[[TMP_5]], %[[TMP_9]] + // CHECK: %[[TMP_11:.*]] = mhlo.multiply %[[TMP_10]], %[[TMP_10]] + // CHECK: %[[TMP_12:.*]] = mhlo.divide %[[TMP_8]], %[[TMP_11]] + // CHECK: %[[TMP_13:.*]] = mhlo.subtract %[[TMP_6]], %[[TMP_12]] + // CHECK: %[[TMP_14:.*]] = mhlo.divide %[[TMP_8]], %[[TMP_10]] + // CHECK: %[[TMP_15:.*]] = mhlo.add %[[TMP_7]], %[[TMP_14]] + // CHECK: %[[TMP_16:.*]] = mhlo.constant dense<-1259.13916> + // CHECK: %[[TMP_17:.*]] = mhlo.constant dense<2.000000e+00> + // CHECK: %[[TMP_18:.*]] = mhlo.add %[[TMP_5]], %[[TMP_17]] + // CHECK: %[[TMP_19:.*]] = mhlo.multiply %[[TMP_18]], %[[TMP_18]] + // CHECK: %[[TMP_20:.*]] = mhlo.divide %[[TMP_16]], %[[TMP_19]] + // CHECK: %[[TMP_21:.*]] = mhlo.subtract %[[TMP_13]], %[[TMP_20]] + // CHECK: %[[TMP_22:.*]] = mhlo.divide %[[TMP_16]], %[[TMP_18]] + // CHECK: %[[TMP_23:.*]] = mhlo.add %[[TMP_15]], %[[TMP_22]] + // CHECK: %[[TMP_24:.*]] = mhlo.constant dense<771.323425> + // CHECK: %[[TMP_25:.*]] = mhlo.constant dense<3.000000e+00> + // CHECK: %[[TMP_26:.*]] = mhlo.add %[[TMP_5]], %[[TMP_25]] + // CHECK: %[[TMP_27:.*]] = mhlo.multiply %[[TMP_26]], %[[TMP_26]] + // CHECK: %[[TMP_28:.*]] = mhlo.divide %[[TMP_24]], %[[TMP_27]] + // CHECK: %[[TMP_29:.*]] = mhlo.subtract %[[TMP_21]], %[[TMP_28]] + // CHECK: %[[TMP_30:.*]] = mhlo.divide %[[TMP_24]], %[[TMP_26]] + // CHECK: %[[TMP_31:.*]] = mhlo.add %[[TMP_23]], %[[TMP_30]] + // CHECK: %[[TMP_32:.*]] = mhlo.constant dense<-176.615036> + // CHECK: %[[TMP_33:.*]] = mhlo.constant dense<4.000000e+00> + // CHECK: %[[TMP_34:.*]] = mhlo.add %[[TMP_5]], %[[TMP_33]] + // CHECK: %[[TMP_35:.*]] = mhlo.multiply %[[TMP_34]], %[[TMP_34]] + // CHECK: %[[TMP_36:.*]] = mhlo.divide %[[TMP_32]], %[[TMP_35]] + // CHECK: %[[TMP_37:.*]] = mhlo.subtract %[[TMP_29]], %[[TMP_36]] + // CHECK: %[[TMP_38:.*]] = mhlo.divide %[[TMP_32]], %[[TMP_34]] + // CHECK: %[[TMP_39:.*]] = mhlo.add %[[TMP_31]], %[[TMP_38]] + // CHECK: %[[TMP_40:.*]] = mhlo.constant dense<12.5073433> + // CHECK: %[[TMP_41:.*]] = mhlo.constant dense<5.000000e+00> + // CHECK: %[[TMP_42:.*]] = mhlo.add %[[TMP_5]], %[[TMP_41]] + // CHECK: %[[TMP_43:.*]] = mhlo.multiply %[[TMP_42]], %[[TMP_42]] + // CHECK: %[[TMP_44:.*]] = mhlo.divide %[[TMP_40]], %[[TMP_43]] + // CHECK: %[[TMP_45:.*]] = mhlo.subtract %[[TMP_37]], %[[TMP_44]] + // CHECK: %[[TMP_46:.*]] = mhlo.divide %[[TMP_40]], %[[TMP_42]] + // CHECK: %[[TMP_47:.*]] = mhlo.add %[[TMP_39]], %[[TMP_46]] + // CHECK: %[[TMP_48:.*]] = mhlo.constant dense<-0.138571098> + // CHECK: %[[TMP_49:.*]] = mhlo.constant dense<6.000000e+00> + // CHECK: %[[TMP_50:.*]] = mhlo.add %[[TMP_5]], %[[TMP_49]] + // CHECK: %[[TMP_51:.*]] = mhlo.multiply %[[TMP_50]], %[[TMP_50]] + // CHECK: %[[TMP_52:.*]] = mhlo.divide %[[TMP_48]], %[[TMP_51]] + // CHECK: %[[TMP_53:.*]] = mhlo.subtract %[[TMP_45]], %[[TMP_52]] + // CHECK: %[[TMP_54:.*]] = mhlo.divide %[[TMP_48]], %[[TMP_50]] + // CHECK: %[[TMP_55:.*]] = mhlo.add %[[TMP_47]], %[[TMP_54]] + // CHECK: %[[TMP_56:.*]] = mhlo.constant dense<9.98436917E-6> + // CHECK: %[[TMP_57:.*]] = mhlo.constant dense<7.000000e+00> + // CHECK: %[[TMP_58:.*]] = mhlo.add %[[TMP_5]], %[[TMP_57]] + // CHECK: %[[TMP_59:.*]] = mhlo.multiply %[[TMP_58]], %[[TMP_58]] + // CHECK: %[[TMP_60:.*]] = mhlo.divide %[[TMP_56]], %[[TMP_59]] + // CHECK: %[[TMP_61:.*]] = mhlo.subtract %[[TMP_53]], %[[TMP_60]] + // CHECK: %[[TMP_62:.*]] = mhlo.divide %[[TMP_56]], %[[TMP_58]] + // CHECK: %[[TMP_63:.*]] = mhlo.add %[[TMP_55]], %[[TMP_62]] + // CHECK: %[[TMP_64:.*]] = mhlo.constant dense<1.50563267E-7> + // CHECK: %[[TMP_65:.*]] = mhlo.constant dense<8.000000e+00> + // CHECK: %[[TMP_66:.*]] = mhlo.add %[[TMP_5]], %[[TMP_65]] + // CHECK: %[[TMP_67:.*]] = mhlo.multiply %[[TMP_66]], %[[TMP_66]] + // CHECK: %[[TMP_68:.*]] = mhlo.divide %[[TMP_64]], %[[TMP_67]] + // CHECK: %[[TMP_69:.*]] = mhlo.subtract %[[TMP_61]], %[[TMP_68]] + // CHECK: %[[TMP_70:.*]] = mhlo.divide %[[TMP_64]], %[[TMP_66]] + // CHECK: %[[TMP_71:.*]] = mhlo.add %[[TMP_63]], %[[TMP_70]] + // CHECK: %[[TMP_72:.*]] = mhlo.constant dense<7.500000e+00> + // CHECK: %[[TMP_73:.*]] = mhlo.add %[[TMP_72]], %[[TMP_5]] + // CHECK: %[[TMP_74:.*]] = mhlo.constant dense<2.01490307> + // CHECK: %[[TMP_75:.*]] = mhlo.divide %[[TMP_5]], %[[TMP_72]] + // CHECK: %[[TMP_76:.*]] = "mhlo.log_plus_one"(%[[TMP_75]]) + // CHECK: %[[TMP_77:.*]] = mhlo.add %[[TMP_74]], %[[TMP_76]] + // CHECK: %[[TMP_78:.*]] = mhlo.divide %[[TMP_69]], %[[TMP_71]] + // CHECK: %[[TMP_79:.*]] = mhlo.constant dense<7.000000e+00> + // CHECK: %[[TMP_80:.*]] = mhlo.divide %[[TMP_79]], %[[TMP_73]] + // CHECK: %[[TMP_81:.*]] = mhlo.add %[[TMP_77]], %[[TMP_78]] + // CHECK: %[[TMP_82:.*]] = mhlo.subtract %[[TMP_81]], %[[TMP_80]] + // CHECK: %[[TMP_83:.*]] = mhlo.constant dense<5.000000e-01> + // CHECK: %[[TMP_84:.*]] = mhlo.add %arg0, %[[TMP_83]] + // CHECK: %[[TMP_85:.*]] = "mhlo.floor"(%[[TMP_84]]) + // CHECK: %[[TMP_86:.*]] = "mhlo.abs"(%[[TMP_85]]) + // CHECK: %[[TMP_87:.*]] = mhlo.add %arg0, %[[TMP_86]] + // CHECK: %[[TMP_88:.*]] = mhlo.constant dense<3.14159274> + // CHECK: %[[TMP_89:.*]] = mhlo.multiply %[[TMP_88]], %[[TMP_87]] + // CHECK: %[[TMP_90:.*]] = "mhlo.cosine"(%[[TMP_89]]) + // CHECK: %[[TMP_92:.*]] = "mhlo.sine"(%[[TMP_89]]) + // CHECK: %[[TMP_91:.*]] = mhlo.multiply %[[TMP_88]], %[[TMP_90]] + // CHECK: %[[TMP_93:.*]] = mhlo.divide %[[TMP_91]], %[[TMP_92]] + // CHECK: %[[TMP_94:.*]] = mhlo.subtract %[[TMP_82]], %[[TMP_93]] + // CHECK: %[[TMP_95:.*]] = "mhlo.select"(%[[TMP_1]], %[[TMP_94]], %[[TMP_82]]) + // CHECK: %[[TMP_96:.*]] = "mhlo.compare"(%arg0, %[[TMP_6]]) {comparison_direction = "LE"} + // CHECK: %[[TMP_97:.*]] = "mhlo.floor"(%arg0) + // CHECK: %[[TMP_98:.*]] = "mhlo.compare"(%arg0, %[[TMP_97]]) {comparison_direction = "EQ"} + // CHECK: %[[TMP_99:.*]] = mhlo.and %[[TMP_96]], %[[TMP_98]] + // CHECK: %[[TMP_100:.*]] = mhlo.constant dense<0x7FC00000> + // CHECK: %[[RES:.*]] = "mhlo.select"(%[[TMP_99]], %[[TMP_100]], %[[TMP_95]]) + // CHECK: return %[[RES]] + %1 = chlo.digamma %arg : tensor -> tensor + return %1 : tensor +} + +// CHECK-LABEL: @digamma_f16 +// CHECK-SAME: (%[[ARG:.*]]: tensor) +func @digamma_f16(%arg : tensor) -> tensor { + // CHECK: "mhlo.convert"(%[[ARG]]) : (tensor) -> tensor + // CHECK: %[[RES:.*]] = "mhlo.convert"(%{{.*}}) : (tensor) -> tensor + // CHECK: return %[[RES]] + %1 = chlo.digamma %arg : tensor -> tensor + return %1 : tensor +}