Correct HLO atan2 lowering in cases of -inf and -0 inputs.
This is being done by just removing the approximation and lowering to atan2 lib calls later to make the implementation the same as XLA. Note that if the approximation is brought back later, it can be fixed by changing the IR checking `less-than(X, 0)` to `less-than(copysign(X, 1), 0)` PiperOrigin-RevId: 356253941
This commit is contained in:
		
							parent
							
								
									bd0856578f
								
							
						
					
					
						commit
						d086b8a0ec
					
				|  | @ -144,114 +144,6 @@ class ApproximateTanhLowering | |||
|   } | ||||
| }; | ||||
| 
 | ||||
| class ApproximateAtan2Lowering | ||||
|     : public ApproximateOnExtendedF32Lowering<Atan2Op> { | ||||
|  public: | ||||
|   explicit ApproximateAtan2Lowering(MLIRContext *ctx) | ||||
|       : ApproximateOnExtendedF32Lowering<Atan2Op>(ctx) {} | ||||
| 
 | ||||
|   // Reduces atan2 to atan in the same way XLA does it.
 | ||||
|   Value emitApproximation(ValueRange args, Location loc, | ||||
|                           PatternRewriter &rewriter) const override { | ||||
|     Value y = args[0]; | ||||
|     Value x = args[1]; | ||||
|     assert(x.getType().isF32() && y.getType().isF32() && | ||||
|            "expect f32 arguments"); | ||||
|     Value ax = rewriter.create<AbsFOp>(loc, x); | ||||
|     Value ay = rewriter.create<AbsFOp>(loc, y); | ||||
|     Value le_ax_ay = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLE, ax, ay); | ||||
|     Value min_ax_ay = rewriter.create<mlir::SelectOp>(loc, le_ax_ay, ax, ay); | ||||
|     Value max_ax_ay = rewriter.create<mlir::SelectOp>(loc, le_ax_ay, ay, ax); | ||||
|     Value zero_to_one = rewriter.create<DivFOp>(loc, min_ax_ay, max_ax_ay); | ||||
|     Value a = emitAtanCoreApproximation(zero_to_one, loc, rewriter); | ||||
| 
 | ||||
|     Value pi_over_2 = | ||||
|         rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.57079637f)); | ||||
|     a = rewriter.create<mlir::SelectOp>( | ||||
|         loc, le_ax_ay, rewriter.create<SubFOp>(loc, pi_over_2, a), a); | ||||
| 
 | ||||
|     Value zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0)); | ||||
|     Value lt_x_0 = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, x, zero); | ||||
|     Value pi = | ||||
|         rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(3.14159274f)); | ||||
|     a = rewriter.create<mlir::SelectOp>(loc, lt_x_0, | ||||
|                                         rewriter.create<SubFOp>(loc, pi, a), a); | ||||
| 
 | ||||
|     Value t = rewriter.create<mlir::SelectOp>(loc, lt_x_0, pi, zero); | ||||
|     Value eq_y_0 = rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, y, zero); | ||||
|     a = rewriter.create<mlir::SelectOp>(loc, eq_y_0, t, a); | ||||
| 
 | ||||
|     // Propagate nan.
 | ||||
|     Value is_nan = rewriter.create<CmpFOp>(loc, CmpFPredicate::UNO, y, x); | ||||
|     Value nan = rewriter.create<ConstantOp>( | ||||
|         loc, rewriter.getF32FloatAttr(std::numeric_limits<float>::quiet_NaN())); | ||||
|     a = rewriter.create<mlir::SelectOp>(loc, is_nan, nan, a); | ||||
| 
 | ||||
|     // x and y are +- inf.
 | ||||
|     Value three_pi_over_4 = | ||||
|         rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(2.3561945f)); | ||||
|     Value pi_over_4 = rewriter.create<ConstantOp>( | ||||
|         loc, rewriter.getF32FloatAttr(0.785398185f)); | ||||
|     t = rewriter.create<mlir::SelectOp>(loc, lt_x_0, three_pi_over_4, | ||||
|                                         pi_over_4); | ||||
|     Value inf = rewriter.create<ConstantOp>( | ||||
|         loc, rewriter.getF32FloatAttr(std::numeric_limits<float>::infinity())); | ||||
|     Value eq_x_inf = rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, x, inf); | ||||
|     Value eq_y_inf = rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, y, inf); | ||||
|     Value all_inf = rewriter.create<mlir::AndOp>(loc, eq_x_inf, eq_y_inf); | ||||
|     a = rewriter.create<mlir::SelectOp>(loc, all_inf, t, a); | ||||
| 
 | ||||
|     return rewriter.create<CopySignOp>(loc, a, y); | ||||
|   } | ||||
| 
 | ||||
|  private: | ||||
|   // The core atan reduction derives from the heuristic described in
 | ||||
|   // https://arxiv.org/abs/1508.03211 and has a < 0.95 ulp error in the [-1, 1]
 | ||||
|   // range (though that assumed FMA was available, and it is not here).  This is
 | ||||
|   // the same approximation that is also used by XLA.
 | ||||
|   Value emitAtanCoreApproximation(Value x, Location loc, | ||||
|                                   PatternRewriter &rewriter) const { | ||||
|     auto constant = [&](float c) { | ||||
|       return rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(c)); | ||||
|     }; | ||||
| 
 | ||||
