[MLIR][CHLO] Generalize lowering with upcast to n-ary operation
Allows reuse for zeta lowering now and for the polygamma lowering soon. PiperOrigin-RevId: 357739910
This commit is contained in:
parent
81abaf364d
commit
c06de24f6c
1
BUILD
1
BUILD
|
@ -977,6 +977,7 @@ cc_library(
|
|||
":chlo_legalize_to_hlo_inc_gen",
|
||||
":hlo",
|
||||
":map_chlo_to_hlo_op",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:SCFDialect",
|
||||
"@llvm-project//mlir:Shape",
|
||||
|
|
|
@ -22,6 +22,7 @@ limitations under the License.
|
|||
#include <numeric>
|
||||
#include <vector>
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
|
||||
|
@ -96,7 +97,8 @@ Value MaterializePolynomialApproximation(ConversionPatternRewriter &rewriter,
|
|||
// argument and derive the final approximation for all |x| >= 1.
|
||||
// This implementation is based on Cephes.
|
||||
Value MaterializeErfcApproximationF64ForMagnituteGEOne(
|
||||
ConversionPatternRewriter &rewriter, Location loc, Value x) {
|
||||
ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
|
||||
Value x = args.front();
|
||||
assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
|
||||
"expect f64 element type");
|
||||
const double kMaxlog = 7.09782712893383996843E2;
|
||||
|
@ -179,7 +181,8 @@ Value MaterializeErfcApproximationF64ForMagnituteGEOne(
|
|||
// Precondition is |x| <= 1. Use erfc approximation, otherwise.
|
||||
// This implementation is based on Cephes.
|
||||
Value MaterializeErfApproximationF64ForMagnituteLEOne(
|
||||
ConversionPatternRewriter &rewriter, Location loc, Value x) {
|
||||
ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
|
||||
Value x = args.front();
|
||||
assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
|
||||
"expect f64 element type");
|
||||
const std::vector<double> kErfTCoefficients{
|
||||
|
@ -204,7 +207,8 @@ Value MaterializeErfApproximationF64ForMagnituteLEOne(
|
|||
|
||||
// This implementation is based on Cephes.
|
||||
Value MaterializeErfApproximationF64(ConversionPatternRewriter &rewriter,
|
||||
Location loc, Value x) {
|
||||
Location loc, ValueRange args) {
|
||||
Value x = args.front();
|
||||
assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
|
||||
"expect f64 element type");
|
||||
|
||||
|
@ -230,7 +234,8 @@ Value MaterializeErfApproximationF64(ConversionPatternRewriter &rewriter,
|
|||
}
|
||||
|
||||
Value MaterializeErfcApproximationF64(ConversionPatternRewriter &rewriter,
|
||||
Location loc, Value x) {
|
||||
Location loc, ValueRange args) {
|
||||
Value x = args.front();
|
||||
assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
|
||||
"expect f64 element type");
|
||||
|
||||
|
@ -261,7 +266,8 @@ Value MaterializeErfcApproximationF64(ConversionPatternRewriter &rewriter,
|
|||
// argument and derive the final approximation for all |x| >= 1.
|
||||
// This implementation is based on Cephes.
|
||||
Value MaterializeErfcApproximationF32ForMagnitudeGEOne(
|
||||
ConversionPatternRewriter &rewriter, Location loc, Value x) {
|
||||
ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
|
||||
Value x = args.front();
|
||||
assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
|
||||
"expect f32 element type");
|
||||
const double kMaxlog = 88.72283905206835;
|
||||
|
@ -325,7 +331,8 @@ Value MaterializeErfcApproximationF32ForMagnitudeGEOne(
|
|||
// Precondition is |x| <= 1. Use erfc approximation, otherwise.
|
||||
// This implementation is based on Cephes.
|
||||
Value MaterializeErfApproximationF32ForMagnitudeLEOne(
|
||||
ConversionPatternRewriter &rewriter, Location loc, Value x) {
|
||||
ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
|
||||
Value x = args.front();
|
||||
assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
|
||||
"expect f32 element type");
|
||||
const std::vector<float> kErfTCoefficients{
|
||||
|
@ -344,8 +351,9 @@ Value MaterializeErfApproximationF32ForMagnitudeLEOne(
|
|||
|
||||
// This is the same approximation as used in Eigen.
|
||||
Value MaterializeErfApproximationF32(ConversionPatternRewriter &rewriter,
|
||||
Location loc, Value operand) {
|
||||
assert(operand.getType().cast<ShapedType>().getElementType().isF32() &&
|
||||
Location loc, ValueRange args) {
|
||||
Value x = args.front();
|
||||
assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
|
||||
"expect f32 element type");
|
||||
const std::vector<float> kAlpha{
|
||||
-2.72614225801306e-10f, 2.77068142495902e-08f, -2.10102402082508e-06f,
|
||||
|
@ -358,10 +366,9 @@ Value MaterializeErfApproximationF32(ConversionPatternRewriter &rewriter,
|
|||
};
|
||||
|
||||
// Clamp argument between -4 and 4.
|
||||
Value lb = chlo::getConstantLike(rewriter, loc, -4.0, operand);
|
||||
Value ub = chlo::getConstantLike(rewriter, loc, 4.0, operand);
|
||||
Value x =
|
||||
rewriter.create<mhlo::ClampOp>(loc, operand.getType(), lb, operand, ub);
|
||||
Value lb = chlo::getConstantLike(rewriter, loc, -4.0, x);
|
||||
Value ub = chlo::getConstantLike(rewriter, loc, 4.0, x);
|
||||
x = rewriter.create<mhlo::ClampOp>(loc, x.getType(), lb, x, ub);
|
||||
Value x_sq = rewriter.create<mhlo::MulOp>(loc, x, x);
|
||||
|
||||
// Materialize polynomial approximation for x in [-4, 4] as
|
||||
|
@ -375,7 +382,8 @@ Value MaterializeErfApproximationF32(ConversionPatternRewriter &rewriter,
|
|||
}
|
||||
|
||||
Value MaterializeErfcApproximationF32(ConversionPatternRewriter &rewriter,
|
||||
Location loc, Value x) {
|
||||
Location loc, ValueRange args) {
|
||||
Value x = args.front();
|
||||
assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
|
||||
"expect f32 element type");
|
||||
|
||||
|
@ -401,18 +409,30 @@ Value MaterializeErfcApproximationF32(ConversionPatternRewriter &rewriter,
|
|||
}
|
||||
|
||||
Value MaterializeWithUpcast(ConversionPatternRewriter &rewriter, Location loc,
|
||||
Value arg, FloatType min_precision_ty,
|
||||
ValueRange args, FloatType min_precision_ty,
|
||||
Value callback(ConversionPatternRewriter &,
|
||||
Location, Value)) {
|
||||
auto original_ty = getElementTypeOrSelf(arg.getType()).cast<FloatType>();
|
||||
Location, ValueRange)) {
|
||||
auto original_ty =
|
||||
getElementTypeOrSelf(args.front().getType()).cast<FloatType>();
|
||||
bool needs_upcast = original_ty.getWidth() < min_precision_ty.getWidth();
|
||||
|
||||
// Upcast arguments if necessary.
|
||||
llvm::SmallVector<Value, 2> casted_args;
|
||||
if (needs_upcast) {
|
||||
arg = rewriter.create<mhlo::ConvertOp>(loc, arg, min_precision_ty);
|
||||
for (Value a : args) {
|
||||
casted_args.push_back(
|
||||
rewriter.create<mhlo::ConvertOp>(loc, a, min_precision_ty));
|
||||
}
|
||||
Value result = callback(rewriter, loc, arg);
|
||||
args = casted_args;
|
||||
}
|
||||
|
||||
Value result = callback(rewriter, loc, args);
|
||||
|
||||
// Cast back if necessary.
|
||||
if (needs_upcast) {
|
||||
result = rewriter.create<mhlo::ConvertOp>(loc, result, original_ty);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
|
@ -434,8 +454,8 @@ struct ConvertErfOp : public OpConversionPattern<ErfOp> {
|
|||
return success();
|
||||
}
|
||||
|
||||
rewriter.replaceOp(
|
||||
op, MaterializeWithUpcast(rewriter, loc, x, rewriter.getF32Type(),
|
||||
rewriter.replaceOp(op, MaterializeWithUpcast(
|
||||
rewriter, loc, operands, rewriter.getF32Type(),
|
||||
&MaterializeErfApproximationF32));
|
||||
return success();
|
||||
}
|
||||
|
@ -459,8 +479,8 @@ struct ConvertErfcOp : public OpConversionPattern<ErfcOp> {
|
|||
return success();
|
||||
}
|
||||
|
||||
rewriter.replaceOp(
|
||||
op, MaterializeWithUpcast(rewriter, loc, x, rewriter.getF32Type(),
|
||||
rewriter.replaceOp(op, MaterializeWithUpcast(
|
||||
rewriter, loc, operands, rewriter.getF32Type(),
|
||||
&MaterializeErfcApproximationF32));
|
||||
return success();
|
||||
}
|
||||
|
@ -491,12 +511,13 @@ constexpr std::array<double, 8> kLanczosCoefficients = {
|
|||
// a(z) = kBaseLanczosCoeff
|
||||
// + sum(k = 1, n, kLanczosCoefficients[i] / (z + k))
|
||||
Value MaterializeLgamma(ConversionPatternRewriter &rewriter, Location loc,
|
||||
Value x) {
|
||||
ValueRange args) {
|
||||
// If the input is less than 0.5 use Euler's reflection formula.
|
||||
// gamma(x) = pi / (sin(pi * x) * gamma(1 - x))
|
||||
// Let z be
|
||||
// z = -x if x < 1/2
|
||||
// z = x - 1 otheriwse
|
||||
Value x = args.front();
|
||||
const StringAttr kLT = rewriter.getStringAttr(
|
||||
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
|
||||
Value half = getConstantLike(rewriter, loc, 0.5, x);
|
||||
|
@ -635,12 +656,13 @@ Value MaterializeLgamma(ConversionPatternRewriter &rewriter, Location loc,
|
|||
// + sum(k = 1, n, kLanczosCoefficients[i] / (z + k))
|
||||
// a'(z) = - sum(k = 1, n, kLanczosCoefficients[i] / (z + k) / (z + k))
|
||||
Value MaterializeDigamma(ConversionPatternRewriter &rewriter, Location loc,
|
||||
Value x) {
|
||||
ValueRange args) {
|
||||
// If the input is less than 0.5 use Euler's reflection formula.
|
||||
// digamma(x) = digamma(1 - x) - pi * cot(pi * x)
|
||||
// Let z be
|
||||
// z = -x if x < 1/2
|
||||
// z = x - 1 otheriwse
|
||||
Value x = args.front();
|
||||
const StringAttr kLT = rewriter.getStringAttr(
|
||||
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
|
||||
Value half = getConstantLike(rewriter, loc, 0.5, x);
|
||||
|
@ -739,36 +761,11 @@ Value MaterializeDigamma(ConversionPatternRewriter &rewriter, Location loc,
|
|||
digamma);
|
||||
}
|
||||
|
||||
struct ConvertLgammaOp : public OpConversionPattern<LgammaOp> {
|
||||
using OpConversionPattern<LgammaOp>::OpConversionPattern;
|
||||
LogicalResult matchAndRewrite(
|
||||
LgammaOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
LgammaOp::Adaptor transformed(operands);
|
||||
FloatType min_precision_ty = rewriter.getF32Type();
|
||||
rewriter.replaceOp(
|
||||
op, MaterializeWithUpcast(rewriter, op.getLoc(), transformed.operand(),
|
||||
min_precision_ty, &MaterializeLgamma));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ConvertDigammaOp : public OpConversionPattern<DigammaOp> {
|
||||
using OpConversionPattern<DigammaOp>::OpConversionPattern;
|
||||
LogicalResult matchAndRewrite(
|
||||
DigammaOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
DigammaOp::Adaptor transformed(operands);
|
||||
FloatType min_precision_ty = rewriter.getF32Type();
|
||||
rewriter.replaceOp(
|
||||
op, MaterializeWithUpcast(rewriter, op.getLoc(), transformed.operand(),
|
||||
min_precision_ty, &MaterializeDigamma));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
Value MaterializeZetaComputation(ConversionPatternRewriter &rewriter,
|
||||
Location loc, Value x, Value q) {
|
||||
Value MaterializeZeta(ConversionPatternRewriter &rewriter, Location loc,
|
||||
ValueRange args) {
|
||||
assert(args.size() == 2);
|
||||
Value x = args[0];
|
||||
Value q = args[1];
|
||||
static const std::array<double, 12> kZetaCoeffs{
|
||||
-7.1661652561756670113e18,
|
||||
1.8152105401943546773e17,
|
||||
|
@ -897,34 +894,42 @@ Value MaterializeZetaComputation(ConversionPatternRewriter &rewriter,
|
|||
return output;
|
||||
}
|
||||
|
||||
struct ConvertLgammaOp : public OpConversionPattern<LgammaOp> {
|
||||
using OpConversionPattern<LgammaOp>::OpConversionPattern;
|
||||
LogicalResult matchAndRewrite(
|
||||
LgammaOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
FloatType min_precision_ty = rewriter.getF32Type();
|
||||
rewriter.replaceOp(
|
||||
op, MaterializeWithUpcast(rewriter, op.getLoc(), operands,
|
||||
min_precision_ty, &MaterializeLgamma));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ConvertDigammaOp : public OpConversionPattern<DigammaOp> {
|
||||
using OpConversionPattern<DigammaOp>::OpConversionPattern;
|
||||
LogicalResult matchAndRewrite(
|
||||
DigammaOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
FloatType min_precision_ty = rewriter.getF32Type();
|
||||
rewriter.replaceOp(
|
||||
op, MaterializeWithUpcast(rewriter, op.getLoc(), operands,
|
||||
min_precision_ty, &MaterializeDigamma));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ConvertZetaOp : public OpConversionPattern<ZetaOp> {
|
||||
using OpConversionPattern<ZetaOp>::OpConversionPattern;
|
||||
LogicalResult matchAndRewrite(
|
||||
ZetaOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
ZetaOpAdaptor adaptor(operands);
|
||||
Location loc = op.getLoc();
|
||||
|
||||
// Zeta is only defined on tensors of float elements and statically
|
||||
// verified that both have the same type. So it suffices to look at one
|
||||
// here.
|
||||
auto elm_type = adaptor.x().getType().cast<ShapedType>().getElementType();
|
||||
|
||||
bool needs_upcast = elm_type.isF16() || elm_type.isBF16();
|
||||
|
||||
Value x = adaptor.x();
|
||||
Value q = adaptor.q();
|
||||
|
||||
if (needs_upcast) {
|
||||
x = rewriter.create<mhlo::ConvertOp>(loc, x, rewriter.getF32Type());
|
||||
q = rewriter.create<mhlo::ConvertOp>(loc, q, rewriter.getF32Type());
|
||||
}
|
||||
Value result = MaterializeZetaComputation(rewriter, loc, x, q);
|
||||
if (needs_upcast) {
|
||||
result = rewriter.create<mhlo::ConvertOp>(loc, result, elm_type);
|
||||
}
|
||||
rewriter.replaceOp(op, {result});
|
||||
|
||||
FloatType min_precision_ty = rewriter.getF32Type();
|
||||
rewriter.replaceOp(
|
||||
op, MaterializeWithUpcast(rewriter, loc, operands, min_precision_ty,
|
||||
&MaterializeZeta));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue