From 049ca060a1a6bf1df594888018ca723d71cf3817 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 1 Oct 2020 05:33:59 -0700 Subject: [PATCH] [MLIR][KernelGen] Legalize `atan2` to approximation Legalize `atan2` analogously to XLA. `atan2` is first reduced to `atan` on the interval [-1, 1] and subsequently approximated. This CL also adds e2e tests for trigonometric approximations. PiperOrigin-RevId: 334794336 --- ...legalize_trigonometric_to_approximation.cc | 287 ++++++++++++------ ...galize-trigonometric-to-approximation.mlir | 121 ++++++++ ...galize-trigonometric-to-approximation.mlir | 142 +++++++++ 3 files changed, 463 insertions(+), 87 deletions(-) create mode 100644 tests/end2end/legalize-trigonometric-to-approximation.mlir diff --git a/lib/Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc b/lib/Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc index 7021fb8..e24cd9c 100644 --- a/lib/Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc +++ b/lib/Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// This file implements logic for lowering the tanh standard ops to an -// approximation. +// This file implements the lowering for trigonometric standard ops to +// approximations. #include "mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" @@ -27,102 +27,211 @@ namespace mlir { namespace mhlo { namespace { -/// Emits the fast tanh approximation that is also used by XLA. -Value EmitTanhApproximation(Value input, Location loc, - PatternRewriter &rewriter) { - // For small values of x, we can approximate tanh(x)=x. For extremely small - // values of x (|x| < 1e-37), the other approximation would evaluate - // tanh(x) = 0. - constexpr float kCanUseApprox = 0.0004; - Value abs_value = rewriter.create(loc, input); - Value can_use_approx = - rewriter.create(loc, rewriter.getF32FloatAttr(kCanUseApprox)); - Value return_input = rewriter.create(loc, CmpFPredicate::OLT, - abs_value, can_use_approx); - // Clamp the input to [-c, c]. - Value max_clamp = rewriter.create( - loc, rewriter.getF32FloatAttr(7.90531110763549805f)); - Value smaller_than_max = - rewriter.create(loc, CmpFPredicate::ULE, input, max_clamp); - Value clamped_half = - rewriter.create(loc, smaller_than_max, input, max_clamp); - Value min_clamp = rewriter.create( - loc, rewriter.getF32FloatAttr(-7.90531110763549805f)); - Value larger_than_min = - rewriter.create(loc, CmpFPredicate::UGE, clamped_half, min_clamp); - Value input_clamped = - rewriter.create(loc, larger_than_min, clamped_half, min_clamp); - - static constexpr std::array numerator_coeffs{ - -2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f, - 5.12229709037114e-08f, 1.48572235717979e-05f, 6.37261928875436e-04f, - 4.89352455891786e-03f}; - - static constexpr std::array denominator_coeffs{ - 1.19825839466702e-06f, 1.18534705686654e-04f, 2.26843463243900e-03f, - 4.89352518554385e-03f}; - - Value input_squared = - rewriter.create(loc, input_clamped, input_clamped); - Value numerator = rewriter.create( - loc, rewriter.getF32FloatAttr(numerator_coeffs[0])); - for (int i = 1; i < numerator_coeffs.size(); i++) { - numerator = rewriter.create( - loc, rewriter.create(loc, input_squared, numerator), - rewriter.create( - loc, rewriter.getF32FloatAttr(numerator_coeffs[i]))); - } - - numerator = rewriter.create(loc, input_clamped, numerator); - - Value denominator = rewriter.create( - loc, rewriter.getF32FloatAttr(denominator_coeffs[0])); - for (int i = 1; i < denominator_coeffs.size(); i++) { - denominator = rewriter.create( - loc, rewriter.create(loc, input_squared, denominator), - rewriter.create( - loc, rewriter.getF32FloatAttr(denominator_coeffs[i]))); - } - - Value approx = rewriter.create(loc, numerator, denominator); - - return rewriter.create(loc, return_input, input, approx); -} - -class ApproximateTanhLowering : public OpRewritePattern { +template +class ApproximateOnExtendedF32Lowering : public OpRewritePattern { public: - explicit ApproximateTanhLowering(MLIRContext *ctx) - : OpRewritePattern(ctx, 100) {} + explicit ApproximateOnExtendedF32Lowering(MLIRContext *ctx) + : OpRewritePattern(ctx, /*benefit=*/100) {} - LogicalResult matchAndRewrite(TanhOp tanhOp, + virtual Value emitApproximation(ValueRange, Location, + PatternRewriter &) const = 0; + + LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { - Type operand_type = tanhOp.getType(); + Location loc = op.getLoc(); + auto raw_args = op.getOperation()->getOperands(); - if (operand_type.isF64()) { + // Supports only f16 and f32 for now. + if (!op.getType().isF16() && !op.getType().isF32()) return failure(); + + // Extend operands to f32 if needed and possible. + SmallVector f32_args; + f32_args.reserve(raw_args.size()); + for (Value arg : raw_args) { // Similar to XLA, do not rewrite f64 as precision might matter. - return failure(); + Type arg_ty = arg.getType(); + if (arg_ty.isF64()) return failure(); + + if (arg_ty.isF16()) + arg = rewriter.create(loc, arg, rewriter.getF32Type()); + + // If we still do not have f32, fail. + if (!arg.getType().isF32()) return failure(); + + f32_args.push_back(arg); } - Location loc = tanhOp.getLoc(); - Value input = tanhOp.operand(); - if (operand_type.isF16()) { - input = rewriter.create(loc, input, rewriter.getF32Type()); - } - - // If we still do not have f32, fail. - if (!input.getType().isF32()) { - return failure(); - } - - Value result = EmitTanhApproximation(input, loc, rewriter); + Value result = emitApproximation(f32_args, loc, rewriter); + assert(result.getType().isF32() && "Expect f32 intermediate result."); // Truncate back if needed. - if (operand_type.isF16()) { + if (op.getType().isF16()) result = rewriter.create(loc, result, rewriter.getF16Type()); + + rewriter.replaceOp(op, {result}); + return success(); + } +}; + +class ApproximateTanhLowering + : public ApproximateOnExtendedF32Lowering { + public: + explicit ApproximateTanhLowering(MLIRContext *ctx) + : ApproximateOnExtendedF32Lowering(ctx) {} + + // Emits the fast tanh approximation that is also used by XLA. + Value emitApproximation(ValueRange args, Location loc, + PatternRewriter &rewriter) const override { + // For small values of x, we can approximate tanh(x) = x. For extremely + // small values of x (|x| < 1e-37), the other approximation would evaluate + // tanh(x) = 0. + Value input = args.front(); + assert(input.getType().isF32()); + constexpr float kCanUseApprox = 0.0004; + Value abs_value = rewriter.create(loc, input); + Value can_use_approx = rewriter.create( + loc, rewriter.getF32FloatAttr(kCanUseApprox)); + Value return_input = rewriter.create(loc, CmpFPredicate::OLT, + abs_value, can_use_approx); + // Clamp the input to [-c, c]. + Value max_clamp = rewriter.create( + loc, rewriter.getF32FloatAttr(7.90531110763549805f)); + Value smaller_than_max = + rewriter.create(loc, CmpFPredicate::ULE, input, max_clamp); + Value clamped_half = + rewriter.create(loc, smaller_than_max, input, max_clamp); + Value min_clamp = rewriter.create( + loc, rewriter.getF32FloatAttr(-7.90531110763549805f)); + Value larger_than_min = rewriter.create(loc, CmpFPredicate::UGE, + clamped_half, min_clamp); + Value input_clamped = rewriter.create(loc, larger_than_min, + clamped_half, min_clamp); + + static constexpr std::array numerator_coeffs{ + -2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f, + 5.12229709037114e-08f, 1.48572235717979e-05f, 6.37261928875436e-04f, + 4.89352455891786e-03f}; + + static constexpr std::array denominator_coeffs{ + 1.19825839466702e-06f, 1.18534705686654e-04f, 2.26843463243900e-03f, + 4.89352518554385e-03f}; + + Value input_squared = + rewriter.create(loc, input_clamped, input_clamped); + Value numerator = rewriter.create( + loc, rewriter.getF32FloatAttr(numerator_coeffs[0])); + for (int i = 1; i < numerator_coeffs.size(); i++) { + numerator = rewriter.create( + loc, rewriter.create(loc, input_squared, numerator), + rewriter.create( + loc, rewriter.getF32FloatAttr(numerator_coeffs[i]))); } - rewriter.replaceOp(tanhOp, {result}); - return success(); + numerator = rewriter.create(loc, input_clamped, numerator); + + Value denominator = rewriter.create( + loc, rewriter.getF32FloatAttr(denominator_coeffs[0])); + for (int i = 1; i < denominator_coeffs.size(); i++) { + denominator = rewriter.create( + loc, rewriter.create(loc, input_squared, denominator), + rewriter.create( + loc, rewriter.getF32FloatAttr(denominator_coeffs[i]))); + } + + Value approx = rewriter.create(loc, numerator, denominator); + + return rewriter.create(loc, return_input, input, approx); + } +}; + +class ApproximateAtan2Lowering + : public ApproximateOnExtendedF32Lowering { + public: + explicit ApproximateAtan2Lowering(MLIRContext *ctx) + : ApproximateOnExtendedF32Lowering(ctx) {} + + // Reduces atan2 to atan in the same way XLA does it. + Value emitApproximation(ValueRange args, Location loc, + PatternRewriter &rewriter) const override { + Value y = args[0]; + Value x = args[1]; + assert(x.getType().isF32() && y.getType().isF32() && + "expect f32 arguments"); + Value ax = rewriter.create(loc, x); + Value ay = rewriter.create(loc, y); + Value le_ax_ay = rewriter.create(loc, CmpFPredicate::OLE, ax, ay); + Value min_ax_ay = rewriter.create(loc, le_ax_ay, ax, ay); + Value max_ax_ay = rewriter.create(loc, le_ax_ay, ay, ax); + Value zero_to_one = rewriter.create(loc, min_ax_ay, max_ax_ay); + Value a = emitAtanCoreApproximation(zero_to_one, loc, rewriter); + + Value pi_over_2 = + rewriter.create(loc, rewriter.getF32FloatAttr(1.57079637f)); + a = rewriter.create( + loc, le_ax_ay, rewriter.create(loc, pi_over_2, a), a); + + Value zero = rewriter.create(loc, rewriter.getF32FloatAttr(0)); + Value lt_x_0 = rewriter.create(loc, CmpFPredicate::OLT, x, zero); + Value pi = + rewriter.create(loc, rewriter.getF32FloatAttr(3.14159274f)); + a = rewriter.create(loc, lt_x_0, + rewriter.create(loc, pi, a), a); + + Value t = rewriter.create(loc, lt_x_0, pi, zero); + Value eq_y_0 = rewriter.create(loc, CmpFPredicate::OEQ, y, zero); + a = rewriter.create(loc, eq_y_0, t, a); + + // Propagate nan. + Value is_nan = rewriter.create(loc, CmpFPredicate::UNO, y, x); + Value nan = rewriter.create( + loc, rewriter.getF32FloatAttr(std::numeric_limits::quiet_NaN())); + a = rewriter.create(loc, is_nan, nan, a); + + // x and y are +- inf. + Value three_pi_over_4 = + rewriter.create(loc, rewriter.getF32FloatAttr(2.3561945f)); + Value pi_over_4 = rewriter.create( + loc, rewriter.getF32FloatAttr(0.785398185f)); + t = rewriter.create(loc, lt_x_0, three_pi_over_4, + pi_over_4); + Value inf = rewriter.create( + loc, rewriter.getF32FloatAttr(std::numeric_limits::infinity())); + Value eq_x_inf = rewriter.create(loc, CmpFPredicate::OEQ, x, inf); + Value eq_y_inf = rewriter.create(loc, CmpFPredicate::OEQ, y, inf); + Value all_inf = rewriter.create(loc, eq_x_inf, eq_y_inf); + a = rewriter.create(loc, all_inf, t, a); + + return rewriter.create(loc, a, y); + } + + private: + // The core atan reduction derives from the heuristic described in + // https://arxiv.org/abs/1508.03211 and has a < 0.95 ulp error in the [-1, 1] + // range (though that assumed FMA was available, and it is not here). This is + // the same approximation that is also used by XLA. + Value emitAtanCoreApproximation(Value x, Location loc, + PatternRewriter &rewriter) const { + auto constant = [&](float c) { + return rewriter.create(loc, rewriter.getF32FloatAttr(c)); + }; + + // Computes ab + c. + auto mul_add = [&](Value a, Value b, Value c) { + Value prod = rewriter.create(loc, a, b); + return rewriter.create(loc, prod, c); + }; + + Value s = rewriter.create(loc, x, x); + Value r = constant(0.0027856871f); + r = mul_add(r, s, constant(-0.0158660002f)); + r = mul_add(r, s, constant(0.042472221f)); + r = mul_add(r, s, constant(-0.0749753043f)); + r = mul_add(r, s, constant(0.106448799f)); + r = mul_add(r, s, constant(-0.142070308f)); + r = mul_add(r, s, constant(0.199934542f)); + r = mul_add(r, s, constant(-0.333331466f)); + r = rewriter.create(loc, r, s); + return mul_add(r, x, x); } }; @@ -146,7 +255,11 @@ createLegalizeTrigonometricToApproximationPass() { void PopulateTrigonometricToApproximationPatterns( mlir::MLIRContext *context, OwningRewritePatternList *patterns) { - patterns->insert(context); + // clang-format off + patterns->insert< + ApproximateAtan2Lowering, + ApproximateTanhLowering>(context); + // clang-format on } } // namespace mhlo diff --git a/tests/end2end/legalize-trigonometric-to-approximation.mlir b/tests/end2end/legalize-trigonometric-to-approximation.mlir new file mode 100644 index 0000000..48b9d3b --- /dev/null +++ b/tests/end2end/legalize-trigonometric-to-approximation.mlir @@ -0,0 +1,121 @@ +// RUN: mlir-hlo-opt %s --mhlo-legalize-trigonometric-to-approximation --convert-std-to-llvm | mlir-cpu-runner -e main -entry-point-result=void --shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | FileCheck %s + +func @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface } + +// Helper function to print scalar values. +func @print_f32(%arg : f32) { + %mem = alloca() : memref<1xf32> + %c0 = constant 0 : index + store %arg, %mem[%c0] : memref<1xf32> + %mem_unranked = memref_cast %mem : memref<1xf32> to memref<*xf32> + call @print_memref_f32(%mem_unranked) : (memref<*xf32>) -> () + return +} + +// Compute and print trigonometric function. +func @atan2_f32(%arg0 : f32, %arg1 : f32) { + %res = atan2 %arg0, %arg1 : f32 + call @print_f32(%res) : (f32) -> () + return +} + +func @tanh_f32(%arg : f32) { + %res = tanh %arg : f32 + call @print_f32(%res) : (f32) -> () + return +} + +func @main() { + // Some constants to use as arguments. + %cf_n50_0 = constant -50.0 : f32 + %cf_n5_0 = constant -5.0 : f32 + %cf_n3_0 = constant -3.0 : f32 + %cf_n2_0 = constant -2.0 : f32 + %cf_n1_0 = constant -1.0 : f32 + %cf_n0_5 = constant -0.5 : f32 + %cf_n0_1 = constant -0.1 : f32 + %cf_0_0 = constant 0.0 : f32 + %cf_0_1 = constant 0.1 : f32 + %cf_0_5 = constant 0.5 : f32 + %cf_1_0 = constant 1.0 : f32 + %cf_2_0 = constant 2.0 : f32 + %cf_3_0 = constant 3.0 : f32 + %cf_5_0 = constant 5.0 : f32 + %cf_50_0 = constant 50.0 : f32 + + // Tanh. + call @tanh_f32(%cf_n50_0) : (f32) -> () + // CHECK: -1 + call @tanh_f32(%cf_n5_0) : (f32) -> () + // CHECK: -0.999{{.*}} + call @tanh_f32(%cf_n3_0) : (f32) -> () + // CHECK: -0.995{{.*}} + call @tanh_f32(%cf_n2_0) : (f32) -> () + // CHECK: -0.964{{.*}} + call @tanh_f32(%cf_n1_0) : (f32) -> () + // CHECK: -0.761{{.*}} + call @tanh_f32(%cf_n0_5) : (f32) -> () + // CHECK: -0.462{{.*}} + call @tanh_f32(%cf_n0_1) : (f32) -> () + // CHECK: -0.099{{.*}} + call @tanh_f32(%cf_0_0) : (f32) -> () + // CHECK: 0 + call @tanh_f32(%cf_0_1) : (f32) -> () + // CHECK: 0.099{{.*}} + call @tanh_f32(%cf_0_5) : (f32) -> () + // CHECK: 0.462{{.*}} + call @tanh_f32(%cf_1_0) : (f32) -> () + // CHECK: 0.761{{.*}} + call @tanh_f32(%cf_2_0) : (f32) -> () + // CHECK: 0.964{{.*}} + call @tanh_f32(%cf_3_0) : (f32) -> () + // CHECK: 0.995{{.*}} + call @tanh_f32(%cf_5_0) : (f32) -> () + // CHECK: 0.999{{.*}} + call @tanh_f32(%cf_50_0) : (f32) -> () + // CHECK: 1 + + // Atan2 with divisor 1. + call @atan2_f32(%cf_n50_0, %cf_1_0) : (f32, f32) -> () + // CHECK: -1.550{{.*}} + call @atan2_f32(%cf_n5_0, %cf_1_0) : (f32, f32) -> () + // CHECK: -1.373{{.*}} + call @atan2_f32(%cf_n3_0, %cf_1_0) : (f32, f32) -> () + // CHECK: -1.249{{.*}} + call @atan2_f32(%cf_n2_0, %cf_1_0) : (f32, f32) -> () + // CHECK: -1.107{{.*}} + call @atan2_f32(%cf_n1_0, %cf_1_0) : (f32, f32) -> () + // CHECK: -0.785{{.*}} + call @atan2_f32(%cf_n0_5, %cf_1_0) : (f32, f32) -> () + // CHECK: -0.463{{.*}} + call @atan2_f32(%cf_n0_1, %cf_1_0) : (f32, f32) -> () + // CHECK: -0.099{{.*}} + call @atan2_f32(%cf_0_0, %cf_1_0) : (f32, f32) -> () + // CHECK: 0 + call @atan2_f32(%cf_0_1, %cf_1_0) : (f32, f32) -> () + // CHECK: 0.099{{.*}} + call @atan2_f32(%cf_0_5, %cf_1_0) : (f32, f32) -> () + // CHECK: 0.463{{.*}} + call @atan2_f32(%cf_1_0, %cf_1_0) : (f32, f32) -> () + // CHECK: 0.785{{.*}} + call @atan2_f32(%cf_2_0, %cf_1_0) : (f32, f32) -> () + // CHECK: 1.107{{.*}} + call @atan2_f32(%cf_3_0, %cf_1_0) : (f32, f32) -> () + // CHECK: 1.249{{.*}} + call @atan2_f32(%cf_5_0, %cf_1_0) : (f32, f32) -> () + // CHECK: 1.373{{.*}} + call @atan2_f32(%cf_50_0, %cf_1_0) : (f32, f32) -> () + // CHECK: 1.550{{.*}} + + // Atan2 all four quadrants. + call @atan2_f32(%cf_n1_0, %cf_n1_0) : (f32, f32) -> () + // CHECK: -2.356{{.*}} + call @atan2_f32(%cf_n1_0, %cf_1_0) : (f32, f32) -> () + // CHECK: -0.785{{.*}} + call @atan2_f32(%cf_1_0, %cf_n1_0) : (f32, f32) -> () + // CHECK: 2.356{{.*}} + call @atan2_f32(%cf_1_0, %cf_1_0) : (f32, f32) -> () + // CHECK: 0.785{{.*}} + + return +} diff --git a/tests/legalize-trigonometric-to-approximation.mlir b/tests/legalize-trigonometric-to-approximation.mlir index b138e1c..43278ff 100644 --- a/tests/legalize-trigonometric-to-approximation.mlir +++ b/tests/legalize-trigonometric-to-approximation.mlir @@ -119,3 +119,145 @@ func @tanh_f16(%arg0 : f16) -> f16 { // CHECK: %[[VAL_43:.*]] = select %[[VAL_17]], %[[VAL_15]], %[[VAL_42]] : f32 // CHECK: %[[VAL_44:.*]] = fptrunc %[[VAL_43]] : f32 to f16 // CHECK: return %[[VAL_44]] : f16 + +// ----- + +// CHECK-LABEL: @atan2_f64 +func @atan2_f64(%arg0 : f64, %arg1 : f64) -> f64 { + // CHECK: atan2 + %res = atan2 %arg0, %arg1 : f64 + return %res : f64 +} + +// ----- + +// CHECK-LABEL: func @atan2_f32 +// CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32) -> f32 +func @atan2_f32(%arg0 : f32, %arg1 : f32) -> f32 { + // CHECK: %[[CST:.*]] = constant 0.0027856871 : f32 + // CHECK: %[[CST_0:.*]] = constant -1.586600e-02 : f32 + // CHECK: %[[CST_1:.*]] = constant 0.042472221 : f32 + // CHECK: %[[CST_2:.*]] = constant -0.0749753043 : f32 + // CHECK: %[[CST_3:.*]] = constant 0.106448799 : f32 + // CHECK: %[[CST_4:.*]] = constant -0.142070308 : f32 + // CHECK: %[[CST_5:.*]] = constant 0.199934542 : f32 + // CHECK: %[[CST_6:.*]] = constant -0.333331466 : f32 + // CHECK: %[[CST_7:.*]] = constant 1.57079637 : f32 + // CHECK: %[[CST_8:.*]] = constant 0.000000e+00 : f32 + // CHECK: %[[CST_9:.*]] = constant 3.14159274 : f32 + // CHECK: %[[CST_10:.*]] = constant 0x7FC00000 : f32 + // CHECK: %[[CST_11:.*]] = constant 2.3561945 : f32 + // CHECK: %[[CST_12:.*]] = constant 0.785398185 : f32 + // CHECK: %[[CST_13:.*]] = constant 0x7F800000 : f32 + // CHECK: %[[VAL_0:.*]] = absf %[[ARG1]] : f32 + // CHECK: %[[VAL_1:.*]] = absf %[[ARG0]] : f32 + // CHECK: %[[VAL_2:.*]] = cmpf "ole", %[[VAL_0]], %[[VAL_1]] : f32 + // CHECK: %[[VAL_3:.*]] = select %[[VAL_2]], %[[VAL_0]], %[[VAL_1]] : f32 + // CHECK: %[[VAL_4:.*]] = select %[[VAL_2]], %[[VAL_1]], %[[VAL_0]] : f32 + // CHECK: %[[VAL_5:.*]] = divf %[[VAL_3]], %[[VAL_4]] : f32 + // CHECK: %[[VAL_6:.*]] = mulf %[[VAL_5]], %[[VAL_5]] : f32 + // CHECK: %[[VAL_7:.*]] = mulf %[[CST]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_8:.*]] = addf %[[VAL_7]], %[[CST_0]] : f32 + // CHECK: %[[VAL_9:.*]] = mulf %[[VAL_8]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_10:.*]] = addf %[[VAL_9]], %[[CST_1]] : f32 + // CHECK: %[[VAL_11:.*]] = mulf %[[VAL_10]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_12:.*]] = addf %[[VAL_11]], %[[CST_2]] : f32 + // CHECK: %[[VAL_13:.*]] = mulf %[[VAL_12]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_14:.*]] = addf %[[VAL_13]], %[[CST_3]] : f32 + // CHECK: %[[VAL_15:.*]] = mulf %[[VAL_14]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_16:.*]] = addf %[[VAL_15]], %[[CST_4]] : f32 + // CHECK: %[[VAL_17:.*]] = mulf %[[VAL_16]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_18:.*]] = addf %[[VAL_17]], %[[CST_5]] : f32 + // CHECK: %[[VAL_19:.*]] = mulf %[[VAL_18]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_20:.*]] = addf %[[VAL_19]], %[[CST_6]] : f32 + // CHECK: %[[VAL_21:.*]] = mulf %[[VAL_20]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_22:.*]] = mulf %[[VAL_21]], %[[VAL_5]] : f32 + // CHECK: %[[VAL_23:.*]] = addf %[[VAL_22]], %[[VAL_5]] : f32 + // CHECK: %[[VAL_24:.*]] = subf %[[CST_7]], %[[VAL_23]] : f32 + // CHECK: %[[VAL_25:.*]] = select %[[VAL_2]], %[[VAL_24]], %[[VAL_23]] : f32 + // CHECK: %[[VAL_26:.*]] = cmpf "olt", %[[ARG1]], %[[CST_8]] : f32 + // CHECK: %[[VAL_27:.*]] = subf %[[CST_9]], %[[VAL_25]] : f32 + // CHECK: %[[VAL_28:.*]] = select %[[VAL_26]], %[[VAL_27]], %[[VAL_25]] : f32 + // CHECK: %[[VAL_29:.*]] = select %[[VAL_26]], %[[CST_9]], %[[CST_8]] : f32 + // CHECK: %[[VAL_30:.*]] = cmpf "oeq", %[[ARG0]], %[[CST_8]] : f32 + // CHECK: %[[VAL_31:.*]] = select %[[VAL_30]], %[[VAL_29]], %[[VAL_28]] : f32 + // CHECK: %[[VAL_32:.*]] = cmpf "uno", %[[ARG0]], %[[ARG1]] : f32 + // CHECK: %[[VAL_35:.*]] = select %[[VAL_32]], %[[CST_10]], %[[VAL_31]] : f32 + // CHECK: %[[VAL_36:.*]] = select %[[VAL_26]], %[[CST_11]], %[[CST_12]] : f32 + // CHECK: %[[VAL_37:.*]] = cmpf "oeq", %[[ARG1]], %[[CST_13]] : f32 + // CHECK: %[[VAL_38:.*]] = cmpf "oeq", %[[ARG0]], %[[CST_13]] : f32 + // CHECK: %[[VAL_39:.*]] = and %[[VAL_37]], %[[VAL_38]] : i1 + // CHECK: %[[VAL_40:.*]] = select %[[VAL_39]], %[[VAL_36]], %[[VAL_35]] : f32 + // CHECK: %[[VAL_41:.*]] = copysign %[[VAL_40]], %[[ARG0]] : f32 + // CHECK: return %[[VAL_41]] : f32 + %res = atan2 %arg0, %arg1 : f32 + return %res : f32 +} + +// ----- + +// CHECK-LABEL: @atan2_f16 +// CHECK-SAME: (%[[ARG0:.*]]: f16, %[[ARG1:.*]]: f16) -> f16 +func @atan2_f16(%arg0 : f16, %arg1 : f16) -> f16 { + // CHECK: %[[CST:.*]] = constant 0.0027856871 : f32 + // CHECK: %[[CST_0:.*]] = constant -1.586600e-02 : f32 + // CHECK: %[[CST_1:.*]] = constant 0.042472221 : f32 + // CHECK: %[[CST_2:.*]] = constant -0.0749753043 : f32 + // CHECK: %[[CST_3:.*]] = constant 0.106448799 : f32 + // CHECK: %[[CST_4:.*]] = constant -0.142070308 : f32 + // CHECK: %[[CST_5:.*]] = constant 0.199934542 : f32 + // CHECK: %[[CST_6:.*]] = constant -0.333331466 : f32 + // CHECK: %[[CST_7:.*]] = constant 1.57079637 : f32 + // CHECK: %[[CST_8:.*]] = constant 0.000000e+00 : f32 + // CHECK: %[[CST_9:.*]] = constant 3.14159274 : f32 + // CHECK: %[[CST_10:.*]] = constant 0x7FC00000 : f32 + // CHECK: %[[CST_11:.*]] = constant 2.3561945 : f32 + // CHECK: %[[CST_12:.*]] = constant 0.785398185 : f32 + // CHECK: %[[CST_13:.*]] = constant 0x7F800000 : f32 + // CHECK: %[[VAL_0:.*]] = fpext %[[ARG0]] : f16 to f32 + // CHECK: %[[VAL_1:.*]] = fpext %[[ARG1]] : f16 to f32 + // CHECK: %[[VAL_2:.*]] = absf %[[VAL_1]] : f32 + // CHECK: %[[VAL_3:.*]] = absf %[[VAL_0]] : f32 + // CHECK: %[[VAL_4:.*]] = cmpf "ole", %[[VAL_2]], %[[VAL_3]] : f32 + // CHECK: %[[VAL_5:.*]] = select %[[VAL_4]], %[[VAL_2]], %[[VAL_3]] : f32 + // CHECK: %[[VAL_6:.*]] = select %[[VAL_4]], %[[VAL_3]], %[[VAL_2]] : f32 + // CHECK: %[[VAL_7:.*]] = divf %[[VAL_5]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_8:.*]] = mulf %[[VAL_7]], %[[VAL_7]] : f32 + // CHECK: %[[VAL_9:.*]] = mulf %[[CST]], %[[VAL_8]] : f32 + // CHECK: %[[VAL_10:.*]] = addf %[[VAL_9]], %[[CST_0]] : f32 + // CHECK: %[[VAL_11:.*]] = mulf %[[VAL_10]], %[[VAL_8]] : f32 + // CHECK: %[[VAL_12:.*]] = addf %[[VAL_11]], %[[CST_1]] : f32 + // CHECK: %[[VAL_13:.*]] = mulf %[[VAL_12]], %[[VAL_8]] : f32 + // CHECK: %[[VAL_14:.*]] = addf %[[VAL_13]], %[[CST_2]] : f32 + // CHECK: %[[VAL_15:.*]] = mulf %[[VAL_14]], %[[VAL_8]] : f32 + // CHECK: %[[VAL_16:.*]] = addf %[[VAL_15]], %[[CST_3]] : f32 + // CHECK: %[[VAL_17:.*]] = mulf %[[VAL_16]], %[[VAL_8]] : f32 + // CHECK: %[[VAL_18:.*]] = addf %[[VAL_17]], %[[CST_4]] : f32 + // CHECK: %[[VAL_19:.*]] = mulf %[[VAL_18]], %[[VAL_8]] : f32 + // CHECK: %[[VAL_20:.*]] = addf %[[VAL_19]], %[[CST_5]] : f32 + // CHECK: %[[VAL_21:.*]] = mulf %[[VAL_20]], %[[VAL_8]] : f32 + // CHECK: %[[VAL_22:.*]] = addf %[[VAL_21]], %[[CST_6]] : f32 + // CHECK: %[[VAL_23:.*]] = mulf %[[VAL_22]], %[[VAL_8]] : f32 + // CHECK: %[[VAL_24:.*]] = mulf %[[VAL_23]], %[[VAL_7]] : f32 + // CHECK: %[[VAL_25:.*]] = addf %[[VAL_24]], %[[VAL_7]] : f32 + // CHECK: %[[VAL_26:.*]] = subf %[[CST_7]], %[[VAL_25]] : f32 + // CHECK: %[[VAL_27:.*]] = select %[[VAL_4]], %[[VAL_26]], %[[VAL_25]] : f32 + // CHECK: %[[VAL_28:.*]] = cmpf "olt", %[[VAL_1]], %[[CST_8]] : f32 + // CHECK: %[[VAL_29:.*]] = subf %[[CST_9]], %[[VAL_27]] : f32 + // CHECK: %[[VAL_30:.*]] = select %[[VAL_28]], %[[VAL_29]], %[[VAL_27]] : f32 + // CHECK: %[[VAL_31:.*]] = select %[[VAL_28]], %[[CST_9]], %[[CST_8]] : f32 + // CHECK: %[[VAL_32:.*]] = cmpf "oeq", %[[VAL_0]], %[[CST_8]] : f32 + // CHECK: %[[VAL_33:.*]] = select %[[VAL_32]], %[[VAL_31]], %[[VAL_30]] : f32 + // CHECK: %[[VAL_34:.*]] = cmpf "uno", %[[VAL_0]], %[[VAL_1]] : f32 + // CHECK: %[[VAL_37:.*]] = select %[[VAL_34]], %[[CST_10]], %[[VAL_33]] : f32 + // CHECK: %[[VAL_38:.*]] = select %[[VAL_28]], %[[CST_11]], %[[CST_12]] : f32 + // CHECK: %[[VAL_39:.*]] = cmpf "oeq", %[[VAL_1]], %[[CST_13]] : f32 + // CHECK: %[[VAL_40:.*]] = cmpf "oeq", %[[VAL_0]], %[[CST_13]] : f32 + // CHECK: %[[VAL_41:.*]] = and %[[VAL_39]], %[[VAL_40]] : i1 + // CHECK: %[[VAL_42:.*]] = select %[[VAL_41]], %[[VAL_38]], %[[VAL_37]] : f32 + // CHECK: %[[VAL_43:.*]] = copysign %[[VAL_42]], %[[VAL_0]] : f32 + // CHECK: %[[VAL_44:.*]] = fptrunc %[[VAL_43]] : f32 to f16 + // CHECK: return %[[VAL_44]] : f16 + %res = atan2 %arg0, %arg1 : f16 + return %res : f16 +}