[MLIR][KernelGen] Add chlo.erfc lowering for f32
PiperOrigin-RevId: 353201886
This commit is contained in:
parent
56758a9562
commit
c846f925d4
|
@ -254,6 +254,93 @@ Value MaterializeErfcApproximationF64(ConversionPatternRewriter &rewriter,
|
|||
erfc_approx);
|
||||
}
|
||||
|
||||
// Precondition is |x| >= 1. Use erf approximation, otherwise.
|
||||
//
|
||||
// We rely on multiple polynomial approximations for x >= 1. We pass |x| as an
|
||||
// argument and derive the final approximation for all |x| >= 1.
|
||||
// This implementation is based on Cephes.
|
||||
Value MaterializeErfcApproximationF32ForMagnitudeGEOne(
|
||||
ConversionPatternRewriter &rewriter, Location loc, Value x) {
|
||||
assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
|
||||
"expect f32 element type");
|
||||
const double kMaxlog = 88.72283905206835;
|
||||
const std::vector<float> kErfcPCoefficients{
|
||||
+2.326819970068386E-2, -1.387039388740657E-1, +3.687424674597105E-1,
|
||||
-5.824733027278666E-1, +6.210004621745983E-1, -4.944515323274145E-1,
|
||||
+3.404879937665872E-1, -2.741127028184656E-1, +5.638259427386472E-1,
|
||||
};
|
||||
const std::vector<float> kErfcRCoefficients{
|
||||
-1.047766399936249E+1, +1.297719955372516E+1, -7.495518717768503E+0,
|
||||
+2.921019019210786E+0, -1.015265279202700E+0, +4.218463358204948E-1,
|
||||
-2.820767439740514E-1, +5.641895067754075E-1,
|
||||
};
|
||||
|
||||
// Let z = -x^2.
|
||||
Value x_sq = rewriter.create<mhlo::MulOp>(loc, x, x);
|
||||
Value z = rewriter.create<mhlo::NegOp>(loc, x_sq);
|
||||
|
||||
// Materialize polynomial approximation for x >= 1 as
|
||||
// erfc(x) = exp(z) 1/x P(1/x^2) if x in [1, 2)
|
||||
// erfc(x) = exp(z) 1/x R(1/x^2) if x >= 2
|
||||
const StringAttr kLT = rewriter.getStringAttr(
|
||||
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
|
||||
Value abs_x = rewriter.create<mhlo::AbsOp>(loc, x);
|
||||
Value one = chlo::getConstantLike(rewriter, loc, 1.0, x);
|
||||
Value reciprocal_x_sq = rewriter.create<mhlo::DivOp>(loc, one, x_sq);
|
||||
Value exp_z = rewriter.create<mhlo::ExpOp>(loc, z);
|
||||
Value one_div_abs_x = rewriter.create<mhlo::DivOp>(loc, one, abs_x);
|
||||
Value exp_z_mul_one_div_abs_x =
|
||||
rewriter.create<mhlo::MulOp>(loc, exp_z, one_div_abs_x);
|
||||
Value two = chlo::getConstantLike(rewriter, loc, 2.0, x);
|
||||
Value abs_x_lt_two = rewriter.create<mhlo::CompareOp>(loc, abs_x, two, kLT);
|
||||
Value poly_p = MaterializePolynomialApproximation(
|
||||
rewriter, loc, reciprocal_x_sq, kErfcPCoefficients);
|
||||
Value poly_r = MaterializePolynomialApproximation(
|
||||
rewriter, loc, reciprocal_x_sq, kErfcRCoefficients);
|
||||
Value poly =
|
||||
rewriter.create<mhlo::SelectOp>(loc, abs_x_lt_two, poly_p, poly_r);
|
||||
Value erfc_approx =
|
||||
rewriter.create<mhlo::MulOp>(loc, exp_z_mul_one_div_abs_x, poly);
|
||||
|
||||
// Clamp to prevent overflow and materialize approximation for large x as
|
||||
// erfc(x) = 0.
|
||||
Value z_lt_neq_maxlog = rewriter.create<mhlo::CompareOp>(
|
||||
loc, z, chlo::getConstantLike(rewriter, loc, -kMaxlog, x), kLT);
|
||||
Value zero = chlo::getConstantLike(rewriter, loc, 0.0, x);
|
||||
Value erfc_approx_clamped =
|
||||
rewriter.create<mhlo::SelectOp>(loc, z_lt_neq_maxlog, zero, erfc_approx);
|
||||
|
||||
// Derive approximation for x <= -1 as
|
||||
// erfc(x) = 2 - erfc(-x).
|
||||
// Reuse previously materialized approximations all of which take |x| as their
|
||||
// argument.
|
||||
Value x_lt_zero = rewriter.create<mhlo::CompareOp>(loc, x, zero, kLT);
|
||||
Value two_sub_erfc_approx =
|
||||
rewriter.create<mhlo::SubOp>(loc, two, erfc_approx_clamped);
|
||||
return rewriter.create<mhlo::SelectOp>(loc, x_lt_zero, two_sub_erfc_approx,
|
||||
erfc_approx_clamped);
|
||||
}
|
||||
|
||||
// Precondition is |x| <= 1. Use erfc approximation, otherwise.
|
||||
// This implementation is based on Cephes.
|
||||
Value MaterializeErfApproximationF32ForMagnitudeLEOne(
|
||||
ConversionPatternRewriter &rewriter, Location loc, Value x) {
|
||||
assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
|
||||
"expect f32 element type");
|
||||
const std::vector<float> kErfTCoefficients{
|
||||
+7.853861353153693E-5, -8.010193625184903E-4, +5.188327685732524E-3,
|
||||
-2.685381193529856E-2, +1.128358514861418E-1, -3.761262582423300E-1,
|
||||
+1.128379165726710E+0,
|
||||
};
|
||||
|
||||
// Materialize polynomial approximation for |x| <= 1 as
|
||||
// erf(x) = x T(x^2).
|
||||
Value x_sq = rewriter.create<mhlo::MulOp>(loc, x, x);
|
||||
Value poly_t = MaterializePolynomialApproximation(rewriter, loc, x_sq,
|
||||
kErfTCoefficients);
|
||||
return rewriter.create<mhlo::MulOp>(loc, x, poly_t);
|
||||
}
|
||||
|
||||
// This is the same approximation as used in Eigen.
|
||||
Value MaterializeErfApproximationF32(ConversionPatternRewriter &rewriter,
|
||||
Location loc, Value operand) {
|
||||
|
@ -286,6 +373,32 @@ Value MaterializeErfApproximationF32(ConversionPatternRewriter &rewriter,
|
|||
return rewriter.create<mhlo::DivOp>(loc, x_mul_alpha_poly, beta_poly);
|
||||
}
|
||||
|
||||
Value MaterializeErfcApproximationF32(ConversionPatternRewriter &rewriter,
|
||||
Location loc, Value x) {
|
||||
assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
|
||||
"expect f32 element type");
|
||||
|
||||
// Rely on erfc approximation for |x| >= 1
|
||||
// erfc(x) = erfc_approx(x)
|
||||
Value erfc_approx =
|
||||
MaterializeErfcApproximationF32ForMagnitudeGEOne(rewriter, loc, x);
|
||||
|
||||
// Rely on erf approximation for |x| < 1 and materialize erfc as
|
||||
// erfc(x) = 1 - erf_approx(x)
|
||||
Value one = chlo::getConstantLike(rewriter, loc, 1.0, x);
|
||||
Value erf_approx =
|
||||
MaterializeErfApproximationF32ForMagnitudeLEOne(rewriter, loc, x);
|
||||
Value erf_based_approx = rewriter.create<mhlo::SubOp>(loc, one, erf_approx);
|
||||
|
||||
// Materialize approximation selection based on argument.
|
||||
const StringAttr kLT = rewriter.getStringAttr(
|
||||
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
|
||||
Value abs_x = rewriter.create<mhlo::AbsOp>(loc, x);
|
||||
Value abs_x_lt_one = rewriter.create<mhlo::CompareOp>(loc, abs_x, one, kLT);
|
||||
return rewriter.create<mhlo::SelectOp>(loc, abs_x_lt_one, erf_based_approx,
|
||||
erfc_approx);
|
||||
}
|
||||
|
||||
struct ConvertErfOp : public OpConversionPattern<ErfOp> {
|
||||
using OpConversionPattern<ErfOp>::OpConversionPattern;
|
||||
LogicalResult matchAndRewrite(
|
||||
|
@ -332,10 +445,15 @@ struct ConvertErfcOp : public OpConversionPattern<ErfcOp> {
|
|||
Value x = transformed.operand();
|
||||
Type ty = x.getType().cast<ShapedType>().getElementType();
|
||||
|
||||
// For now, we support only f64.
|
||||
if (!ty.isF64()) return failure();
|
||||
// For now, we support only f64 and f32.
|
||||
if (!ty.isF64() && !ty.isF32()) return failure();
|
||||
|
||||
rewriter.replaceOp(op, MaterializeErfcApproximationF64(rewriter, loc, x));
|
||||
if (ty.isF64()) {
|
||||
rewriter.replaceOp(op, MaterializeErfcApproximationF64(rewriter, loc, x));
|
||||
return success();
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, MaterializeErfcApproximationF32(rewriter, loc, x));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
|
|
@ -533,3 +533,112 @@ func @erfc_f64(%arg : tensor<f64>) -> tensor<f64> {
|
|||
%1 = "chlo.erfc"(%arg) : (tensor<f64>) -> tensor<f64>
|
||||
return %1 : tensor<f64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @erfc_f32
|
||||
// CHECK-SAME: %[[ARG:.*]]: tensor<f32>
|
||||
func @erfc_f32(%arg : tensor<f32>) -> tensor<f32> {
|
||||
// CHECK: %[[TMP_0:.*]] = mhlo.multiply %[[ARG]], %[[ARG]]
|
||||
// CHECK: %[[TMP_1:.*]] = "mhlo.negate"(%[[TMP_0]])
|
||||
// CHECK: %[[TMP_2:.*]] = "mhlo.abs"(%[[ARG]])
|
||||
// CHECK: %[[TMP_3:.*]] = mhlo.constant dense<1.000000e+00>
|
||||
// CHECK: %[[TMP_4:.*]] = mhlo.divide %[[TMP_3]], %[[TMP_0]]
|
||||
// CHECK: %[[TMP_5:.*]] = "mhlo.exponential"(%[[TMP_1]])
|
||||
// CHECK: %[[TMP_7:.*]] = mhlo.divide %[[TMP_3]], %[[TMP_2]]
|
||||
// CHECK: %[[TMP_8:.*]] = mhlo.multiply %[[TMP_5]], %[[TMP_7]]
|
||||
// CHECK: %[[TMP_9:.*]] = mhlo.constant dense<2.000000e+00>
|
||||
// CHECK: %[[TMP_10:.*]] = "mhlo.compare"(%[[TMP_2]], %[[TMP_9]]) {comparison_direction = "LT"}
|
||||
// CHECK: %[[TMP_11:.*]] = mhlo.constant dense<0.000000e+00>
|
||||
// CHECK: %[[TMP_12:.*]] = mhlo.multiply %[[TMP_11]], %[[TMP_4]]
|
||||
// CHECK: %[[TMP_13:.*]] = mhlo.constant dense<2.326820e-02>
|
||||
// CHECK: %[[TMP_14:.*]] = mhlo.add %[[TMP_12]], %[[TMP_13]]
|
||||
// CHECK: %[[TMP_15:.*]] = mhlo.multiply %[[TMP_14]], %[[TMP_4]]
|
||||
// CHECK: %[[TMP_16:.*]] = mhlo.constant dense<-0.138703942>
|
||||
// CHECK: %[[TMP_17:.*]] = mhlo.add %[[TMP_15]], %[[TMP_16]]
|
||||
// CHECK: %[[TMP_18:.*]] = mhlo.multiply %[[TMP_17]], %[[TMP_4]]
|
||||
// CHECK: %[[TMP_19:.*]] = mhlo.constant dense<0.368742466>
|
||||
// CHECK: %[[TMP_20:.*]] = mhlo.add %[[TMP_18]], %[[TMP_19]]
|
||||
// CHECK: %[[TMP_21:.*]] = mhlo.multiply %[[TMP_20]], %[[TMP_4]]
|
||||
// CHECK: %[[TMP_22:.*]] = mhlo.constant dense<-0.582473278>
|
||||
// CHECK: %[[TMP_23:.*]] = mhlo.add %[[TMP_21]], %[[TMP_22]]
|
||||
// CHECK: %[[TMP_24:.*]] = mhlo.multiply %[[TMP_23]], %[[TMP_4]]
|
||||
// CHECK: %[[TMP_25:.*]] = mhlo.constant dense<0.621000468>
|
||||
// CHECK: %[[TMP_26:.*]] = mhlo.add %[[TMP_24]], %[[TMP_25]]
|
||||
// CHECK: %[[TMP_27:.*]] = mhlo.multiply %[[TMP_26]], %[[TMP_4]]
|
||||
// CHECK: %[[TMP_28:.*]] = mhlo.constant dense<-0.494451523>
|
||||
// CHECK: %[[TMP_29:.*]] = mhlo.add %[[TMP_27]], %[[TMP_28]]
|
||||
// CHECK: %[[TMP_30:.*]] = mhlo.multiply %[[TMP_29]], %[[TMP_4]]
|
||||
// CHECK: %[[TMP_31:.*]] = mhlo.constant dense<3.404880e-01>
|
||||
// CHECK: %[[TMP_32:.*]] = mhlo.add %[[TMP_30]], %[[TMP_31]]
|
||||
// CHECK: %[[TMP_33:.*]] = mhlo.multiply %[[TMP_32]], %[[TMP_4]]
|
||||
// CHECK: %[[TMP_34:.*]] = mhlo.constant dense<-0.274112701>
|
||||
// CHECK: %[[TMP_35:.*]] = mhlo.add %[[TMP_33]], %[[TMP_34]]
|
||||
// CHECK: %[[TMP_36:.*]] = mhlo.multiply %[[TMP_35]], %[[TMP_4]]
|
||||
// CHECK: %[[TMP_37:.*]] = mhlo.constant dense<0.563825965>
|
||||
// CHECK: %[[TMP_38:.*]] = mhlo.add %[[TMP_36]], %[[TMP_37]]
|
||||
// CHECK: %[[TMP_39:.*]] = mhlo.constant dense<0.000000e+00>
|
||||
// CHECK: %[[TMP_40:.*]] = mhlo.multiply %[[TMP_39]], %[[TMP_4]]
|
||||
// CHECK: %[[TMP_41:.*]] = mhlo.constant dense<-10.477664>
|
||||
// CHECK: %[[TMP_42:.*]] = mhlo.add %[[TMP_40]], %[[TMP_41]]
|
||||
// CHECK: %[[TMP_43:.*]] = mhlo.multiply %[[TMP_42]], %[[TMP_4]]
|
||||
// CHECK: %[[TMP_44:.*]] = mhlo.constant dense<1.297720e+01>
|
||||
// CHECK: %[[TMP_45:.*]] = mhlo.add %[[TMP_43]], %[[TMP_44]]
|
||||
// CHECK: %[[TMP_46:.*]] = mhlo.multiply %[[TMP_45]], %[[TMP_4]]
|
||||
// CHECK: %[[TMP_47:.*]] = mhlo.constant dense<-7.49551868>
|
||||
// CHECK: %[[TMP_48:.*]] = mhlo.add %[[TMP_46]], %[[TMP_47]]
|
||||
// CHECK: %[[TMP_49:.*]] = mhlo.multiply %[[TMP_48]], %[[TMP_4]]
|
||||
// CHECK: %[[TMP_50:.*]] = mhlo.constant dense<2.92101908>
|
||||
// CHECK: %[[TMP_51:.*]] = mhlo.add %[[TMP_49]], %[[TMP_50]]
|
||||
// CHECK: %[[TMP_52:.*]] = mhlo.multiply %[[TMP_51]], %[[TMP_4]]
|
||||
// CHECK: %[[TMP_53:.*]] = mhlo.constant dense<-1.01526523>
|
||||
// CHECK: %[[TMP_54:.*]] = mhlo.add %[[TMP_52]], %[[TMP_53]]
|
||||
// CHECK: %[[TMP_55:.*]] = mhlo.multiply %[[TMP_54]], %[[TMP_4]]
|
||||
// CHECK: %[[TMP_56:.*]] = mhlo.constant dense<0.42184633>
|
||||
// CHECK: %[[TMP_57:.*]] = mhlo.add %[[TMP_55]], %[[TMP_56]]
|
||||
// CHECK: %[[TMP_58:.*]] = mhlo.multiply %[[TMP_57]], %[[TMP_4]]
|
||||
// CHECK: %[[TMP_59:.*]] = mhlo.constant dense<-0.282076746>
|
||||
// CHECK: %[[TMP_60:.*]] = mhlo.add %[[TMP_58]], %[[TMP_59]]
|
||||
// CHECK: %[[TMP_61:.*]] = mhlo.multiply %[[TMP_60]], %[[TMP_4]]
|
||||
// CHECK: %[[TMP_62:.*]] = mhlo.constant dense<0.564189494>
|
||||
// CHECK: %[[TMP_63:.*]] = mhlo.add %[[TMP_61]], %[[TMP_62]]
|
||||
// CHECK: %[[TMP_64:.*]] = "mhlo.select"(%[[TMP_10]], %[[TMP_38]], %[[TMP_63]])
|
||||
// CHECK: %[[TMP_65:.*]] = mhlo.multiply %[[TMP_8]], %[[TMP_64]]
|
||||
// CHECK: %[[TMP_66:.*]] = mhlo.constant dense<-88.7228394>
|
||||
// CHECK: %[[TMP_67:.*]] = "mhlo.compare"(%[[TMP_1]], %[[TMP_66]]) {comparison_direction = "LT"}
|
||||
// CHECK: %[[TMP_68:.*]] = mhlo.constant dense<0.000000e+00>
|
||||
// CHECK: %[[TMP_69:.*]] = "mhlo.select"(%[[TMP_67]], %[[TMP_68]], %[[TMP_65]])
|
||||
// CHECK: %[[TMP_71:.*]] = "mhlo.compare"(%[[ARG]], %[[TMP_68]]) {comparison_direction = "LT"}
|
||||
// CHECK: %[[TMP_73:.*]] = mhlo.subtract %[[TMP_9]], %[[TMP_69]]
|
||||
// CHECK: %[[TMP_74:.*]] = "mhlo.select"(%[[TMP_71]], %[[TMP_73]], %[[TMP_69]])
|
||||
// CHECK: %[[TMP_75:.*]] = mhlo.constant dense<1.000000e+00>
|
||||
// CHECK: %[[TMP_76:.*]] = mhlo.multiply %[[ARG]], %[[ARG]]
|
||||
// CHECK: %[[TMP_77:.*]] = mhlo.constant dense<0.000000e+00>
|
||||
// CHECK: %[[TMP_78:.*]] = mhlo.multiply %[[TMP_77]], %[[TMP_76]]
|
||||
// CHECK: %[[TMP_79:.*]] = mhlo.constant dense<7.85386146E-5>
|
||||
// CHECK: %[[TMP_80:.*]] = mhlo.add %[[TMP_78]], %[[TMP_79]]
|
||||
// CHECK: %[[TMP_81:.*]] = mhlo.multiply %[[TMP_80]], %[[TMP_76]]
|
||||
// CHECK: %[[TMP_82:.*]] = mhlo.constant dense<-8.0101937E-4>
|
||||
// CHECK: %[[TMP_83:.*]] = mhlo.add %[[TMP_81]], %[[TMP_82]]
|
||||
// CHECK: %[[TMP_84:.*]] = mhlo.multiply %[[TMP_83]], %[[TMP_76]]
|
||||
// CHECK: %[[TMP_85:.*]] = mhlo.constant dense<0.00518832775>
|
||||
// CHECK: %[[TMP_86:.*]] = mhlo.add %[[TMP_84]], %[[TMP_85]]
|
||||
// CHECK: %[[TMP_87:.*]] = mhlo.multiply %[[TMP_86]], %[[TMP_76]]
|
||||
// CHECK: %[[TMP_88:.*]] = mhlo.constant dense<-0.0268538129>
|
||||
// CHECK: %[[TMP_89:.*]] = mhlo.add %[[TMP_87]], %[[TMP_88]]
|
||||
// CHECK: %[[TMP_90:.*]] = mhlo.multiply %[[TMP_89]], %[[TMP_76]]
|
||||
// CHECK: %[[TMP_91:.*]] = mhlo.constant dense<0.112835854>
|
||||
// CHECK: %[[TMP_92:.*]] = mhlo.add %[[TMP_90]], %[[TMP_91]]
|
||||
// CHECK: %[[TMP_93:.*]] = mhlo.multiply %[[TMP_92]], %[[TMP_76]]
|
||||
// CHECK: %[[TMP_94:.*]] = mhlo.constant dense<-0.37612626>
|
||||
// CHECK: %[[TMP_95:.*]] = mhlo.add %[[TMP_93]], %[[TMP_94]]
|
||||
// CHECK: %[[TMP_96:.*]] = mhlo.multiply %[[TMP_95]], %[[TMP_76]]
|
||||
// CHECK: %[[TMP_97:.*]] = mhlo.constant dense<1.12837911>
|
||||
// CHECK: %[[TMP_98:.*]] = mhlo.add %[[TMP_96]], %[[TMP_97]]
|
||||
// CHECK: %[[TMP_99:.*]] = mhlo.multiply %[[ARG]], %[[TMP_98]]
|
||||
// CHECK: %[[TMP_100:.*]] = mhlo.subtract %[[TMP_75]], %[[TMP_99]]
|
||||
// CHECK: %[[TMP_101:.*]] = "mhlo.abs"(%[[ARG]])
|
||||
// CHECK: %[[TMP_103:.*]] = "mhlo.compare"(%[[TMP_101]], %[[TMP_75]]) {comparison_direction = "LT"}
|
||||
// CHECK: %[[RESULT:.*]] = "mhlo.select"(%[[TMP_103]], %[[TMP_100]], %[[TMP_74]])
|
||||
// CHECK: return %[[RESULT]]
|
||||
%1 = "chlo.erfc"(%arg) : (tensor<f32>) -> tensor<f32>
|
||||
return %1 : tensor<f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue