[MLIR][CHLO] Add `chlo.lgamma` and lowering to `hlo`
PiperOrigin-RevId: 354287316
This commit is contained in:
parent
c3ddcd6c7f
commit
e0a7be7fb1
|
@ -554,6 +554,15 @@ def HLOClient_IsPosInfOp : HLOClient_UnaryElementwiseOp<"is_pos_inf",
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def HLOClient_LgammaOp : HLOClient_UnaryElementwiseOp<"lgamma",
|
||||||
|
[SameOperandsAndResultType], HLO_FpTensor, HLO_FpTensor> {
|
||||||
|
let summary = "Lgamma function";
|
||||||
|
|
||||||
|
let description = [{
|
||||||
|
Returns `Lgamma(operand)` element-wise.
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Broadcasting compare op
|
// Broadcasting compare op
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -471,6 +471,191 @@ struct ConvertErfcOp : public OpConversionPattern<ErfcOp> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Coefficients for the Lanczos approximation of the gamma function. The
|
||||||
|
// coefficients are uniquely determined by the choice of g and n (kLanczosGamma
|
||||||
|
// and kLanczosCoefficients.size() + 1). The coefficients below correspond to
|
||||||
|
// [7, 9]. [5, 7], [7, 9], [9, 10], and [607/128.0, 15] were evaluated and
|
||||||
|
// [7, 9] seemed to be the least sensitive to the quality of the log function.
|
||||||
|
// In particular, [5, 7] is the only choice where -1.5e-5 <= lgamma(2) <= 1.5e-5
|
||||||
|
// for a particularly inaccurate log function.
|
||||||
|
constexpr double kLanczosGamma = 7; // aka g
|
||||||
|
constexpr double kBaseLanczosCoeff = 0.99999999999980993227684700473478;
|
||||||
|
constexpr std::array<double, 8> kLanczosCoefficients = {
|
||||||
|
676.520368121885098567009190444019, -1259.13921672240287047156078755283,
|
||||||
|
771.3234287776530788486528258894, -176.61502916214059906584551354,
|
||||||
|
12.507343278686904814458936853, -0.13857109526572011689554707,
|
||||||
|
9.984369578019570859563e-6, 1.50563273514931155834e-7};
|
||||||
|
|
||||||
|
// Compute the Lgamma function using Lanczos' approximation from "A Precision
|
||||||
|
// Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
|
||||||
|
// series B. Vol. 1:
|
||||||
|
// lgamma(z + 1) = (log(2) + log(pi)) / 2
|
||||||
|
// + (z + 1/2) * log(t(z))
|
||||||
|
// - t(z) + log(a(z))
|
||||||
|
// with t(z) = z + kLanczosGamma + 1/2
|
||||||
|
// a(z) = kBaseLanczosCoeff
|
||||||
|
// + sum(k = 1, n, kLanczosCoefficients[i] / (z + k))
|
||||||
|
Value MaterializeLgamma(ConversionPatternRewriter &rewriter, Location loc,
|
||||||
|
Value x) {
|
||||||
|
// If the input is less than 0.5 use Euler's reflection formula.
|
||||||
|
// gamma(x) = pi / (sin(pi * x) * gamma(1 - x))
|
||||||
|
// Let z be
|
||||||
|
// z = -x if x < 1/2
|
||||||
|
// z = x - 1 otheriwse
|
||||||
|
const StringAttr kLT = rewriter.getStringAttr(
|
||||||
|
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
|
||||||
|
Value half = getConstantLike(rewriter, loc, 0.5, x);
|
||||||
|
Value need_to_reflect = rewriter.create<mhlo::CompareOp>(loc, x, half, kLT);
|
||||||
|
Value neg_x = rewriter.create<mhlo::NegOp>(loc, x);
|
||||||
|
Value one = getConstantLike(rewriter, loc, 1, x);
|
||||||
|
Value x_sub_one = rewriter.create<mhlo::SubOp>(loc, x, one);
|
||||||
|
Value z =
|
||||||
|
rewriter.create<mhlo::SelectOp>(loc, need_to_reflect, neg_x, x_sub_one);
|
||||||
|
|
||||||
|
// Materialize
|
||||||
|
// a(z) = kBaseLanczosCoeff
|
||||||
|
// + sum(k = 1, n, kLanczosCoefficients[i] / (z + k))
|
||||||
|
Value a = getConstantLike(rewriter, loc, kBaseLanczosCoeff, x);
|
||||||
|
for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) {
|
||||||
|
Value coeff = getConstantLike(rewriter, loc, kLanczosCoefficients[i], x);
|
||||||
|
Value one_based_index = getConstantLike(rewriter, loc, i + 1, x);
|
||||||
|
Value quotient = rewriter.create<mhlo::DivOp>(
|
||||||
|
loc, coeff, rewriter.create<mhlo::AddOp>(loc, z, one_based_index));
|
||||||
|
a = rewriter.create<mhlo::AddOp>(loc, a, quotient);
|
||||||
|
}
|
||||||
|
|
||||||
|
// To improve accuracy on platforms with less-precise log implementations,
|
||||||
|
// compute log(kLanczosGamma + 1/2) at compile time and use log1p on the
|
||||||
|
// device.
|
||||||
|
// Materialize as
|
||||||
|
// log(t) = log(kLanczosGamma + 1/2 + z)
|
||||||
|
// = log(kLanczosGamma + 1/2) + log1p(z / (kLanczosGamma + 1/2)).
|
||||||
|
Value lanczos_plus_half =
|
||||||
|
getConstantLike(rewriter, loc, kLanczosGamma + 0.5, x);
|
||||||
|
Value t = rewriter.create<mhlo::AddOp>(loc, lanczos_plus_half, z);
|
||||||
|
Value log_term =
|
||||||
|
getConstantLike(rewriter, loc, std::log(kLanczosGamma + 0.5), x);
|
||||||
|
Value log1p_term = rewriter.create<mhlo::Log1pOp>(
|
||||||
|
loc, rewriter.create<mhlo::DivOp>(loc, z, lanczos_plus_half));
|
||||||
|
Value log_t = rewriter.create<mhlo::AddOp>(loc, log_term, log1p_term);
|
||||||
|
|
||||||
|
// Note that t(z) may be large and we need to be careful not to overflow to
|
||||||
|
// infinity in the relevant term
|
||||||
|
// r = (z + 1/2) * log(t(z)) - t(z).
|
||||||
|
// Therefore, we compute this as
|
||||||
|
// r = (z + 1/2 - t(z) / log(t(z))) * log(t(z)).
|
||||||
|
Value t_div_log_t = rewriter.create<mhlo::DivOp>(loc, t, log_t);
|
||||||
|
Value sum = rewriter.create<mhlo::SubOp>(
|
||||||
|
loc, rewriter.create<mhlo::AddOp>(loc, z, half), t_div_log_t);
|
||||||
|
Value r = rewriter.create<mhlo::MulOp>(loc, sum, log_t);
|
||||||
|
|
||||||
|
// Compute the final result (modulo reflection) as
|
||||||
|
// lgamma(z + 1) = (log(2) + log(pi)) / 2 + r + log(a(z)).
|
||||||
|
Value log_a = rewriter.create<mhlo::LogOp>(loc, a);
|
||||||
|
Value lgamma = rewriter.create<mhlo::AddOp>(
|
||||||
|
loc,
|
||||||
|
rewriter.create<mhlo::AddOp>(
|
||||||
|
loc,
|
||||||
|
getConstantLike(rewriter, loc, (std::log(2) + std::log(M_PI)) / 2, x),
|
||||||
|
r),
|
||||||
|
log_a);
|
||||||
|
|
||||||
|
// Compute the reflected value for x < 0.5 as
|
||||||
|
// lgamma(x) = log(pi) - lgamma(1-x) - log(abs(sin(pi * x))).
|
||||||
|
//
|
||||||
|
// The abs is needed because lgamma is the log of the absolute value of the
|
||||||
|
// gamma function.
|
||||||
|
//
|
||||||
|
// We have to be careful when computing the final term above. gamma(x) goes
|
||||||
|
// to +/-inf at every integer x < 0, and this is controlled by the sin(pi * x)
|
||||||
|
// term. The slope is large, so precision is particularly important.
|
||||||
|
//
|
||||||
|
// Because abs(sin(pi * x)) has period of 1 we can equivalently use
|
||||||
|
// abs(sin(pi * frac(x))) where frac(x) is the fractional part of x. This is
|
||||||
|
// more numerically accurate: It doesn't overflow to inf like pi * x would and
|
||||||
|
// if x is an integer it evaluates to exactly 0 which is important because we
|
||||||
|
// then take the log of this value, and log(0) is inf.
|
||||||
|
//
|
||||||
|
// We don't have a frac(x) primitive in HLO and computing it is tricky, but
|
||||||
|
// because abs(sin(pi * x)) = abs(sin(pi * abs(x))), it's good enough for our
|
||||||
|
// purposes to use abs(frac(x)) = abs(x) - floor(abs(x)).
|
||||||
|
//
|
||||||
|
// Furthermore, pi * abs(frac(x)) loses precision when abs(frac(x)) is close
|
||||||
|
// to 1. To remedy this, we can use the fact that sin(pi * x) in the domain
|
||||||
|
// [0, 1] is symmetric across the line Y=0.5.
|
||||||
|
//
|
||||||
|
|
||||||
|
// Convert values of abs_frac > 0.5 to (1 - abs_frac) to improve precision of
|
||||||
|
// pi * abs_frac for values of abs_frac close to 1.
|
||||||
|
Value abs = rewriter.create<mhlo::AbsOp>(loc, x);
|
||||||
|
Value abs_frac = rewriter.create<mhlo::SubOp>(
|
||||||
|
loc, abs, rewriter.create<mhlo::FloorOp>(loc, abs));
|
||||||
|
Value reduce_abs_frac =
|
||||||
|
rewriter.create<mhlo::CompareOp>(loc, half, abs_frac, kLT);
|
||||||
|
abs_frac = rewriter.create<mhlo::SelectOp>(
|
||||||
|
loc, reduce_abs_frac, rewriter.create<mhlo::SubOp>(loc, one, abs_frac),
|
||||||
|
abs_frac);
|
||||||
|
|
||||||
|
// Materialize reflection.
|
||||||
|
Value reflection_denom = rewriter.create<mhlo::LogOp>(
|
||||||
|
loc,
|
||||||
|
rewriter.create<mhlo::SinOp>(
|
||||||
|
loc, rewriter.create<mhlo::MulOp>(
|
||||||
|
loc, getConstantLike(rewriter, loc, M_PI, x), abs_frac)));
|
||||||
|
Value lgamma_reflection = rewriter.create<mhlo::SubOp>(
|
||||||
|
loc,
|
||||||
|
rewriter.create<mhlo::SubOp>(
|
||||||
|
loc, getConstantLike(rewriter, loc, std::log(M_PI), x),
|
||||||
|
reflection_denom),
|
||||||
|
lgamma);
|
||||||
|
|
||||||
|
// Avoid computing -inf - inf, which is nan. If reflection_denom is +/-inf,
|
||||||
|
// then it "wins" and the result is +/-inf.
|
||||||
|
Value finite_reflection_denom =
|
||||||
|
rewriter.create<mhlo::IsFiniteOp>(loc, reflection_denom);
|
||||||
|
Value neg_reflection_denom =
|
||||||
|
rewriter.create<mhlo::NegOp>(loc, reflection_denom);
|
||||||
|
lgamma_reflection = rewriter.create<mhlo::SelectOp>(
|
||||||
|
loc, finite_reflection_denom, lgamma_reflection, neg_reflection_denom);
|
||||||
|
|
||||||
|
// Select whether or not to rely on the reflection.
|
||||||
|
lgamma = rewriter.create<mhlo::SelectOp>(loc, need_to_reflect,
|
||||||
|
lgamma_reflection, lgamma);
|
||||||
|
|
||||||
|
// Materialize +/-inf behavior as
|
||||||
|
// lgamma(+/-inf) = +inf.
|
||||||
|
Value x_is_inf = rewriter.create<chlo::IsInfOp>(loc, x);
|
||||||
|
return rewriter.create<mhlo::SelectOp>(
|
||||||
|
loc, x_is_inf,
|
||||||
|
chlo::getConstantLikeInfValue(rewriter, loc, x, /*negative=*/false),
|
||||||
|
lgamma);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ConvertLgammaOp : public OpConversionPattern<LgammaOp> {
|
||||||
|
using OpConversionPattern<LgammaOp>::OpConversionPattern;
|
||||||
|
LogicalResult matchAndRewrite(
|
||||||
|
LgammaOp op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
Location loc = op.getLoc();
|
||||||
|
LgammaOp::Adaptor transformed(operands);
|
||||||
|
Value x = transformed.operand();
|
||||||
|
Type ty = getElementTypeOrSelf(op.getType());
|
||||||
|
|
||||||
|
if (ty.isF32() || ty.isF64()) {
|
||||||
|
rewriter.replaceOp(op, MaterializeLgamma(rewriter, loc, x));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Materialize lgamma with upcast to f32.
|
||||||
|
x = rewriter.create<mhlo::ConvertOp>(loc, x, rewriter.getF32Type());
|
||||||
|
Value result = MaterializeLgamma(rewriter, loc, x);
|
||||||
|
result = rewriter.create<mhlo::ConvertOp>(loc, result, ty);
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, result);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// Converts binary ops that statically are determined to not broadcast directly
|
// Converts binary ops that statically are determined to not broadcast directly
|
||||||
// to the corresponding mhlo non-broadcasting op.
|
// to the corresponding mhlo non-broadcasting op.
|
||||||
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
|
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
|
||||||
|
@ -622,7 +807,8 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
|
||||||
context, patterns, 5);
|
context, patterns, 5);
|
||||||
|
|
||||||
// Other patterns.
|
// Other patterns.
|
||||||
patterns->insert<ConvertConstantLikeOp, ConvertErfOp, ConvertErfcOp>(context);
|
patterns->insert<ConvertConstantLikeOp, ConvertErfOp, ConvertErfcOp,
|
||||||
|
ConvertLgammaOp>(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace chlo
|
} // namespace chlo
|
||||||
|
|
|
@ -683,3 +683,195 @@ func @is_neg_inf_f32(%arg : tensor<f32>) -> tensor<i1> {
|
||||||
%1 = chlo.is_neg_inf %arg : tensor<f32> -> tensor<i1>
|
%1 = chlo.is_neg_inf %arg : tensor<f32> -> tensor<i1>
|
||||||
return %1 : tensor<i1>
|
return %1 : tensor<i1>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @lgamma_f64
|
||||||
|
// CHECK-SAME: (%[[ARG:.*]]: tensor<f64>)
|
||||||
|
func @lgamma_f64(%arg : tensor<f64>) -> tensor<f64> {
|
||||||
|
// CHECK: %[[TMP_1:.*]] = mhlo.constant dense<5.000000e-01>
|
||||||
|
// CHECK: %[[TMP_9:.*]] = "mhlo.compare"(%[[ARG]], %[[TMP_1]]) {comparison_direction = "LT"}
|
||||||
|
// CHECK: %[[TMP_10:.*]] = "mhlo.negate"(%[[ARG]])
|
||||||
|
// CHECK: %[[TMP_2:.*]] = mhlo.constant dense<1.000000e+00>
|
||||||
|
// CHECK: %[[TMP_11:.*]] = mhlo.subtract %[[ARG]], %[[TMP_2]]
|
||||||
|
// CHECK: %[[TMP_12:.*]] = "mhlo.select"(%[[TMP_9]], %[[TMP_10]], %[[TMP_11]])
|
||||||
|
// CHECK: %[[TMP_8:.*]] = mhlo.constant dense<0.99999999999980993>
|
||||||
|
// CHECK: %[[TMP_13:.*]] = mhlo.constant dense<676.5203681218851>
|
||||||
|
// CHECK: %[[TMP_14:.*]] = mhlo.constant dense<1.000000e+00>
|
||||||
|
// CHECK: %[[TMP_15:.*]] = mhlo.add %[[TMP_12]], %[[TMP_14]]
|
||||||
|
// CHECK: %[[TMP_16:.*]] = mhlo.divide %[[TMP_13]], %[[TMP_15]]
|
||||||
|
// CHECK: %[[TMP_17:.*]] = mhlo.add %[[TMP_8]], %[[TMP_16]]
|
||||||
|
// CHECK: %[[TMP_18:.*]] = mhlo.constant dense<-1259.1392167224028>
|
||||||
|
// CHECK: %[[TMP_19:.*]] = mhlo.constant dense<2.000000e+00>
|
||||||
|
// CHECK: %[[TMP_20:.*]] = mhlo.add %[[TMP_12]], %[[TMP_19]]
|
||||||
|
// CHECK: %[[TMP_21:.*]] = mhlo.divide %[[TMP_18]], %[[TMP_20]]
|
||||||
|
// CHECK: %[[TMP_22:.*]] = mhlo.add %[[TMP_17]], %[[TMP_21]]
|
||||||
|
// CHECK: %[[TMP_23:.*]] = mhlo.constant dense<771.32342877765313>
|
||||||
|
// CHECK: %[[TMP_24:.*]] = mhlo.constant dense<3.000000e+00>
|
||||||
|
// CHECK: %[[TMP_25:.*]] = mhlo.add %[[TMP_12]], %[[TMP_24]]
|
||||||
|
// CHECK: %[[TMP_26:.*]] = mhlo.divide %[[TMP_23]], %[[TMP_25]]
|
||||||
|
// CHECK: %[[TMP_27:.*]] = mhlo.add %[[TMP_22]], %[[TMP_26]]
|
||||||
|
// CHECK: %[[TMP_28:.*]] = mhlo.constant dense<-176.61502916214059>
|
||||||
|
// CHECK: %[[TMP_29:.*]] = mhlo.constant dense<4.000000e+00>
|
||||||
|
// CHECK: %[[TMP_30:.*]] = mhlo.add %[[TMP_12]], %[[TMP_29]]
|
||||||
|
// CHECK: %[[TMP_31:.*]] = mhlo.divide %[[TMP_28]], %[[TMP_30]]
|
||||||
|
// CHECK: %[[TMP_32:.*]] = mhlo.add %[[TMP_27]], %[[TMP_31]]
|
||||||
|
// CHECK: %[[TMP_33:.*]] = mhlo.constant dense<12.507343278686905>
|
||||||
|
// CHECK: %[[TMP_34:.*]] = mhlo.constant dense<5.000000e+00>
|
||||||
|
// CHECK: %[[TMP_35:.*]] = mhlo.add %[[TMP_12]], %[[TMP_34]]
|
||||||
|
// CHECK: %[[TMP_36:.*]] = mhlo.divide %[[TMP_33]], %[[TMP_35]]
|
||||||
|
// CHECK: %[[TMP_37:.*]] = mhlo.add %[[TMP_32]], %[[TMP_36]]
|
||||||
|
// CHECK: %[[TMP_38:.*]] = mhlo.constant dense<-0.13857109526572012>
|
||||||
|
// CHECK: %[[TMP_39:.*]] = mhlo.constant dense<6.000000e+00>
|
||||||
|
// CHECK: %[[TMP_40:.*]] = mhlo.add %[[TMP_12]], %[[TMP_39]]
|
||||||
|
// CHECK: %[[TMP_41:.*]] = mhlo.divide %[[TMP_38]], %[[TMP_40]]
|
||||||
|
// CHECK: %[[TMP_42:.*]] = mhlo.add %[[TMP_37]], %[[TMP_41]]
|
||||||
|
// CHECK: %[[TMP_43:.*]] = mhlo.constant dense<9.9843695780195716E-6>
|
||||||
|
// CHECK: %[[TMP_44:.*]] = mhlo.constant dense<7.000000e+00>
|
||||||
|
// CHECK: %[[TMP_45:.*]] = mhlo.add %[[TMP_12]], %[[TMP_44]]
|
||||||
|
// CHECK: %[[TMP_46:.*]] = mhlo.divide %[[TMP_43]], %[[TMP_45]]
|
||||||
|
// CHECK: %[[TMP_47:.*]] = mhlo.add %[[TMP_42]], %[[TMP_46]]
|
||||||
|
// CHECK: %[[TMP_48:.*]] = mhlo.constant dense<1.5056327351493116E-7>
|
||||||
|
// CHECK: %[[TMP_49:.*]] = mhlo.constant dense<8.000000e+00>
|
||||||
|
// CHECK: %[[TMP_50:.*]] = mhlo.add %[[TMP_12]], %[[TMP_49]]
|
||||||
|
// CHECK: %[[TMP_51:.*]] = mhlo.divide %[[TMP_48]], %[[TMP_50]]
|
||||||
|
// CHECK: %[[TMP_52:.*]] = mhlo.add %[[TMP_47]], %[[TMP_51]]
|
||||||
|
// CHECK: %[[TMP_6:.*]] = mhlo.constant dense<7.500000e+00>
|
||||||
|
// CHECK: %[[TMP_53:.*]] = mhlo.add %[[TMP_6]], %[[TMP_12]]
|
||||||
|
// CHECK: %[[TMP_7:.*]] = mhlo.constant dense<2.0149030205422647>
|
||||||
|
// CHECK: %[[TMP_54:.*]] = mhlo.divide %[[TMP_12]], %[[TMP_6]]
|
||||||
|
// CHECK: %[[TMP_55:.*]] = "mhlo.log_plus_one"(%[[TMP_54]])
|
||||||
|
// CHECK: %[[TMP_56:.*]] = mhlo.add %[[TMP_7]], %[[TMP_55]]
|
||||||
|
// CHECK: %[[TMP_57:.*]] = mhlo.divide %[[TMP_53]], %[[TMP_56]]
|
||||||
|
// CHECK: %[[TMP_58:.*]] = mhlo.add %[[TMP_12]], %[[TMP_1]]
|
||||||
|
// CHECK: %[[TMP_59:.*]] = mhlo.subtract %[[TMP_58]], %[[TMP_57]]
|
||||||
|
// CHECK: %[[TMP_60:.*]] = mhlo.multiply %[[TMP_59]], %[[TMP_56]]
|
||||||
|
// CHECK: %[[TMP_61:.*]] = "mhlo.log"(%[[TMP_52]])
|
||||||
|
// CHECK: %[[TMP_5:.*]] = mhlo.constant dense<0.91893853320467266>
|
||||||
|
// CHECK: %[[TMP_62:.*]] = mhlo.add %[[TMP_5]], %[[TMP_60]]
|
||||||
|
// CHECK: %[[TMP_63:.*]] = mhlo.add %[[TMP_62]], %[[TMP_61]]
|
||||||
|
// CHECK: %[[TMP_64:.*]] = "mhlo.abs"(%[[ARG]])
|
||||||
|
// CHECK: %[[TMP_65:.*]] = "mhlo.floor"(%[[TMP_64]])
|
||||||
|
// CHECK: %[[TMP_66:.*]] = mhlo.subtract %[[TMP_64]], %[[TMP_65]]
|
||||||
|
// CHECK: %[[TMP_67:.*]] = "mhlo.compare"(%[[TMP_1]], %[[TMP_66]]) {comparison_direction = "LT"}
|
||||||
|
// CHECK: %[[TMP_68:.*]] = mhlo.subtract %[[TMP_2]], %[[TMP_66]]
|
||||||
|
// CHECK: %[[TMP_69:.*]] = "mhlo.select"(%[[TMP_67]], %[[TMP_68]], %[[TMP_66]])
|
||||||
|
// CHECK: %[[TMP_3:.*]] = mhlo.constant dense<3.1415926535897931>
|
||||||
|
// CHECK: %[[TMP_70:.*]] = mhlo.multiply %[[TMP_3]], %[[TMP_69]]
|
||||||
|
// CHECK: %[[TMP_71:.*]] = "mhlo.sine"(%[[TMP_70]])
|
||||||
|
// CHECK: %[[TMP_72:.*]] = "mhlo.log"(%[[TMP_71]])
|
||||||
|
// CHECK: %[[TMP_4:.*]] = mhlo.constant dense<1.1447298858494002>
|
||||||
|
// CHECK: %[[TMP_75:.*]] = mhlo.subtract %[[TMP_4]], %[[TMP_72]]
|
||||||
|
// CHECK: %[[TMP_76:.*]] = mhlo.subtract %[[TMP_75]], %[[TMP_63]]
|
||||||
|
// CHECK: %[[TMP_73:.*]] = "mhlo.is_finite"(%[[TMP_72]])
|
||||||
|
// CHECK: %[[TMP_74:.*]] = "mhlo.negate"(%[[TMP_72]])
|
||||||
|
// CHECK: %[[TMP_77:.*]] = "mhlo.select"(%[[TMP_73]], %[[TMP_76]], %[[TMP_74]])
|
||||||
|
// CHECK: %[[TMP_78:.*]] = "mhlo.select"(%[[TMP_9]], %[[TMP_77]], %[[TMP_63]])
|
||||||
|
// CHECK: %[[TMP_79:.*]] = "mhlo.abs"(%[[ARG]])
|
||||||
|
// CHECK: %[[TMP_80:.*]] = mhlo.constant dense<0x7FF0000000000000>
|
||||||
|
// CHECK: %[[TMP_81:.*]] = "mhlo.compare"(%[[TMP_79]], %[[TMP_80]]) {comparison_direction = "EQ"}
|
||||||
|
// CHECK: %[[TMP_0:.*]] = mhlo.constant dense<0x7FF0000000000000>
|
||||||
|
// CHECK: %[[TMP_82:.*]] = "mhlo.select"(%[[TMP_81]], %[[TMP_0]], %[[TMP_78]])
|
||||||
|
// CHECK: return %[[TMP_82]]
|
||||||
|
%1 = chlo.lgamma %arg : tensor<f64> -> tensor<f64>
|
||||||
|
return %1 : tensor<f64>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @lgamma_f32
|
||||||
|
// CHECK-SAME: (%[[ARG:.*]]: tensor<f32>)
|
||||||
|
func @lgamma_f32(%arg : tensor<f32>) -> tensor<f32> {
|
||||||
|
// CHECK: %[[TMP_1:.*]] = mhlo.constant dense<5.000000e-01>
|
||||||
|
// CHECK: %[[TMP_9:.*]] = "mhlo.compare"(%[[ARG]], %[[TMP_1]]) {comparison_direction = "LT"}
|
||||||
|
// CHECK: %[[TMP_10:.*]] = "mhlo.negate"(%[[ARG]])
|
||||||
|
// CHECK: %[[TMP_2:.*]] = mhlo.constant dense<1.000000e+00>
|
||||||
|
// CHECK: %[[TMP_11:.*]] = mhlo.subtract %[[ARG]], %[[TMP_2]]
|
||||||
|
// CHECK: %[[TMP_12:.*]] = "mhlo.select"(%[[TMP_9]], %[[TMP_10]], %[[TMP_11]])
|
||||||
|
// CHECK: %[[TMP_8:.*]] = mhlo.constant dense<1.000000e+00>
|
||||||
|
// CHECK: %[[TMP_13:.*]] = mhlo.constant dense<676.520386>
|
||||||
|
// CHECK: %[[TMP_14:.*]] = mhlo.constant dense<1.000000e+00>
|
||||||
|
// CHECK: %[[TMP_15:.*]] = mhlo.add %[[TMP_12]], %[[TMP_14]]
|
||||||
|
// CHECK: %[[TMP_16:.*]] = mhlo.divide %[[TMP_13]], %[[TMP_15]]
|
||||||
|
// CHECK: %[[TMP_17:.*]] = mhlo.add %[[TMP_8]], %[[TMP_16]]
|
||||||
|
// CHECK: %[[TMP_18:.*]] = mhlo.constant dense<-1259.13916>
|
||||||
|
// CHECK: %[[TMP_19:.*]] = mhlo.constant dense<2.000000e+00>
|
||||||
|
// CHECK: %[[TMP_20:.*]] = mhlo.add %[[TMP_12]], %[[TMP_19]]
|
||||||
|
// CHECK: %[[TMP_21:.*]] = mhlo.divide %[[TMP_18]], %[[TMP_20]]
|
||||||
|
// CHECK: %[[TMP_22:.*]] = mhlo.add %[[TMP_17]], %[[TMP_21]]
|
||||||
|
// CHECK: %[[TMP_23:.*]] = mhlo.constant dense<771.323425>
|
||||||
|
// CHECK: %[[TMP_24:.*]] = mhlo.constant dense<3.000000e+00>
|
||||||
|
// CHECK: %[[TMP_25:.*]] = mhlo.add %[[TMP_12]], %[[TMP_24]]
|
||||||
|
// CHECK: %[[TMP_26:.*]] = mhlo.divide %[[TMP_23]], %[[TMP_25]]
|
||||||
|
// CHECK: %[[TMP_27:.*]] = mhlo.add %[[TMP_22]], %[[TMP_26]]
|
||||||
|
// CHECK: %[[TMP_28:.*]] = mhlo.constant dense<-176.615036>
|
||||||
|
// CHECK: %[[TMP_29:.*]] = mhlo.constant dense<4.000000e+00>
|
||||||
|
// CHECK: %[[TMP_30:.*]] = mhlo.add %[[TMP_12]], %[[TMP_29]]
|
||||||
|
// CHECK: %[[TMP_31:.*]] = mhlo.divide %[[TMP_28]], %[[TMP_30]]
|
||||||
|
// CHECK: %[[TMP_32:.*]] = mhlo.add %[[TMP_27]], %[[TMP_31]]
|
||||||
|
// CHECK: %[[TMP_33:.*]] = mhlo.constant dense<12.5073433>
|
||||||
|
// CHECK: %[[TMP_34:.*]] = mhlo.constant dense<5.000000e+00>
|
||||||
|
// CHECK: %[[TMP_35:.*]] = mhlo.add %[[TMP_12]], %[[TMP_34]]
|
||||||
|
// CHECK: %[[TMP_36:.*]] = mhlo.divide %[[TMP_33]], %[[TMP_35]]
|
||||||
|
// CHECK: %[[TMP_37:.*]] = mhlo.add %[[TMP_32]], %[[TMP_36]]
|
||||||
|
// CHECK: %[[TMP_38:.*]] = mhlo.constant dense<-0.138571098>
|
||||||
|
// CHECK: %[[TMP_39:.*]] = mhlo.constant dense<6.000000e+00>
|
||||||
|
// CHECK: %[[TMP_40:.*]] = mhlo.add %[[TMP_12]], %[[TMP_39]]
|
||||||
|
// CHECK: %[[TMP_41:.*]] = mhlo.divide %[[TMP_38]], %[[TMP_40]]
|
||||||
|
// CHECK: %[[TMP_42:.*]] = mhlo.add %[[TMP_37]], %[[TMP_41]]
|
||||||
|
// CHECK: %[[TMP_43:.*]] = mhlo.constant dense<9.98436917E-6>
|
||||||
|
// CHECK: %[[TMP_44:.*]] = mhlo.constant dense<7.000000e+00>
|
||||||
|
// CHECK: %[[TMP_45:.*]] = mhlo.add %[[TMP_12]], %[[TMP_44]]
|
||||||
|
// CHECK: %[[TMP_46:.*]] = mhlo.divide %[[TMP_43]], %[[TMP_45]]
|
||||||
|
// CHECK: %[[TMP_47:.*]] = mhlo.add %[[TMP_42]], %[[TMP_46]]
|
||||||
|
// CHECK: %[[TMP_48:.*]] = mhlo.constant dense<1.50563267E-7>
|
||||||
|
// CHECK: %[[TMP_49:.*]] = mhlo.constant dense<8.000000e+00>
|
||||||
|
// CHECK: %[[TMP_50:.*]] = mhlo.add %[[TMP_12]], %[[TMP_49]]
|
||||||
|
// CHECK: %[[TMP_51:.*]] = mhlo.divide %[[TMP_48]], %[[TMP_50]]
|
||||||
|
// CHECK: %[[TMP_52:.*]] = mhlo.add %[[TMP_47]], %[[TMP_51]]
|
||||||
|
// CHECK: %[[TMP_6:.*]] = mhlo.constant dense<7.500000e+00>
|
||||||
|
// CHECK: %[[TMP_53:.*]] = mhlo.add %[[TMP_6]], %[[TMP_12]]
|
||||||
|
// CHECK: %[[TMP_7:.*]] = mhlo.constant dense<2.01490307>
|
||||||
|
// CHECK: %[[TMP_54:.*]] = mhlo.divide %[[TMP_12]], %[[TMP_6]]
|
||||||
|
// CHECK: %[[TMP_55:.*]] = "mhlo.log_plus_one"(%[[TMP_54]])
|
||||||
|
// CHECK: %[[TMP_56:.*]] = mhlo.add %[[TMP_7]], %[[TMP_55]]
|
||||||
|
// CHECK: %[[TMP_57:.*]] = mhlo.divide %[[TMP_53]], %[[TMP_56]]
|
||||||
|
// CHECK: %[[TMP_58:.*]] = mhlo.add %[[TMP_12]], %[[TMP_1]]
|
||||||
|
// CHECK: %[[TMP_59:.*]] = mhlo.subtract %[[TMP_58]], %[[TMP_57]]
|
||||||
|
// CHECK: %[[TMP_60:.*]] = mhlo.multiply %[[TMP_59]], %[[TMP_56]]
|
||||||
|
// CHECK: %[[TMP_61:.*]] = "mhlo.log"(%[[TMP_52]])
|
||||||
|
// CHECK: %[[TMP_5:.*]] = mhlo.constant dense<0.918938517>
|
||||||
|
// CHECK: %[[TMP_62:.*]] = mhlo.add %[[TMP_5]], %[[TMP_60]]
|
||||||
|
// CHECK: %[[TMP_63:.*]] = mhlo.add %[[TMP_62]], %[[TMP_61]]
|
||||||
|
// CHECK: %[[TMP_64:.*]] = "mhlo.abs"(%[[ARG]])
|
||||||
|
// CHECK: %[[TMP_65:.*]] = "mhlo.floor"(%[[TMP_64]])
|
||||||
|
// CHECK: %[[TMP_66:.*]] = mhlo.subtract %[[TMP_64]], %[[TMP_65]]
|
||||||
|
// CHECK: %[[TMP_67:.*]] = "mhlo.compare"(%[[TMP_1]], %[[TMP_66]]) {comparison_direction = "LT"}
|
||||||
|
// CHECK: %[[TMP_68:.*]] = mhlo.subtract %[[TMP_2]], %[[TMP_66]]
|
||||||
|
// CHECK: %[[TMP_69:.*]] = "mhlo.select"(%[[TMP_67]], %[[TMP_68]], %[[TMP_66]])
|
||||||
|
// CHECK: %[[TMP_3:.*]] = mhlo.constant dense<3.14159274>
|
||||||
|
// CHECK: %[[TMP_70:.*]] = mhlo.multiply %[[TMP_3]], %[[TMP_69]]
|
||||||
|
// CHECK: %[[TMP_71:.*]] = "mhlo.sine"(%[[TMP_70]])
|
||||||
|
// CHECK: %[[TMP_72:.*]] = "mhlo.log"(%[[TMP_71]])
|
||||||
|
// CHECK: %[[TMP_4:.*]] = mhlo.constant dense<1.14472985>
|
||||||
|
// CHECK: %[[TMP_75:.*]] = mhlo.subtract %[[TMP_4]], %[[TMP_72]]
|
||||||
|
// CHECK: %[[TMP_76:.*]] = mhlo.subtract %[[TMP_75]], %[[TMP_63]]
|
||||||
|
// CHECK: %[[TMP_73:.*]] = "mhlo.is_finite"(%[[TMP_72]])
|
||||||
|
// CHECK: %[[TMP_74:.*]] = "mhlo.negate"(%[[TMP_72]])
|
||||||
|
// CHECK: %[[TMP_77:.*]] = "mhlo.select"(%[[TMP_73]], %[[TMP_76]], %[[TMP_74]])
|
||||||
|
// CHECK: %[[TMP_78:.*]] = "mhlo.select"(%[[TMP_9]], %[[TMP_77]], %[[TMP_63]])
|
||||||
|
// CHECK: %[[TMP_79:.*]] = "mhlo.abs"(%[[ARG]])
|
||||||
|
// CHECK: %[[TMP_80:.*]] = mhlo.constant dense<0x7F800000>
|
||||||
|
// CHECK: %[[TMP_81:.*]] = "mhlo.compare"(%[[TMP_79]], %[[TMP_80]]) {comparison_direction = "EQ"}
|
||||||
|
// CHECK: %[[TMP_0:.*]] = mhlo.constant dense<0x7F800000>
|
||||||
|
// CHECK: %[[TMP_82:.*]] = "mhlo.select"(%[[TMP_81]], %[[TMP_0]], %[[TMP_78]])
|
||||||
|
// CHECK: return %[[TMP_82]]
|
||||||
|
%1 = chlo.lgamma %arg : tensor<f32> -> tensor<f32>
|
||||||
|
return %1 : tensor<f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @lgamma_f16
|
||||||
|
// CHECK-SAME: (%[[ARG:.*]]: tensor<f16>)
|
||||||
|
func @lgamma_f16(%arg : tensor<f16>) -> tensor<f16> {
|
||||||
|
// CHECK: "mhlo.convert"(%[[ARG]]) : (tensor<f16>) -> tensor<f32>
|
||||||
|
// CHECK: %[[RES:.*]] = "mhlo.convert"(%{{.*}}) : (tensor<f32>) -> tensor<f16>
|
||||||
|
// CHECK: return %[[RES]]
|
||||||
|
%1 = chlo.lgamma %arg : tensor<f16> -> tensor<f16>
|
||||||
|
return %1 : tensor<f16>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue