Fix tanh lowering for NaN input.

If the input is NaN, the result should be NaN, too.

PiperOrigin-RevId: 364788902
This commit is contained in:
Adrian Kuegel 2021-03-24 06:33:43 -07:00 committed by TensorFlow MLIR Team
parent 7dd0fe4592
commit a34aa699f8
2 changed files with 6 additions and 1 deletions

View File

@ -137,6 +137,8 @@ class ApproximateTanhLowering
loc, CmpFPredicate::ULT, input,
rewriter.create<ConstantOp>(
loc, rewriter.getF32FloatAttr(-7.90531110763549805f)));
Value input_is_nan =
rewriter.create<CmpFOp>(loc, CmpFPredicate::UNE, input, input);
approx = rewriter.create<SelectOp>(
loc, too_large_input,
rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0)),
@ -145,6 +147,7 @@ class ApproximateTanhLowering
loc, too_small_input,
rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(-1.0)),
approx);
approx = rewriter.create<SelectOp>(loc, input_is_nan, input, approx);
return approx;
}

View File

@ -54,9 +54,11 @@ func @tanh_f32(%arg0 : f32) -> f32 {
// CHECK-DAG: %[[TMP23:.*]] = select %[[TMP22]], %[[ARG]], %[[TMP20]] : f32
// CHECK-DAG: %[[TMP24:.*]] = cmpf ugt, %[[ARG]], %[[C11]] : f32
// CHECK-DAG: %[[TMP25:.*]] = cmpf ult, %[[ARG]], %[[C12]] : f32
// CHECK-DAG: %[[IS_NAN:.*]] = cmpf une, %[[ARG]], %[[ARG]] : f32
// CHECK-DAG: %[[TMP26:.*]] = select %[[TMP24]], %[[C13]], %[[TMP23]] : f32
// CHECK-DAG: %[[TMP27:.*]] = select %[[TMP25]], %[[C14]], %[[TMP26]] : f32
// CHECK: return %[[TMP27]] : f32
// CHECK-DAG: %[[RESULT:.*]] = select %[[IS_NAN]], %[[ARG]], %[[TMP27]] : f32
// CHECK: return %[[RESULT]] : f32
%res = math.tanh %arg0 : f32
return %res : f32
}