[MLIR][CHLO] Add `chlo.digamma` and lowering to MHLO
PiperOrigin-RevId: 355122765
This commit is contained in:
parent
c2115f56c7
commit
f40ccc5b4b
|
@ -555,6 +555,15 @@ def HLOClient_ConstantLikeOp : HLOClient_Op<"constant_like",
|
|||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
// chlo.digamma: unary element-wise op declared with the
// SameOperandsAndResultType trait and HLO_FpTensor operand/result constraints.
def HLOClient_DigammaOp
    : HLOClient_UnaryElementwiseOp<"digamma", [SameOperandsAndResultType],
                                   HLO_FpTensor, HLO_FpTensor> {
  let summary = "Digamma function";

  let description = [{
    Returns `Digamma(operand)` element-wise.
  }];
}
|
||||
|
||||
def HLOClient_ErfOp : HLOClient_UnaryElementwiseOp<"erf",
|
||||
[SameOperandsAndResultType], HLO_FpTensor, HLO_FpTensor> {
|
||||
let summary = "Erfc operator";
|
||||
|
|
|
@ -625,6 +625,119 @@ Value MaterializeLgamma(ConversionPatternRewriter &rewriter, Location loc,
|
|||
lgamma);
|
||||
}
|
||||
|
||||
// Compute the Digamma function using Lanczos' approximation from "A Precision
|
||||
// Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
|
||||
// series B. Vol. 1:
|
||||
// digamma(z + 1) = log(t(z)) + a'(z) / a(z) - kLanczosGamma / t(z)
|
||||
// with t(z) = z + kLanczosGamma + 1/2
|
||||
// a(z) = kBaseLanczosCoeff
|
||||
// + sum(k = 1, n, kLanczosCoefficients[i] / (z + k))
|
||||
// a'(z) = - sum(k = 1, n, kLanczosCoefficients[i] / (z + k) / (z + k))
|
||||
Value MaterializeDigamma(ConversionPatternRewriter &rewriter, Location loc,
|
||||
Value x) {
|
||||
// If the input is less than 0.5 use Euler's reflection formula.
|
||||
// digamma(x) = digamma(1 - x) - pi * cot(pi * x)
|
||||
// Let z be
|
||||
// z = -x if x < 1/2
|
||||
// z = x - 1 otheriwse
|
||||
const StringAttr kLT = rewriter.getStringAttr(
|
||||
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
|
||||
Value half = getConstantLike(rewriter, loc, 0.5, x);
|
||||
Value need_to_reflect = rewriter.create<mhlo::CompareOp>(loc, x, half, kLT);
|
||||
Value neg_x = rewriter.create<mhlo::NegOp>(loc, x);
|
||||
Value one = getConstantLike(rewriter, loc, 1, x);
|
||||
Value x_sub_one = rewriter.create<mhlo::SubOp>(loc, x, one);
|
||||
Value z =
|
||||
rewriter.create<mhlo::SelectOp>(loc, need_to_reflect, neg_x, x_sub_one);
|
||||
|
||||
// Materialize
|
||||
// a(z) = kBaseLanczosCoeff
|
||||
// + sum(k = 1, n, kLanczosCoefficients[i] / (z + k))
|
||||
// a'(z) = - sum(k = 1, n, kLanczosCoefficients[i] / (z + k) / (z + k))
|
||||
Value zero = getConstantLike(rewriter, loc, 0.0, x);
|
||||
Value a = getConstantLike(rewriter, loc, kBaseLanczosCoeff, x);
|
||||
Value a_prime = zero;
|
||||
for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) {
|
||||
Value coeff = getConstantLike(rewriter, loc, kLanczosCoefficients[i], x);
|
||||
Value one_based_index = getConstantLike(rewriter, loc, i + 1, x);
|
||||
Value z_term = rewriter.create<mhlo::AddOp>(loc, z, one_based_index);
|
||||
a_prime = rewriter.create<mhlo::SubOp>(
|
||||
loc, a_prime,
|
||||
rewriter.create<mhlo::DivOp>(
|
||||
loc, coeff, rewriter.create<mhlo::MulOp>(loc, z_term, z_term)));
|
||||
a = rewriter.create<mhlo::AddOp>(
|
||||
loc, a, rewriter.create<mhlo::DivOp>(loc, coeff, z_term));
|
||||
}
|
||||
|
||||
// To improve accuracy on platforms with less-precise log implementations,
|
||||
// compute log(kLanczosGamma + 1/2) at compile time and use log1p on the
|
||||
// device.
|
||||
// Materialize as
|
||||
// log(t) = log(kLanczosGamma + 1/2 + z)
|
||||
// = log(kLanczosGamma + 1/2) + log1p(z / (kLanczosGamma + 1/2)).
|
||||
Value lanczos_plus_half =
|
||||
getConstantLike(rewriter, loc, kLanczosGamma + 0.5, x);
|
||||
Value t = rewriter.create<mhlo::AddOp>(loc, lanczos_plus_half, z);
|
||||
Value log_term =
|
||||
getConstantLike(rewriter, loc, std::log(kLanczosGamma + 0.5), x);
|
||||
Value log1p_term = rewriter.create<mhlo::Log1pOp>(
|
||||
loc, rewriter.create<mhlo::DivOp>(loc, z, lanczos_plus_half));
|
||||
Value log_t = rewriter.create<mhlo::AddOp>(loc, log_term, log1p_term);
|
||||
|
||||
// Materialize the final result (modulo reflection) as
|
||||
// digamma(z + 1) = log(t(z)) + a'(z) / a(z) - kLanczosGamma / t(z).
|
||||
Value a_prime_div_a = rewriter.create<mhlo::DivOp>(loc, a_prime, a);
|
||||
Value lanczos_gamma_div_t = rewriter.create<mhlo::DivOp>(
|
||||
loc, getConstantLike(rewriter, loc, kLanczosGamma, x), t);
|
||||
Value digamma = rewriter.create<mhlo::SubOp>(
|
||||
loc, rewriter.create<mhlo::AddOp>(loc, log_t, a_prime_div_a),
|
||||
lanczos_gamma_div_t);
|
||||
|
||||
// We need to be careful how we compute cot(pi * input) below: For
|
||||
// near-integral arguments, pi * input can lose precision.
|
||||
//
|
||||
// Input is already known to be less than 0.5 (otherwise we don't have to
|
||||
// reflect). We shift values smaller than -0.5 into the range [-0.5, 0.5] to
|
||||
// increase precision of pi * x and the resulting cotangent.
|
||||
Value reduced_x = rewriter.create<mhlo::AddOp>(
|
||||
loc, x,
|
||||
rewriter.create<mhlo::AbsOp>(
|
||||
loc, rewriter.create<mhlo::FloorOp>(
|
||||
loc, rewriter.create<mhlo::AddOp>(
|
||||
loc, x, getConstantLike(rewriter, loc, 0.5, x)))));
|
||||
|
||||
// Materialize reflection for inputs less than 0.5 as
|
||||
// digamma(x) = digamma(1 - x) - pi * cot(pi * x)
|
||||
// = digamma(1 - x) - pi * cos(pi * x) / sin(pi * x)
|
||||
Value pi = getConstantLike(rewriter, loc, M_PI, x);
|
||||
Value pi_mul_reduced_x = rewriter.create<mhlo::MulOp>(loc, pi, reduced_x);
|
||||
Value cos = rewriter.create<mhlo::CosOp>(loc, pi_mul_reduced_x);
|
||||
Value sin = rewriter.create<mhlo::SinOp>(loc, pi_mul_reduced_x);
|
||||
Value reflection = rewriter.create<mhlo::SubOp>(
|
||||
loc, digamma,
|
||||
rewriter.create<mhlo::DivOp>(
|
||||
loc, rewriter.create<mhlo::MulOp>(loc, pi, cos), sin));
|
||||
|
||||
// Select whether or not to rely on the reflection.
|
||||
digamma = rewriter.create<mhlo::SelectOp>(loc, need_to_reflect, reflection,
|
||||
digamma);
|
||||
|
||||
// Digamma has poles at negative integers and zero; return nan for those.
|
||||
const StringAttr kLE = rewriter.getStringAttr(
|
||||
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LE));
|
||||
Value is_le_zero = rewriter.create<mhlo::CompareOp>(loc, x, zero, kLE);
|
||||
const StringAttr kEQ = rewriter.getStringAttr(
|
||||
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::EQ));
|
||||
Value is_int = rewriter.create<mhlo::CompareOp>(
|
||||
loc, x, rewriter.create<mhlo::FloorOp>(loc, x), kEQ);
|
||||
Value is_pole = rewriter.create<mhlo::AndOp>(loc, is_le_zero, is_int);
|
||||
return rewriter.create<mhlo::SelectOp>(
|
||||
loc, is_pole,
|
||||
getConstantLike(rewriter, loc, std::numeric_limits<double>::quiet_NaN(),
|
||||
x),
|
||||
digamma);
|
||||
}
|
||||
|
||||
struct ConvertLgammaOp : public OpConversionPattern<LgammaOp> {
|
||||
using OpConversionPattern<LgammaOp>::OpConversionPattern;
|
||||
LogicalResult matchAndRewrite(
|
||||
|
@ -639,6 +752,20 @@ struct ConvertLgammaOp : public OpConversionPattern<LgammaOp> {
|
|||
}
|
||||
};
|
||||
|
||||
// Conversion pattern for chlo.digamma: replaces the op with the value produced
// by MaterializeWithUpcast, which is given the operand, f32 as the minimum
// precision type, and MaterializeDigamma as the materialization callback.
struct ConvertDigammaOp : public OpConversionPattern<DigammaOp> {
  using OpConversionPattern<DigammaOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      DigammaOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    DigammaOp::Adaptor adaptor(operands);
    FloatType min_precision_ty = rewriter.getF32Type();
    Value lowered =
        MaterializeWithUpcast(rewriter, op.getLoc(), adaptor.operand(),
                              min_precision_ty, &MaterializeDigamma);
    rewriter.replaceOp(op, lowered);
    return success();
  }
};
|
||||
|
||||
// Converts binary ops that statically are determined to not broadcast directly
|
||||
// to the corresponding mhlo non-broadcasting op.
|
||||
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
|
||||
|
@ -790,8 +917,13 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
|
|||
context, patterns, 5);
|
||||
|
||||
// Other patterns.
|
||||
patterns->insert<ConvertConstantLikeOp, ConvertErfOp, ConvertErfcOp,
|
||||
// clang-format off
|
||||
patterns->insert<ConvertConstantLikeOp,
|
||||
ConvertDigammaOp,
|
||||
ConvertErfOp,
|
||||
ConvertErfcOp,
|
||||
ConvertLgammaOp>(context);
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
} // namespace chlo
|
||||
|
|
|
@ -875,3 +875,233 @@ func @lgamma_f16(%arg : tensor<f16>) -> tensor<f16> {
|
|||
%1 = chlo.lgamma %arg : tensor<f16> -> tensor<f16>
|
||||
return %1 : tensor<f16>
|
||||
}
|
||||
|
||||
// Verifies the full f64 lowering of chlo.digamma. The function argument is
// matched through the ARG pattern variable captured on the signature line
// rather than through a hard-coded SSA name.
// CHECK-LABEL: @digamma_f64
// CHECK-SAME: (%[[ARG:.*]]: tensor<f64>)
func @digamma_f64(%arg : tensor<f64>) -> tensor<f64> {
  // CHECK: %[[TMP_0:.*]] = mhlo.constant dense<5.000000e-01>
  // CHECK: %[[TMP_1:.*]] = "mhlo.compare"(%[[ARG]], %[[TMP_0]]) {comparison_direction = "LT"}
  // CHECK: %[[TMP_2:.*]] = "mhlo.negate"(%[[ARG]])
  // CHECK: %[[TMP_3:.*]] = mhlo.constant dense<1.000000e+00>
  // CHECK: %[[TMP_4:.*]] = mhlo.subtract %[[ARG]], %[[TMP_3]]
  // CHECK: %[[TMP_5:.*]] = "mhlo.select"(%[[TMP_1]], %[[TMP_2]], %[[TMP_4]])
  // CHECK: %[[TMP_6:.*]] = mhlo.constant dense<0.000000e+00>
  // CHECK: %[[TMP_7:.*]] = mhlo.constant dense<0.99999999999980993>
  // CHECK: %[[TMP_8:.*]] = mhlo.constant dense<676.5203681218851>
  // CHECK: %[[TMP_9:.*]] = mhlo.constant dense<1.000000e+00>
  // CHECK: %[[TMP_10:.*]] = mhlo.add %[[TMP_5]], %[[TMP_9]]
  // CHECK: %[[TMP_11:.*]] = mhlo.multiply %[[TMP_10]], %[[TMP_10]]
  // CHECK: %[[TMP_12:.*]] = mhlo.divide %[[TMP_8]], %[[TMP_11]]
  // CHECK: %[[TMP_13:.*]] = mhlo.subtract %[[TMP_6]], %[[TMP_12]]
  // CHECK: %[[TMP_14:.*]] = mhlo.divide %[[TMP_8]], %[[TMP_10]]
  // CHECK: %[[TMP_15:.*]] = mhlo.add %[[TMP_7]], %[[TMP_14]]
  // CHECK: %[[TMP_16:.*]] = mhlo.constant dense<-1259.1392167224028>
  // CHECK: %[[TMP_17:.*]] = mhlo.constant dense<2.000000e+00>
  // CHECK: %[[TMP_18:.*]] = mhlo.add %[[TMP_5]], %[[TMP_17]]
  // CHECK: %[[TMP_19:.*]] = mhlo.multiply %[[TMP_18]], %[[TMP_18]]
  // CHECK: %[[TMP_20:.*]] = mhlo.divide %[[TMP_16]], %[[TMP_19]]
  // CHECK: %[[TMP_21:.*]] = mhlo.subtract %[[TMP_13]], %[[TMP_20]]
  // CHECK: %[[TMP_22:.*]] = mhlo.divide %[[TMP_16]], %[[TMP_18]]
  // CHECK: %[[TMP_23:.*]] = mhlo.add %[[TMP_15]], %[[TMP_22]]
  // CHECK: %[[TMP_24:.*]] = mhlo.constant dense<771.32342877765313>
  // CHECK: %[[TMP_25:.*]] = mhlo.constant dense<3.000000e+00>
  // CHECK: %[[TMP_26:.*]] = mhlo.add %[[TMP_5]], %[[TMP_25]]
  // CHECK: %[[TMP_27:.*]] = mhlo.multiply %[[TMP_26]], %[[TMP_26]]
  // CHECK: %[[TMP_28:.*]] = mhlo.divide %[[TMP_24]], %[[TMP_27]]
  // CHECK: %[[TMP_29:.*]] = mhlo.subtract %[[TMP_21]], %[[TMP_28]]
  // CHECK: %[[TMP_30:.*]] = mhlo.divide %[[TMP_24]], %[[TMP_26]]
  // CHECK: %[[TMP_31:.*]] = mhlo.add %[[TMP_23]], %[[TMP_30]]
  // CHECK: %[[TMP_32:.*]] = mhlo.constant dense<-176.61502916214059>
  // CHECK: %[[TMP_33:.*]] = mhlo.constant dense<4.000000e+00>
  // CHECK: %[[TMP_34:.*]] = mhlo.add %[[TMP_5]], %[[TMP_33]]
  // CHECK: %[[TMP_35:.*]] = mhlo.multiply %[[TMP_34]], %[[TMP_34]]
  // CHECK: %[[TMP_36:.*]] = mhlo.divide %[[TMP_32]], %[[TMP_35]]
  // CHECK: %[[TMP_37:.*]] = mhlo.subtract %[[TMP_29]], %[[TMP_36]]
  // CHECK: %[[TMP_38:.*]] = mhlo.divide %[[TMP_32]], %[[TMP_34]]
  // CHECK: %[[TMP_39:.*]] = mhlo.add %[[TMP_31]], %[[TMP_38]]
  // CHECK: %[[TMP_40:.*]] = mhlo.constant dense<12.507343278686905>
  // CHECK: %[[TMP_41:.*]] = mhlo.constant dense<5.000000e+00>
  // CHECK: %[[TMP_42:.*]] = mhlo.add %[[TMP_5]], %[[TMP_41]]
  // CHECK: %[[TMP_43:.*]] = mhlo.multiply %[[TMP_42]], %[[TMP_42]]
  // CHECK: %[[TMP_44:.*]] = mhlo.divide %[[TMP_40]], %[[TMP_43]]
  // CHECK: %[[TMP_45:.*]] = mhlo.subtract %[[TMP_37]], %[[TMP_44]]
  // CHECK: %[[TMP_46:.*]] = mhlo.divide %[[TMP_40]], %[[TMP_42]]
  // CHECK: %[[TMP_47:.*]] = mhlo.add %[[TMP_39]], %[[TMP_46]]
  // CHECK: %[[TMP_48:.*]] = mhlo.constant dense<-0.13857109526572012>
  // CHECK: %[[TMP_49:.*]] = mhlo.constant dense<6.000000e+00>
  // CHECK: %[[TMP_50:.*]] = mhlo.add %[[TMP_5]], %[[TMP_49]]
  // CHECK: %[[TMP_51:.*]] = mhlo.multiply %[[TMP_50]], %[[TMP_50]]
  // CHECK: %[[TMP_52:.*]] = mhlo.divide %[[TMP_48]], %[[TMP_51]]
  // CHECK: %[[TMP_53:.*]] = mhlo.subtract %[[TMP_45]], %[[TMP_52]]
  // CHECK: %[[TMP_54:.*]] = mhlo.divide %[[TMP_48]], %[[TMP_50]]
  // CHECK: %[[TMP_55:.*]] = mhlo.add %[[TMP_47]], %[[TMP_54]]
  // CHECK: %[[TMP_56:.*]] = mhlo.constant dense<9.9843695780195716E-6>
  // CHECK: %[[TMP_57:.*]] = mhlo.constant dense<7.000000e+00>
  // CHECK: %[[TMP_58:.*]] = mhlo.add %[[TMP_5]], %[[TMP_57]]
  // CHECK: %[[TMP_59:.*]] = mhlo.multiply %[[TMP_58]], %[[TMP_58]]
  // CHECK: %[[TMP_60:.*]] = mhlo.divide %[[TMP_56]], %[[TMP_59]]
  // CHECK: %[[TMP_61:.*]] = mhlo.subtract %[[TMP_53]], %[[TMP_60]]
  // CHECK: %[[TMP_62:.*]] = mhlo.divide %[[TMP_56]], %[[TMP_58]]
  // CHECK: %[[TMP_63:.*]] = mhlo.add %[[TMP_55]], %[[TMP_62]]
  // CHECK: %[[TMP_64:.*]] = mhlo.constant dense<1.5056327351493116E-7>
  // CHECK: %[[TMP_65:.*]] = mhlo.constant dense<8.000000e+00>
  // CHECK: %[[TMP_66:.*]] = mhlo.add %[[TMP_5]], %[[TMP_65]]
  // CHECK: %[[TMP_67:.*]] = mhlo.multiply %[[TMP_66]], %[[TMP_66]]
  // CHECK: %[[TMP_68:.*]] = mhlo.divide %[[TMP_64]], %[[TMP_67]]
  // CHECK: %[[TMP_69:.*]] = mhlo.subtract %[[TMP_61]], %[[TMP_68]]
  // CHECK: %[[TMP_70:.*]] = mhlo.divide %[[TMP_64]], %[[TMP_66]]
  // CHECK: %[[TMP_71:.*]] = mhlo.add %[[TMP_63]], %[[TMP_70]]
  // CHECK: %[[TMP_72:.*]] = mhlo.constant dense<7.500000e+00>
  // CHECK: %[[TMP_73:.*]] = mhlo.add %[[TMP_72]], %[[TMP_5]]
  // CHECK: %[[TMP_74:.*]] = mhlo.constant dense<2.0149030205422647>
  // CHECK: %[[TMP_75:.*]] = mhlo.divide %[[TMP_5]], %[[TMP_72]]
  // CHECK: %[[TMP_76:.*]] = "mhlo.log_plus_one"(%[[TMP_75]])
  // CHECK: %[[TMP_77:.*]] = mhlo.add %[[TMP_74]], %[[TMP_76]]
  // CHECK: %[[TMP_78:.*]] = mhlo.divide %[[TMP_69]], %[[TMP_71]]
  // CHECK: %[[TMP_79:.*]] = mhlo.constant dense<7.000000e+00>
  // CHECK: %[[TMP_80:.*]] = mhlo.divide %[[TMP_79]], %[[TMP_73]]
  // CHECK: %[[TMP_81:.*]] = mhlo.add %[[TMP_77]], %[[TMP_78]]
  // CHECK: %[[TMP_82:.*]] = mhlo.subtract %[[TMP_81]], %[[TMP_80]]
  // CHECK: %[[TMP_83:.*]] = mhlo.constant dense<5.000000e-01>
  // CHECK: %[[TMP_84:.*]] = mhlo.add %[[ARG]], %[[TMP_83]]
  // CHECK: %[[TMP_85:.*]] = "mhlo.floor"(%[[TMP_84]])
  // CHECK: %[[TMP_86:.*]] = "mhlo.abs"(%[[TMP_85]])
  // CHECK: %[[TMP_87:.*]] = mhlo.add %[[ARG]], %[[TMP_86]]
  // CHECK: %[[TMP_88:.*]] = mhlo.constant dense<3.1415926535897931>
  // CHECK: %[[TMP_89:.*]] = mhlo.multiply %[[TMP_88]], %[[TMP_87]]
  // CHECK: %[[TMP_90:.*]] = "mhlo.cosine"(%[[TMP_89]])
  // CHECK: %[[TMP_92:.*]] = "mhlo.sine"(%[[TMP_89]])
  // CHECK: %[[TMP_91:.*]] = mhlo.multiply %[[TMP_88]], %[[TMP_90]]
  // CHECK: %[[TMP_93:.*]] = mhlo.divide %[[TMP_91]], %[[TMP_92]]
  // CHECK: %[[TMP_94:.*]] = mhlo.subtract %[[TMP_82]], %[[TMP_93]]
  // CHECK: %[[TMP_95:.*]] = "mhlo.select"(%[[TMP_1]], %[[TMP_94]], %[[TMP_82]])
  // CHECK: %[[TMP_96:.*]] = "mhlo.compare"(%[[ARG]], %[[TMP_6]]) {comparison_direction = "LE"}
  // CHECK: %[[TMP_97:.*]] = "mhlo.floor"(%[[ARG]])
  // CHECK: %[[TMP_98:.*]] = "mhlo.compare"(%[[ARG]], %[[TMP_97]]) {comparison_direction = "EQ"}
  // CHECK: %[[TMP_99:.*]] = mhlo.and %[[TMP_96]], %[[TMP_98]]
  // CHECK: %[[TMP_100:.*]] = mhlo.constant dense<0x7FF8000000000000>
  // CHECK: %[[RES:.*]] = "mhlo.select"(%[[TMP_99]], %[[TMP_100]], %[[TMP_95]])
  // CHECK: return %[[RES]]
  %1 = chlo.digamma %arg : tensor<f64> -> tensor<f64>
  return %1 : tensor<f64>
}
|
||||
|
||||
// Verifies the full f32 lowering of chlo.digamma. The function argument is
// matched through the ARG pattern variable captured on the signature line
// rather than through a hard-coded SSA name.
// CHECK-LABEL: @digamma_f32
// CHECK-SAME: (%[[ARG:.*]]: tensor<f32>)
func @digamma_f32(%arg : tensor<f32>) -> tensor<f32> {
  // CHECK: %[[TMP_0:.*]] = mhlo.constant dense<5.000000e-01>
  // CHECK: %[[TMP_1:.*]] = "mhlo.compare"(%[[ARG]], %[[TMP_0]]) {comparison_direction = "LT"}
  // CHECK: %[[TMP_2:.*]] = "mhlo.negate"(%[[ARG]])
  // CHECK: %[[TMP_3:.*]] = mhlo.constant dense<1.000000e+00>
  // CHECK: %[[TMP_4:.*]] = mhlo.subtract %[[ARG]], %[[TMP_3]]
  // CHECK: %[[TMP_5:.*]] = "mhlo.select"(%[[TMP_1]], %[[TMP_2]], %[[TMP_4]])
  // CHECK: %[[TMP_6:.*]] = mhlo.constant dense<0.000000e+00>
  // CHECK: %[[TMP_7:.*]] = mhlo.constant dense<1.000000e+00>
  // CHECK: %[[TMP_8:.*]] = mhlo.constant dense<676.520386>
  // CHECK: %[[TMP_9:.*]] = mhlo.constant dense<1.000000e+00>
  // CHECK: %[[TMP_10:.*]] = mhlo.add %[[TMP_5]], %[[TMP_9]]
  // CHECK: %[[TMP_11:.*]] = mhlo.multiply %[[TMP_10]], %[[TMP_10]]
  // CHECK: %[[TMP_12:.*]] = mhlo.divide %[[TMP_8]], %[[TMP_11]]
  // CHECK: %[[TMP_13:.*]] = mhlo.subtract %[[TMP_6]], %[[TMP_12]]
  // CHECK: %[[TMP_14:.*]] = mhlo.divide %[[TMP_8]], %[[TMP_10]]
  // CHECK: %[[TMP_15:.*]] = mhlo.add %[[TMP_7]], %[[TMP_14]]
  // CHECK: %[[TMP_16:.*]] = mhlo.constant dense<-1259.13916>
  // CHECK: %[[TMP_17:.*]] = mhlo.constant dense<2.000000e+00>
  // CHECK: %[[TMP_18:.*]] = mhlo.add %[[TMP_5]], %[[TMP_17]]
  // CHECK: %[[TMP_19:.*]] = mhlo.multiply %[[TMP_18]], %[[TMP_18]]
  // CHECK: %[[TMP_20:.*]] = mhlo.divide %[[TMP_16]], %[[TMP_19]]
  // CHECK: %[[TMP_21:.*]] = mhlo.subtract %[[TMP_13]], %[[TMP_20]]
  // CHECK: %[[TMP_22:.*]] = mhlo.divide %[[TMP_16]], %[[TMP_18]]
  // CHECK: %[[TMP_23:.*]] = mhlo.add %[[TMP_15]], %[[TMP_22]]
  // CHECK: %[[TMP_24:.*]] = mhlo.constant dense<771.323425>
  // CHECK: %[[TMP_25:.*]] = mhlo.constant dense<3.000000e+00>
  // CHECK: %[[TMP_26:.*]] = mhlo.add %[[TMP_5]], %[[TMP_25]]
  // CHECK: %[[TMP_27:.*]] = mhlo.multiply %[[TMP_26]], %[[TMP_26]]
  // CHECK: %[[TMP_28:.*]] = mhlo.divide %[[TMP_24]], %[[TMP_27]]
  // CHECK: %[[TMP_29:.*]] = mhlo.subtract %[[TMP_21]], %[[TMP_28]]
  // CHECK: %[[TMP_30:.*]] = mhlo.divide %[[TMP_24]], %[[TMP_26]]
  // CHECK: %[[TMP_31:.*]] = mhlo.add %[[TMP_23]], %[[TMP_30]]
  // CHECK: %[[TMP_32:.*]] = mhlo.constant dense<-176.615036>
  // CHECK: %[[TMP_33:.*]] = mhlo.constant dense<4.000000e+00>
  // CHECK: %[[TMP_34:.*]] = mhlo.add %[[TMP_5]], %[[TMP_33]]
  // CHECK: %[[TMP_35:.*]] = mhlo.multiply %[[TMP_34]], %[[TMP_34]]
  // CHECK: %[[TMP_36:.*]] = mhlo.divide %[[TMP_32]], %[[TMP_35]]
  // CHECK: %[[TMP_37:.*]] = mhlo.subtract %[[TMP_29]], %[[TMP_36]]
  // CHECK: %[[TMP_38:.*]] = mhlo.divide %[[TMP_32]], %[[TMP_34]]
  // CHECK: %[[TMP_39:.*]] = mhlo.add %[[TMP_31]], %[[TMP_38]]
  // CHECK: %[[TMP_40:.*]] = mhlo.constant dense<12.5073433>
  // CHECK: %[[TMP_41:.*]] = mhlo.constant dense<5.000000e+00>
  // CHECK: %[[TMP_42:.*]] = mhlo.add %[[TMP_5]], %[[TMP_41]]
  // CHECK: %[[TMP_43:.*]] = mhlo.multiply %[[TMP_42]], %[[TMP_42]]
  // CHECK: %[[TMP_44:.*]] = mhlo.divide %[[TMP_40]], %[[TMP_43]]
  // CHECK: %[[TMP_45:.*]] = mhlo.subtract %[[TMP_37]], %[[TMP_44]]
  // CHECK: %[[TMP_46:.*]] = mhlo.divide %[[TMP_40]], %[[TMP_42]]
  // CHECK: %[[TMP_47:.*]] = mhlo.add %[[TMP_39]], %[[TMP_46]]
  // CHECK: %[[TMP_48:.*]] = mhlo.constant dense<-0.138571098>
  // CHECK: %[[TMP_49:.*]] = mhlo.constant dense<6.000000e+00>
  // CHECK: %[[TMP_50:.*]] = mhlo.add %[[TMP_5]], %[[TMP_49]]
  // CHECK: %[[TMP_51:.*]] = mhlo.multiply %[[TMP_50]], %[[TMP_50]]
  // CHECK: %[[TMP_52:.*]] = mhlo.divide %[[TMP_48]], %[[TMP_51]]
  // CHECK: %[[TMP_53:.*]] = mhlo.subtract %[[TMP_45]], %[[TMP_52]]
  // CHECK: %[[TMP_54:.*]] = mhlo.divide %[[TMP_48]], %[[TMP_50]]
  // CHECK: %[[TMP_55:.*]] = mhlo.add %[[TMP_47]], %[[TMP_54]]
  // CHECK: %[[TMP_56:.*]] = mhlo.constant dense<9.98436917E-6>
  // CHECK: %[[TMP_57:.*]] = mhlo.constant dense<7.000000e+00>
  // CHECK: %[[TMP_58:.*]] = mhlo.add %[[TMP_5]], %[[TMP_57]]
  // CHECK: %[[TMP_59:.*]] = mhlo.multiply %[[TMP_58]], %[[TMP_58]]
  // CHECK: %[[TMP_60:.*]] = mhlo.divide %[[TMP_56]], %[[TMP_59]]
  // CHECK: %[[TMP_61:.*]] = mhlo.subtract %[[TMP_53]], %[[TMP_60]]
  // CHECK: %[[TMP_62:.*]] = mhlo.divide %[[TMP_56]], %[[TMP_58]]
  // CHECK: %[[TMP_63:.*]] = mhlo.add %[[TMP_55]], %[[TMP_62]]
  // CHECK: %[[TMP_64:.*]] = mhlo.constant dense<1.50563267E-7>
  // CHECK: %[[TMP_65:.*]] = mhlo.constant dense<8.000000e+00>
  // CHECK: %[[TMP_66:.*]] = mhlo.add %[[TMP_5]], %[[TMP_65]]
  // CHECK: %[[TMP_67:.*]] = mhlo.multiply %[[TMP_66]], %[[TMP_66]]
  // CHECK: %[[TMP_68:.*]] = mhlo.divide %[[TMP_64]], %[[TMP_67]]
  // CHECK: %[[TMP_69:.*]] = mhlo.subtract %[[TMP_61]], %[[TMP_68]]
  // CHECK: %[[TMP_70:.*]] = mhlo.divide %[[TMP_64]], %[[TMP_66]]
  // CHECK: %[[TMP_71:.*]] = mhlo.add %[[TMP_63]], %[[TMP_70]]
  // CHECK: %[[TMP_72:.*]] = mhlo.constant dense<7.500000e+00>
  // CHECK: %[[TMP_73:.*]] = mhlo.add %[[TMP_72]], %[[TMP_5]]
  // CHECK: %[[TMP_74:.*]] = mhlo.constant dense<2.01490307>
  // CHECK: %[[TMP_75:.*]] = mhlo.divide %[[TMP_5]], %[[TMP_72]]
  // CHECK: %[[TMP_76:.*]] = "mhlo.log_plus_one"(%[[TMP_75]])
  // CHECK: %[[TMP_77:.*]] = mhlo.add %[[TMP_74]], %[[TMP_76]]
  // CHECK: %[[TMP_78:.*]] = mhlo.divide %[[TMP_69]], %[[TMP_71]]
  // CHECK: %[[TMP_79:.*]] = mhlo.constant dense<7.000000e+00>
  // CHECK: %[[TMP_80:.*]] = mhlo.divide %[[TMP_79]], %[[TMP_73]]
  // CHECK: %[[TMP_81:.*]] = mhlo.add %[[TMP_77]], %[[TMP_78]]
  // CHECK: %[[TMP_82:.*]] = mhlo.subtract %[[TMP_81]], %[[TMP_80]]
  // CHECK: %[[TMP_83:.*]] = mhlo.constant dense<5.000000e-01>
  // CHECK: %[[TMP_84:.*]] = mhlo.add %[[ARG]], %[[TMP_83]]
  // CHECK: %[[TMP_85:.*]] = "mhlo.floor"(%[[TMP_84]])
  // CHECK: %[[TMP_86:.*]] = "mhlo.abs"(%[[TMP_85]])
  // CHECK: %[[TMP_87:.*]] = mhlo.add %[[ARG]], %[[TMP_86]]
  // CHECK: %[[TMP_88:.*]] = mhlo.constant dense<3.14159274>
  // CHECK: %[[TMP_89:.*]] = mhlo.multiply %[[TMP_88]], %[[TMP_87]]
  // CHECK: %[[TMP_90:.*]] = "mhlo.cosine"(%[[TMP_89]])
  // CHECK: %[[TMP_92:.*]] = "mhlo.sine"(%[[TMP_89]])
  // CHECK: %[[TMP_91:.*]] = mhlo.multiply %[[TMP_88]], %[[TMP_90]]
  // CHECK: %[[TMP_93:.*]] = mhlo.divide %[[TMP_91]], %[[TMP_92]]
  // CHECK: %[[TMP_94:.*]] = mhlo.subtract %[[TMP_82]], %[[TMP_93]]
  // CHECK: %[[TMP_95:.*]] = "mhlo.select"(%[[TMP_1]], %[[TMP_94]], %[[TMP_82]])
  // CHECK: %[[TMP_96:.*]] = "mhlo.compare"(%[[ARG]], %[[TMP_6]]) {comparison_direction = "LE"}
  // CHECK: %[[TMP_97:.*]] = "mhlo.floor"(%[[ARG]])
  // CHECK: %[[TMP_98:.*]] = "mhlo.compare"(%[[ARG]], %[[TMP_97]]) {comparison_direction = "EQ"}
  // CHECK: %[[TMP_99:.*]] = mhlo.and %[[TMP_96]], %[[TMP_98]]
  // CHECK: %[[TMP_100:.*]] = mhlo.constant dense<0x7FC00000>
  // CHECK: %[[RES:.*]] = "mhlo.select"(%[[TMP_99]], %[[TMP_100]], %[[TMP_95]])
  // CHECK: return %[[RES]]
  %1 = chlo.digamma %arg : tensor<f32> -> tensor<f32>
  return %1 : tensor<f32>
}
|
||||
|
||||
// For f16 only the precision handling is verified: the argument must be
// converted to f32, and the value returned must come from a conversion of an
// f32 value back to f16.
// CHECK-LABEL: @digamma_f16
// CHECK-SAME: (%[[ARG:.*]]: tensor<f16>)
func @digamma_f16(%operand : tensor<f16>) -> tensor<f16> {
  // CHECK: "mhlo.convert"(%[[ARG]]) : (tensor<f16>) -> tensor<f32>
  // CHECK: %[[RES:.*]] = "mhlo.convert"(%{{.*}}) : (tensor<f32>) -> tensor<f16>
  // CHECK: return %[[RES]]
  %result = chlo.digamma %operand : tensor<f16> -> tensor<f16>
  return %result : tensor<f16>
}
|
||||
|
|
Loading…
Reference in New Issue