Fix tanh lowering for NaN input.
If the input is NaN, the result should be NaN, too. PiperOrigin-RevId: 364788902
This commit is contained in:
parent
7dd0fe4592
commit
a34aa699f8
|
@ -137,6 +137,8 @@ class ApproximateTanhLowering
|
|||
loc, CmpFPredicate::ULT, input,
|
||||
rewriter.create<ConstantOp>(
|
||||
loc, rewriter.getF32FloatAttr(-7.90531110763549805f)));
|
||||
Value input_is_nan =
|
||||
rewriter.create<CmpFOp>(loc, CmpFPredicate::UNE, input, input);
|
||||
approx = rewriter.create<SelectOp>(
|
||||
loc, too_large_input,
|
||||
rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0)),
|
||||
|
@ -145,6 +147,7 @@ class ApproximateTanhLowering
|
|||
loc, too_small_input,
|
||||
rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(-1.0)),
|
||||
approx);
|
||||
approx = rewriter.create<SelectOp>(loc, input_is_nan, input, approx);
|
||||
|
||||
return approx;
|
||||
}
|
||||
|
|
|
@ -54,9 +54,11 @@ func @tanh_f32(%arg0 : f32) -> f32 {
|
|||
// CHECK-DAG: %[[TMP23:.*]] = select %[[TMP22]], %[[ARG]], %[[TMP20]] : f32
|
||||
// CHECK-DAG: %[[TMP24:.*]] = cmpf ugt, %[[ARG]], %[[C11]] : f32
|
||||
// CHECK-DAG: %[[TMP25:.*]] = cmpf ult, %[[ARG]], %[[C12]] : f32
|
||||
// CHECK-DAG: %[[IS_NAN:.*]] = cmpf une, %[[ARG]], %[[ARG]] : f32
|
||||
// CHECK-DAG: %[[TMP26:.*]] = select %[[TMP24]], %[[C13]], %[[TMP23]] : f32
|
||||
// CHECK-DAG: %[[TMP27:.*]] = select %[[TMP25]], %[[C14]], %[[TMP26]] : f32
|
||||
// CHECK: return %[[TMP27]] : f32
|
||||
// CHECK-DAG: %[[RESULT:.*]] = select %[[IS_NAN]], %[[ARG]], %[[TMP27]] : f32
|
||||
// CHECK: return %[[RESULT]] : f32
|
||||
%res = math.tanh %arg0 : f32
|
||||
return %res : f32
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue