Correct HLO atan2 lowering in cases of -inf and -0 inputs.

This is being done by just removing the approximation and lowering to atan2 lib calls later to make the implementation the same as XLA. Note that if the approximation is brought back later, it can be fixed by changing the IR checking `less-than(X, 0)` to `less-than(copysign(X, 1), 0)`

PiperOrigin-RevId: 356253941
This commit is contained in:
Tres Popp 2021-02-08 06:57:16 -08:00 committed by TensorFlow MLIR Team
parent bd0856578f
commit d086b8a0ec
3 changed files with 1 additions and 440 deletions

View File

@ -144,114 +144,6 @@ class ApproximateTanhLowering
} }
}; };
class ApproximateAtan2Lowering
: public ApproximateOnExtendedF32Lowering<Atan2Op> {
public:
explicit ApproximateAtan2Lowering(MLIRContext *ctx)
: ApproximateOnExtendedF32Lowering<Atan2Op>(ctx) {}
// Reduces atan2 to atan in the same way XLA does it.
Value emitApproximation(ValueRange args, Location loc,
PatternRewriter &rewriter) const override {
Value y = args[0];
Value x = args[1];
assert(x.getType().isF32() && y.getType().isF32() &&
"expect f32 arguments");
Value ax = rewriter.create<AbsFOp>(loc, x);
Value ay = rewriter.create<AbsFOp>(loc, y);
Value le_ax_ay = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLE, ax, ay);
Value min_ax_ay = rewriter.create<mlir::SelectOp>(loc, le_ax_ay, ax, ay);
Value max_ax_ay = rewriter.create<mlir::SelectOp>(loc, le_ax_ay, ay, ax);
Value zero_to_one = rewriter.create<DivFOp>(loc, min_ax_ay, max_ax_ay);
Value a = emitAtanCoreApproximation(zero_to_one, loc, rewriter);
Value pi_over_2 =
rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.57079637f));
a = rewriter.create<mlir::SelectOp>(
loc, le_ax_ay, rewriter.create<SubFOp>(loc, pi_over_2, a), a);
Value zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0));
Value lt_x_0 = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, x, zero);
Value pi =
rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(3.14159274f));
a = rewriter.create<mlir::SelectOp>(loc, lt_x_0,
rewriter.create<SubFOp>(loc, pi, a), a);
Value t = rewriter.create<mlir::SelectOp>(loc, lt_x_0, pi, zero);
Value eq_y_0 = rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, y, zero);
a = rewriter.create<mlir::SelectOp>(loc, eq_y_0, t, a);
// Propagate nan.
Value is_nan = rewriter.create<CmpFOp>(loc, CmpFPredicate::UNO, y, x);
Value nan = rewriter.create<ConstantOp>(
loc, rewriter.getF32FloatAttr(std::numeric_limits<float>::quiet_NaN()));
a = rewriter.create<mlir::SelectOp>(loc, is_nan, nan, a);
// x and y are +- inf.
Value three_pi_over_4 =
rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(2.3561945f));
Value pi_over_4 = rewriter.create<ConstantOp>(
loc, rewriter.getF32FloatAttr(0.785398185f));
t = rewriter.create<mlir::SelectOp>(loc, lt_x_0, three_pi_over_4,
pi_over_4);
Value inf = rewriter.create<ConstantOp>(
loc, rewriter.getF32FloatAttr(std::numeric_limits<float>::infinity()));
Value eq_x_inf = rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, x, inf);
Value eq_y_inf = rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, y, inf);
Value all_inf = rewriter.create<mlir::AndOp>(loc, eq_x_inf, eq_y_inf);
a = rewriter.create<mlir::SelectOp>(loc, all_inf, t, a);
return rewriter.create<CopySignOp>(loc, a, y);
}
private:
// The core atan reduction derives from the heuristic described in
// https://arxiv.org/abs/1508.03211 and has a < 0.95 ulp error in the [-1, 1]
// range (though that assumed FMA was available, and it is not here). This is
// the same approximation that is also used by XLA.
Value emitAtanCoreApproximation(Value x, Location loc,
PatternRewriter &rewriter) const {
auto constant = [&](float c) {
return rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(c));
};
// Computes ab + c.
auto mul_add = [&](Value a, Value b, Value c) {
Value prod = rewriter.create<MulFOp>(loc, a, b);
return rewriter.create<AddFOp>(loc, prod, c);
};
Value s = rewriter.create<MulFOp>(loc, x, x);
Value r = constant(0.0027856871f);
r = mul_add(r, s, constant(-0.0158660002f));
r = mul_add(r, s, constant(0.042472221f));
r = mul_add(r, s, constant(-0.0749753043f));
r = mul_add(r, s, constant(0.106448799f));
r = mul_add(r, s, constant(-0.142070308f));
r = mul_add(r, s, constant(0.199934542f));
r = mul_add(r, s, constant(-0.333331466f));
r = rewriter.create<MulFOp>(loc, r, s);
return mul_add(r, x, x);
}
};
class ApproximateAtanLowering
: public ApproximateOnExtendedF32Lowering<AtanOp> {
public:
explicit ApproximateAtanLowering(MLIRContext *ctx)
: ApproximateOnExtendedF32Lowering<AtanOp>(ctx) {}
// Reduce atan(x) to atan2(x, 1) to subsequently rely on an atan approximation
// for the argument range [-1, 1].
Value emitApproximation(ValueRange args, Location loc,
PatternRewriter &rewriter) const override {
Value x = args.front();
assert(x.getType().isF32());
Value one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1));
return rewriter.create<Atan2Op>(loc, x, one);
}
};
struct LegalizeTrigonometricToApproximationPass struct LegalizeTrigonometricToApproximationPass
: public PassWrapper<LegalizeTrigonometricToApproximationPass, : public PassWrapper<LegalizeTrigonometricToApproximationPass,
FunctionPass> { FunctionPass> {
@ -273,10 +165,7 @@ createLegalizeTrigonometricToApproximationPass() {
void PopulateTrigonometricToApproximationPatterns( void PopulateTrigonometricToApproximationPatterns(
mlir::MLIRContext *context, OwningRewritePatternList *patterns) { mlir::MLIRContext *context, OwningRewritePatternList *patterns) {
// clang-format off // clang-format off
patterns->insert< patterns->insert<ApproximateTanhLowering>(context);
ApproximateAtanLowering,
ApproximateAtan2Lowering,
ApproximateTanhLowering>(context);
// clang-format on // clang-format on
} }

View File

@ -12,19 +12,6 @@ func @print_f32(%arg : f32) {
return return
} }
// Compute and print trigonometric function.
func @atan_f32(%arg : f32) {
%res = atan %arg : f32
call @print_f32(%res) : (f32) -> ()
return
}
func @atan2_f32(%arg0 : f32, %arg1 : f32) {
%res = atan2 %arg0, %arg1 : f32
call @print_f32(%res) : (f32) -> ()
return
}
func @tanh_f32(%arg : f32) { func @tanh_f32(%arg : f32) {
%res = tanh %arg : f32 %res = tanh %arg : f32
call @print_f32(%res) : (f32) -> () call @print_f32(%res) : (f32) -> ()
@ -81,79 +68,5 @@ func @main() {
call @tanh_f32(%cf_50_0) : (f32) -> () call @tanh_f32(%cf_50_0) : (f32) -> ()
// CHECK: 1 // CHECK: 1
// Atan2 with divisor 1.
call @atan2_f32(%cf_n50_0, %cf_1_0) : (f32, f32) -> ()
// CHECK: -1.550{{.*}}
call @atan2_f32(%cf_n5_0, %cf_1_0) : (f32, f32) -> ()
// CHECK: -1.373{{.*}}
call @atan2_f32(%cf_n3_0, %cf_1_0) : (f32, f32) -> ()
// CHECK: -1.249{{.*}}
call @atan2_f32(%cf_n2_0, %cf_1_0) : (f32, f32) -> ()
// CHECK: -1.107{{.*}}
call @atan2_f32(%cf_n1_0, %cf_1_0) : (f32, f32) -> ()
// CHECK: -0.785{{.*}}
call @atan2_f32(%cf_n0_5, %cf_1_0) : (f32, f32) -> ()
// CHECK: -0.463{{.*}}
call @atan2_f32(%cf_n0_1, %cf_1_0) : (f32, f32) -> ()
// CHECK: -0.099{{.*}}
call @atan2_f32(%cf_0_0, %cf_1_0) : (f32, f32) -> ()
// CHECK: 0
call @atan2_f32(%cf_0_1, %cf_1_0) : (f32, f32) -> ()
// CHECK: 0.099{{.*}}
call @atan2_f32(%cf_0_5, %cf_1_0) : (f32, f32) -> ()
// CHECK: 0.463{{.*}}
call @atan2_f32(%cf_1_0, %cf_1_0) : (f32, f32) -> ()
// CHECK: 0.785{{.*}}
call @atan2_f32(%cf_2_0, %cf_1_0) : (f32, f32) -> ()
// CHECK: 1.107{{.*}}
call @atan2_f32(%cf_3_0, %cf_1_0) : (f32, f32) -> ()
// CHECK: 1.249{{.*}}
call @atan2_f32(%cf_5_0, %cf_1_0) : (f32, f32) -> ()
// CHECK: 1.373{{.*}}
call @atan2_f32(%cf_50_0, %cf_1_0) : (f32, f32) -> ()
// CHECK: 1.550{{.*}}
// Atan2 all four quadrants.
call @atan2_f32(%cf_n1_0, %cf_n1_0) : (f32, f32) -> ()
// CHECK: -2.356{{.*}}
call @atan2_f32(%cf_n1_0, %cf_1_0) : (f32, f32) -> ()
// CHECK: -0.785{{.*}}
call @atan2_f32(%cf_1_0, %cf_n1_0) : (f32, f32) -> ()
// CHECK: 2.356{{.*}}
call @atan2_f32(%cf_1_0, %cf_1_0) : (f32, f32) -> ()
// CHECK: 0.785{{.*}}
// Atan.
call @atan_f32(%cf_n50_0) : (f32) -> ()
// CHECK: -1.550{{.*}}
call @atan_f32(%cf_n5_0) : (f32) -> ()
// CHECK: -1.373{{.*}}
call @atan_f32(%cf_n3_0) : (f32) -> ()
// CHECK: -1.249{{.*}}
call @atan_f32(%cf_n2_0) : (f32) -> ()
// CHECK: -1.107{{.*}}
call @atan_f32(%cf_n1_0) : (f32) -> ()
// CHECK: -0.785{{.*}}
call @atan_f32(%cf_n0_5) : (f32) -> ()
// CHECK: -0.463{{.*}}
call @atan_f32(%cf_n0_1) : (f32) -> ()
// CHECK: -0.099{{.*}}
call @atan_f32(%cf_0_0) : (f32) -> ()
// CHECK: 0
call @atan_f32(%cf_0_1) : (f32) -> ()
// CHECK: 0.099{{.*}}
call @atan_f32(%cf_0_5) : (f32) -> ()
// CHECK: 0.463{{.*}}
call @atan_f32(%cf_1_0) : (f32) -> ()
// CHECK: 0.785{{.*}}
call @atan_f32(%cf_2_0) : (f32) -> ()
// CHECK: 1.107{{.*}}
call @atan_f32(%cf_3_0) : (f32) -> ()
// CHECK: 1.249{{.*}}
call @atan_f32(%cf_5_0) : (f32) -> ()
// CHECK: 1.373{{.*}}
call @atan_f32(%cf_50_0) : (f32) -> ()
// CHECK: 1.550{{.*}}
return return
} }

View File

@ -131,250 +131,9 @@ func @atan2_f64(%arg0 : f64, %arg1 : f64) -> f64 {
// ----- // -----
// CHECK-LABEL: func @atan2_f32
// CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32) -> f32
func @atan2_f32(%arg0 : f32, %arg1 : f32) -> f32 {
// CHECK: %[[CST:.*]] = constant 0.0027856871 : f32
// CHECK: %[[CST_0:.*]] = constant -1.586600e-02 : f32
// CHECK: %[[CST_1:.*]] = constant 0.042472221 : f32
// CHECK: %[[CST_2:.*]] = constant -0.0749753043 : f32
// CHECK: %[[CST_3:.*]] = constant 0.106448799 : f32
// CHECK: %[[CST_4:.*]] = constant -0.142070308 : f32
// CHECK: %[[CST_5:.*]] = constant 0.199934542 : f32
// CHECK: %[[CST_6:.*]] = constant -0.333331466 : f32
// CHECK: %[[CST_7:.*]] = constant 1.57079637 : f32
// CHECK: %[[CST_8:.*]] = constant 0.000000e+00 : f32
// CHECK: %[[CST_9:.*]] = constant 3.14159274 : f32
// CHECK: %[[CST_10:.*]] = constant 0x7FC00000 : f32
// CHECK: %[[CST_11:.*]] = constant 2.3561945 : f32
// CHECK: %[[CST_12:.*]] = constant 0.785398185 : f32
// CHECK: %[[CST_13:.*]] = constant 0x7F800000 : f32
// CHECK: %[[VAL_0:.*]] = absf %[[ARG1]] : f32
// CHECK: %[[VAL_1:.*]] = absf %[[ARG0]] : f32
// CHECK: %[[VAL_2:.*]] = cmpf ole, %[[VAL_0]], %[[VAL_1]] : f32
// CHECK: %[[VAL_3:.*]] = select %[[VAL_2]], %[[VAL_0]], %[[VAL_1]] : f32
// CHECK: %[[VAL_4:.*]] = select %[[VAL_2]], %[[VAL_1]], %[[VAL_0]] : f32
// CHECK: %[[VAL_5:.*]] = divf %[[VAL_3]], %[[VAL_4]] : f32
// CHECK: %[[VAL_6:.*]] = mulf %[[VAL_5]], %[[VAL_5]] : f32
// CHECK: %[[VAL_7:.*]] = mulf %[[CST]], %[[VAL_6]] : f32
// CHECK: %[[VAL_8:.*]] = addf %[[VAL_7]], %[[CST_0]] : f32
// CHECK: %[[VAL_9:.*]] = mulf %[[VAL_8]], %[[VAL_6]] : f32
// CHECK: %[[VAL_10:.*]] = addf %[[VAL_9]], %[[CST_1]] : f32
// CHECK: %[[VAL_11:.*]] = mulf %[[VAL_10]], %[[VAL_6]] : f32
// CHECK: %[[VAL_12:.*]] = addf %[[VAL_11]], %[[CST_2]] : f32
// CHECK: %[[VAL_13:.*]] = mulf %[[VAL_12]], %[[VAL_6]] : f32
// CHECK: %[[VAL_14:.*]] = addf %[[VAL_13]], %[[CST_3]] : f32
// CHECK: %[[VAL_15:.*]] = mulf %[[VAL_14]], %[[VAL_6]] : f32
// CHECK: %[[VAL_16:.*]] = addf %[[VAL_15]], %[[CST_4]] : f32
// CHECK: %[[VAL_17:.*]] = mulf %[[VAL_16]], %[[VAL_6]] : f32
// CHECK: %[[VAL_18:.*]] = addf %[[VAL_17]], %[[CST_5]] : f32
// CHECK: %[[VAL_19:.*]] = mulf %[[VAL_18]], %[[VAL_6]] : f32
// CHECK: %[[VAL_20:.*]] = addf %[[VAL_19]], %[[CST_6]] : f32
// CHECK: %[[VAL_21:.*]] = mulf %[[VAL_20]], %[[VAL_6]] : f32
// CHECK: %[[VAL_22:.*]] = mulf %[[VAL_21]], %[[VAL_5]] : f32
// CHECK: %[[VAL_23:.*]] = addf %[[VAL_22]], %[[VAL_5]] : f32
// CHECK: %[[VAL_24:.*]] = subf %[[CST_7]], %[[VAL_23]] : f32
// CHECK: %[[VAL_25:.*]] = select %[[VAL_2]], %[[VAL_24]], %[[VAL_23]] : f32
// CHECK: %[[VAL_26:.*]] = cmpf olt, %[[ARG1]], %[[CST_8]] : f32
// CHECK: %[[VAL_27:.*]] = subf %[[CST_9]], %[[VAL_25]] : f32
// CHECK: %[[VAL_28:.*]] = select %[[VAL_26]], %[[VAL_27]], %[[VAL_25]] : f32
// CHECK: %[[VAL_29:.*]] = select %[[VAL_26]], %[[CST_9]], %[[CST_8]] : f32
// CHECK: %[[VAL_30:.*]] = cmpf oeq, %[[ARG0]], %[[CST_8]] : f32
// CHECK: %[[VAL_31:.*]] = select %[[VAL_30]], %[[VAL_29]], %[[VAL_28]] : f32
// CHECK: %[[VAL_32:.*]] = cmpf uno, %[[ARG0]], %[[ARG1]] : f32
// CHECK: %[[VAL_35:.*]] = select %[[VAL_32]], %[[CST_10]], %[[VAL_31]] : f32
// CHECK: %[[VAL_36:.*]] = select %[[VAL_26]], %[[CST_11]], %[[CST_12]] : f32
// CHECK: %[[VAL_37:.*]] = cmpf oeq, %[[ARG1]], %[[CST_13]] : f32
// CHECK: %[[VAL_38:.*]] = cmpf oeq, %[[ARG0]], %[[CST_13]] : f32
// CHECK: %[[VAL_39:.*]] = and %[[VAL_37]], %[[VAL_38]] : i1
// CHECK: %[[VAL_40:.*]] = select %[[VAL_39]], %[[VAL_36]], %[[VAL_35]] : f32
// CHECK: %[[VAL_41:.*]] = copysign %[[VAL_40]], %[[ARG0]] : f32
// CHECK: return %[[VAL_41]] : f32
%res = atan2 %arg0, %arg1 : f32
return %res : f32
}
// -----
// CHECK-LABEL: @atan2_f16
// CHECK-SAME: (%[[ARG0:.*]]: f16, %[[ARG1:.*]]: f16) -> f16
func @atan2_f16(%arg0 : f16, %arg1 : f16) -> f16 {
// CHECK: %[[CST:.*]] = constant 0.0027856871 : f32
// CHECK: %[[CST_0:.*]] = constant -1.586600e-02 : f32
// CHECK: %[[CST_1:.*]] = constant 0.042472221 : f32
// CHECK: %[[CST_2:.*]] = constant -0.0749753043 : f32
// CHECK: %[[CST_3:.*]] = constant 0.106448799 : f32
// CHECK: %[[CST_4:.*]] = constant -0.142070308 : f32
// CHECK: %[[CST_5:.*]] = constant 0.199934542 : f32
// CHECK: %[[CST_6:.*]] = constant -0.333331466 : f32
// CHECK: %[[CST_7:.*]] = constant 1.57079637 : f32
// CHECK: %[[CST_8:.*]] = constant 0.000000e+00 : f32
// CHECK: %[[CST_9:.*]] = constant 3.14159274 : f32
// CHECK: %[[CST_10:.*]] = constant 0x7FC00000 : f32
// CHECK: %[[CST_11:.*]] = constant 2.3561945 : f32
// CHECK: %[[CST_12:.*]] = constant 0.785398185 : f32
// CHECK: %[[CST_13:.*]] = constant 0x7F800000 : f32
// CHECK: %[[VAL_0:.*]] = fpext %[[ARG0]] : f16 to f32
// CHECK: %[[VAL_1:.*]] = fpext %[[ARG1]] : f16 to f32
// CHECK: %[[VAL_2:.*]] = absf %[[VAL_1]] : f32
// CHECK: %[[VAL_3:.*]] = absf %[[VAL_0]] : f32
// CHECK: %[[VAL_4:.*]] = cmpf ole, %[[VAL_2]], %[[VAL_3]] : f32
// CHECK: %[[VAL_5:.*]] = select %[[VAL_4]], %[[VAL_2]], %[[VAL_3]] : f32
// CHECK: %[[VAL_6:.*]] = select %[[VAL_4]], %[[VAL_3]], %[[VAL_2]] : f32
// CHECK: %[[VAL_7:.*]] = divf %[[VAL_5]], %[[VAL_6]] : f32
// CHECK: %[[VAL_8:.*]] = mulf %[[VAL_7]], %[[VAL_7]] : f32
// CHECK: %[[VAL_9:.*]] = mulf %[[CST]], %[[VAL_8]] : f32
// CHECK: %[[VAL_10:.*]] = addf %[[VAL_9]], %[[CST_0]] : f32
// CHECK: %[[VAL_11:.*]] = mulf %[[VAL_10]], %[[VAL_8]] : f32
// CHECK: %[[VAL_12:.*]] = addf %[[VAL_11]], %[[CST_1]] : f32
// CHECK: %[[VAL_13:.*]] = mulf %[[VAL_12]], %[[VAL_8]] : f32
// CHECK: %[[VAL_14:.*]] = addf %[[VAL_13]], %[[CST_2]] : f32
// CHECK: %[[VAL_15:.*]] = mulf %[[VAL_14]], %[[VAL_8]] : f32
// CHECK: %[[VAL_16:.*]] = addf %[[VAL_15]], %[[CST_3]] : f32
// CHECK: %[[VAL_17:.*]] = mulf %[[VAL_16]], %[[VAL_8]] : f32
// CHECK: %[[VAL_18:.*]] = addf %[[VAL_17]], %[[CST_4]] : f32
// CHECK: %[[VAL_19:.*]] = mulf %[[VAL_18]], %[[VAL_8]] : f32
// CHECK: %[[VAL_20:.*]] = addf %[[VAL_19]], %[[CST_5]] : f32
// CHECK: %[[VAL_21:.*]] = mulf %[[VAL_20]], %[[VAL_8]] : f32
// CHECK: %[[VAL_22:.*]] = addf %[[VAL_21]], %[[CST_6]] : f32
// CHECK: %[[VAL_23:.*]] = mulf %[[VAL_22]], %[[VAL_8]] : f32
// CHECK: %[[VAL_24:.*]] = mulf %[[VAL_23]], %[[VAL_7]] : f32
// CHECK: %[[VAL_25:.*]] = addf %[[VAL_24]], %[[VAL_7]] : f32
// CHECK: %[[VAL_26:.*]] = subf %[[CST_7]], %[[VAL_25]] : f32
// CHECK: %[[VAL_27:.*]] = select %[[VAL_4]], %[[VAL_26]], %[[VAL_25]] : f32
// CHECK: %[[VAL_28:.*]] = cmpf olt, %[[VAL_1]], %[[CST_8]] : f32
// CHECK: %[[VAL_29:.*]] = subf %[[CST_9]], %[[VAL_27]] : f32
// CHECK: %[[VAL_30:.*]] = select %[[VAL_28]], %[[VAL_29]], %[[VAL_27]] : f32
// CHECK: %[[VAL_31:.*]] = select %[[VAL_28]], %[[CST_9]], %[[CST_8]] : f32
// CHECK: %[[VAL_32:.*]] = cmpf oeq, %[[VAL_0]], %[[CST_8]] : f32
// CHECK: %[[VAL_33:.*]] = select %[[VAL_32]], %[[VAL_31]], %[[VAL_30]] : f32
// CHECK: %[[VAL_34:.*]] = cmpf uno, %[[VAL_0]], %[[VAL_1]] : f32
// CHECK: %[[VAL_37:.*]] = select %[[VAL_34]], %[[CST_10]], %[[VAL_33]] : f32
// CHECK: %[[VAL_38:.*]] = select %[[VAL_28]], %[[CST_11]], %[[CST_12]] : f32
// CHECK: %[[VAL_39:.*]] = cmpf oeq, %[[VAL_1]], %[[CST_13]] : f32
// CHECK: %[[VAL_40:.*]] = cmpf oeq, %[[VAL_0]], %[[CST_13]] : f32
// CHECK: %[[VAL_41:.*]] = and %[[VAL_39]], %[[VAL_40]] : i1
// CHECK: %[[VAL_42:.*]] = select %[[VAL_41]], %[[VAL_38]], %[[VAL_37]] : f32
// CHECK: %[[VAL_43:.*]] = copysign %[[VAL_42]], %[[VAL_0]] : f32
// CHECK: %[[VAL_44:.*]] = fptrunc %[[VAL_43]] : f32 to f16
// CHECK: return %[[VAL_44]] : f16
%res = atan2 %arg0, %arg1 : f16
return %res : f16
}
// -----
// CHECK-LABEL: @atan_f64 // CHECK-LABEL: @atan_f64
func @atan_f64(%arg : f64) -> f64 { func @atan_f64(%arg : f64) -> f64 {
// CHECK: atan // CHECK: atan
%res = atan %arg : f64 %res = atan %arg : f64
return %res : f64 return %res : f64
} }
// -----
// CHECK-LABEL: func @atan_f32
// CHECK-SAME: (%[[ARG:.*]]: f32) -> f32
func @atan_f32(%arg : f32) -> f32 {
// CHECK: %[[CST:.*]] = constant 1.000000e+00 : f32
// CHECK: %[[CST_0:.*]] = constant 0.0027856871 : f32
// CHECK: %[[CST_1:.*]] = constant -1.586600e-02 : f32
// CHECK: %[[CST_2:.*]] = constant 0.042472221 : f32
// CHECK: %[[CST_3:.*]] = constant -0.0749753043 : f32
// CHECK: %[[CST_4:.*]] = constant 0.106448799 : f32
// CHECK: %[[CST_5:.*]] = constant -0.142070308 : f32
// CHECK: %[[CST_6:.*]] = constant 0.199934542 : f32
// CHECK: %[[CST_7:.*]] = constant -0.333331466 : f32
// CHECK: %[[CST_8:.*]] = constant 1.57079637 : f32
// CHECK: %[[CST_9:.*]] = constant 0.000000e+00 : f32
// CHECK: %[[CST_10:.*]] = constant 0x7FC00000 : f32
// CHECK: %[[VAL_0:.*]] = absf %[[CST]] : f32
// CHECK: %[[VAL_1:.*]] = absf %arg0 : f32
// CHECK: %[[VAL_2:.*]] = cmpf ole, %[[VAL_0]], %[[VAL_1]] : f32
// CHECK: %[[VAL_3:.*]] = select %[[VAL_2]], %[[VAL_0]], %[[VAL_1]] : f32
// CHECK: %[[VAL_4:.*]] = select %[[VAL_2]], %[[VAL_1]], %[[VAL_0]] : f32
// CHECK: %[[VAL_5:.*]] = divf %[[VAL_3]], %[[VAL_4]] : f32
// CHECK: %[[VAL_6:.*]] = mulf %[[VAL_5]], %[[VAL_5]] : f32
// CHECK: %[[VAL_7:.*]] = mulf %[[CST_0]], %[[VAL_6]] : f32
// CHECK: %[[VAL_8:.*]] = addf %[[VAL_7]], %[[CST_1]] : f32
// CHECK: %[[VAL_9:.*]] = mulf %[[VAL_8]], %[[VAL_6]] : f32
// CHECK: %[[VAL_10:.*]] = addf %[[VAL_9]], %[[CST_2]] : f32
// CHECK: %[[VAL_11:.*]] = mulf %[[VAL_10]], %[[VAL_6]] : f32
// CHECK: %[[VAL_12:.*]] = addf %[[VAL_11]], %[[CST_3]] : f32
// CHECK: %[[VAL_13:.*]] = mulf %[[VAL_12]], %[[VAL_6]] : f32
// CHECK: %[[VAL_14:.*]] = addf %[[VAL_13]], %[[CST_4]] : f32
// CHECK: %[[VAL_15:.*]] = mulf %[[VAL_14]], %[[VAL_6]] : f32
// CHECK: %[[VAL_16:.*]] = addf %[[VAL_15]], %[[CST_5]] : f32
// CHECK: %[[VAL_17:.*]] = mulf %[[VAL_16]], %[[VAL_6]] : f32
// CHECK: %[[VAL_18:.*]] = addf %[[VAL_17]], %[[CST_6]] : f32
// CHECK: %[[VAL_19:.*]] = mulf %[[VAL_18]], %[[VAL_6]] : f32
// CHECK: %[[VAL_20:.*]] = addf %[[VAL_19]], %[[CST_7]] : f32
// CHECK: %[[VAL_21:.*]] = mulf %[[VAL_20]], %[[VAL_6]] : f32
// CHECK: %[[VAL_22:.*]] = mulf %[[VAL_21]], %[[VAL_5]] : f32
// CHECK: %[[VAL_23:.*]] = addf %[[VAL_22]], %[[VAL_5]] : f32
// CHECK: %[[VAL_24:.*]] = subf %[[CST_8]], %[[VAL_23]] : f32
// CHECK: %[[VAL_25:.*]] = select %[[VAL_2]], %[[VAL_24]], %[[VAL_23]] : f32
// CHECK: %[[VAL_26:.*]] = cmpf oeq, %arg0, %[[CST_9]] : f32
// CHECK: %[[VAL_27:.*]] = select %[[VAL_26]], %[[CST_9]], %[[VAL_25]] : f32
// CHECK: %[[VAL_28:.*]] = cmpf uno, %arg0, %[[CST]] : f32
// CHECK: %[[VAL_29:.*]] = select %[[VAL_28]], %[[CST_10]], %[[VAL_27]] : f32
// CHECK: %[[VAL_30:.*]] = copysign %[[VAL_29]], %arg0 : f32
// CHECK: return %[[VAL_30]] : f32
%res = atan %arg : f32
return %res : f32
}
// -----
// CHECK-LABEL: @atan_f16
// CHECK-SAME: (%[[ARG:.*]]: f16) -> f16
func @atan_f16(%arg : f16) -> f16 {
// CHECK: %[[CST:.*]] = constant 1.000000e+00 : f32
// CHECK: %[[CST_0:.*]] = constant 0.0027856871 : f32
// CHECK: %[[CST_1:.*]] = constant -1.586600e-02 : f32
// CHECK: %[[CST_2:.*]] = constant 0.042472221 : f32
// CHECK: %[[CST_3:.*]] = constant -0.0749753043 : f32
// CHECK: %[[CST_4:.*]] = constant 0.106448799 : f32
// CHECK: %[[CST_5:.*]] = constant -0.142070308 : f32
// CHECK: %[[CST_6:.*]] = constant 0.199934542 : f32
// CHECK: %[[CST_7:.*]] = constant -0.333331466 : f32
// CHECK: %[[CST_8:.*]] = constant 1.57079637 : f32
// CHECK: %[[CST_9:.*]] = constant 0.000000e+00 : f32
// CHECK: %[[CST_10:.*]] = constant 0x7FC00000 : f32
// CHECK: %[[VAL_0:.*]] = fpext %arg0 : f16 to f32
// CHECK: %[[VAL_1:.*]] = absf %[[CST]] : f32
// CHECK: %[[VAL_2:.*]] = absf %[[VAL_0]] : f32
// CHECK: %[[VAL_3:.*]] = cmpf ole, %[[VAL_1]], %[[VAL_2]] : f32
// CHECK: %[[VAL_4:.*]] = select %[[VAL_3]], %[[VAL_1]], %[[VAL_2]] : f32
// CHECK: %[[VAL_5:.*]] = select %[[VAL_3]], %[[VAL_2]], %[[VAL_1]] : f32
// CHECK: %[[VAL_6:.*]] = divf %[[VAL_4]], %[[VAL_5]] : f32
// CHECK: %[[VAL_7:.*]] = mulf %[[VAL_6]], %[[VAL_6]] : f32
// CHECK: %[[VAL_8:.*]] = mulf %[[CST_0]], %[[VAL_7]] : f32
// CHECK: %[[VAL_9:.*]] = addf %[[VAL_8]], %[[CST_1]] : f32
// CHECK: %[[VAL_10:.*]] = mulf %[[VAL_9]], %[[VAL_7]] : f32
// CHECK: %[[VAL_11:.*]] = addf %[[VAL_10]], %[[CST_2]] : f32
// CHECK: %[[VAL_12:.*]] = mulf %[[VAL_11]], %[[VAL_7]] : f32
// CHECK: %[[VAL_13:.*]] = addf %[[VAL_12]], %[[CST_3]] : f32
// CHECK: %[[VAL_14:.*]] = mulf %[[VAL_13]], %[[VAL_7]] : f32
// CHECK: %[[VAL_15:.*]] = addf %[[VAL_14]], %[[CST_4]] : f32
// CHECK: %[[VAL_16:.*]] = mulf %[[VAL_15]], %[[VAL_7]] : f32
// CHECK: %[[VAL_17:.*]] = addf %[[VAL_16]], %[[CST_5]] : f32
// CHECK: %[[VAL_18:.*]] = mulf %[[VAL_17]], %[[VAL_7]] : f32
// CHECK: %[[VAL_19:.*]] = addf %[[VAL_18]], %[[CST_6]] : f32
// CHECK: %[[VAL_20:.*]] = mulf %[[VAL_19]], %[[VAL_7]] : f32
// CHECK: %[[VAL_21:.*]] = addf %[[VAL_20]], %[[CST_7]] : f32
// CHECK: %[[VAL_22:.*]] = mulf %[[VAL_21]], %[[VAL_7]] : f32
// CHECK: %[[VAL_23:.*]] = mulf %[[VAL_22]], %[[VAL_6]] : f32
// CHECK: %[[VAL_24:.*]] = addf %[[VAL_23]], %[[VAL_6]] : f32
// CHECK: %[[VAL_25:.*]] = subf %[[CST_8]], %[[VAL_24]] : f32
// CHECK: %[[VAL_26:.*]] = select %[[VAL_3]], %[[VAL_25]], %[[VAL_24]] : f32
// CHECK: %[[VAL_27:.*]] = cmpf oeq, %[[VAL_0]], %[[CST_9]] : f32
// CHECK: %[[VAL_28:.*]] = select %[[VAL_27]], %[[CST_9]], %[[VAL_26]] : f32
// CHECK: %[[VAL_29:.*]] = cmpf uno, %[[VAL_0]], %[[CST]] : f32
// CHECK: %[[VAL_30:.*]] = select %[[VAL_29]], %[[CST_10]], %[[VAL_28]] : f32
// CHECK: %[[VAL_31:.*]] = copysign %[[VAL_30]], %[[VAL_0]] : f32
// CHECK: %[[VAL_32:.*]] = fptrunc %[[VAL_31]] : f32 to f16
// CHECK: return %[[VAL_32]] : f16
%res = atan %arg : f16
return %res : f16
}