From a34aa699f8f1ee954d12262876e35e85c1aee1d3 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Wed, 24 Mar 2021 06:33:43 -0700 Subject: [PATCH] Fix tanh lowering for NaN input. If the input is NaN, the result should be NaN, too. PiperOrigin-RevId: 364788902 --- .../transforms/legalize_trigonometric_to_approximation.cc | 3 +++ tests/legalize-trigonometric-to-approximation.mlir | 4 +++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/lib/Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc b/lib/Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc index 180bf9d..24058e0 100644 --- a/lib/Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc +++ b/lib/Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc @@ -137,6 +137,8 @@ class ApproximateTanhLowering loc, CmpFPredicate::ULT, input, rewriter.create( loc, rewriter.getF32FloatAttr(-7.90531110763549805f))); + Value input_is_nan = + rewriter.create(loc, CmpFPredicate::UNE, input, input); approx = rewriter.create( loc, too_large_input, rewriter.create(loc, rewriter.getF32FloatAttr(1.0)), @@ -145,6 +147,7 @@ class ApproximateTanhLowering loc, too_small_input, rewriter.create(loc, rewriter.getF32FloatAttr(-1.0)), approx); + approx = rewriter.create(loc, input_is_nan, input, approx); return approx; } diff --git a/tests/legalize-trigonometric-to-approximation.mlir b/tests/legalize-trigonometric-to-approximation.mlir index 9b77d14..e19c4f0 100644 --- a/tests/legalize-trigonometric-to-approximation.mlir +++ b/tests/legalize-trigonometric-to-approximation.mlir @@ -54,9 +54,11 @@ func @tanh_f32(%arg0 : f32) -> f32 { // CHECK-DAG: %[[TMP23:.*]] = select %[[TMP22]], %[[ARG]], %[[TMP20]] : f32 // CHECK-DAG: %[[TMP24:.*]] = cmpf ugt, %[[ARG]], %[[C11]] : f32 // CHECK-DAG: %[[TMP25:.*]] = cmpf ult, %[[ARG]], %[[C12]] : f32 + // CHECK-DAG: %[[IS_NAN:.*]] = cmpf une, %[[ARG]], %[[ARG]] : f32 // CHECK-DAG: %[[TMP26:.*]] = select %[[TMP24]], %[[C13]], %[[TMP23]] : f32 // CHECK-DAG: %[[TMP27:.*]] = select %[[TMP25]], %[[C14]], %[[TMP26]] : f32 - // CHECK: return %[[TMP27]] : f32 + // CHECK-DAG: %[[RESULT:.*]] = select %[[IS_NAN]], %[[ARG]], %[[TMP27]] : f32 + // CHECK: return %[[RESULT]] : f32 %res = math.tanh %arg0 : f32 return %res : f32 }