[MLIR][KernelGen] Fix zeta lowering at poles
Return nan at zeta poles or inf where the limit is defined. Also test the kernel based on the series representation of zeta. PiperOrigin-RevId: 361993482
This commit is contained in:
parent
7629dfdd81
commit
218476128e
|
@ -784,17 +784,17 @@ Value MaterializeZeta(ConversionPatternRewriter &rewriter, Location loc,
|
||||||
// For speed we'll always use 9 iterations for the initial series estimate,
|
// For speed we'll always use 9 iterations for the initial series estimate,
|
||||||
// and a 12 term expansion for the Euler-Maclaurin formula.
|
// and a 12 term expansion for the Euler-Maclaurin formula.
|
||||||
Value a = q;
|
Value a = q;
|
||||||
Value zero_like_a = chlo::getConstantLike(rewriter, loc, 0.0, a);
|
Value zero = chlo::getConstantLike(rewriter, loc, 0.0, a);
|
||||||
Value neg_power = zero_like_a;
|
Value neg_power = zero;
|
||||||
Value neg_x = rewriter.create<mhlo::NegOp>(loc, x);
|
Value neg_x = rewriter.create<mhlo::NegOp>(loc, x);
|
||||||
Value initial_sum = rewriter.create<mhlo::PowOp>(loc, q, neg_x);
|
Value initial_sum = rewriter.create<mhlo::PowOp>(loc, q, neg_x);
|
||||||
Value one_like_a = chlo::getConstantLike(rewriter, loc, 1.0, a);
|
Value one = chlo::getConstantLike(rewriter, loc, 1.0, a);
|
||||||
for (int i = 0; i < 9; ++i) {
|
for (int i = 0; i < 9; ++i) {
|
||||||
a = rewriter.create<mhlo::AddOp>(loc, a, one_like_a);
|
a = rewriter.create<mhlo::AddOp>(loc, a, one);
|
||||||
neg_power = rewriter.create<mhlo::PowOp>(loc, a, neg_x);
|
neg_power = rewriter.create<mhlo::PowOp>(loc, a, neg_x);
|
||||||
initial_sum = rewriter.create<mhlo::AddOp>(loc, initial_sum, neg_power);
|
initial_sum = rewriter.create<mhlo::AddOp>(loc, initial_sum, neg_power);
|
||||||
}
|
}
|
||||||
a = rewriter.create<mhlo::AddOp>(loc, a, one_like_a);
|
a = rewriter.create<mhlo::AddOp>(loc, a, one);
|
||||||
neg_power = rewriter.create<mhlo::PowOp>(loc, a, neg_x);
|
neg_power = rewriter.create<mhlo::PowOp>(loc, a, neg_x);
|
||||||
Value one_like_x = chlo::getConstantLike(rewriter, loc, 1.0, x);
|
Value one_like_x = chlo::getConstantLike(rewriter, loc, 1.0, x);
|
||||||
Value x_minus_one = rewriter.create<mhlo::SubOp>(loc, x, one_like_x);
|
Value x_minus_one = rewriter.create<mhlo::SubOp>(loc, x, one_like_x);
|
||||||
|
@ -804,10 +804,10 @@ Value MaterializeZeta(ConversionPatternRewriter &rewriter, Location loc,
|
||||||
Value s = rewriter.create<mhlo::AddOp>(loc, initial_sum,
|
Value s = rewriter.create<mhlo::AddOp>(loc, initial_sum,
|
||||||
neg_power_mul_a_div_x_minus_one);
|
neg_power_mul_a_div_x_minus_one);
|
||||||
Value a_inverse_square = rewriter.create<mhlo::DivOp>(
|
Value a_inverse_square = rewriter.create<mhlo::DivOp>(
|
||||||
loc, one_like_a, rewriter.create<mhlo::MulOp>(loc, a, a));
|
loc, one, rewriter.create<mhlo::MulOp>(loc, a, a));
|
||||||
|
|
||||||
Value horner_sum = zero_like_a;
|
Value horner_sum = zero;
|
||||||
Value factor = one_like_a;
|
Value factor = one;
|
||||||
// Use Horner's rule for this.
|
// Use Horner's rule for this.
|
||||||
// Note this differs from Cephes which does a 'naive' polynomial evaluation.
|
// Note this differs from Cephes which does a 'naive' polynomial evaluation.
|
||||||
// Using Horner's rule allows to avoid some NaN's and Infs from happening,
|
// Using Horner's rule allows to avoid some NaN's and Infs from happening,
|
||||||
|
@ -842,8 +842,7 @@ Value MaterializeZeta(ConversionPatternRewriter &rewriter, Location loc,
|
||||||
chlo::getConstantLike(rewriter, loc, 1. / kZetaCoeffs[11],
|
chlo::getConstantLike(rewriter, loc, 1. / kZetaCoeffs[11],
|
||||||
a),
|
a),
|
||||||
horner_sum)))));
|
horner_sum)))));
|
||||||
const double nan = std::numeric_limits<double>::quiet_NaN();
|
|
||||||
const double inf = std::numeric_limits<double>::infinity();
|
|
||||||
// Use the initial zeta sum without the correction term coming
|
// Use the initial zeta sum without the correction term coming
|
||||||
// from Euler-Maclaurin if it is accurate enough.
|
// from Euler-Maclaurin if it is accurate enough.
|
||||||
const StringAttr kLT = rewriter.getStringAttr(
|
const StringAttr kLT = rewriter.getStringAttr(
|
||||||
|
@ -859,38 +858,49 @@ Value MaterializeZeta(ConversionPatternRewriter &rewriter, Location loc,
|
||||||
chlo::getConstantLikeSmallestFiniteValue(rewriter, loc, a)),
|
chlo::getConstantLikeSmallestFiniteValue(rewriter, loc, a)),
|
||||||
kLT),
|
kLT),
|
||||||
initial_sum, s);
|
initial_sum, s);
|
||||||
// This is the harmonic series.
|
|
||||||
const StringAttr kEQ = rewriter.getStringAttr(
|
|
||||||
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::EQ));
|
|
||||||
Value inf_like_x = chlo::getConstantLike(rewriter, loc, inf, x);
|
|
||||||
output = rewriter.create<mhlo::SelectOp>(
|
|
||||||
loc, rewriter.create<mhlo::CompareOp>(loc, x, one_like_x, kEQ),
|
|
||||||
inf_like_x, output);
|
|
||||||
// Function is not defined for x < 1.
|
// Function is not defined for x < 1.
|
||||||
Value nan_like_x = chlo::getConstantLike(rewriter, loc, nan, x);
|
Value nan = chlo::getConstantLike(
|
||||||
|
rewriter, loc, std::numeric_limits<double>::quiet_NaN(), x);
|
||||||
output = rewriter.create<mhlo::SelectOp>(
|
output = rewriter.create<mhlo::SelectOp>(
|
||||||
loc, rewriter.create<mhlo::CompareOp>(loc, x, one_like_x, kLT),
|
loc, rewriter.create<mhlo::CompareOp>(loc, x, one_like_x, kLT), nan,
|
||||||
nan_like_x, output);
|
output);
|
||||||
// If q <= 0, then when q is an integer or x is not an integer, this is
|
|
||||||
// NaN.
|
// For q <= 0, x must be an integer.
|
||||||
const StringAttr kLE = rewriter.getStringAttr(
|
const StringAttr kLE = rewriter.getStringAttr(
|
||||||
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LE));
|
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LE));
|
||||||
const StringAttr kNE = rewriter.getStringAttr(
|
const StringAttr kNE = rewriter.getStringAttr(
|
||||||
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::NE));
|
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::NE));
|
||||||
Value zero_like_q = chlo::getConstantLike(rewriter, loc, 0.0, q);
|
Value q_le_zero = rewriter.create<mhlo::CompareOp>(loc, q, zero, kLE);
|
||||||
Value q_le_zero = rewriter.create<mhlo::CompareOp>(loc, q, zero_like_q, kLE);
|
Value x_not_int = rewriter.create<mhlo::CompareOp>(
|
||||||
Value domain_error = rewriter.create<mhlo::AndOp>(
|
loc, x, rewriter.create<mhlo::FloorOp>(loc, x), kNE);
|
||||||
loc, q_le_zero,
|
Value x_domain_error =
|
||||||
rewriter.create<mhlo::CompareOp>(
|
rewriter.create<mhlo::AndOp>(loc, q_le_zero, x_not_int);
|
||||||
loc, x, rewriter.create<mhlo::FloorOp>(loc, x), kNE));
|
output = rewriter.create<mhlo::SelectOp>(loc, x_domain_error, nan, output);
|
||||||
Value negative_integer_q = rewriter.create<mhlo::AndOp>(
|
|
||||||
loc, q_le_zero,
|
// For all integer q <= 0, zeta has a pole. The limit is only defined as
|
||||||
rewriter.create<mhlo::CompareOp>(
|
// +inf if x is and even integer.
|
||||||
loc, q, rewriter.create<mhlo::FloorOp>(loc, q), kEQ));
|
const StringAttr kEQ = rewriter.getStringAttr(
|
||||||
output = rewriter.create<mhlo::SelectOp>(loc, negative_integer_q, inf_like_x,
|
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::EQ));
|
||||||
output);
|
Value inf = chlo::getConstantLike(rewriter, loc,
|
||||||
output =
|
std::numeric_limits<double>::infinity(), x);
|
||||||
rewriter.create<mhlo::SelectOp>(loc, domain_error, nan_like_x, output);
|
Value q_is_int = rewriter.create<mhlo::CompareOp>(
|
||||||
|
loc, q, rewriter.create<mhlo::FloorOp>(loc, q), kEQ);
|
||||||
|
Value at_pole = rewriter.create<mhlo::AndOp>(loc, q_le_zero, q_is_int);
|
||||||
|
Value two = chlo::getConstantLike(rewriter, loc, 2.0, x);
|
||||||
|
Value x_is_int = rewriter.create<mhlo::CompareOp>(
|
||||||
|
loc, x, rewriter.create<mhlo::FloorOp>(loc, x), kEQ);
|
||||||
|
Value x_is_even = rewriter.create<mhlo::CompareOp>(
|
||||||
|
loc, rewriter.create<mhlo::RemOp>(loc, x, two), zero, kEQ);
|
||||||
|
Value x_is_even_int = rewriter.create<mhlo::AndOp>(loc, x_is_int, x_is_even);
|
||||||
|
output = rewriter.create<mhlo::SelectOp>(
|
||||||
|
loc, at_pole,
|
||||||
|
rewriter.create<mhlo::SelectOp>(loc, x_is_even_int, inf, nan), output);
|
||||||
|
|
||||||
|
// For x = 1, this is the harmonic series and diverges.
|
||||||
|
output = rewriter.create<mhlo::SelectOp>(
|
||||||
|
loc, rewriter.create<mhlo::CompareOp>(loc, x, one, kEQ), inf, output);
|
||||||
|
|
||||||
return output;
|
return output;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -923,7 +933,7 @@ Value MaterializePolygamma(ConversionPatternRewriter &rewriter, Location loc,
|
||||||
result = rewriter.create<mhlo::SelectOp>(
|
result = rewriter.create<mhlo::SelectOp>(
|
||||||
loc, n_eq_zero, rewriter.create<chlo::DigammaOp>(loc, x), result);
|
loc, n_eq_zero, rewriter.create<chlo::DigammaOp>(loc, x), result);
|
||||||
|
|
||||||
// Check that n is a natural number.
|
// Check that n is a natural number. Return nan, otherwise.
|
||||||
const StringAttr kNE = rewriter.getStringAttr(
|
const StringAttr kNE = rewriter.getStringAttr(
|
||||||
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::NE));
|
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::NE));
|
||||||
Value non_int = rewriter.create<mhlo::CompareOp>(
|
Value non_int = rewriter.create<mhlo::CompareOp>(
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue