diff --git a/lib/Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc b/lib/Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc index 180bf9d..24058e0 100644 --- a/lib/Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc +++ b/lib/Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc @@ -137,6 +137,8 @@ class ApproximateTanhLowering loc, CmpFPredicate::ULT, input, rewriter.create( loc, rewriter.getF32FloatAttr(-7.90531110763549805f))); + Value input_is_nan = + rewriter.create(loc, CmpFPredicate::UNE, input, input); approx = rewriter.create( loc, too_large_input, rewriter.create(loc, rewriter.getF32FloatAttr(1.0)), @@ -145,6 +147,7 @@ class ApproximateTanhLowering loc, too_small_input, rewriter.create(loc, rewriter.getF32FloatAttr(-1.0)), approx); + approx = rewriter.create(loc, input_is_nan, input, approx); return approx; } diff --git a/tests/legalize-trigonometric-to-approximation.mlir b/tests/legalize-trigonometric-to-approximation.mlir index 9b77d14..e19c4f0 100644 --- a/tests/legalize-trigonometric-to-approximation.mlir +++ b/tests/legalize-trigonometric-to-approximation.mlir @@ -54,9 +54,11 @@ func @tanh_f32(%arg0 : f32) -> f32 { // CHECK-DAG: %[[TMP23:.*]] = select %[[TMP22]], %[[ARG]], %[[TMP20]] : f32 // CHECK-DAG: %[[TMP24:.*]] = cmpf ugt, %[[ARG]], %[[C11]] : f32 // CHECK-DAG: %[[TMP25:.*]] = cmpf ult, %[[ARG]], %[[C12]] : f32 + // CHECK-DAG: %[[IS_NAN:.*]] = cmpf une, %[[ARG]], %[[ARG]] : f32 // CHECK-DAG: %[[TMP26:.*]] = select %[[TMP24]], %[[C13]], %[[TMP23]] : f32 // CHECK-DAG: %[[TMP27:.*]] = select %[[TMP25]], %[[C14]], %[[TMP26]] : f32 - // CHECK: return %[[TMP27]] : f32 + // CHECK-DAG: %[[RESULT:.*]] = select %[[IS_NAN]], %[[ARG]], %[[TMP27]] : f32 + // CHECK: return %[[RESULT]] : f32 %res = math.tanh %arg0 : f32 return %res : f32 }