[MLIR][CHLO] Generalize lowering with upcast to n-ary operation
Allows reuse for zeta lowering now and for the polygamma lowering soon. PiperOrigin-RevId: 357739910
This commit is contained in:
parent
81abaf364d
commit
c06de24f6c
1
BUILD
1
BUILD
|
@ -977,6 +977,7 @@ cc_library(
|
||||||
":chlo_legalize_to_hlo_inc_gen",
|
":chlo_legalize_to_hlo_inc_gen",
|
||||||
":hlo",
|
":hlo",
|
||||||
":map_chlo_to_hlo_op",
|
":map_chlo_to_hlo_op",
|
||||||
|
"@llvm-project//llvm:Support",
|
||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
"@llvm-project//mlir:SCFDialect",
|
"@llvm-project//mlir:SCFDialect",
|
||||||
"@llvm-project//mlir:Shape",
|
"@llvm-project//mlir:Shape",
|
||||||
|
|
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
|
#include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
|
||||||
|
@ -96,7 +97,8 @@ Value MaterializePolynomialApproximation(ConversionPatternRewriter &rewriter,
|
||||||
// argument and derive the final approximation for all |x| >= 1.
|
// argument and derive the final approximation for all |x| >= 1.
|
||||||
// This implementation is based on Cephes.
|
// This implementation is based on Cephes.
|
||||||
Value MaterializeErfcApproximationF64ForMagnituteGEOne(
|
Value MaterializeErfcApproximationF64ForMagnituteGEOne(
|
||||||
ConversionPatternRewriter &rewriter, Location loc, Value x) {
|
ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
|
||||||
|
Value x = args.front();
|
||||||
assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
|
assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
|
||||||
"expect f64 element type");
|
"expect f64 element type");
|
||||||
const double kMaxlog = 7.09782712893383996843E2;
|
const double kMaxlog = 7.09782712893383996843E2;
|
||||||
|
@ -179,7 +181,8 @@ Value MaterializeErfcApproximationF64ForMagnituteGEOne(
|
||||||
// Precondition is |x| <= 1. Use erfc approximation, otherwise.
|
// Precondition is |x| <= 1. Use erfc approximation, otherwise.
|
||||||
// This implementation is based on Cephes.
|
// This implementation is based on Cephes.
|
||||||
Value MaterializeErfApproximationF64ForMagnituteLEOne(
|
Value MaterializeErfApproximationF64ForMagnituteLEOne(
|
||||||
ConversionPatternRewriter &rewriter, Location loc, Value x) {
|
ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
|
||||||
|
Value x = args.front();
|
||||||
assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
|
assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
|
||||||
"expect f64 element type");
|
"expect f64 element type");
|
||||||
const std::vector<double> kErfTCoefficients{
|
const std::vector<double> kErfTCoefficients{
|
||||||
|
@ -204,7 +207,8 @@ Value MaterializeErfApproximationF64ForMagnituteLEOne(
|
||||||
|
|
||||||
// This implementation is based on Cephes.
|
// This implementation is based on Cephes.
|
||||||
Value MaterializeErfApproximationF64(ConversionPatternRewriter &rewriter,
|
Value MaterializeErfApproximationF64(ConversionPatternRewriter &rewriter,
|
||||||
Location loc, Value x) {
|
Location loc, ValueRange args) {
|
||||||
|
Value x = args.front();
|
||||||
assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
|
assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
|
||||||
"expect f64 element type");
|
"expect f64 element type");
|
||||||
|
|
||||||
|
@ -230,7 +234,8 @@ Value MaterializeErfApproximationF64(ConversionPatternRewriter &rewriter,
|
||||||
}
|
}
|
||||||
|
|
||||||
Value MaterializeErfcApproximationF64(ConversionPatternRewriter &rewriter,
|
Value MaterializeErfcApproximationF64(ConversionPatternRewriter &rewriter,
|
||||||
Location loc, Value x) {
|
Location loc, ValueRange args) {
|
||||||
|
Value x = args.front();
|
||||||
assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
|
assert(x.getType().cast<ShapedType>().getElementType().isF64() &&
|
||||||
"expect f64 element type");
|
"expect f64 element type");
|
||||||
|
|
||||||
|
@ -261,7 +266,8 @@ Value MaterializeErfcApproximationF64(ConversionPatternRewriter &rewriter,
|
||||||
// argument and derive the final approximation for all |x| >= 1.
|
// argument and derive the final approximation for all |x| >= 1.
|
||||||
// This implementation is based on Cephes.
|
// This implementation is based on Cephes.
|
||||||
Value MaterializeErfcApproximationF32ForMagnitudeGEOne(
|
Value MaterializeErfcApproximationF32ForMagnitudeGEOne(
|
||||||
ConversionPatternRewriter &rewriter, Location loc, Value x) {
|
ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
|
||||||
|
Value x = args.front();
|
||||||
assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
|
assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
|
||||||
"expect f32 element type");
|
"expect f32 element type");
|
||||||
const double kMaxlog = 88.72283905206835;
|
const double kMaxlog = 88.72283905206835;
|
||||||
|
@ -325,7 +331,8 @@ Value MaterializeErfcApproximationF32ForMagnitudeGEOne(
|
||||||
// Precondition is |x| <= 1. Use erfc approximation, otherwise.
|
// Precondition is |x| <= 1. Use erfc approximation, otherwise.
|
||||||
// This implementation is based on Cephes.
|
// This implementation is based on Cephes.
|
||||||
Value MaterializeErfApproximationF32ForMagnitudeLEOne(
|
Value MaterializeErfApproximationF32ForMagnitudeLEOne(
|
||||||
ConversionPatternRewriter &rewriter, Location loc, Value x) {
|
ConversionPatternRewriter &rewriter, Location loc, ValueRange args) {
|
||||||
|
Value x = args.front();
|
||||||
assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
|
assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
|
||||||
"expect f32 element type");
|
"expect f32 element type");
|
||||||
const std::vector<float> kErfTCoefficients{
|
const std::vector<float> kErfTCoefficients{
|
||||||
|
@ -344,8 +351,9 @@ Value MaterializeErfApproximationF32ForMagnitudeLEOne(
|
||||||
|
|
||||||
// This is the same approximation as used in Eigen.
|
// This is the same approximation as used in Eigen.
|
||||||
Value MaterializeErfApproximationF32(ConversionPatternRewriter &rewriter,
|
Value MaterializeErfApproximationF32(ConversionPatternRewriter &rewriter,
|
||||||
Location loc, Value operand) {
|
Location loc, ValueRange args) {
|
||||||
assert(operand.getType().cast<ShapedType>().getElementType().isF32() &&
|
Value x = args.front();
|
||||||
|
assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
|
||||||
"expect f32 element type");
|
"expect f32 element type");
|
||||||
const std::vector<float> kAlpha{
|
const std::vector<float> kAlpha{
|
||||||
-2.72614225801306e-10f, 2.77068142495902e-08f, -2.10102402082508e-06f,
|
-2.72614225801306e-10f, 2.77068142495902e-08f, -2.10102402082508e-06f,
|
||||||
|
@ -358,10 +366,9 @@ Value MaterializeErfApproximationF32(ConversionPatternRewriter &rewriter,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Clamp argument between -4 and 4.
|
// Clamp argument between -4 and 4.
|
||||||
Value lb = chlo::getConstantLike(rewriter, loc, -4.0, operand);
|
Value lb = chlo::getConstantLike(rewriter, loc, -4.0, x);
|
||||||
Value ub = chlo::getConstantLike(rewriter, loc, 4.0, operand);
|
Value ub = chlo::getConstantLike(rewriter, loc, 4.0, x);
|
||||||
Value x =
|
x = rewriter.create<mhlo::ClampOp>(loc, x.getType(), lb, x, ub);
|
||||||
rewriter.create<mhlo::ClampOp>(loc, operand.getType(), lb, operand, ub);
|
|
||||||
Value x_sq = rewriter.create<mhlo::MulOp>(loc, x, x);
|
Value x_sq = rewriter.create<mhlo::MulOp>(loc, x, x);
|
||||||
|
|
||||||
// Materialize polynomial approximation for x in [-4, 4] as
|
// Materialize polynomial approximation for x in [-4, 4] as
|
||||||
|
@ -375,7 +382,8 @@ Value MaterializeErfApproximationF32(ConversionPatternRewriter &rewriter,
|
||||||
}
|
}
|
||||||
|
|
||||||
Value MaterializeErfcApproximationF32(ConversionPatternRewriter &rewriter,
|
Value MaterializeErfcApproximationF32(ConversionPatternRewriter &rewriter,
|
||||||
Location loc, Value x) {
|
Location loc, ValueRange args) {
|
||||||
|
Value x = args.front();
|
||||||
assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
|
assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
|
||||||
"expect f32 element type");
|
"expect f32 element type");
|
||||||
|
|
||||||
|
@ -401,18 +409,30 @@ Value MaterializeErfcApproximationF32(ConversionPatternRewriter &rewriter,
|
||||||
}
|
}
|
||||||
|
|
||||||
Value MaterializeWithUpcast(ConversionPatternRewriter &rewriter, Location loc,
|
Value MaterializeWithUpcast(ConversionPatternRewriter &rewriter, Location loc,
|
||||||
Value arg, FloatType min_precision_ty,
|
ValueRange args, FloatType min_precision_ty,
|
||||||
Value callback(ConversionPatternRewriter &,
|
Value callback(ConversionPatternRewriter &,
|
||||||
Location, Value)) {
|
Location, ValueRange)) {
|
||||||
auto original_ty = getElementTypeOrSelf(arg.getType()).cast<FloatType>();
|
auto original_ty =
|
||||||
|
getElementTypeOrSelf(args.front().getType()).cast<FloatType>();
|
||||||
bool needs_upcast = original_ty.getWidth() < min_precision_ty.getWidth();
|
bool needs_upcast = original_ty.getWidth() < min_precision_ty.getWidth();
|
||||||
|
|
||||||
|
// Upcast arguments if necessary.
|
||||||
|
llvm::SmallVector<Value, 2> casted_args;
|
||||||
if (needs_upcast) {
|
if (needs_upcast) {
|
||||||
arg = rewriter.create<mhlo::ConvertOp>(loc, arg, min_precision_ty);
|
for (Value a : args) {
|
||||||
|
casted_args.push_back(
|
||||||
|
rewriter.create<mhlo::ConvertOp>(loc, a, min_precision_ty));
|
||||||
|
}
|
||||||
|
args = casted_args;
|
||||||
}
|
}
|
||||||
Value result = callback(rewriter, loc, arg);
|
|
||||||
|
Value result = callback(rewriter, loc, args);
|
||||||
|
|
||||||
|
// Cast back if necessary.
|
||||||
if (needs_upcast) {
|
if (needs_upcast) {
|
||||||
result = rewriter.create<mhlo::ConvertOp>(loc, result, original_ty);
|
result = rewriter.create<mhlo::ConvertOp>(loc, result, original_ty);
|
||||||
}
|
}
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -434,9 +454,9 @@ struct ConvertErfOp : public OpConversionPattern<ErfOp> {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
rewriter.replaceOp(
|
rewriter.replaceOp(op, MaterializeWithUpcast(
|
||||||
op, MaterializeWithUpcast(rewriter, loc, x, rewriter.getF32Type(),
|
rewriter, loc, operands, rewriter.getF32Type(),
|
||||||
&MaterializeErfApproximationF32));
|
&MaterializeErfApproximationF32));
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -459,9 +479,9 @@ struct ConvertErfcOp : public OpConversionPattern<ErfcOp> {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
rewriter.replaceOp(
|
rewriter.replaceOp(op, MaterializeWithUpcast(
|
||||||
op, MaterializeWithUpcast(rewriter, loc, x, rewriter.getF32Type(),
|
rewriter, loc, operands, rewriter.getF32Type(),
|
||||||
&MaterializeErfcApproximationF32));
|
&MaterializeErfcApproximationF32));
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -491,12 +511,13 @@ constexpr std::array<double, 8> kLanczosCoefficients = {
|
||||||
// a(z) = kBaseLanczosCoeff
|
// a(z) = kBaseLanczosCoeff
|
||||||
// + sum(k = 1, n, kLanczosCoefficients[i] / (z + k))
|
// + sum(k = 1, n, kLanczosCoefficients[i] / (z + k))
|
||||||
Value MaterializeLgamma(ConversionPatternRewriter &rewriter, Location loc,
|
Value MaterializeLgamma(ConversionPatternRewriter &rewriter, Location loc,
|
||||||
Value x) {
|
ValueRange args) {
|
||||||
// If the input is less than 0.5 use Euler's reflection formula.
|
// If the input is less than 0.5 use Euler's reflection formula.
|
||||||
// gamma(x) = pi / (sin(pi * x) * gamma(1 - x))
|
// gamma(x) = pi / (sin(pi * x) * gamma(1 - x))
|
||||||
// Let z be
|
// Let z be
|
||||||
// z = -x if x < 1/2
|
// z = -x if x < 1/2
|
||||||
// z = x - 1 otherwise
|
// z = x - 1 otherwise
|
||||||
|
Value x = args.front();
|
||||||
const StringAttr kLT = rewriter.getStringAttr(
|
const StringAttr kLT = rewriter.getStringAttr(
|
||||||
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
|
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
|
||||||
Value half = getConstantLike(rewriter, loc, 0.5, x);
|
Value half = getConstantLike(rewriter, loc, 0.5, x);
|
||||||
|
@ -635,12 +656,13 @@ Value MaterializeLgamma(ConversionPatternRewriter &rewriter, Location loc,
|
||||||
// + sum(k = 1, n, kLanczosCoefficients[i] / (z + k))
|
// + sum(k = 1, n, kLanczosCoefficients[i] / (z + k))
|
||||||
// a'(z) = - sum(k = 1, n, kLanczosCoefficients[i] / (z + k) / (z + k))
|
// a'(z) = - sum(k = 1, n, kLanczosCoefficients[i] / (z + k) / (z + k))
|
||||||
Value MaterializeDigamma(ConversionPatternRewriter &rewriter, Location loc,
|
Value MaterializeDigamma(ConversionPatternRewriter &rewriter, Location loc,
|
||||||
Value x) {
|
ValueRange args) {
|
||||||
// If the input is less than 0.5 use Euler's reflection formula.
|
// If the input is less than 0.5 use Euler's reflection formula.
|
||||||
// digamma(x) = digamma(1 - x) - pi * cot(pi * x)
|
// digamma(x) = digamma(1 - x) - pi * cot(pi * x)
|
||||||
// Let z be
|
// Let z be
|
||||||
// z = -x if x < 1/2
|
// z = -x if x < 1/2
|
||||||
// z = x - 1 otherwise
|
// z = x - 1 otherwise
|
||||||
|
Value x = args.front();
|
||||||
const StringAttr kLT = rewriter.getStringAttr(
|
const StringAttr kLT = rewriter.getStringAttr(
|
||||||
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
|
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
|
||||||
Value half = getConstantLike(rewriter, loc, 0.5, x);
|
Value half = getConstantLike(rewriter, loc, 0.5, x);
|
||||||
|
@ -739,36 +761,11 @@ Value MaterializeDigamma(ConversionPatternRewriter &rewriter, Location loc,
|
||||||
digamma);
|
digamma);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ConvertLgammaOp : public OpConversionPattern<LgammaOp> {
|
Value MaterializeZeta(ConversionPatternRewriter &rewriter, Location loc,
|
||||||
using OpConversionPattern<LgammaOp>::OpConversionPattern;
|
ValueRange args) {
|
||||||
LogicalResult matchAndRewrite(
|
assert(args.size() == 2);
|
||||||
LgammaOp op, ArrayRef<Value> operands,
|
Value x = args[0];
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
Value q = args[1];
|
||||||
LgammaOp::Adaptor transformed(operands);
|
|
||||||
FloatType min_precision_ty = rewriter.getF32Type();
|
|
||||||
rewriter.replaceOp(
|
|
||||||
op, MaterializeWithUpcast(rewriter, op.getLoc(), transformed.operand(),
|
|
||||||
min_precision_ty, &MaterializeLgamma));
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct ConvertDigammaOp : public OpConversionPattern<DigammaOp> {
|
|
||||||
using OpConversionPattern<DigammaOp>::OpConversionPattern;
|
|
||||||
LogicalResult matchAndRewrite(
|
|
||||||
DigammaOp op, ArrayRef<Value> operands,
|
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
|
||||||
DigammaOp::Adaptor transformed(operands);
|
|
||||||
FloatType min_precision_ty = rewriter.getF32Type();
|
|
||||||
rewriter.replaceOp(
|
|
||||||
op, MaterializeWithUpcast(rewriter, op.getLoc(), transformed.operand(),
|
|
||||||
min_precision_ty, &MaterializeDigamma));
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
Value MaterializeZetaComputation(ConversionPatternRewriter &rewriter,
|
|
||||||
Location loc, Value x, Value q) {
|
|
||||||
static const std::array<double, 12> kZetaCoeffs{
|
static const std::array<double, 12> kZetaCoeffs{
|
||||||
-7.1661652561756670113e18,
|
-7.1661652561756670113e18,
|
||||||
1.8152105401943546773e17,
|
1.8152105401943546773e17,
|
||||||
|
@ -897,34 +894,42 @@ Value MaterializeZetaComputation(ConversionPatternRewriter &rewriter,
|
||||||
return output;
|
return output;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct ConvertLgammaOp : public OpConversionPattern<LgammaOp> {
|
||||||
|
using OpConversionPattern<LgammaOp>::OpConversionPattern;
|
||||||
|
LogicalResult matchAndRewrite(
|
||||||
|
LgammaOp op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
FloatType min_precision_ty = rewriter.getF32Type();
|
||||||
|
rewriter.replaceOp(
|
||||||
|
op, MaterializeWithUpcast(rewriter, op.getLoc(), operands,
|
||||||
|
min_precision_ty, &MaterializeLgamma));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ConvertDigammaOp : public OpConversionPattern<DigammaOp> {
|
||||||
|
using OpConversionPattern<DigammaOp>::OpConversionPattern;
|
||||||
|
LogicalResult matchAndRewrite(
|
||||||
|
DigammaOp op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
FloatType min_precision_ty = rewriter.getF32Type();
|
||||||
|
rewriter.replaceOp(
|
||||||
|
op, MaterializeWithUpcast(rewriter, op.getLoc(), operands,
|
||||||
|
min_precision_ty, &MaterializeDigamma));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
struct ConvertZetaOp : public OpConversionPattern<ZetaOp> {
|
struct ConvertZetaOp : public OpConversionPattern<ZetaOp> {
|
||||||
using OpConversionPattern<ZetaOp>::OpConversionPattern;
|
using OpConversionPattern<ZetaOp>::OpConversionPattern;
|
||||||
LogicalResult matchAndRewrite(
|
LogicalResult matchAndRewrite(
|
||||||
ZetaOp op, ArrayRef<Value> operands,
|
ZetaOp op, ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
ZetaOpAdaptor adaptor(operands);
|
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
|
FloatType min_precision_ty = rewriter.getF32Type();
|
||||||
// Zeta is only defined on tensors of float elements and statically
|
rewriter.replaceOp(
|
||||||
// verified that both have the same type. So it suffices to look at one
|
op, MaterializeWithUpcast(rewriter, loc, operands, min_precision_ty,
|
||||||
// here.
|
&MaterializeZeta));
|
||||||
auto elm_type = adaptor.x().getType().cast<ShapedType>().getElementType();
|
|
||||||
|
|
||||||
bool needs_upcast = elm_type.isF16() || elm_type.isBF16();
|
|
||||||
|
|
||||||
Value x = adaptor.x();
|
|
||||||
Value q = adaptor.q();
|
|
||||||
|
|
||||||
if (needs_upcast) {
|
|
||||||
x = rewriter.create<mhlo::ConvertOp>(loc, x, rewriter.getF32Type());
|
|
||||||
q = rewriter.create<mhlo::ConvertOp>(loc, q, rewriter.getF32Type());
|
|
||||||
}
|
|
||||||
Value result = MaterializeZetaComputation(rewriter, loc, x, q);
|
|
||||||
if (needs_upcast) {
|
|
||||||
result = rewriter.create<mhlo::ConvertOp>(loc, result, elm_type);
|
|
||||||
}
|
|
||||||
rewriter.replaceOp(op, {result});
|
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
Loading…
Reference in New Issue