[MLIR][CHLO] Generalize lowering with upcast to n-ary operation

Allows reuse for zeta lowering now and for the polygamma lowering soon.

PiperOrigin-RevId: 357739910
This commit is contained in:
A. Unique TensorFlower 2021-02-16 09:46:33 -08:00 committed by TensorFlow MLIR Team
parent 81abaf364d
commit c06de24f6c
2 changed files with 84 additions and 78 deletions

1
BUILD
View File

@ -977,6 +977,7 @@ cc_library(
":chlo_legalize_to_hlo_inc_gen",
":hlo",
":map_chlo_to_hlo_op",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:Shape",

View File

@ -22,6 +22,7 @@ limitations under the License.
#include <numeric>
#include <vector>
#include "llvm/ADT/SmallVector.h"
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
@ -96,7 +97,8 @@ Value MaterializePolynomialApproximation(ConversionPatternRewriter &rewriter,
// argument and derive the final approximation for all |x| >= 1.
// This implementation is based on Cephes.
Value MaterializeErfcApproximationF64ForMagnituteGEOne(
ConversionPatternRewriter &rewriter, Location loc, Value x) {
ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
Value x = args.front();
assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
"expect f64 element type");
const double kMaxlog = 7.09782712893383996843E2;
@ -179,7 +181,8 @@ Value MaterializeErfcApproximationF64ForMagnituteGEOne(
// Precondition is |x| <= 1. Use erfc approximation, otherwise.
// This implementation is based on Cephes.
Value MaterializeErfApproximationF64ForMagnituteLEOne(
ConversionPatternRewriter &rewriter, Location loc, Value x) {
ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
Value x = args.front();
assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
"expect f64 element type");
const std::vector<double> kErfTCoefficients{
@ -204,7 +207,8 @@ Value MaterializeErfApproximationF64ForMagnituteLEOne(
// This implementation is based on Cephes.
Value MaterializeErfApproximationF64(ConversionPatternRewriter &rewriter,
Location loc, Value x) {
Location loc, ValueRange args) {
Value x = args.front();
assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
"expect f64 element type");
@ -230,7 +234,8 @@ Value MaterializeErfApproximationF64(ConversionPatternRewriter &rewriter,
}
Value MaterializeErfcApproximationF64(ConversionPatternRewriter &rewriter,
Location loc, Value x) {
Location loc, ValueRange args) {
Value x = args.front();
assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
"expect f64 element type");
@ -261,7 +266,8 @@ Value MaterializeErfcApproximationF64(ConversionPatternRewriter &rewriter,
// argument and derive the final approximation for all |x| >= 1.
// This implementation is based on Cephes.
Value MaterializeErfcApproximationF32ForMagnitudeGEOne(
ConversionPatternRewriter &rewriter, Location loc, Value x) {
ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
Value x = args.front();
assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
"expect f32 element type");
const double kMaxlog = 88.72283905206835;
@ -325,7 +331,8 @@ Value MaterializeErfcApproximationF32ForMagnitudeGEOne(
// Precondition is |x| <= 1. Use erfc approximation, otherwise.
// This implementation is based on Cephes.
Value MaterializeErfApproximationF32ForMagnitudeLEOne(
ConversionPatternRewriter &rewriter, Location loc, Value x) {
ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
Value x = args.front();
assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
"expect f32 element type");
const std::vector<float> kErfTCoefficients{
@ -344,8 +351,9 @@ Value MaterializeErfApproximationF32ForMagnitudeLEOne(
// This is the same approximation as used in Eigen.
Value MaterializeErfApproximationF32(ConversionPatternRewriter &rewriter,
Location loc, Value operand) {
assert(operand.getType().cast<ShapedType>().getElementType().isF32() &&
Location loc, ValueRange args) {
Value x = args.front();
assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
"expect f32 element type");
const std::vector<float> kAlpha{
-2.72614225801306e-10f, 2.77068142495902e-08f, -2.10102402082508e-06f,
@ -358,10 +366,9 @@ Value MaterializeErfApproximationF32(ConversionPatternRewriter &rewriter,
};
// Clamp argument between -4 and 4.
Value lb = chlo::getConstantLike(rewriter, loc, -4.0, operand);
Value ub = chlo::getConstantLike(rewriter, loc, 4.0, operand);
Value x =
rewriter.create<mhlo::ClampOp>(loc, operand.getType(), lb, operand, ub);
Value lb = chlo::getConstantLike(rewriter, loc, -4.0, x);
Value ub = chlo::getConstantLike(rewriter, loc, 4.0, x);
x = rewriter.create<mhlo::ClampOp>(loc, x.getType(), lb, x, ub);
Value x_sq = rewriter.create<mhlo::MulOp>(loc, x, x);
// Materialize polynomial approximation for x in [-4, 4] as
@ -375,7 +382,8 @@ Value MaterializeErfApproximationF32(ConversionPatternRewriter &rewriter,
}
Value MaterializeErfcApproximationF32(ConversionPatternRewriter &rewriter,
Location loc, Value x) {
Location loc, ValueRange args) {
Value x = args.front();
assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
"expect f32 element type");
@ -401,18 +409,30 @@ Value MaterializeErfcApproximationF32(ConversionPatternRewriter &rewriter,
}
Value MaterializeWithUpcast(ConversionPatternRewriter &rewriter, Location loc,
Value arg, FloatType min_precision_ty,
ValueRange args, FloatType min_precision_ty,
Value callback(ConversionPatternRewriter &,
Location, Value)) {
auto original_ty = getElementTypeOrSelf(arg.getType()).cast<FloatType>();
Location, ValueRange)) {
auto original_ty =
getElementTypeOrSelf(args.front().getType()).cast<FloatType>();
bool needs_upcast = original_ty.getWidth() < min_precision_ty.getWidth();
// Upcast arguments if necessary.
llvm::SmallVector<Value, 2> casted_args;
if (needs_upcast) {
arg = rewriter.create<mhlo::ConvertOp>(loc, arg, min_precision_ty);
for (Value a : args) {
casted_args.push_back(
rewriter.create<mhlo::ConvertOp>(loc, a, min_precision_ty));
}
args = casted_args;
}
Value result = callback(rewriter, loc, arg);
Value result = callback(rewriter, loc, args);
// Cast back if necessary.
if (needs_upcast) {
result = rewriter.create<mhlo::ConvertOp>(loc, result, original_ty);
}
return result;
}
@ -434,9 +454,9 @@ struct ConvertErfOp : public OpConversionPattern<ErfOp> {
return success();
}
rewriter.replaceOp(
op, MaterializeWithUpcast(rewriter, loc, x, rewriter.getF32Type(),
&MaterializeErfApproximationF32));
rewriter.replaceOp(op, MaterializeWithUpcast(
rewriter, loc, operands, rewriter.getF32Type(),
&MaterializeErfApproximationF32));
return success();
}
};
@ -459,9 +479,9 @@ struct ConvertErfcOp : public OpConversionPattern<ErfcOp> {
return success();
}
rewriter.replaceOp(
op, MaterializeWithUpcast(rewriter, loc, x, rewriter.getF32Type(),
&MaterializeErfcApproximationF32));
rewriter.replaceOp(op, MaterializeWithUpcast(
rewriter, loc, operands, rewriter.getF32Type(),
&MaterializeErfcApproximationF32));
return success();
}
};
@ -491,12 +511,13 @@ constexpr std::array<double, 8> kLanczosCoefficients = {
// a(z) = kBaseLanczosCoeff
// + sum(k = 1, n, kLanczosCoefficients[i] / (z + k))
Value MaterializeLgamma(ConversionPatternRewriter &rewriter, Location loc,
Value x) {
ValueRange args) {
// If the input is less than 0.5 use Euler's reflection formula.
// gamma(x) = pi / (sin(pi * x) * gamma(1 - x))
// Let z be
// z = -x if x < 1/2
// z = x - 1 otheriwse
Value x = args.front();
const StringAttr kLT = rewriter.getStringAttr(
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
Value half = getConstantLike(rewriter, loc, 0.5, x);
@ -635,12 +656,13 @@ Value MaterializeLgamma(ConversionPatternRewriter &rewriter, Location loc,
// + sum(k = 1, n, kLanczosCoefficients[i] / (z + k))
// a'(z) = - sum(k = 1, n, kLanczosCoefficients[i] / (z + k) / (z + k))
Value MaterializeDigamma(ConversionPatternRewriter &rewriter, Location loc,
Value x) {
ValueRange args) {
// If the input is less than 0.5 use Euler's reflection formula.
// digamma(x) = digamma(1 - x) - pi * cot(pi * x)
// Let z be
// z = -x if x < 1/2
// z = x - 1 otheriwse
Value x = args.front();
const StringAttr kLT = rewriter.getStringAttr(
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
Value half = getConstantLike(rewriter, loc, 0.5, x);
@ -739,36 +761,11 @@ Value MaterializeDigamma(ConversionPatternRewriter &rewriter, Location loc,
digamma);
}
struct ConvertLgammaOp : public OpConversionPattern<LgammaOp> {
using OpConversionPattern<LgammaOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
LgammaOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
LgammaOp::Adaptor transformed(operands);
FloatType min_precision_ty = rewriter.getF32Type();
rewriter.replaceOp(
op, MaterializeWithUpcast(rewriter, op.getLoc(), transformed.operand(),
min_precision_ty, &MaterializeLgamma));
return success();
}
};
struct ConvertDigammaOp : public OpConversionPattern<DigammaOp> {
using OpConversionPattern<DigammaOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
DigammaOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
DigammaOp::Adaptor transformed(operands);
FloatType min_precision_ty = rewriter.getF32Type();
rewriter.replaceOp(
op, MaterializeWithUpcast(rewriter, op.getLoc(), transformed.operand(),
min_precision_ty, &MaterializeDigamma));
return success();
}
};
Value MaterializeZetaComputation(ConversionPatternRewriter &rewriter,
Location loc, Value x, Value q) {
Value MaterializeZeta(ConversionPatternRewriter &rewriter, Location loc,
ValueRange args) {
assert(args.size() == 2);
Value x = args[0];
Value q = args[1];
static const std::array<double, 12> kZetaCoeffs{
-7.1661652561756670113e18,
1.8152105401943546773e17,
@ -897,34 +894,42 @@ Value MaterializeZetaComputation(ConversionPatternRewriter &rewriter,
return output;
}
struct ConvertLgammaOp : public OpConversionPattern<LgammaOp> {
using OpConversionPattern<LgammaOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
LgammaOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
FloatType min_precision_ty = rewriter.getF32Type();
rewriter.replaceOp(
op, MaterializeWithUpcast(rewriter, op.getLoc(), operands,
min_precision_ty, &MaterializeLgamma));
return success();
}
};
struct ConvertDigammaOp : public OpConversionPattern<DigammaOp> {
using OpConversionPattern<DigammaOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
DigammaOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
FloatType min_precision_ty = rewriter.getF32Type();
rewriter.replaceOp(
op, MaterializeWithUpcast(rewriter, op.getLoc(), operands,
min_precision_ty, &MaterializeDigamma));
return success();
}
};
struct ConvertZetaOp : public OpConversionPattern<ZetaOp> {
using OpConversionPattern<ZetaOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
ZetaOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
ZetaOpAdaptor adaptor(operands);
Location loc = op.getLoc();
// Zeta is only defined on tensors of float elements and statically
// verified that both have the same type. So it suffices to look at one
// here.
auto elm_type = adaptor.x().getType().cast<ShapedType>().getElementType();
bool needs_upcast = elm_type.isF16() || elm_type.isBF16();
Value x = adaptor.x();
Value q = adaptor.q();
if (needs_upcast) {
x = rewriter.create<mhlo::ConvertOp>(loc, x, rewriter.getF32Type());
q = rewriter.create<mhlo::ConvertOp>(loc, q, rewriter.getF32Type());
}
Value result = MaterializeZetaComputation(rewriter, loc, x, q);
if (needs_upcast) {
result = rewriter.create<mhlo::ConvertOp>(loc, result, elm_type);
}
rewriter.replaceOp(op, {result});
FloatType min_precision_ty = rewriter.getF32Type();
rewriter.replaceOp(
op, MaterializeWithUpcast(rewriter, loc, operands, min_precision_ty,
&MaterializeZeta));
return success();
}
};