From 2be112a603dd60716b34d49ced687cd6c1c8ef79 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 16 Mar 2021 10:11:47 -0700 Subject: [PATCH] [MLIR][MHLO] Approximate `tf.Tanh` as constant +/-1 for small/large values Fix issue raised in https://github.com/tensorflow/tensorflow/issues/47724 PiperOrigin-RevId: 363210296 --- ...legalize_trigonometric_to_approximation.cc | 67 ++++---- ...galize-trigonometric-to-approximation.mlir | 155 ++++++------------ 2 files changed, 90 insertions(+), 132 deletions(-) diff --git a/lib/Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc b/lib/Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc index 64be60f..a6e0b78 100644 --- a/lib/Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc +++ b/lib/Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc @@ -74,6 +74,9 @@ class ApproximateOnExtendedF32Lowering : public OpRewritePattern { } }; +// This approximation resembles Eigen and realizes a constant approximation for +// the +/-1 limits on top. +// https://gitlab.com/libeigen/eigen/-/blob/master/Eigen/src/Core/MathFunctionsImpl.h class ApproximateTanhLowering : public ApproximateOnExtendedF32Lowering { public: @@ -83,42 +86,18 @@ class ApproximateTanhLowering // Emits the fast tanh approximation that is also used by XLA. Value emitApproximation(ValueRange args, Location loc, PatternRewriter &rewriter) const override { - // For small values of x, we can approximate tanh(x) = x. For extremely - // small values of x (|x| < 1e-37), the other approximation would evaluate - // tanh(x) = 0. Value input = args.front(); assert(input.getType().isF32()); - constexpr float kCanUseApprox = 0.0004; - Value abs_value = rewriter.create(loc, input); - Value can_use_approx = rewriter.create( - loc, rewriter.getF32FloatAttr(kCanUseApprox)); - Value return_input = rewriter.create(loc, CmpFPredicate::OLT, - abs_value, can_use_approx); - // Clamp the input to [-c, c]. 
- Value max_clamp = rewriter.create( - loc, rewriter.getF32FloatAttr(7.90531110763549805f)); - Value smaller_than_max = - rewriter.create(loc, CmpFPredicate::ULE, input, max_clamp); - Value clamped_half = - rewriter.create(loc, smaller_than_max, input, max_clamp); - Value min_clamp = rewriter.create( - loc, rewriter.getF32FloatAttr(-7.90531110763549805f)); - Value larger_than_min = rewriter.create(loc, CmpFPredicate::UGE, - clamped_half, min_clamp); - Value input_clamped = rewriter.create(loc, larger_than_min, - clamped_half, min_clamp); - static constexpr std::array numerator_coeffs{ -2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f, 5.12229709037114e-08f, 1.48572235717979e-05f, 6.37261928875436e-04f, 4.89352455891786e-03f}; - static constexpr std::array denominator_coeffs{ 1.19825839466702e-06f, 1.18534705686654e-04f, 2.26843463243900e-03f, 4.89352518554385e-03f}; - Value input_squared = - rewriter.create(loc, input_clamped, input_clamped); + // Materialize polynomial approximation. + Value input_squared = rewriter.create(loc, input, input); Value numerator = rewriter.create( loc, rewriter.getF32FloatAttr(numerator_coeffs[0])); for (int i = 1; i < numerator_coeffs.size(); i++) { @@ -127,9 +106,7 @@ class ApproximateTanhLowering rewriter.create( loc, rewriter.getF32FloatAttr(numerator_coeffs[i]))); } - - numerator = rewriter.create(loc, input_clamped, numerator); - + numerator = rewriter.create(loc, input, numerator); Value denominator = rewriter.create( loc, rewriter.getF32FloatAttr(denominator_coeffs[0])); for (int i = 1; i < denominator_coeffs.size(); i++) { @@ -138,10 +115,38 @@ class ApproximateTanhLowering rewriter.create( loc, rewriter.getF32FloatAttr(denominator_coeffs[i]))); } - Value approx = rewriter.create(loc, numerator, denominator); - return rewriter.create(loc, return_input, input, approx); + // For small values of |x|, we can approximate tanh(x) = x. 
For extremely + // small values of x (|x| < 1e-37), the other approximation would evaluate + // tanh(x) = 0. + constexpr float kUseIdentityApprox = 0.0004; + Value abs_input = rewriter.create(loc, input); + Value use_identity_approx = rewriter.create( + loc, CmpFPredicate::OLT, abs_input, + rewriter.create( + loc, rewriter.getF32FloatAttr(kUseIdentityApprox))); + approx = rewriter.create(loc, use_identity_approx, input, approx); + + // Saturate to +/-1 outside the bounds; ordered compares let NaN propagate. + Value too_large_input = rewriter.create( + loc, CmpFPredicate::OGT, input, + rewriter.create( + loc, rewriter.getF32FloatAttr(7.90531110763549805f))); + Value too_small_input = rewriter.create( + loc, CmpFPredicate::OLT, input, + rewriter.create( + loc, rewriter.getF32FloatAttr(-7.90531110763549805f))); + approx = rewriter.create( + loc, too_large_input, + rewriter.create(loc, rewriter.getF32FloatAttr(1.0)), + approx); + approx = rewriter.create( + loc, too_small_input, + rewriter.create(loc, rewriter.getF32FloatAttr(-1.0)), + approx); + + return approx; } }; diff --git a/tests/legalize-trigonometric-to-approximation.mlir b/tests/legalize-trigonometric-to-approximation.mlir index 7178c6a..9b77d14 100644 --- a/tests/legalize-trigonometric-to-approximation.mlir +++ b/tests/legalize-trigonometric-to-approximation.mlir @@ -1,125 +1,78 @@ // RUN: mlir-hlo-opt --mhlo-legalize-trigonometric-to-approximation --split-input-file %s | FileCheck %s +// CHECK-LABEL: @tanh_f64 func @tanh_f64(%arg0 : f64) -> f64 { + // CHECK: tanh %res = math.tanh %arg0 : f64 return %res : f64 } -// CHECK-LABEL: @tanh_f64 -// CHECK: tanh - // ----- +// CHECK-LABEL: @tanh_f32 +// CHECK-SAME: (%[[ARG:.*]]: f32) -> f32 func @tanh_f32(%arg0 : f32) -> f32 { + // CHECK-DAG: %[[C:.*]] = constant -2.76076837E-16 : f32 + // CHECK-DAG: %[[C0:.*]] = constant 2.00018794E-13 : f32 + // CHECK-DAG: %[[C1:.*]] = constant -8.60467184E-11 : f32 + // CHECK-DAG: %[[C2:.*]] = constant 5.12229725E-8 : f32 + // CHECK-DAG: 
%[[C3:.*]] = constant 1.48572235E-5 : f32 + // CHECK-DAG: %[[C4:.*]] = constant 6.37261954E-4 : f32 + // CHECK-DAG: %[[C5:.*]] = constant 0.00489352457 : f32 + // CHECK-DAG: %[[C6:.*]] = constant 1.19825836E-6 : f32 + // CHECK-DAG: %[[C7:.*]] = constant 1.18534706E-4 : f32 + // CHECK-DAG: %[[C8:.*]] = constant 0.00226843474 : f32 + // CHECK-DAG: %[[C9:.*]] = constant 0.00489352504 : f32 + // CHECK-DAG: %[[C10:.*]] = constant 4.000000e-04 : f32 + // CHECK-DAG: %[[C11:.*]] = constant 7.90531111 : f32 + // CHECK-DAG: %[[C12:.*]] = constant -7.90531111 : f32 + // CHECK-DAG: %[[C13:.*]] = constant 1.000000e+00 : f32 + // CHECK-DAG: %[[C14:.*]] = constant -1.000000e+00 : f32 + // CHECK-DAG: %[[TMP0:.*]] = mulf %[[ARG]], %[[ARG]] : f32 + // CHECK-DAG: %[[TMP1:.*]] = mulf %[[TMP0]], %[[C]] : f32 + // CHECK-DAG: %[[TMP2:.*]] = addf %[[TMP1]], %[[C0]] : f32 + // CHECK-DAG: %[[TMP3:.*]] = mulf %[[TMP0]], %[[TMP2]] : f32 + // CHECK-DAG: %[[TMP4:.*]] = addf %[[TMP3]], %[[C1]] : f32 + // CHECK-DAG: %[[TMP5:.*]] = mulf %[[TMP0]], %[[TMP4]] : f32 + // CHECK-DAG: %[[TMP6:.*]] = addf %[[TMP5]], %[[C2]] : f32 + // CHECK-DAG: %[[TMP7:.*]] = mulf %[[TMP0]], %[[TMP6]] : f32 + // CHECK-DAG: %[[TMP8:.*]] = addf %[[TMP7]], %[[C3]] : f32 + // CHECK-DAG: %[[TMP9:.*]] = mulf %[[TMP0]], %[[TMP8]] : f32 + // CHECK-DAG: %[[TMP10:.*]] = addf %[[TMP9]], %[[C4]] : f32 + // CHECK-DAG: %[[TMP11:.*]] = mulf %[[TMP0]], %[[TMP10]] : f32 + // CHECK-DAG: %[[TMP12:.*]] = addf %[[TMP11]], %[[C5]] : f32 + // CHECK-DAG: %[[TMP13:.*]] = mulf %[[ARG]], %[[TMP12]] : f32 + // CHECK-DAG: %[[TMP14:.*]] = mulf %[[TMP0]], %[[C6]] : f32 + // CHECK-DAG: %[[TMP15:.*]] = addf %[[TMP14]], %[[C7]] : f32 + // CHECK-DAG: %[[TMP16:.*]] = mulf %[[TMP0]], %[[TMP15]] : f32 + // CHECK-DAG: %[[TMP17:.*]] = addf %[[TMP16]], %[[C8]] : f32 + // CHECK-DAG: %[[TMP18:.*]] = mulf %[[TMP0]], %[[TMP17]] : f32 + // CHECK-DAG: %[[TMP19:.*]] = addf %[[TMP18]], %[[C9]] : f32 + // CHECK-DAG: %[[TMP20:.*]] = divf %[[TMP13]], %[[TMP19]] : f32 + 
// CHECK-DAG: %[[TMP21:.*]] = absf %[[ARG]] : f32 + // CHECK-DAG: %[[TMP22:.*]] = cmpf olt, %[[TMP21]], %[[C10]] : f32 + // CHECK-DAG: %[[TMP23:.*]] = select %[[TMP22]], %[[ARG]], %[[TMP20]] : f32 + // CHECK-DAG: %[[TMP24:.*]] = cmpf ogt, %[[ARG]], %[[C11]] : f32 + // CHECK-DAG: %[[TMP25:.*]] = cmpf olt, %[[ARG]], %[[C12]] : f32 + // CHECK-DAG: %[[TMP26:.*]] = select %[[TMP24]], %[[C13]], %[[TMP23]] : f32 + // CHECK-DAG: %[[TMP27:.*]] = select %[[TMP25]], %[[C14]], %[[TMP26]] : f32 + // CHECK: return %[[TMP27]] : f32 %res = math.tanh %arg0 : f32 return %res : f32 } -// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py - -// CHECK-LABEL: func @tanh_f32 -// CHECK-SAME: (%[[VAL_0:.*]]: f32) -> f32 -// CHECK: %[[VAL_1:.*]] = constant 4.000000e-04 : f32 -// CHECK: %[[VAL_2:.*]] = constant 7.90531111 : f32 -// CHECK: %[[VAL_3:.*]] = constant -7.90531111 : f32 -// CHECK: %[[VAL_4:.*]] = constant -2.76076837E-16 : f32 -// CHECK: %[[VAL_5:.*]] = constant 2.00018794E-13 : f32 -// CHECK: %[[VAL_6:.*]] = constant -8.60467184E-11 : f32 -// CHECK: %[[VAL_7:.*]] = constant 5.12229725E-8 : f32 -// CHECK: %[[VAL_8:.*]] = constant 1.48572235E-5 : f32 -// CHECK: %[[VAL_9:.*]] = constant 6.37261954E-4 : f32 -// CHECK: %[[VAL_10:.*]] = constant 0.00489352457 : f32 -// CHECK: %[[VAL_11:.*]] = constant 1.19825836E-6 : f32 -// CHECK: %[[VAL_12:.*]] = constant 1.18534706E-4 : f32 -// CHECK: %[[VAL_13:.*]] = constant 0.00226843474 : f32 -// CHECK: %[[VAL_14:.*]] = constant 0.00489352504 : f32 -// CHECK: %[[VAL_15:.*]] = absf %[[VAL_0]] : f32 -// CHECK: %[[VAL_16:.*]] = cmpf olt, %[[VAL_15]], %[[VAL_1]] : f32 -// CHECK: %[[VAL_17:.*]] = cmpf ule, %[[VAL_0]], %[[VAL_2]] : f32 -// CHECK: %[[VAL_18:.*]] = select %[[VAL_17]], %[[VAL_0]], %[[VAL_2]] : f32 -// CHECK: %[[VAL_19:.*]] = cmpf uge, %[[VAL_18]], %[[VAL_3]] : f32 -// CHECK: %[[VAL_20:.*]] = select %[[VAL_19]], %[[VAL_18]], %[[VAL_3]] : f32 -// CHECK: %[[VAL_21:.*]] = mulf %[[VAL_20]], %[[VAL_20]] : f32 -// CHECK: 
%[[VAL_22:.*]] = mulf %[[VAL_21]], %[[VAL_4]] : f32 -// CHECK: %[[VAL_23:.*]] = addf %[[VAL_22]], %[[VAL_5]] : f32 -// CHECK: %[[VAL_24:.*]] = mulf %[[VAL_21]], %[[VAL_23]] : f32 -// CHECK: %[[VAL_25:.*]] = addf %[[VAL_24]], %[[VAL_6]] : f32 -// CHECK: %[[VAL_26:.*]] = mulf %[[VAL_21]], %[[VAL_25]] : f32 -// CHECK: %[[VAL_27:.*]] = addf %[[VAL_26]], %[[VAL_7]] : f32 -// CHECK: %[[VAL_28:.*]] = mulf %[[VAL_21]], %[[VAL_27]] : f32 -// CHECK: %[[VAL_29:.*]] = addf %[[VAL_28]], %[[VAL_8]] : f32 -// CHECK: %[[VAL_30:.*]] = mulf %[[VAL_21]], %[[VAL_29]] : f32 -// CHECK: %[[VAL_31:.*]] = addf %[[VAL_30]], %[[VAL_9]] : f32 -// CHECK: %[[VAL_32:.*]] = mulf %[[VAL_21]], %[[VAL_31]] : f32 -// CHECK: %[[VAL_33:.*]] = addf %[[VAL_32]], %[[VAL_10]] : f32 -// CHECK: %[[VAL_34:.*]] = mulf %[[VAL_20]], %[[VAL_33]] : f32 -// CHECK: %[[VAL_35:.*]] = mulf %[[VAL_21]], %[[VAL_11]] : f32 -// CHECK: %[[VAL_36:.*]] = addf %[[VAL_35]], %[[VAL_12]] : f32 -// CHECK: %[[VAL_37:.*]] = mulf %[[VAL_21]], %[[VAL_36]] : f32 -// CHECK: %[[VAL_38:.*]] = addf %[[VAL_37]], %[[VAL_13]] : f32 -// CHECK: %[[VAL_39:.*]] = mulf %[[VAL_21]], %[[VAL_38]] : f32 -// CHECK: %[[VAL_40:.*]] = addf %[[VAL_39]], %[[VAL_14]] : f32 -// CHECK: %[[VAL_41:.*]] = divf %[[VAL_34]], %[[VAL_40]] : f32 -// CHECK: %[[VAL_42:.*]] = select %[[VAL_16]], %[[VAL_0]], %[[VAL_41]] : f32 -// CHECK: return %[[VAL_42]] : f32 - // ----- func @tanh_f16(%arg0 : f16) -> f16 { + // CHECK-LABEL: func @tanh_f16 + // CHECK-SAME: (%[[ARG:.*]]: f16) -> f16 + // CHECK: %{{.*}} = fpext %[[ARG]] : f16 to f32 + // CHECK: %[[RES:.*]] = fptrunc %{{.*}} : f32 to f16 + // CHECK: return %[[RES]] : f16 %res = math.tanh %arg0 : f16 return %res : f16 } -// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py - -// CHECK-LABEL: func @tanh_f16 -// CHECK-SAME: (%[[VAL_0:.*]]: f16) -> f16 -// CHECK: %[[VAL_1:.*]] = constant 4.000000e-04 : f32 -// CHECK: %[[VAL_2:.*]] = constant 7.90531111 : f32 -// CHECK: %[[VAL_3:.*]] = constant 
-7.90531111 : f32 -// CHECK: %[[VAL_4:.*]] = constant -2.76076837E-16 : f32 -// CHECK: %[[VAL_5:.*]] = constant 2.00018794E-13 : f32 -// CHECK: %[[VAL_6:.*]] = constant -8.60467184E-11 : f32 -// CHECK: %[[VAL_7:.*]] = constant 5.12229725E-8 : f32 -// CHECK: %[[VAL_8:.*]] = constant 1.48572235E-5 : f32 -// CHECK: %[[VAL_9:.*]] = constant 6.37261954E-4 : f32 -// CHECK: %[[VAL_10:.*]] = constant 0.00489352457 : f32 -// CHECK: %[[VAL_11:.*]] = constant 1.19825836E-6 : f32 -// CHECK: %[[VAL_12:.*]] = constant 1.18534706E-4 : f32 -// CHECK: %[[VAL_13:.*]] = constant 0.00226843474 : f32 -// CHECK: %[[VAL_14:.*]] = constant 0.00489352504 : f32 -// CHECK: %[[VAL_15:.*]] = fpext %[[VAL_0]] : f16 to f32 -// CHECK: %[[VAL_16:.*]] = absf %[[VAL_15]] : f32 -// CHECK: %[[VAL_17:.*]] = cmpf olt, %[[VAL_16]], %[[VAL_1]] : f32 -// CHECK: %[[VAL_18:.*]] = cmpf ule, %[[VAL_15]], %[[VAL_2]] : f32 -// CHECK: %[[VAL_19:.*]] = select %[[VAL_18]], %[[VAL_15]], %[[VAL_2]] : f32 -// CHECK: %[[VAL_20:.*]] = cmpf uge, %[[VAL_19]], %[[VAL_3]] : f32 -// CHECK: %[[VAL_21:.*]] = select %[[VAL_20]], %[[VAL_19]], %[[VAL_3]] : f32 -// CHECK: %[[VAL_22:.*]] = mulf %[[VAL_21]], %[[VAL_21]] : f32 -// CHECK: %[[VAL_23:.*]] = mulf %[[VAL_22]], %[[VAL_4]] : f32 -// CHECK: %[[VAL_24:.*]] = addf %[[VAL_23]], %[[VAL_5]] : f32 -// CHECK: %[[VAL_25:.*]] = mulf %[[VAL_22]], %[[VAL_24]] : f32 -// CHECK: %[[VAL_26:.*]] = addf %[[VAL_25]], %[[VAL_6]] : f32 -// CHECK: %[[VAL_27:.*]] = mulf %[[VAL_22]], %[[VAL_26]] : f32 -// CHECK: %[[VAL_28:.*]] = addf %[[VAL_27]], %[[VAL_7]] : f32 -// CHECK: %[[VAL_29:.*]] = mulf %[[VAL_22]], %[[VAL_28]] : f32 -// CHECK: %[[VAL_30:.*]] = addf %[[VAL_29]], %[[VAL_8]] : f32 -// CHECK: %[[VAL_31:.*]] = mulf %[[VAL_22]], %[[VAL_30]] : f32 -// CHECK: %[[VAL_32:.*]] = addf %[[VAL_31]], %[[VAL_9]] : f32 -// CHECK: %[[VAL_33:.*]] = mulf %[[VAL_22]], %[[VAL_32]] : f32 -// CHECK: %[[VAL_34:.*]] = addf %[[VAL_33]], %[[VAL_10]] : f32 -// CHECK: %[[VAL_35:.*]] = mulf %[[VAL_21]], %[[VAL_34]] : 
f32 -// CHECK: %[[VAL_36:.*]] = mulf %[[VAL_22]], %[[VAL_11]] : f32 -// CHECK: %[[VAL_37:.*]] = addf %[[VAL_36]], %[[VAL_12]] : f32 -// CHECK: %[[VAL_38:.*]] = mulf %[[VAL_22]], %[[VAL_37]] : f32 -// CHECK: %[[VAL_39:.*]] = addf %[[VAL_38]], %[[VAL_13]] : f32 -// CHECK: %[[VAL_40:.*]] = mulf %[[VAL_22]], %[[VAL_39]] : f32 -// CHECK: %[[VAL_41:.*]] = addf %[[VAL_40]], %[[VAL_14]] : f32 -// CHECK: %[[VAL_42:.*]] = divf %[[VAL_35]], %[[VAL_41]] : f32 -// CHECK: %[[VAL_43:.*]] = select %[[VAL_17]], %[[VAL_15]], %[[VAL_42]] : f32 -// CHECK: %[[VAL_44:.*]] = fptrunc %[[VAL_43]] : f32 to f16 -// CHECK: return %[[VAL_44]] : f16 - // ----- // CHECK-LABEL: @atan2_f64