[MLIR][KernelGen] Add chlo.erfc lowering for f32

PiperOrigin-RevId: 353201886
This commit is contained in:
A. Unique TensorFlower 2021-01-22 02:32:39 -08:00 committed by TensorFlow MLIR Team
parent 56758a9562
commit c846f925d4
2 changed files with 230 additions and 3 deletions

View File

@ -254,6 +254,93 @@ Value MaterializeErfcApproximationF64(ConversionPatternRewriter &rewriter,
erfc_approx); erfc_approx);
} }
// Precondition is |x| >= 1. Use erf approximation, otherwise.
//
// We rely on multiple polynomial approximations for x >= 1. We pass |x| as an
// argument and derive the final approximation for all |x| >= 1.
// This implementation is based on Cephes.
Value MaterializeErfcApproximationF32ForMagnitudeGEOne(
ConversionPatternRewriter &rewriter, Location loc, Value x) {
assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
"expect f32 element type");
const double kMaxlog = 88.72283905206835;
const std::vector<float> kErfcPCoefficients{
+2.326819970068386E-2, -1.387039388740657E-1, +3.687424674597105E-1,
-5.824733027278666E-1, +6.210004621745983E-1, -4.944515323274145E-1,
+3.404879937665872E-1, -2.741127028184656E-1, +5.638259427386472E-1,
};
const std::vector<float> kErfcRCoefficients{
-1.047766399936249E+1, +1.297719955372516E+1, -7.495518717768503E+0,
+2.921019019210786E+0, -1.015265279202700E+0, +4.218463358204948E-1,
-2.820767439740514E-1, +5.641895067754075E-1,
};
// Let z = -x^2.
Value x_sq = rewriter.create<mhlo::MulOp>(loc, x, x);
Value z = rewriter.create<mhlo::NegOp>(loc, x_sq);
// Materialize polynomial approximation for x >= 1 as
// erfc(x) = exp(z) 1/x P(1/x^2) if x in [1, 2)
// erfc(x) = exp(z) 1/x R(1/x^2) if x >= 2
const StringAttr kLT = rewriter.getStringAttr(
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
Value abs_x = rewriter.create<mhlo::AbsOp>(loc, x);
Value one = chlo::getConstantLike(rewriter, loc, 1.0, x);
Value reciprocal_x_sq = rewriter.create<mhlo::DivOp>(loc, one, x_sq);
Value exp_z = rewriter.create<mhlo::ExpOp>(loc, z);
Value one_div_abs_x = rewriter.create<mhlo::DivOp>(loc, one, abs_x);
Value exp_z_mul_one_div_abs_x =
rewriter.create<mhlo::MulOp>(loc, exp_z, one_div_abs_x);
Value two = chlo::getConstantLike(rewriter, loc, 2.0, x);
Value abs_x_lt_two = rewriter.create<mhlo::CompareOp>(loc, abs_x, two, kLT);
Value poly_p = MaterializePolynomialApproximation(
rewriter, loc, reciprocal_x_sq, kErfcPCoefficients);
Value poly_r = MaterializePolynomialApproximation(
rewriter, loc, reciprocal_x_sq, kErfcRCoefficients);
Value poly =
rewriter.create<mhlo::SelectOp>(loc, abs_x_lt_two, poly_p, poly_r);
Value erfc_approx =
rewriter.create<mhlo::MulOp>(loc, exp_z_mul_one_div_abs_x, poly);
// Clamp to prevent overflow and materialize approximation for large x as
// erfc(x) = 0.
Value z_lt_neq_maxlog = rewriter.create<mhlo::CompareOp>(
loc, z, chlo::getConstantLike(rewriter, loc, -kMaxlog, x), kLT);
Value zero = chlo::getConstantLike(rewriter, loc, 0.0, x);
Value erfc_approx_clamped =
rewriter.create<mhlo::SelectOp>(loc, z_lt_neq_maxlog, zero, erfc_approx);
// Derive approximation for x <= -1 as
// erfc(x) = 2 - erfc(-x).
// Reuse previously materialized approximations all of which take |x| as their
// argument.
Value x_lt_zero = rewriter.create<mhlo::CompareOp>(loc, x, zero, kLT);
Value two_sub_erfc_approx =
rewriter.create<mhlo::SubOp>(loc, two, erfc_approx_clamped);
return rewriter.create<mhlo::SelectOp>(loc, x_lt_zero, two_sub_erfc_approx,
erfc_approx_clamped);
}
// Precondition is |x| <= 1. Use erfc approximation, otherwise.
// This implementation is based on Cephes.
Value MaterializeErfApproximationF32ForMagnitudeLEOne(
ConversionPatternRewriter &rewriter, Location loc, Value x) {
assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
"expect f32 element type");
const std::vector<float> kErfTCoefficients{
+7.853861353153693E-5, -8.010193625184903E-4, +5.188327685732524E-3,
-2.685381193529856E-2, +1.128358514861418E-1, -3.761262582423300E-1,
+1.128379165726710E+0,
};
// Materialize polynomial approximation for |x| <= 1 as
// erf(x) = x T(x^2).
Value x_sq = rewriter.create<mhlo::MulOp>(loc, x, x);
Value poly_t = MaterializePolynomialApproximation(rewriter, loc, x_sq,
kErfTCoefficients);
return rewriter.create<mhlo::MulOp>(loc, x, poly_t);
}
// This is the same approximation as used in Eigen. // This is the same approximation as used in Eigen.
Value MaterializeErfApproximationF32(ConversionPatternRewriter &rewriter, Value MaterializeErfApproximationF32(ConversionPatternRewriter &rewriter,
Location loc, Value operand) { Location loc, Value operand) {
@ -286,6 +373,32 @@ Value MaterializeErfApproximationF32(ConversionPatternRewriter &rewriter,
return rewriter.create<mhlo::DivOp>(loc, x_mul_alpha_poly, beta_poly); return rewriter.create<mhlo::DivOp>(loc, x_mul_alpha_poly, beta_poly);
} }
Value MaterializeErfcApproximationF32(ConversionPatternRewriter &rewriter,
Location loc, Value x) {
assert(x.getType().cast<ShapedType>().getElementType().isF32() &&
"expect f32 element type");
// Rely on erfc approximation for |x| >= 1
// erfc(x) = erfc_approx(x)
Value erfc_approx =
MaterializeErfcApproximationF32ForMagnitudeGEOne(rewriter, loc, x);
// Rely on erf approximation for |x| < 1 and materialize erfc as
// erfc(x) = 1 - erf_approx(x)
Value one = chlo::getConstantLike(rewriter, loc, 1.0, x);
Value erf_approx =
MaterializeErfApproximationF32ForMagnitudeLEOne(rewriter, loc, x);
Value erf_based_approx = rewriter.create<mhlo::SubOp>(loc, one, erf_approx);
// Materialize approximation selection based on argument.
const StringAttr kLT = rewriter.getStringAttr(
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
Value abs_x = rewriter.create<mhlo::AbsOp>(loc, x);
Value abs_x_lt_one = rewriter.create<mhlo::CompareOp>(loc, abs_x, one, kLT);
return rewriter.create<mhlo::SelectOp>(loc, abs_x_lt_one, erf_based_approx,
erfc_approx);
}
struct ConvertErfOp : public OpConversionPattern<ErfOp> { struct ConvertErfOp : public OpConversionPattern<ErfOp> {
using OpConversionPattern<ErfOp>::OpConversionPattern; using OpConversionPattern<ErfOp>::OpConversionPattern;
LogicalResult matchAndRewrite( LogicalResult matchAndRewrite(
@ -332,10 +445,15 @@ struct ConvertErfcOp : public OpConversionPattern<ErfcOp> {
Value x = transformed.operand(); Value x = transformed.operand();
Type ty = x.getType().cast<ShapedType>().getElementType(); Type ty = x.getType().cast<ShapedType>().getElementType();
// For now, we support only f64. // For now, we support only f64 and f32.
if (!ty.isF64()) return failure(); if (!ty.isF64() && !ty.isF32()) return failure();
rewriter.replaceOp(op, MaterializeErfcApproximationF64(rewriter, loc, x)); if (ty.isF64()) {
rewriter.replaceOp(op, MaterializeErfcApproximationF64(rewriter, loc, x));
return success();
}
rewriter.replaceOp(op, MaterializeErfcApproximationF32(rewriter, loc, x));
return success(); return success();
} }
}; };

View File

@ -533,3 +533,112 @@ func @erfc_f64(%arg : tensor<f64>) -> tensor<f64> {
%1 = "chlo.erfc"(%arg) : (tensor<f64>) -> tensor<f64> %1 = "chlo.erfc"(%arg) : (tensor<f64>) -> tensor<f64>
return %1 : tensor<f64> return %1 : tensor<f64>
} }
// CHECK-LABEL: @erfc_f32
// CHECK-SAME: %[[ARG:.*]]: tensor<f32>
func @erfc_f32(%arg : tensor<f32>) -> tensor<f32> {
// CHECK: %[[TMP_0:.*]] = mhlo.multiply %[[ARG]], %[[ARG]]
// CHECK: %[[TMP_1:.*]] = "mhlo.negate"(%[[TMP_0]])
// CHECK: %[[TMP_2:.*]] = "mhlo.abs"(%[[ARG]])
// CHECK: %[[TMP_3:.*]] = mhlo.constant dense<1.000000e+00>
// CHECK: %[[TMP_4:.*]] = mhlo.divide %[[TMP_3]], %[[TMP_0]]
// CHECK: %[[TMP_5:.*]] = "mhlo.exponential"(%[[TMP_1]])
// CHECK: %[[TMP_7:.*]] = mhlo.divide %[[TMP_3]], %[[TMP_2]]
// CHECK: %[[TMP_8:.*]] = mhlo.multiply %[[TMP_5]], %[[TMP_7]]
// CHECK: %[[TMP_9:.*]] = mhlo.constant dense<2.000000e+00>
// CHECK: %[[TMP_10:.*]] = "mhlo.compare"(%[[TMP_2]], %[[TMP_9]]) {comparison_direction = "LT"}
// CHECK: %[[TMP_11:.*]] = mhlo.constant dense<0.000000e+00>
// CHECK: %[[TMP_12:.*]] = mhlo.multiply %[[TMP_11]], %[[TMP_4]]
// CHECK: %[[TMP_13:.*]] = mhlo.constant dense<2.326820e-02>
// CHECK: %[[TMP_14:.*]] = mhlo.add %[[TMP_12]], %[[TMP_13]]
// CHECK: %[[TMP_15:.*]] = mhlo.multiply %[[TMP_14]], %[[TMP_4]]
// CHECK: %[[TMP_16:.*]] = mhlo.constant dense<-0.138703942>
// CHECK: %[[TMP_17:.*]] = mhlo.add %[[TMP_15]], %[[TMP_16]]
// CHECK: %[[TMP_18:.*]] = mhlo.multiply %[[TMP_17]], %[[TMP_4]]
// CHECK: %[[TMP_19:.*]] = mhlo.constant dense<0.368742466>
// CHECK: %[[TMP_20:.*]] = mhlo.add %[[TMP_18]], %[[TMP_19]]
// CHECK: %[[TMP_21:.*]] = mhlo.multiply %[[TMP_20]], %[[TMP_4]]
// CHECK: %[[TMP_22:.*]] = mhlo.constant dense<-0.582473278>
// CHECK: %[[TMP_23:.*]] = mhlo.add %[[TMP_21]], %[[TMP_22]]
// CHECK: %[[TMP_24:.*]] = mhlo.multiply %[[TMP_23]], %[[TMP_4]]
// CHECK: %[[TMP_25:.*]] = mhlo.constant dense<0.621000468>
// CHECK: %[[TMP_26:.*]] = mhlo.add %[[TMP_24]], %[[TMP_25]]
// CHECK: %[[TMP_27:.*]] = mhlo.multiply %[[TMP_26]], %[[TMP_4]]
// CHECK: %[[TMP_28:.*]] = mhlo.constant dense<-0.494451523>
// CHECK: %[[TMP_29:.*]] = mhlo.add %[[TMP_27]], %[[TMP_28]]
// CHECK: %[[TMP_30:.*]] = mhlo.multiply %[[TMP_29]], %[[TMP_4]]
// CHECK: %[[TMP_31:.*]] = mhlo.constant dense<3.404880e-01>
// CHECK: %[[TMP_32:.*]] = mhlo.add %[[TMP_30]], %[[TMP_31]]
// CHECK: %[[TMP_33:.*]] = mhlo.multiply %[[TMP_32]], %[[TMP_4]]
// CHECK: %[[TMP_34:.*]] = mhlo.constant dense<-0.274112701>
// CHECK: %[[TMP_35:.*]] = mhlo.add %[[TMP_33]], %[[TMP_34]]
// CHECK: %[[TMP_36:.*]] = mhlo.multiply %[[TMP_35]], %[[TMP_4]]
// CHECK: %[[TMP_37:.*]] = mhlo.constant dense<0.563825965>
// CHECK: %[[TMP_38:.*]] = mhlo.add %[[TMP_36]], %[[TMP_37]]
// CHECK: %[[TMP_39:.*]] = mhlo.constant dense<0.000000e+00>
// CHECK: %[[TMP_40:.*]] = mhlo.multiply %[[TMP_39]], %[[TMP_4]]
// CHECK: %[[TMP_41:.*]] = mhlo.constant dense<-10.477664>
// CHECK: %[[TMP_42:.*]] = mhlo.add %[[TMP_40]], %[[TMP_41]]
// CHECK: %[[TMP_43:.*]] = mhlo.multiply %[[TMP_42]], %[[TMP_4]]
// CHECK: %[[TMP_44:.*]] = mhlo.constant dense<1.297720e+01>
// CHECK: %[[TMP_45:.*]] = mhlo.add %[[TMP_43]], %[[TMP_44]]
// CHECK: %[[TMP_46:.*]] = mhlo.multiply %[[TMP_45]], %[[TMP_4]]
// CHECK: %[[TMP_47:.*]] = mhlo.constant dense<-7.49551868>
// CHECK: %[[TMP_48:.*]] = mhlo.add %[[TMP_46]], %[[TMP_47]]
// CHECK: %[[TMP_49:.*]] = mhlo.multiply %[[TMP_48]], %[[TMP_4]]
// CHECK: %[[TMP_50:.*]] = mhlo.constant dense<2.92101908>
// CHECK: %[[TMP_51:.*]] = mhlo.add %[[TMP_49]], %[[TMP_50]]
// CHECK: %[[TMP_52:.*]] = mhlo.multiply %[[TMP_51]], %[[TMP_4]]
// CHECK: %[[TMP_53:.*]] = mhlo.constant dense<-1.01526523>
// CHECK: %[[TMP_54:.*]] = mhlo.add %[[TMP_52]], %[[TMP_53]]
// CHECK: %[[TMP_55:.*]] = mhlo.multiply %[[TMP_54]], %[[TMP_4]]
// CHECK: %[[TMP_56:.*]] = mhlo.constant dense<0.42184633>
// CHECK: %[[TMP_57:.*]] = mhlo.add %[[TMP_55]], %[[TMP_56]]
// CHECK: %[[TMP_58:.*]] = mhlo.multiply %[[TMP_57]], %[[TMP_4]]
// CHECK: %[[TMP_59:.*]] = mhlo.constant dense<-0.282076746>
// CHECK: %[[TMP_60:.*]] = mhlo.add %[[TMP_58]], %[[TMP_59]]
// CHECK: %[[TMP_61:.*]] = mhlo.multiply %[[TMP_60]], %[[TMP_4]]
// CHECK: %[[TMP_62:.*]] = mhlo.constant dense<0.564189494>
// CHECK: %[[TMP_63:.*]] = mhlo.add %[[TMP_61]], %[[TMP_62]]
// CHECK: %[[TMP_64:.*]] = "mhlo.select"(%[[TMP_10]], %[[TMP_38]], %[[TMP_63]])
// CHECK: %[[TMP_65:.*]] = mhlo.multiply %[[TMP_8]], %[[TMP_64]]
// CHECK: %[[TMP_66:.*]] = mhlo.constant dense<-88.7228394>
// CHECK: %[[TMP_67:.*]] = "mhlo.compare"(%[[TMP_1]], %[[TMP_66]]) {comparison_direction = "LT"}
// CHECK: %[[TMP_68:.*]] = mhlo.constant dense<0.000000e+00>
// CHECK: %[[TMP_69:.*]] = "mhlo.select"(%[[TMP_67]], %[[TMP_68]], %[[TMP_65]])
// CHECK: %[[TMP_71:.*]] = "mhlo.compare"(%[[ARG]], %[[TMP_68]]) {comparison_direction = "LT"}
// CHECK: %[[TMP_73:.*]] = mhlo.subtract %[[TMP_9]], %[[TMP_69]]
// CHECK: %[[TMP_74:.*]] = "mhlo.select"(%[[TMP_71]], %[[TMP_73]], %[[TMP_69]])
// CHECK: %[[TMP_75:.*]] = mhlo.constant dense<1.000000e+00>
// CHECK: %[[TMP_76:.*]] = mhlo.multiply %[[ARG]], %[[ARG]]
// CHECK: %[[TMP_77:.*]] = mhlo.constant dense<0.000000e+00>
// CHECK: %[[TMP_78:.*]] = mhlo.multiply %[[TMP_77]], %[[TMP_76]]
// CHECK: %[[TMP_79:.*]] = mhlo.constant dense<7.85386146E-5>
// CHECK: %[[TMP_80:.*]] = mhlo.add %[[TMP_78]], %[[TMP_79]]
// CHECK: %[[TMP_81:.*]] = mhlo.multiply %[[TMP_80]], %[[TMP_76]]
// CHECK: %[[TMP_82:.*]] = mhlo.constant dense<-8.0101937E-4>
// CHECK: %[[TMP_83:.*]] = mhlo.add %[[TMP_81]], %[[TMP_82]]
// CHECK: %[[TMP_84:.*]] = mhlo.multiply %[[TMP_83]], %[[TMP_76]]
// CHECK: %[[TMP_85:.*]] = mhlo.constant dense<0.00518832775>
// CHECK: %[[TMP_86:.*]] = mhlo.add %[[TMP_84]], %[[TMP_85]]
// CHECK: %[[TMP_87:.*]] = mhlo.multiply %[[TMP_86]], %[[TMP_76]]
// CHECK: %[[TMP_88:.*]] = mhlo.constant dense<-0.0268538129>
// CHECK: %[[TMP_89:.*]] = mhlo.add %[[TMP_87]], %[[TMP_88]]
// CHECK: %[[TMP_90:.*]] = mhlo.multiply %[[TMP_89]], %[[TMP_76]]
// CHECK: %[[TMP_91:.*]] = mhlo.constant dense<0.112835854>
// CHECK: %[[TMP_92:.*]] = mhlo.add %[[TMP_90]], %[[TMP_91]]
// CHECK: %[[TMP_93:.*]] = mhlo.multiply %[[TMP_92]], %[[TMP_76]]
// CHECK: %[[TMP_94:.*]] = mhlo.constant dense<-0.37612626>
// CHECK: %[[TMP_95:.*]] = mhlo.add %[[TMP_93]], %[[TMP_94]]
// CHECK: %[[TMP_96:.*]] = mhlo.multiply %[[TMP_95]], %[[TMP_76]]
// CHECK: %[[TMP_97:.*]] = mhlo.constant dense<1.12837911>
// CHECK: %[[TMP_98:.*]] = mhlo.add %[[TMP_96]], %[[TMP_97]]
// CHECK: %[[TMP_99:.*]] = mhlo.multiply %[[ARG]], %[[TMP_98]]
// CHECK: %[[TMP_100:.*]] = mhlo.subtract %[[TMP_75]], %[[TMP_99]]
// CHECK: %[[TMP_101:.*]] = "mhlo.abs"(%[[ARG]])
// CHECK: %[[TMP_103:.*]] = "mhlo.compare"(%[[TMP_101]], %[[TMP_75]]) {comparison_direction = "LT"}
// CHECK: %[[RESULT:.*]] = "mhlo.select"(%[[TMP_103]], %[[TMP_100]], %[[TMP_74]])
// CHECK: return %[[RESULT]]
%1 = "chlo.erfc"(%arg) : (tensor<f32>) -> tensor<f32>
return %1 : tensor<f32>
}