|     // Computes ab + c.
 | ||||
|     auto mul_add = [&](Value a, Value b, Value c) { | ||||
|       Value prod = rewriter.create<MulFOp>(loc, a, b); | ||||
|       return rewriter.create<AddFOp>(loc, prod, c); | ||||
|     }; | ||||
| 
 | ||||
|     Value s = rewriter.create<MulFOp>(loc, x, x); | ||||
|     Value r = constant(0.0027856871f); | ||||
|     r = mul_add(r, s, constant(-0.0158660002f)); | ||||
|     r = mul_add(r, s, constant(0.042472221f)); | ||||
|     r = mul_add(r, s, constant(-0.0749753043f)); | ||||
|     r = mul_add(r, s, constant(0.106448799f)); | ||||
|     r = mul_add(r, s, constant(-0.142070308f)); | ||||
|     r = mul_add(r, s, constant(0.199934542f)); | ||||
|     r = mul_add(r, s, constant(-0.333331466f)); | ||||
|     r = rewriter.create<MulFOp>(loc, r, s); | ||||
|     return mul_add(r, x, x); | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| class ApproximateAtanLowering | ||||
|     : public ApproximateOnExtendedF32Lowering<AtanOp> { | ||||
|  public: | ||||
|   explicit ApproximateAtanLowering(MLIRContext *ctx) | ||||
|       : ApproximateOnExtendedF32Lowering<AtanOp>(ctx) {} | ||||
| 
 | ||||
|   // Reduce atan(x) to atan2(x, 1) to subsequently rely on an atan approximation
 | ||||
|   // for the argument range [-1, 1].
 | ||||
|   Value emitApproximation(ValueRange args, Location loc, | ||||
|                           PatternRewriter &rewriter) const override { | ||||
|     Value x = args.front(); | ||||
|     assert(x.getType().isF32()); | ||||
|     Value one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1)); | ||||
|     return rewriter.create<Atan2Op>(loc, x, one); | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| struct LegalizeTrigonometricToApproximationPass | ||||
|     : public PassWrapper<LegalizeTrigonometricToApproximationPass, | ||||
|                          FunctionPass> { | ||||
|  | @ -273,10 +165,7 @@ createLegalizeTrigonometricToApproximationPass() { | |||
| void PopulateTrigonometricToApproximationPatterns( | ||||
|     mlir::MLIRContext *context, OwningRewritePatternList *patterns) { | ||||
|   // clang-format off
 | ||||
|   patterns->insert< | ||||
|       ApproximateAtanLowering, | ||||
|       ApproximateAtan2Lowering, | ||||
|       ApproximateTanhLowering>(context); | ||||
|   patterns->insert<ApproximateTanhLowering>(context); | ||||
|   // clang-format on
 | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -12,19 +12,6 @@ func @print_f32(%arg : f32) { | |||
|   return | ||||
| } | ||||
| 
 | ||||
| // Compute and print trigonometric function. | ||||
| func @atan_f32(%arg : f32) { | ||||
|   %res = atan %arg : f32 | ||||
|   call @print_f32(%res) : (f32) -> () | ||||
|   return | ||||
| } | ||||
| 
 | ||||
| func @atan2_f32(%arg0 : f32, %arg1 : f32) { | ||||
|   %res = atan2 %arg0, %arg1 : f32 | ||||
|   call @print_f32(%res) : (f32) -> () | ||||
|   return | ||||
| } | ||||
| 
 | ||||
| func @tanh_f32(%arg : f32) { | ||||
|   %res = tanh %arg : f32 | ||||
|   call @print_f32(%res) : (f32) -> () | ||||
|  | @ -81,79 +68,5 @@ func @main() { | |||
|   call @tanh_f32(%cf_50_0) : (f32) -> () | ||||
|   // CHECK: 1 | ||||
| 
 | ||||
|   // Atan2 with divisor 1. | ||||
|   call @atan2_f32(%cf_n50_0, %cf_1_0) : (f32, f32) -> () | ||||
|   // CHECK: -1.550{{.*}} | ||||
|   call @atan2_f32(%cf_n5_0, %cf_1_0) : (f32, f32) -> () | ||||
|   // CHECK: -1.373{{.*}} | ||||
|   call @atan2_f32(%cf_n3_0, %cf_1_0) : (f32, f32) -> () | ||||
|   // CHECK: -1.249{{.*}} | ||||
|   call @atan2_f32(%cf_n2_0, %cf_1_0) : (f32, f32) -> () | ||||
|   // CHECK: -1.107{{.*}} | ||||
|   call @atan2_f32(%cf_n1_0, %cf_1_0) : (f32, f32) -> () | ||||
|   // CHECK: -0.785{{.*}} | ||||
|   call @atan2_f32(%cf_n0_5, %cf_1_0) : (f32, f32) -> () | ||||
|   // CHECK: -0.463{{.*}} | ||||
|   call @atan2_f32(%cf_n0_1, %cf_1_0) : (f32, f32) -> () | ||||
|   // CHECK: -0.099{{.*}} | ||||
|   call @atan2_f32(%cf_0_0, %cf_1_0) : (f32, f32) -> () | ||||
|   // CHECK: 0 | ||||
|   call @atan2_f32(%cf_0_1, %cf_1_0) : (f32, f32) -> () | ||||
|   // CHECK: 0.099{{.*}} | ||||
|   call @atan2_f32(%cf_0_5, %cf_1_0) : (f32, f32) -> () | ||||
|   // CHECK: 0.463{{.*}} | ||||
|   call @atan2_f32(%cf_1_0, %cf_1_0) : (f32, f32) -> () | ||||
|   // CHECK: 0.785{{.*}} | ||||
|   call @atan2_f32(%cf_2_0, %cf_1_0) : (f32, f32) -> () | ||||
|   // CHECK: 1.107{{.*}} | ||||
|   call @atan2_f32(%cf_3_0, %cf_1_0) : (f32, f32) -> () | ||||
|   // CHECK: 1.249{{.*}} | ||||
|   call @atan2_f32(%cf_5_0, %cf_1_0) : (f32, f32) -> () | ||||
|   // CHECK: 1.373{{.*}} | ||||
|   call @atan2_f32(%cf_50_0, %cf_1_0) : (f32, f32) -> () | ||||
|   // CHECK: 1.550{{.*}} | ||||
| 
 | ||||
|   // Atan2 all four quadrants. | ||||
|   call @atan2_f32(%cf_n1_0, %cf_n1_0) : (f32, f32) -> () | ||||
|   // CHECK: -2.356{{.*}} | ||||
|   call @atan2_f32(%cf_n1_0, %cf_1_0) : (f32, f32) -> () | ||||
|   // CHECK: -0.785{{.*}} | ||||
|   call @atan2_f32(%cf_1_0, %cf_n1_0) : (f32, f32) -> () | ||||
|   // CHECK: 2.356{{.*}} | ||||
|   call @atan2_f32(%cf_1_0, %cf_1_0) : (f32, f32) -> () | ||||
|   // CHECK: 0.785{{.*}} | ||||
| 
 | ||||
|   // Atan. | ||||
|   call @atan_f32(%cf_n50_0) : (f32) -> () | ||||
|   // CHECK: -1.550{{.*}} | ||||
|   call @atan_f32(%cf_n5_0) : (f32) -> () | ||||
|   // CHECK: -1.373{{.*}} | ||||
|   call @atan_f32(%cf_n3_0) : (f32) -> () | ||||
|   // CHECK: -1.249{{.*}} | ||||
|   call @atan_f32(%cf_n2_0) : (f32) -> () | ||||
|   // CHECK: -1.107{{.*}} | ||||
|   call @atan_f32(%cf_n1_0) : (f32) -> () | ||||
|   // CHECK: -0.785{{.*}} | ||||
|   call @atan_f32(%cf_n0_5) : (f32) -> () | ||||
|   // CHECK: -0.463{{.*}} | ||||
|   call @atan_f32(%cf_n0_1) : (f32) -> () | ||||
|   // CHECK: -0.099{{.*}} | ||||
|   call @atan_f32(%cf_0_0) : (f32) -> () | ||||
|   // CHECK: 0 | ||||
|   call @atan_f32(%cf_0_1) : (f32) -> () | ||||
|   // CHECK: 0.099{{.*}} | ||||
|   call @atan_f32(%cf_0_5) : (f32) -> () | ||||
|   // CHECK: 0.463{{.*}} | ||||
|   call @atan_f32(%cf_1_0) : (f32) -> () | ||||
|   // CHECK: 0.785{{.*}} | ||||
|   call @atan_f32(%cf_2_0) : (f32) -> () | ||||
|   // CHECK: 1.107{{.*}} | ||||
|   call @atan_f32(%cf_3_0) : (f32) -> () | ||||
|   // CHECK: 1.249{{.*}} | ||||
|   call @atan_f32(%cf_5_0) : (f32) -> () | ||||
|   // CHECK: 1.373{{.*}} | ||||
|   call @atan_f32(%cf_50_0) : (f32) -> () | ||||
|   // CHECK: 1.550{{.*}} | ||||
| 
 | ||||
|   return | ||||
| } | ||||
|  |  | |||
|  | @ -131,250 +131,9 @@ func @atan2_f64(%arg0 : f64, %arg1 : f64) -> f64 { | |||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| // CHECK-LABEL: func @atan2_f32 | ||||
| // CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32) -> f32 | ||||
| func @atan2_f32(%arg0 : f32, %arg1 : f32) -> f32 { | ||||
|   // CHECK: %[[CST:.*]] = constant 0.0027856871 : f32 | ||||
|   // CHECK: %[[CST_0:.*]] = constant -1.586600e-02 : f32 | ||||
|   // CHECK: %[[CST_1:.*]] = constant 0.042472221 : f32 | ||||
|   // CHECK: %[[CST_2:.*]] = constant -0.0749753043 : f32 | ||||
|   // CHECK: %[[CST_3:.*]] = constant 0.106448799 : f32 | ||||
|   // CHECK: %[[CST_4:.*]] = constant -0.142070308 : f32 | ||||
|   // CHECK: %[[CST_5:.*]] = constant 0.199934542 : f32 | ||||
|   // CHECK: %[[CST_6:.*]] = constant -0.333331466 : f32 | ||||
|   // CHECK: %[[CST_7:.*]] = constant 1.57079637 : f32 | ||||
|   // CHECK: %[[CST_8:.*]] = constant 0.000000e+00 : f32 | ||||
|   // CHECK: %[[CST_9:.*]] = constant 3.14159274 : f32 | ||||
|   // CHECK: %[[CST_10:.*]] = constant 0x7FC00000 : f32 | ||||
|   // CHECK: %[[CST_11:.*]] = constant 2.3561945 : f32 | ||||
|   // CHECK: %[[CST_12:.*]] = constant 0.785398185 : f32 | ||||
|   // CHECK: %[[CST_13:.*]] = constant 0x7F800000 : f32 | ||||
|   // CHECK: %[[VAL_0:.*]] = absf %[[ARG1]] : f32 | ||||
|   // CHECK: %[[VAL_1:.*]] = absf %[[ARG0]] : f32 | ||||
|   // CHECK: %[[VAL_2:.*]] = cmpf ole, %[[VAL_0]], %[[VAL_1]] : f32 | ||||
|   // CHECK: %[[VAL_3:.*]] = select %[[VAL_2]], %[[VAL_0]], %[[VAL_1]] : f32 | ||||
|   // CHECK: %[[VAL_4:.*]] = select %[[VAL_2]], %[[VAL_1]], %[[VAL_0]] : f32 | ||||
|   // CHECK: %[[VAL_5:.*]] = divf %[[VAL_3]], %[[VAL_4]] : f32 | ||||
|   // CHECK: %[[VAL_6:.*]] = mulf %[[VAL_5]], %[[VAL_5]] : f32 | ||||
|   // CHECK: %[[VAL_7:.*]] = mulf %[[CST]], %[[VAL_6]] : f32 | ||||
|   // CHECK: %[[VAL_8:.*]] = addf %[[VAL_7]], %[[CST_0]] : f32 | ||||
|   // CHECK: %[[VAL_9:.*]] = mulf %[[VAL_8]], %[[VAL_6]] : f32 | ||||
|   // CHECK: %[[VAL_10:.*]] = addf %[[VAL_9]], %[[CST_1]] : f32 | ||||
|   // CHECK: %[[VAL_11:.*]] = mulf %[[VAL_10]], %[[VAL_6]] : f32 | ||||
|   // CHECK: %[[VAL_12:.*]] = addf %[[VAL_11]], %[[CST_2]] : f32 | ||||
|   // CHECK: %[[VAL_13:.*]] = mulf %[[VAL_12]], %[[VAL_6]] : f32 | ||||
|   // CHECK: %[[VAL_14:.*]] = addf %[[VAL_13]], %[[CST_3]] : f32 | ||||
|   // CHECK: %[[VAL_15:.*]] = mulf %[[VAL_14]], %[[VAL_6]] : f32 | ||||
|   // CHECK: %[[VAL_16:.*]] = addf %[[VAL_15]], %[[CST_4]] : f32 | ||||
|   // CHECK: %[[VAL_17:.*]] = mulf %[[VAL_16]], %[[VAL_6]] : f32 | ||||
|   // CHECK: %[[VAL_18:.*]] = addf %[[VAL_17]], %[[CST_5]] : f32 | ||||
|   // CHECK: %[[VAL_19:.*]] = mulf %[[VAL_18]], %[[VAL_6]] : f32 | ||||
|   // CHECK: %[[VAL_20:.*]] = addf %[[VAL_19]], %[[CST_6]] : f32 | ||||
|   // CHECK: %[[VAL_21:.*]] = mulf %[[VAL_20]], %[[VAL_6]] : f32 | ||||
|   // CHECK: %[[VAL_22:.*]] = mulf %[[VAL_21]], %[[VAL_5]] : f32 | ||||
|   // CHECK: %[[VAL_23:.*]] = addf %[[VAL_22]], %[[VAL_5]] : f32 | ||||
|   // CHECK: %[[VAL_24:.*]] = subf %[[CST_7]], %[[VAL_23]] : f32 | ||||
|   // CHECK: %[[VAL_25:.*]] = select %[[VAL_2]], %[[VAL_24]], %[[VAL_23]] : f32 | ||||
|   // CHECK: %[[VAL_26:.*]] = cmpf olt, %[[ARG1]], %[[CST_8]] : f32 | ||||
|   // CHECK: %[[VAL_27:.*]] = subf %[[CST_9]], %[[VAL_25]] : f32 | ||||
|   // CHECK: %[[VAL_28:.*]] = select %[[VAL_26]], %[[VAL_27]], %[[VAL_25]] : f32 | ||||
|   // CHECK: %[[VAL_29:.*]] = select %[[VAL_26]], %[[CST_9]], %[[CST_8]] : f32 | ||||
|   // CHECK: %[[VAL_30:.*]] = cmpf oeq, %[[ARG0]], %[[CST_8]] : f32 | ||||
|   // CHECK: %[[VAL_31:.*]] = select %[[VAL_30]], %[[VAL_29]], %[[VAL_28]] : f32 | ||||
|   // CHECK: %[[VAL_32:.*]] = cmpf uno, %[[ARG0]], %[[ARG1]] : f32 | ||||
|   // CHECK: %[[VAL_35:.*]] = select %[[VAL_32]], %[[CST_10]], %[[VAL_31]] : f32 | ||||
|   // CHECK: %[[VAL_36:.*]] = select %[[VAL_26]], %[[CST_11]], %[[CST_12]] : f32 | ||||
|   // CHECK: %[[VAL_37:.*]] = cmpf oeq, %[[ARG1]], %[[CST_13]] : f32 | ||||
|   // CHECK: %[[VAL_38:.*]] = cmpf oeq, %[[ARG0]], %[[CST_13]] : f32 | ||||
|   // CHECK: %[[VAL_39:.*]] = and %[[VAL_37]], %[[VAL_38]] : i1 | ||||
|   // CHECK: %[[VAL_40:.*]] = select %[[VAL_39]], %[[VAL_36]], %[[VAL_35]] : f32 | ||||
|   // CHECK: %[[VAL_41:.*]] = copysign %[[VAL_40]], %[[ARG0]] : f32 | ||||
|   // CHECK: return %[[VAL_41]] : f32 | ||||
|   %res = atan2 %arg0, %arg1 : f32 | ||||
|   return %res : f32 | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| // CHECK-LABEL: @atan2_f16 | ||||
| // CHECK-SAME: (%[[ARG0:.*]]: f16, %[[ARG1:.*]]: f16) -> f16 | ||||
| func @atan2_f16(%arg0 : f16, %arg1 : f16) -> f16 { | ||||
|   // CHECK: %[[CST:.*]] = constant 0.0027856871 : f32 | ||||
|   // CHECK: %[[CST_0:.*]] = constant -1.586600e-02 : f32 | ||||
|   // CHECK: %[[CST_1:.*]] = constant 0.042472221 : f32 | ||||
|   // CHECK: %[[CST_2:.*]] = constant -0.0749753043 : f32 | ||||
|   // CHECK: %[[CST_3:.*]] = constant 0.106448799 : f32 | ||||
|   // CHECK: %[[CST_4:.*]] = constant -0.142070308 : f32 | ||||
|   // CHECK: %[[CST_5:.*]] = constant 0.199934542 : f32 | ||||
|   // CHECK: %[[CST_6:.*]] = constant -0.333331466 : f32 | ||||
|   // CHECK: %[[CST_7:.*]] = constant 1.57079637 : f32 | ||||
|   // CHECK: %[[CST_8:.*]] = constant 0.000000e+00 : f32 | ||||
|   // CHECK: %[[CST_9:.*]] = constant 3.14159274 : f32 | ||||
|   // CHECK: %[[CST_10:.*]] = constant 0x7FC00000 : f32 | ||||
|   // CHECK: %[[CST_11:.*]] = constant 2.3561945 : f32 | ||||
|   // CHECK: %[[CST_12:.*]] = constant 0.785398185 : f32 | ||||
|   // CHECK: %[[CST_13:.*]] = constant 0x7F800000 : f32 | ||||
|   // CHECK: %[[VAL_0:.*]] = fpext %[[ARG0]] : f16 to f32 | ||||
|   // CHECK: %[[VAL_1:.*]] = fpext %[[ARG1]] : f16 to f32 | ||||
|   // CHECK: %[[VAL_2:.*]] = absf %[[VAL_1]] : f32 | ||||
|   // CHECK: %[[VAL_3:.*]] = absf %[[VAL_0]] : f32 | ||||
|   // CHECK: %[[VAL_4:.*]] = cmpf ole, %[[VAL_2]], %[[VAL_3]] : f32 | ||||
|   // CHECK: %[[VAL_5:.*]] = select %[[VAL_4]], %[[VAL_2]], %[[VAL_3]] : f32 | ||||
|   // CHECK: %[[VAL_6:.*]] = select %[[VAL_4]], %[[VAL_3]], %[[VAL_2]] : f32 | ||||
|   // CHECK: %[[VAL_7:.*]] = divf %[[VAL_5]], %[[VAL_6]] : f32 | ||||
|   // CHECK: %[[VAL_8:.*]] = mulf %[[VAL_7]], %[[VAL_7]] : f32 | ||||
|   // CHECK: %[[VAL_9:.*]] = mulf %[[CST]], %[[VAL_8]] : f32 | ||||
|   // CHECK: %[[VAL_10:.*]] = addf %[[VAL_9]], %[[CST_0]] : f32 | ||||
|   // CHECK: %[[VAL_11:.*]] = mulf %[[VAL_10]], %[[VAL_8]] : f32 | ||||
|   // CHECK: %[[VAL_12:.*]] = addf %[[VAL_11]], %[[CST_1]] : f32 | ||||
|   // CHECK: %[[VAL_13:.*]] = mulf %[[VAL_12]], %[[VAL_8]] : f32 | ||||
|   // CHECK: %[[VAL_14:.*]] = addf %[[VAL_13]], %[[CST_2]] : f32 | ||||
|   // CHECK: %[[VAL_15:.*]] = mulf %[[VAL_14]], %[[VAL_8]] : f32 | ||||
|   // CHECK: %[[VAL_16:.*]] = addf %[[VAL_15]], %[[CST_3]] : f32 | ||||
|   // CHECK: %[[VAL_17:.*]] = mulf %[[VAL_16]], %[[VAL_8]] : f32 | ||||
|   // CHECK: %[[VAL_18:.*]] = addf %[[VAL_17]], %[[CST_4]] : f32 | ||||
|   // CHECK: %[[VAL_19:.*]] = mulf %[[VAL_18]], %[[VAL_8]] : f32 | ||||
|   // CHECK: %[[VAL_20:.*]] = addf %[[VAL_19]], %[[CST_5]] : f32 | ||||
|   // CHECK: %[[VAL_21:.*]] = mulf %[[VAL_20]], %[[VAL_8]] : f32 | ||||
|   // CHECK: %[[VAL_22:.*]] = addf %[[VAL_21]], %[[CST_6]] : f32 | ||||
|   // CHECK: %[[VAL_23:.*]] = mulf %[[VAL_22]], %[[VAL_8]] : f32 | ||||
|   // CHECK: %[[VAL_24:.*]] = mulf %[[VAL_23]], %[[VAL_7]] : f32 | ||||
|   // CHECK: %[[VAL_25:.*]] = addf %[[VAL_24]], %[[VAL_7]] : f32 | ||||
|   // CHECK: %[[VAL_26:.*]] = subf %[[CST_7]], %[[VAL_25]] : f32 | ||||
|   // CHECK: %[[VAL_27:.*]] = select %[[VAL_4]], %[[VAL_26]], %[[VAL_25]] : f32 | ||||
|   // CHECK: %[[VAL_28:.*]] = cmpf olt, %[[VAL_1]], %[[CST_8]] : f32 | ||||
|   // CHECK: %[[VAL_29:.*]] = subf %[[CST_9]], %[[VAL_27]] : f32 | ||||
|   // CHECK: %[[VAL_30:.*]] = select %[[VAL_28]], %[[VAL_29]], %[[VAL_27]] : f32 | ||||
|   // CHECK: %[[VAL_31:.*]] = select %[[VAL_28]], %[[CST_9]], %[[CST_8]] : f32 | ||||
|   // CHECK: %[[VAL_32:.*]] = cmpf oeq, %[[VAL_0]], %[[CST_8]] : f32 | ||||
|   // CHECK: %[[VAL_33:.*]] = select %[[VAL_32]], %[[VAL_31]], %[[VAL_30]] : f32 | ||||
|   // CHECK: %[[VAL_34:.*]] = cmpf uno, %[[VAL_0]], %[[VAL_1]] : f32 | ||||
|   // CHECK: %[[VAL_37:.*]] = select %[[VAL_34]], %[[CST_10]], %[[VAL_33]] : f32 | ||||
|   // CHECK: %[[VAL_38:.*]] = select %[[VAL_28]], %[[CST_11]], %[[CST_12]] : f32 | ||||
|   // CHECK: %[[VAL_39:.*]] = cmpf oeq, %[[VAL_1]], %[[CST_13]] : f32 | ||||
|   // CHECK: %[[VAL_40:.*]] = cmpf oeq, %[[VAL_0]], %[[CST_13]] : f32 | ||||
|   // CHECK: %[[VAL_41:.*]] = and %[[VAL_39]], %[[VAL_40]] : i1 | ||||
|   // CHECK: %[[VAL_42:.*]] = select %[[VAL_41]], %[[VAL_38]], %[[VAL_37]] : f32 | ||||
|   // CHECK: %[[VAL_43:.*]] = copysign %[[VAL_42]], %[[VAL_0]] : f32 | ||||
|   // CHECK: %[[VAL_44:.*]] = fptrunc %[[VAL_43]] : f32 to f16 | ||||
|   // CHECK: return %[[VAL_44]] : f16 | ||||
|   %res = atan2 %arg0, %arg1 : f16 | ||||
|   return %res : f16 | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| // CHECK-LABEL: @atan_f64 | ||||
| func @atan_f64(%arg : f64) -> f64 { | ||||
|   // CHECK: atan | ||||
|   %res = atan %arg : f64 | ||||
|   return %res : f64 | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| // CHECK-LABEL: func @atan_f32 | ||||
| // CHECK-SAME: (%[[ARG:.*]]: f32) -> f32 | ||||
| func @atan_f32(%arg : f32) -> f32 { | ||||
|   // CHECK: %[[CST:.*]] = constant 1.000000e+00 : f32 | ||||
|   // CHECK: %[[CST_0:.*]] = constant 0.0027856871 : f32 | ||||
|   // CHECK: %[[CST_1:.*]] = constant -1.586600e-02 : f32 | ||||
|   // CHECK: %[[CST_2:.*]] = constant 0.042472221 : f32 | ||||
|   // CHECK: %[[CST_3:.*]] = constant -0.0749753043 : f32 | ||||
|   // CHECK: %[[CST_4:.*]] = constant 0.106448799 : f32 | ||||
|   // CHECK: %[[CST_5:.*]] = constant -0.142070308 : f32 | ||||
|   // CHECK: %[[CST_6:.*]] = constant 0.199934542 : f32 | ||||
|   // CHECK: %[[CST_7:.*]] = constant -0.333331466 : f32 | ||||
|   // CHECK: %[[CST_8:.*]] = constant 1.57079637 : f32 | ||||
|   // CHECK: %[[CST_9:.*]] = constant 0.000000e+00 : f32 | ||||
|   // CHECK: %[[CST_10:.*]] = constant 0x7FC00000 : f32 | ||||
|   // CHECK: %[[VAL_0:.*]] = absf %[[CST]] : f32 | ||||
|   // CHECK: %[[VAL_1:.*]] = absf %arg0 : f32 | ||||
|   // CHECK: %[[VAL_2:.*]] = cmpf ole, %[[VAL_0]], %[[VAL_1]] : f32 | ||||
|   // CHECK: %[[VAL_3:.*]] = select %[[VAL_2]], %[[VAL_0]], %[[VAL_1]] : f32 | ||||
|   // CHECK: %[[VAL_4:.*]] = select %[[VAL_2]], %[[VAL_1]], %[[VAL_0]] : f32 | ||||
|   // CHECK: %[[VAL_5:.*]] = divf %[[VAL_3]], %[[VAL_4]] : f32 | ||||
|   // CHECK: %[[VAL_6:.*]] = mulf %[[VAL_5]], %[[VAL_5]] : f32 | ||||
|   // CHECK: %[[VAL_7:.*]] = mulf %[[CST_0]], %[[VAL_6]] : f32 | ||||
|   // CHECK: %[[VAL_8:.*]] = addf %[[VAL_7]], %[[CST_1]] : f32 | ||||
|   // CHECK: %[[VAL_9:.*]] = mulf %[[VAL_8]], %[[VAL_6]] : f32 | ||||
|   // CHECK: %[[VAL_10:.*]] = addf %[[VAL_9]], %[[CST_2]] : f32 | ||||
|   // CHECK: %[[VAL_11:.*]] = mulf %[[VAL_10]], %[[VAL_6]] : f32 | ||||
|   // CHECK: %[[VAL_12:.*]] = addf %[[VAL_11]], %[[CST_3]] : f32 | ||||
|   // CHECK: %[[VAL_13:.*]] = mulf %[[VAL_12]], %[[VAL_6]] : f32 | ||||
|   // CHECK: %[[VAL_14:.*]] = addf %[[VAL_13]], %[[CST_4]] : f32 | ||||
|   // CHECK: %[[VAL_15:.*]] = mulf %[[VAL_14]], %[[VAL_6]] : f32 | ||||
|   // CHECK: %[[VAL_16:.*]] = addf %[[VAL_15]], %[[CST_5]] : f32 | ||||
|   // CHECK: %[[VAL_17:.*]] = mulf %[[VAL_16]], %[[VAL_6]] : f32 | ||||
|   // CHECK: %[[VAL_18:.*]] = addf %[[VAL_17]], %[[CST_6]] : f32 | ||||
|   // CHECK: %[[VAL_19:.*]] = mulf %[[VAL_18]], %[[VAL_6]] : f32 | ||||
|   // CHECK: %[[VAL_20:.*]] = addf %[[VAL_19]], %[[CST_7]] : f32 | ||||
|   // CHECK: %[[VAL_21:.*]] = mulf %[[VAL_20]], %[[VAL_6]] : f32 | ||||
|   // CHECK: %[[VAL_22:.*]] = mulf %[[VAL_21]], %[[VAL_5]] : f32 | ||||
|   // CHECK: %[[VAL_23:.*]] = addf %[[VAL_22]], %[[VAL_5]] : f32 | ||||
|   // CHECK: %[[VAL_24:.*]] = subf %[[CST_8]], %[[VAL_23]] : f32 | ||||
|   // CHECK: %[[VAL_25:.*]] = select %[[VAL_2]], %[[VAL_24]], %[[VAL_23]] : f32 | ||||
|   // CHECK: %[[VAL_26:.*]] = cmpf oeq, %arg0, %[[CST_9]] : f32 | ||||
|   // CHECK: %[[VAL_27:.*]] = select %[[VAL_26]], %[[CST_9]], %[[VAL_25]] : f32 | ||||
|   // CHECK: %[[VAL_28:.*]] = cmpf uno, %arg0, %[[CST]] : f32 | ||||
|   // CHECK: %[[VAL_29:.*]] = select %[[VAL_28]], %[[CST_10]], %[[VAL_27]] : f32 | ||||
|   // CHECK: %[[VAL_30:.*]] = copysign %[[VAL_29]], %arg0 : f32 | ||||
|   // CHECK: return %[[VAL_30]] : f32 | ||||
|   %res = atan %arg : f32 | ||||
|   return %res : f32 | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| // CHECK-LABEL: @atan_f16 | ||||
| // CHECK-SAME: (%[[ARG:.*]]: f16) -> f16 | ||||
| func @atan_f16(%arg : f16) -> f16 { | ||||
|   // CHECK: %[[CST:.*]] = constant 1.000000e+00 : f32 | ||||
|   // CHECK: %[[CST_0:.*]] = constant 0.0027856871 : f32 | ||||
|   // CHECK: %[[CST_1:.*]] = constant -1.586600e-02 : f32 | ||||
|   // CHECK: %[[CST_2:.*]] = constant 0.042472221 : f32 | ||||
|   // CHECK: %[[CST_3:.*]] = constant -0.0749753043 : f32 | ||||
|   // CHECK: %[[CST_4:.*]] = constant 0.106448799 : f32 | ||||
|   // CHECK: %[[CST_5:.*]] = constant -0.142070308 : f32 | ||||
|   // CHECK: %[[CST_6:.*]] = constant 0.199934542 : f32 | ||||
|   // CHECK: %[[CST_7:.*]] = constant -0.333331466 : f32 | ||||
|   // CHECK: %[[CST_8:.*]] = constant 1.57079637 : f32 | ||||
|   // CHECK: %[[CST_9:.*]] = constant 0.000000e+00 : f32 | ||||
|   // CHECK: %[[CST_10:.*]] = constant 0x7FC00000 : f32 | ||||
|   // CHECK: %[[VAL_0:.*]] = fpext %arg0 : f16 to f32 | ||||
|   // CHECK: %[[VAL_1:.*]] = absf %[[CST]] : f32 | ||||
|   // CHECK: %[[VAL_2:.*]] = absf %[[VAL_0]] : f32 | ||||
|   // CHECK: %[[VAL_3:.*]] = cmpf ole, %[[VAL_1]], %[[VAL_2]] : f32 | ||||
|   // CHECK: %[[VAL_4:.*]] = select %[[VAL_3]], %[[VAL_1]], %[[VAL_2]] : f32 | ||||
|   // CHECK: %[[VAL_5:.*]] = select %[[VAL_3]], %[[VAL_2]], %[[VAL_1]] : f32 | ||||
|   // CHECK: %[[VAL_6:.*]] = divf %[[VAL_4]], %[[VAL_5]] : f32 | ||||
|   // CHECK: %[[VAL_7:.*]] = mulf %[[VAL_6]], %[[VAL_6]] : f32 | ||||
|   // CHECK: %[[VAL_8:.*]] = mulf %[[CST_0]], %[[VAL_7]] : f32 | ||||
|   // CHECK: %[[VAL_9:.*]] = addf %[[VAL_8]], %[[CST_1]] : f32 | ||||
|   // CHECK: %[[VAL_10:.*]] = mulf %[[VAL_9]], %[[VAL_7]] : f32 | ||||
|   // CHECK: %[[VAL_11:.*]] = addf %[[VAL_10]], %[[CST_2]] : f32 | ||||
|   // CHECK: %[[VAL_12:.*]] = mulf %[[VAL_11]], %[[VAL_7]] : f32 | ||||
|   // CHECK: %[[VAL_13:.*]] = addf %[[VAL_12]], %[[CST_3]] : f32 | ||||
|   // CHECK: %[[VAL_14:.*]] = mulf %[[VAL_13]], %[[VAL_7]] : f32 | ||||
|   // CHECK: %[[VAL_15:.*]] = addf %[[VAL_14]], %[[CST_4]] : f32 | ||||
|   // CHECK: %[[VAL_16:.*]] = mulf %[[VAL_15]], %[[VAL_7]] : f32 | ||||
|   // CHECK: %[[VAL_17:.*]] = addf %[[VAL_16]], %[[CST_5]] : f32 | ||||
|   // CHECK: %[[VAL_18:.*]] = mulf %[[VAL_17]], %[[VAL_7]] : f32 | ||||
|   // CHECK: %[[VAL_19:.*]] = addf %[[VAL_18]], %[[CST_6]] : f32 | ||||
|   // CHECK: %[[VAL_20:.*]] = mulf %[[VAL_19]], %[[VAL_7]] : f32 | ||||
|   // CHECK: %[[VAL_21:.*]] = addf %[[VAL_20]], %[[CST_7]] : f32 | ||||
|   // CHECK: %[[VAL_22:.*]] = mulf %[[VAL_21]], %[[VAL_7]] : f32 | ||||
|   // CHECK: %[[VAL_23:.*]] = mulf %[[VAL_22]], %[[VAL_6]] : f32 | ||||
|   // CHECK: %[[VAL_24:.*]] = addf %[[VAL_23]], %[[VAL_6]] : f32 | ||||
|   // CHECK: %[[VAL_25:.*]] = subf %[[CST_8]], %[[VAL_24]] : f32 | ||||
|   // CHECK: %[[VAL_26:.*]] = select %[[VAL_3]], %[[VAL_25]], %[[VAL_24]] : f32 | ||||
|   // CHECK: %[[VAL_27:.*]] = cmpf oeq, %[[VAL_0]], %[[CST_9]] : f32 | ||||
|   // CHECK: %[[VAL_28:.*]] = select %[[VAL_27]], %[[CST_9]], %[[VAL_26]] : f32 | ||||
|   // CHECK: %[[VAL_29:.*]] = cmpf uno, %[[VAL_0]], %[[CST]] : f32 | ||||
|   // CHECK: %[[VAL_30:.*]] = select %[[VAL_29]], %[[CST_10]], %[[VAL_28]] : f32 | ||||
|   // CHECK: %[[VAL_31:.*]] = copysign %[[VAL_30]], %[[VAL_0]] : f32 | ||||
|   // CHECK: %[[VAL_32:.*]] = fptrunc %[[VAL_31]] : f32 to f16 | ||||
|   // CHECK: return %[[VAL_32]] : f16 | ||||
|   %res = atan %arg : f16 | ||||
|   return %res : f16 | ||||
| } | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue