[MLIR][KernelGen] Fix zeta lowering at poles

Return nan at zeta poles or inf where the limit is defined. Also test the kernel
based on the series representation of zeta.

PiperOrigin-RevId: 361993482
This commit is contained in:
A. Unique TensorFlower 2021-03-10 01:08:02 -08:00 committed by TensorFlow MLIR Team
parent 7629dfdd81
commit 218476128e
2 changed files with 548 additions and 472 deletions

View File

@ -784,17 +784,17 @@ Value MaterializeZeta(ConversionPatternRewriter &rewriter, Location loc,
// For speed we'll always use 9 iterations for the initial series estimate, // For speed we'll always use 9 iterations for the initial series estimate,
// and a 12 term expansion for the Euler-Maclaurin formula. // and a 12 term expansion for the Euler-Maclaurin formula.
Value a = q; Value a = q;
Value zero_like_a = chlo::getConstantLike(rewriter, loc, 0.0, a); Value zero = chlo::getConstantLike(rewriter, loc, 0.0, a);
Value neg_power = zero_like_a; Value neg_power = zero;
Value neg_x = rewriter.create<mhlo::NegOp>(loc, x); Value neg_x = rewriter.create<mhlo::NegOp>(loc, x);
Value initial_sum = rewriter.create<mhlo::PowOp>(loc, q, neg_x); Value initial_sum = rewriter.create<mhlo::PowOp>(loc, q, neg_x);
Value one_like_a = chlo::getConstantLike(rewriter, loc, 1.0, a); Value one = chlo::getConstantLike(rewriter, loc, 1.0, a);
for (int i = 0; i < 9; ++i) { for (int i = 0; i < 9; ++i) {
a = rewriter.create<mhlo::AddOp>(loc, a, one_like_a); a = rewriter.create<mhlo::AddOp>(loc, a, one);
neg_power = rewriter.create<mhlo::PowOp>(loc, a, neg_x); neg_power = rewriter.create<mhlo::PowOp>(loc, a, neg_x);
initial_sum = rewriter.create<mhlo::AddOp>(loc, initial_sum, neg_power); initial_sum = rewriter.create<mhlo::AddOp>(loc, initial_sum, neg_power);
} }
a = rewriter.create<mhlo::AddOp>(loc, a, one_like_a); a = rewriter.create<mhlo::AddOp>(loc, a, one);
neg_power = rewriter.create<mhlo::PowOp>(loc, a, neg_x); neg_power = rewriter.create<mhlo::PowOp>(loc, a, neg_x);
Value one_like_x = chlo::getConstantLike(rewriter, loc, 1.0, x); Value one_like_x = chlo::getConstantLike(rewriter, loc, 1.0, x);
Value x_minus_one = rewriter.create<mhlo::SubOp>(loc, x, one_like_x); Value x_minus_one = rewriter.create<mhlo::SubOp>(loc, x, one_like_x);
@ -804,10 +804,10 @@ Value MaterializeZeta(ConversionPatternRewriter &rewriter, Location loc,
Value s = rewriter.create<mhlo::AddOp>(loc, initial_sum, Value s = rewriter.create<mhlo::AddOp>(loc, initial_sum,
neg_power_mul_a_div_x_minus_one); neg_power_mul_a_div_x_minus_one);
Value a_inverse_square = rewriter.create<mhlo::DivOp>( Value a_inverse_square = rewriter.create<mhlo::DivOp>(
loc, one_like_a, rewriter.create<mhlo::MulOp>(loc, a, a)); loc, one, rewriter.create<mhlo::MulOp>(loc, a, a));
Value horner_sum = zero_like_a; Value horner_sum = zero;
Value factor = one_like_a; Value factor = one;
// Use Horner's rule for this. // Use Horner's rule for this.
// Note this differs from Cephes which does a 'naive' polynomial evaluation. // Note this differs from Cephes which does a 'naive' polynomial evaluation.
// Using Horner's rule allows to avoid some NaN's and Infs from happening, // Using Horner's rule allows to avoid some NaN's and Infs from happening,
@ -842,8 +842,7 @@ Value MaterializeZeta(ConversionPatternRewriter &rewriter, Location loc,
chlo::getConstantLike(rewriter, loc, 1. / kZetaCoeffs[11], chlo::getConstantLike(rewriter, loc, 1. / kZetaCoeffs[11],
a), a),
horner_sum))))); horner_sum)))));
const double nan = std::numeric_limits<double>::quiet_NaN();
const double inf = std::numeric_limits<double>::infinity();
// Use the initial zeta sum without the correction term coming // Use the initial zeta sum without the correction term coming
// from Euler-Maclaurin if it is accurate enough. // from Euler-Maclaurin if it is accurate enough.
const StringAttr kLT = rewriter.getStringAttr( const StringAttr kLT = rewriter.getStringAttr(
@ -859,38 +858,49 @@ Value MaterializeZeta(ConversionPatternRewriter &rewriter, Location loc,
chlo::getConstantLikeSmallestFiniteValue(rewriter, loc, a)), chlo::getConstantLikeSmallestFiniteValue(rewriter, loc, a)),
kLT), kLT),
initial_sum, s); initial_sum, s);
// This is the harmonic series.
const StringAttr kEQ = rewriter.getStringAttr(
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::EQ));
Value inf_like_x = chlo::getConstantLike(rewriter, loc, inf, x);
output = rewriter.create<mhlo::SelectOp>(
loc, rewriter.create<mhlo::CompareOp>(loc, x, one_like_x, kEQ),
inf_like_x, output);
// Function is not defined for x < 1. // Function is not defined for x < 1.
Value nan_like_x = chlo::getConstantLike(rewriter, loc, nan, x); Value nan = chlo::getConstantLike(
rewriter, loc, std::numeric_limits<double>::quiet_NaN(), x);
output = rewriter.create<mhlo::SelectOp>( output = rewriter.create<mhlo::SelectOp>(
loc, rewriter.create<mhlo::CompareOp>(loc, x, one_like_x, kLT), loc, rewriter.create<mhlo::CompareOp>(loc, x, one_like_x, kLT), nan,
nan_like_x, output); output);
// If q <= 0, then when q is an integer or x is not an integer, this is
// NaN. // For q <= 0, x must be an integer.
const StringAttr kLE = rewriter.getStringAttr( const StringAttr kLE = rewriter.getStringAttr(
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LE)); mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LE));
const StringAttr kNE = rewriter.getStringAttr( const StringAttr kNE = rewriter.getStringAttr(
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::NE)); mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::NE));
Value zero_like_q = chlo::getConstantLike(rewriter, loc, 0.0, q); Value q_le_zero = rewriter.create<mhlo::CompareOp>(loc, q, zero, kLE);
Value q_le_zero = rewriter.create<mhlo::CompareOp>(loc, q, zero_like_q, kLE); Value x_not_int = rewriter.create<mhlo::CompareOp>(
Value domain_error = rewriter.create<mhlo::AndOp>( loc, x, rewriter.create<mhlo::FloorOp>(loc, x), kNE);
loc, q_le_zero, Value x_domain_error =
rewriter.create<mhlo::CompareOp>( rewriter.create<mhlo::AndOp>(loc, q_le_zero, x_not_int);
loc, x, rewriter.create<mhlo::FloorOp>(loc, x), kNE)); output = rewriter.create<mhlo::SelectOp>(loc, x_domain_error, nan, output);
Value negative_integer_q = rewriter.create<mhlo::AndOp>(
loc, q_le_zero, // For all integer q <= 0, zeta has a pole. The limit is only defined as
rewriter.create<mhlo::CompareOp>( // +inf if x is and even integer.
loc, q, rewriter.create<mhlo::FloorOp>(loc, q), kEQ)); const StringAttr kEQ = rewriter.getStringAttr(
output = rewriter.create<mhlo::SelectOp>(loc, negative_integer_q, inf_like_x, mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::EQ));
output); Value inf = chlo::getConstantLike(rewriter, loc,
output = std::numeric_limits<double>::infinity(), x);
rewriter.create<mhlo::SelectOp>(loc, domain_error, nan_like_x, output); Value q_is_int = rewriter.create<mhlo::CompareOp>(
loc, q, rewriter.create<mhlo::FloorOp>(loc, q), kEQ);
Value at_pole = rewriter.create<mhlo::AndOp>(loc, q_le_zero, q_is_int);
Value two = chlo::getConstantLike(rewriter, loc, 2.0, x);
Value x_is_int = rewriter.create<mhlo::CompareOp>(
loc, x, rewriter.create<mhlo::FloorOp>(loc, x), kEQ);
Value x_is_even = rewriter.create<mhlo::CompareOp>(
loc, rewriter.create<mhlo::RemOp>(loc, x, two), zero, kEQ);
Value x_is_even_int = rewriter.create<mhlo::AndOp>(loc, x_is_int, x_is_even);
output = rewriter.create<mhlo::SelectOp>(
loc, at_pole,
rewriter.create<mhlo::SelectOp>(loc, x_is_even_int, inf, nan), output);
// For x = 1, this is the harmonic series and diverges.
output = rewriter.create<mhlo::SelectOp>(
loc, rewriter.create<mhlo::CompareOp>(loc, x, one, kEQ), inf, output);
return output; return output;
} }
@ -923,7 +933,7 @@ Value MaterializePolygamma(ConversionPatternRewriter &rewriter, Location loc,
result = rewriter.create<mhlo::SelectOp>( result = rewriter.create<mhlo::SelectOp>(
loc, n_eq_zero, rewriter.create<chlo::DigammaOp>(loc, x), result); loc, n_eq_zero, rewriter.create<chlo::DigammaOp>(loc, x), result);
// Check that n is a natural number. // Check that n is a natural number. Return nan, otherwise.
const StringAttr kNE = rewriter.getStringAttr( const StringAttr kNE = rewriter.getStringAttr(
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::NE)); mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::NE));
Value non_int = rewriter.create<mhlo::CompareOp>( Value non_int = rewriter.create<mhlo::CompareOp>(

File diff suppressed because it is too large Load Diff