[MLIR][CHLO] Generalize lowering with upcast to n-ary operation

Allows reuse for zeta lowering now and for the polygamma lowering soon.

PiperOrigin-RevId: 357739910
This commit is contained in:
A. Unique TensorFlower 2021-02-16 09:46:33 -08:00 committed by TensorFlow MLIR Team
parent 81abaf364d
commit c06de24f6c
2 changed files with 84 additions and 78 deletions

1
BUILD
View File

@ -977,6 +977,7 @@ cc_library(
":chlo_legalize_to_hlo_inc_gen", ":chlo_legalize_to_hlo_inc_gen",
":hlo", ":hlo",
":map_chlo_to_hlo_op", ":map_chlo_to_hlo_op",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:Shape", "@llvm-project//mlir:Shape",

View File

@ -22,6 +22,7 @@ limitations under the License.
#include <numeric> #include <numeric>
#include <vector> #include <vector>
#include "llvm/ADT/SmallVector.h"
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h" #include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
@ -96,7 +97,8 @@ Value MaterializePolynomialApproximation(ConversionPatternRewriter &rewriter,
// argument and derive the final approximation for all |x| >= 1. // argument and derive the final approximation for all |x| >= 1.
// This implementation is based on Cephes. // This implementation is based on Cephes.
Value MaterializeErfcApproximationF64ForMagnituteGEOne( Value MaterializeErfcApproximationF64ForMagnituteGEOne(
ConversionPatternRewriter &rewriter, Location loc, Value x) { ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
Value x = args.front();
assert(x.getType().cast<ShapedType>().getElementType().isF64() && assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
"expect f64 element type"); "expect f64 element type");
const double kMaxlog = 7.09782712893383996843E2; const double kMaxlog = 7.09782712893383996843E2;
@ -179,7 +181,8 @@ Value MaterializeErfcApproximationF64ForMagnituteGEOne(
// Precondition is |x| <= 1. Use erfc approximation, otherwise. // Precondition is |x| <= 1. Use erfc approximation, otherwise.
// This implementation is based on Cephes. // This implementation is based on Cephes.
Value MaterializeErfApproximationF64ForMagnituteLEOne( Value MaterializeErfApproximationF64ForMagnituteLEOne(
ConversionPatternRewriter &rewriter, Location loc, Value x) { ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
Value x = args.front();
assert(x.getType().cast<ShapedType>().getElementType().isF64() && assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
"expect f64 element type"); "expect f64 element type");
const std::vector<double> kErfTCoefficients{ const std::vector<double> kErfTCoefficients{
@ -204,7 +207,8 @@ Value MaterializeErfApproximationF64ForMagnituteLEOne(
// This implementation is based on Cephes. // This implementation is based on Cephes.
Value MaterializeErfApproximationF64(ConversionPatternRewriter &rewriter, Value MaterializeErfApproximationF64(ConversionPatternRewriter &rewriter,
Location loc, Value x) { Location loc, ValueRange args) {
Value x = args.front();
assert(x.getType().cast<ShapedType>().getElementType().isF64() && assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
"expect f64 element type"); "expect f64 element type");
@ -230,7 +234,8 @@ Value MaterializeErfApproximationF64(ConversionPatternRewriter &rewriter,
} }
Value MaterializeErfcApproximationF64(ConversionPatternRewriter &rewriter, Value MaterializeErfcApproximationF64(ConversionPatternRewriter &rewriter,
Location loc, Value x) { Location loc, ValueRange args) {
Value x = args.front();
assert(x.getType().cast<ShapedType>().getElementType().isF64() && assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
"expect f64 element type"); "expect f64 element type");
@ -261,7 +266,8 @@ Value MaterializeErfcApproximationF64(ConversionPatternRewriter &rewriter,
// argument and derive the final approximation for all |x| >= 1. // argument and derive the final approximation for all |x| >= 1.
// This implementation is based on Cephes. // This implementation is based on Cephes.
Value MaterializeErfcApproximationF32ForMagnitudeGEOne( Value MaterializeErfcApproximationF32ForMagnitudeGEOne(
ConversionPatternRewriter &rewriter, Location loc, Value x) { ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
Value x = args.front();
assert(x.getType().cast<ShapedType>().getElementType().isF32() && assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
"expect f32 element type"); "expect f32 element type");
const double kMaxlog = 88.72283905206835; const double kMaxlog = 88.72283905206835;
@ -325,7 +331,8 @@ Value MaterializeErfcApproximationF32ForMagnitudeGEOne(
// Precondition is |x| <= 1. Use erfc approximation, otherwise. // Precondition is |x| <= 1. Use erfc approximation, otherwise.
// This implementation is based on Cephes. // This implementation is based on Cephes.
Value MaterializeErfApproximationF32ForMagnitudeLEOne( Value MaterializeErfApproximationF32ForMagnitudeLEOne(
ConversionPatternRewriter &rewriter, Location loc, Value x) { ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
Value x = args.front();
assert(x.getType().cast<ShapedType>().getElementType().isF32() && assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
"expect f32 element type"); "expect f32 element type");
const std::vector<float> kErfTCoefficients{ const std::vector<float> kErfTCoefficients{
@ -344,8 +351,9 @@ Value MaterializeErfApproximationF32ForMagnitudeLEOne(
// This is the same approximation as used in Eigen. // This is the same approximation as used in Eigen.
Value MaterializeErfApproximationF32(ConversionPatternRewriter &rewriter, Value MaterializeErfApproximationF32(ConversionPatternRewriter &rewriter,
Location loc, Value operand) { Location loc, ValueRange args) {
assert(operand.getType().cast<ShapedType>().getElementType().isF32() && Value x = args.front();
assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
"expect f32 element type"); "expect f32 element type");
const std::vector<float> kAlpha{ const std::vector<float> kAlpha{
-2.72614225801306e-10f, 2.77068142495902e-08f, -2.10102402082508e-06f, -2.72614225801306e-10f, 2.77068142495902e-08f, -2.10102402082508e-06f,
@ -358,10 +366,9 @@ Value MaterializeErfApproximationF32(ConversionPatternRewriter &rewriter,
}; };
// Clamp argument between -4 and 4. // Clamp argument between -4 and 4.
Value lb = chlo::getConstantLike(rewriter, loc, -4.0, operand); Value lb = chlo::getConstantLike(rewriter, loc, -4.0, x);
Value ub = chlo::getConstantLike(rewriter, loc, 4.0, operand); Value ub = chlo::getConstantLike(rewriter, loc, 4.0, x);
Value x = x = rewriter.create<mhlo::ClampOp>(loc, x.getType(), lb, x, ub);
rewriter.create<mhlo::ClampOp>(loc, operand.getType(), lb, operand, ub);
Value x_sq = rewriter.create<mhlo::MulOp>(loc, x, x); Value x_sq = rewriter.create<mhlo::MulOp>(loc, x, x);
// Materialize polynomial approximation for x in [-4, 4] as // Materialize polynomial approximation for x in [-4, 4] as
@ -375,7 +382,8 @@ Value MaterializeErfApproximationF32(ConversionPatternRewriter &rewriter,
} }
Value MaterializeErfcApproximationF32(ConversionPatternRewriter &rewriter, Value MaterializeErfcApproximationF32(ConversionPatternRewriter &rewriter,
Location loc, Value x) { Location loc, ValueRange args) {
Value x = args.front();
assert(x.getType().cast<ShapedType>().getElementType().isF32() && assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
"expect f32 element type"); "expect f32 element type");
@ -401,18 +409,30 @@ Value MaterializeErfcApproximationF32(ConversionPatternRewriter &rewriter,
} }
Value MaterializeWithUpcast(ConversionPatternRewriter &rewriter, Location loc, Value MaterializeWithUpcast(ConversionPatternRewriter &rewriter, Location loc,
Value arg, FloatType min_precision_ty, ValueRange args, FloatType min_precision_ty,
Value callback(ConversionPatternRewriter &, Value callback(ConversionPatternRewriter &,
Location, Value)) { Location, ValueRange)) {
auto original_ty = getElementTypeOrSelf(arg.getType()).cast<FloatType>(); auto original_ty =
getElementTypeOrSelf(args.front().getType()).cast<FloatType>();
bool needs_upcast = original_ty.getWidth() < min_precision_ty.getWidth(); bool needs_upcast = original_ty.getWidth() < min_precision_ty.getWidth();
// Upcast arguments if necessary.
llvm::SmallVector<Value, 2> casted_args;
if (needs_upcast) { if (needs_upcast) {
arg = rewriter.create<mhlo::ConvertOp>(loc, arg, min_precision_ty); for (Value a : args) {
casted_args.push_back(
rewriter.create<mhlo::ConvertOp>(loc, a, min_precision_ty));
} }
Value result = callback(rewriter, loc, arg); args = casted_args;
}
Value result = callback(rewriter, loc, args);
// Cast back if necessary.
if (needs_upcast) { if (needs_upcast) {
result = rewriter.create<mhlo::ConvertOp>(loc, result, original_ty); result = rewriter.create<mhlo::ConvertOp>(loc, result, original_ty);
} }
return result; return result;
} }
@ -434,8 +454,8 @@ struct ConvertErfOp : public OpConversionPattern<ErfOp> {
return success(); return success();
} }
rewriter.replaceOp( rewriter.replaceOp(op, MaterializeWithUpcast(
op, MaterializeWithUpcast(rewriter, loc, x, rewriter.getF32Type(), rewriter, loc, operands, rewriter.getF32Type(),
&MaterializeErfApproximationF32)); &MaterializeErfApproximationF32));
return success(); return success();
} }
@ -459,8 +479,8 @@ struct ConvertErfcOp : public OpConversionPattern<ErfcOp> {
return success(); return success();
} }
rewriter.replaceOp( rewriter.replaceOp(op, MaterializeWithUpcast(
op, MaterializeWithUpcast(rewriter, loc, x, rewriter.getF32Type(), rewriter, loc, operands, rewriter.getF32Type(),
&MaterializeErfcApproximationF32)); &MaterializeErfcApproximationF32));
return success(); return success();
} }
@ -491,12 +511,13 @@ constexpr std::array<double, 8> kLanczosCoefficients = {
// a(z) = kBaseLanczosCoeff // a(z) = kBaseLanczosCoeff
// + sum(k = 1, n, kLanczosCoefficients[i] / (z + k)) // + sum(k = 1, n, kLanczosCoefficients[i] / (z + k))
Value MaterializeLgamma(ConversionPatternRewriter &rewriter, Location loc, Value MaterializeLgamma(ConversionPatternRewriter &rewriter, Location loc,
Value x) { ValueRange args) {
// If the input is less than 0.5 use Euler's reflection formula. // If the input is less than 0.5 use Euler's reflection formula.
// gamma(x) = pi / (sin(pi * x) * gamma(1 - x)) // gamma(x) = pi / (sin(pi * x) * gamma(1 - x))
// Let z be // Let z be
// z = -x if x < 1/2 // z = -x if x < 1/2
// z = x - 1 otherwise // z = x - 1 otherwise
Value x = args.front();
const StringAttr kLT = rewriter.getStringAttr( const StringAttr kLT = rewriter.getStringAttr(
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT)); mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
Value half = getConstantLike(rewriter, loc, 0.5, x); Value half = getConstantLike(rewriter, loc, 0.5, x);
@ -635,12 +656,13 @@ Value MaterializeLgamma(ConversionPatternRewriter &rewriter, Location loc,
// + sum(k = 1, n, kLanczosCoefficients[i] / (z + k)) // + sum(k = 1, n, kLanczosCoefficients[i] / (z + k))
// a'(z) = - sum(k = 1, n, kLanczosCoefficients[i] / (z + k) / (z + k)) // a'(z) = - sum(k = 1, n, kLanczosCoefficients[i] / (z + k) / (z + k))
Value MaterializeDigamma(ConversionPatternRewriter &rewriter, Location loc, Value MaterializeDigamma(ConversionPatternRewriter &rewriter, Location loc,
Value x) { ValueRange args) {
// If the input is less than 0.5 use Euler's reflection formula. // If the input is less than 0.5 use Euler's reflection formula.
// digamma(x) = digamma(1 - x) - pi * cot(pi * x) // digamma(x) = digamma(1 - x) - pi * cot(pi * x)
// Let z be // Let z be
// z = -x if x < 1/2 // z = -x if x < 1/2
// z = x - 1 otherwise // z = x - 1 otherwise
Value x = args.front();
const StringAttr kLT = rewriter.getStringAttr( const StringAttr kLT = rewriter.getStringAttr(
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT)); mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
Value half = getConstantLike(rewriter, loc, 0.5, x); Value half = getConstantLike(rewriter, loc, 0.5, x);
@ -739,36 +761,11 @@ Value MaterializeDigamma(ConversionPatternRewriter &rewriter, Location loc,
digamma); digamma);
} }
struct ConvertLgammaOp : public OpConversionPattern<LgammaOp> { Value MaterializeZeta(ConversionPatternRewriter &rewriter, Location loc,
using OpConversionPattern<LgammaOp>::OpConversionPattern; ValueRange args) {
LogicalResult matchAndRewrite( assert(args.size() == 2);
LgammaOp op, ArrayRef<Value> operands, Value x = args[0];
ConversionPatternRewriter &rewriter) const override { Value q = args[1];
LgammaOp::Adaptor transformed(operands);
FloatType min_precision_ty = rewriter.getF32Type();
rewriter.replaceOp(
op, MaterializeWithUpcast(rewriter, op.getLoc(), transformed.operand(),
min_precision_ty, &MaterializeLgamma));
return success();
}
};
struct ConvertDigammaOp : public OpConversionPattern<DigammaOp> {
using OpConversionPattern<DigammaOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
DigammaOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
DigammaOp::Adaptor transformed(operands);
FloatType min_precision_ty = rewriter.getF32Type();
rewriter.replaceOp(
op, MaterializeWithUpcast(rewriter, op.getLoc(), transformed.operand(),
min_precision_ty, &MaterializeDigamma));
return success();
}
};
Value MaterializeZetaComputation(ConversionPatternRewriter &rewriter,
Location loc, Value x, Value q) {
static const std::array<double, 12> kZetaCoeffs{ static const std::array<double, 12> kZetaCoeffs{
-7.1661652561756670113e18, -7.1661652561756670113e18,
1.8152105401943546773e17, 1.8152105401943546773e17,
@ -897,34 +894,42 @@ Value MaterializeZetaComputation(ConversionPatternRewriter &rewriter,
return output; return output;
} }
struct ConvertLgammaOp : public OpConversionPattern<LgammaOp> {
using OpConversionPattern<LgammaOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
LgammaOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
FloatType min_precision_ty = rewriter.getF32Type();
rewriter.replaceOp(
op, MaterializeWithUpcast(rewriter, op.getLoc(), operands,
min_precision_ty, &MaterializeLgamma));
return success();
}
};
struct ConvertDigammaOp : public OpConversionPattern<DigammaOp> {
using OpConversionPattern<DigammaOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
DigammaOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
FloatType min_precision_ty = rewriter.getF32Type();
rewriter.replaceOp(
op, MaterializeWithUpcast(rewriter, op.getLoc(), operands,
min_precision_ty, &MaterializeDigamma));
return success();
}
};
struct ConvertZetaOp : public OpConversionPattern<ZetaOp> { struct ConvertZetaOp : public OpConversionPattern<ZetaOp> {
using OpConversionPattern<ZetaOp>::OpConversionPattern; using OpConversionPattern<ZetaOp>::OpConversionPattern;
LogicalResult matchAndRewrite( LogicalResult matchAndRewrite(
ZetaOp op, ArrayRef<Value> operands, ZetaOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
ZetaOpAdaptor adaptor(operands);
Location loc = op.getLoc(); Location loc = op.getLoc();
FloatType min_precision_ty = rewriter.getF32Type();
// Zeta is only defined on tensors of float elements and statically rewriter.replaceOp(
// verified that both have the same type. So it suffices to look at one op, MaterializeWithUpcast(rewriter, loc, operands, min_precision_ty,
// here. &MaterializeZeta));
auto elm_type = adaptor.x().getType().cast<ShapedType>().getElementType();
bool needs_upcast = elm_type.isF16() || elm_type.isBF16();
Value x = adaptor.x();
Value q = adaptor.q();
if (needs_upcast) {
x = rewriter.create<mhlo::ConvertOp>(loc, x, rewriter.getF32Type());
q = rewriter.create<mhlo::ConvertOp>(loc, q, rewriter.getF32Type());
}
Value result = MaterializeZetaComputation(rewriter, loc, x, q);
if (needs_upcast) {
result = rewriter.create<mhlo::ConvertOp>(loc, result, elm_type);
}
rewriter.replaceOp(op, {result});
return success(); return success();
} }
}; };