[MLIR][KernelGen] Legalize `atan2` to approximation

Legalize `atan2` analogously to XLA.  `atan2` is first reduced to `atan` on the
interval [-1, 1] and subsequently approximated.  This CL also adds e2e tests for
trigonometric approximations.

PiperOrigin-RevId: 334794336
This commit is contained in:
A. Unique TensorFlower 2020-10-01 05:33:59 -07:00 committed by TensorFlow MLIR Team
parent 4b1809784a
commit 049ca060a1
3 changed files with 463 additions and 87 deletions

View File

@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file implements logic for lowering the tanh standard ops to an
// approximation.
// This file implements the lowering for trigonometric standard ops to
// approximations.
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
@ -27,102 +27,211 @@ namespace mlir {
namespace mhlo {
namespace {
/// Builds the fast rational tanh approximation that XLA also uses.
///
/// `input` is the value to approximate tanh for; all ops are created at `loc`
/// through `rewriter`. Returns `input` itself where |input| is below a small
/// bound, and numerator(x) / denominator(x) of the clamped input otherwise.
Value EmitTanhApproximation(Value input, Location loc,
                            PatternRewriter &rewriter) {
  // Materializes an f32 constant at `loc`.
  auto f32_constant = [&](float value) -> Value {
    return rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(value));
  };
  // Evaluates a polynomial in `variable` with Horner's scheme. `coeffs` is
  // ordered from the leading coefficient down to the constant term.
  auto horner = [&](const auto &coeffs, Value variable) -> Value {
    Value acc = f32_constant(coeffs[0]);
    for (size_t idx = 1; idx < coeffs.size(); ++idx) {
      acc = rewriter.create<AddFOp>(
          loc, rewriter.create<MulFOp>(loc, variable, acc),
          f32_constant(coeffs[idx]));
    }
    return acc;
  };

  // For small |x| the result is x itself. For extremely small values
  // (|x| < 1e-37) the rational form below would evaluate to 0 instead.
  constexpr float kSmallInputBound = 0.0004;
  Value magnitude = rewriter.create<AbsFOp>(loc, input);
  Value small_bound = f32_constant(kSmallInputBound);
  Value is_small = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, magnitude,
                                           small_bound);

  // Clamp the input to [-c, c]: first against the upper bound, then against
  // the lower one.
  Value upper = f32_constant(7.90531110763549805f);
  Value below_upper =
      rewriter.create<CmpFOp>(loc, CmpFPredicate::ULE, input, upper);
  Value upper_clamped =
      rewriter.create<SelectOp>(loc, below_upper, input, upper);
  Value lower = f32_constant(-7.90531110763549805f);
  Value above_lower =
      rewriter.create<CmpFOp>(loc, CmpFPredicate::UGE, upper_clamped, lower);
  Value clamped =
      rewriter.create<SelectOp>(loc, above_lower, upper_clamped, lower);

  static constexpr std::array<float, 7> kNumerator{
      -2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f,
      5.12229709037114e-08f,  1.48572235717979e-05f, 6.37261928875436e-04f,
      4.89352455891786e-03f};
  static constexpr std::array<float, 4> kDenominator{
      1.19825839466702e-06f, 1.18534705686654e-04f, 2.26843463243900e-03f,
      4.89352518554385e-03f};

  // Both polynomials are in x^2; the numerator gets one extra factor of x.
  Value squared = rewriter.create<MulFOp>(loc, clamped, clamped);
  Value numerator_poly = horner(kNumerator, squared);
  Value numerator = rewriter.create<MulFOp>(loc, clamped, numerator_poly);
  Value denominator = horner(kDenominator, squared);

  Value ratio = rewriter.create<DivFOp>(loc, numerator, denominator);
  return rewriter.create<SelectOp>(loc, is_small, input, ratio);
}
// NOTE(review): this span appears to interleave removed and added lines of a
// diff whose +/- markers are missing (two class headers, two matchAndRewrite
// headers and two `Value result` declarations follow each other directly).
// Confirm against the real file before relying on the exact line sequence.
//
// The lines that use `OpTy` describe a rewrite pattern that runs an
// approximation in f32: it rejects result types other than f16/f32, extends
// f16 operands to f32, asks `emitApproximation` for the f32 result, truncates
// that result back to f16 when the op's type is f16, and replaces the op.
class ApproximateTanhLowering : public OpRewritePattern<TanhOp> {
template <typename OpTy>
class ApproximateOnExtendedF32Lowering : public OpRewritePattern<OpTy> {
public:
explicit ApproximateTanhLowering(MLIRContext *ctx)
: OpRewritePattern<TanhOp>(ctx, 100) {}
explicit ApproximateOnExtendedF32Lowering(MLIRContext *ctx)
: OpRewritePattern<OpTy>(ctx, /*benefit=*/100) {}
LogicalResult matchAndRewrite(TanhOp tanhOp,
// Hook implemented by the concrete lowerings: receives the operands (already
// f32 when called from matchAndRewrite below) and returns the f32 result.
virtual Value emitApproximation(ValueRange, Location,
PatternRewriter &) const = 0;
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
Type operand_type = tanhOp.getType();
Location loc = op.getLoc();
auto raw_args = op.getOperation()->getOperands();
if (operand_type.isF64()) {
// Supports only f16 and f32 for now.
if (!op.getType().isF16() && !op.getType().isF32()) return failure();
// Extend operands to f32 if needed and possible.
SmallVector<Value, 2> f32_args;
f32_args.reserve(raw_args.size());
for (Value arg : raw_args) {
// Similar to XLA, do not rewrite f64 as precision might matter.
return failure();
Type arg_ty = arg.getType();
if (arg_ty.isF64()) return failure();
if (arg_ty.isF16())
arg = rewriter.create<FPExtOp>(loc, arg, rewriter.getF32Type());
// If we still do not have f32, fail.
if (!arg.getType().isF32()) return failure();
f32_args.push_back(arg);
}
Location loc = tanhOp.getLoc();
Value input = tanhOp.operand();
if (operand_type.isF16()) {
input = rewriter.create<FPExtOp>(loc, input, rewriter.getF32Type());
}
// If we still do not have f32, fail.
if (!input.getType().isF32()) {
return failure();
}
Value result = EmitTanhApproximation(input, loc, rewriter);
Value result = emitApproximation(f32_args, loc, rewriter);
assert(result.getType().isF32() && "Expect f32 intermediate result.");
// Truncate back if needed.
if (operand_type.isF16()) {
if (op.getType().isF16())
result = rewriter.create<FPTruncOp>(loc, result, rewriter.getF16Type());
rewriter.replaceOp(op, {result});
return success();
}
};
// Rewrite pattern for `TanhOp`. It only supplies the f32 computation through
// `emitApproximation`; everything else comes from the base class
// ApproximateOnExtendedF32Lowering<TanhOp>.
class ApproximateTanhLowering
: public ApproximateOnExtendedF32Lowering<TanhOp> {
public:
explicit ApproximateTanhLowering(MLIRContext *ctx)
: ApproximateOnExtendedF32Lowering<TanhOp>(ctx) {}
// Emits the fast tanh approximation that is also used by XLA.
// Only the first element of `args` is read, and it is asserted to be f32.
// The returned value is either the input itself (small inputs) or the
// quotient of two polynomials evaluated on the clamped input.
Value emitApproximation(ValueRange args, Location loc,
PatternRewriter &rewriter) const override {
// For small values of x, we can approximate tanh(x) = x. For extremely
// small values of x (|x| < 1e-37), the other approximation would evaluate
// tanh(x) = 0.
Value input = args.front();
assert(input.getType().isF32());
constexpr float kCanUseApprox = 0.0004;
Value abs_value = rewriter.create<AbsFOp>(loc, input);
Value can_use_approx = rewriter.create<ConstantOp>(
loc, rewriter.getF32FloatAttr(kCanUseApprox));
// True when |input| < kCanUseApprox; used by the final select.
Value return_input = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT,
abs_value, can_use_approx);
// Clamp the input to [-c, c].
// The upper bound is applied first, then the lower bound to that result.
Value max_clamp = rewriter.create<ConstantOp>(
loc, rewriter.getF32FloatAttr(7.90531110763549805f));
Value smaller_than_max =
rewriter.create<CmpFOp>(loc, CmpFPredicate::ULE, input, max_clamp);
Value clamped_half =
rewriter.create<SelectOp>(loc, smaller_than_max, input, max_clamp);
Value min_clamp = rewriter.create<ConstantOp>(
loc, rewriter.getF32FloatAttr(-7.90531110763549805f));
Value larger_than_min = rewriter.create<CmpFOp>(loc, CmpFPredicate::UGE,
clamped_half, min_clamp);
Value input_clamped = rewriter.create<SelectOp>(loc, larger_than_min,
clamped_half, min_clamp);
// Polynomial coefficients, listed from the leading term to the constant term
// (the loops below start from index 0 and fold in one coefficient per step).
static constexpr std::array<float, 7> numerator_coeffs{
-2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f,
5.12229709037114e-08f, 1.48572235717979e-05f, 6.37261928875436e-04f,
4.89352455891786e-03f};
static constexpr std::array<float, 4> denominator_coeffs{
1.19825839466702e-06f, 1.18534705686654e-04f, 2.26843463243900e-03f,
4.89352518554385e-03f};
Value input_squared =
rewriter.create<MulFOp>(loc, input_clamped, input_clamped);
// Horner evaluation of the numerator polynomial in input_squared.
Value numerator = rewriter.create<ConstantOp>(
loc, rewriter.getF32FloatAttr(numerator_coeffs[0]));
for (int i = 1; i < numerator_coeffs.size(); i++) {
numerator = rewriter.create<AddFOp>(
loc, rewriter.create<MulFOp>(loc, input_squared, numerator),
rewriter.create<ConstantOp>(
loc, rewriter.getF32FloatAttr(numerator_coeffs[i])));
}
// NOTE(review): the next two lines use `tanhOp` and `result`, neither of which
// is declared in this method. They look like removed lines of the diff that
// were interleaved here — confirm against the real file.
rewriter.replaceOp(tanhOp, {result});
return success();
// The numerator carries one extra factor of the clamped input.
numerator = rewriter.create<MulFOp>(loc, input_clamped, numerator);
// Horner evaluation of the denominator polynomial in input_squared.
Value denominator = rewriter.create<ConstantOp>(
loc, rewriter.getF32FloatAttr(denominator_coeffs[0]));
for (int i = 1; i < denominator_coeffs.size(); i++) {
denominator = rewriter.create<AddFOp>(
loc, rewriter.create<MulFOp>(loc, input_squared, denominator),
rewriter.create<ConstantOp>(
loc, rewriter.getF32FloatAttr(denominator_coeffs[i])));
}
Value approx = rewriter.create<DivFOp>(loc, numerator, denominator);
// Small inputs are passed through unchanged; all others use the quotient.
return rewriter.create<SelectOp>(loc, return_input, input, approx);
}
};
// Rewrite pattern for `Atan2Op`. It only supplies the f32 computation through
// `emitApproximation`; the rest of the rewrite comes from the base class
// ApproximateOnExtendedF32Lowering<Atan2Op>.
class ApproximateAtan2Lowering
    : public ApproximateOnExtendedF32Lowering<Atan2Op> {
 public:
  explicit ApproximateAtan2Lowering(MLIRContext *ctx)
      : ApproximateOnExtendedF32Lowering<Atan2Op>(ctx) {}

  // Reduces atan2 to atan in the same way XLA does it.
  //
  // `args` holds the two f32 operands in the order (y, x). Returns an f32
  // value approximating atan2(y, x); all ops are created at `loc`.
  Value emitApproximation(ValueRange args, Location loc,
                          PatternRewriter &rewriter) const override {
    Value y = args[0];
    Value x = args[1];
    assert(x.getType().isF32() && y.getType().isF32() &&
           "expect f32 arguments");

    // Reduce to the ratio min(|x|, |y|) / max(|x|, |y|), which lies in [0, 1]
    // for finite, non-zero operands, and approximate atan of that ratio.
    Value ax = rewriter.create<AbsFOp>(loc, x);
    Value ay = rewriter.create<AbsFOp>(loc, y);
    Value le_ax_ay = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLE, ax, ay);
    Value min_ax_ay = rewriter.create<mlir::SelectOp>(loc, le_ax_ay, ax, ay);
    Value max_ax_ay = rewriter.create<mlir::SelectOp>(loc, le_ax_ay, ay, ax);
    Value zero_to_one = rewriter.create<DivFOp>(loc, min_ax_ay, max_ax_ay);
    Value a = emitAtanCoreApproximation(zero_to_one, loc, rewriter);

    // If |x| <= |y| the ratio was |x| / |y|, so the angle is pi/2 - a.
    Value pi_over_2 =
        rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.57079637f));
    a = rewriter.create<mlir::SelectOp>(
        loc, le_ax_ay, rewriter.create<SubFOp>(loc, pi_over_2, a), a);

    // Negative x mirrors the angle into the second quadrant: pi - a.
    Value zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0));
    Value lt_x_0 = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, x, zero);
    Value pi =
        rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(3.14159274f));
    a = rewriter.create<mlir::SelectOp>(loc, lt_x_0,
                                        rewriter.create<SubFOp>(loc, pi, a), a);

    // y == 0: the result is pi for negative x and 0 otherwise. This also
    // replaces the NaN produced by 0 / 0 when both operands are zero.
    Value t = rewriter.create<mlir::SelectOp>(loc, lt_x_0, pi, zero);
    Value eq_y_0 = rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, y, zero);
    a = rewriter.create<mlir::SelectOp>(loc, eq_y_0, t, a);

    // Propagate nan.
    Value is_nan = rewriter.create<CmpFOp>(loc, CmpFPredicate::UNO, y, x);
    Value nan = rewriter.create<ConstantOp>(
        loc, rewriter.getF32FloatAttr(std::numeric_limits<float>::quiet_NaN()));
    a = rewriter.create<mlir::SelectOp>(loc, is_nan, nan, a);

    // x and y are +- inf: the ratio above is inf / inf = NaN, so pick the
    // diagonal angle explicitly (3*pi/4 for negative x, pi/4 otherwise).
    Value three_pi_over_4 =
        rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(2.3561945f));
    Value pi_over_4 = rewriter.create<ConstantOp>(
        loc, rewriter.getF32FloatAttr(0.785398185f));
    t = rewriter.create<mlir::SelectOp>(loc, lt_x_0, three_pi_over_4,
                                        pi_over_4);
    Value inf = rewriter.create<ConstantOp>(
        loc, rewriter.getF32FloatAttr(std::numeric_limits<float>::infinity()));
    // Compare the magnitudes, not the signed operands. Comparing x and y
    // themselves against +inf only matches (+inf, +inf); every combination
    // involving -inf would keep the NaN and the x < 0 arm of `t` would be
    // unreachable.
    Value eq_x_inf = rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, ax, inf);
    Value eq_y_inf = rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, ay, inf);
    Value all_inf = rewriter.create<mlir::AndOp>(loc, eq_x_inf, eq_y_inf);
    a = rewriter.create<mlir::SelectOp>(loc, all_inf, t, a);

    // Everything above produced a non-negative angle; the sign comes from y.
    return rewriter.create<CopySignOp>(loc, a, y);
  }

 private:
  // The core atan reduction derives from the heuristic described in
  // https://arxiv.org/abs/1508.03211 and has a < 0.95 ulp error in the [-1, 1]
  // range (though that assumed FMA was available, and it is not here). This is
  // the same approximation that is also used by XLA.
  //
  // Emits x + x * s * p(s) with s = x * x, where p is evaluated with Horner's
  // scheme from separate multiply and add ops.
  Value emitAtanCoreApproximation(Value x, Location loc,
                                  PatternRewriter &rewriter) const {
    auto constant = [&](float c) {
      return rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(c));
    };

    // Computes ab + c.
    auto mul_add = [&](Value a, Value b, Value c) {
      Value prod = rewriter.create<MulFOp>(loc, a, b);
      return rewriter.create<AddFOp>(loc, prod, c);
    };

    Value s = rewriter.create<MulFOp>(loc, x, x);
    Value r = constant(0.0027856871f);
    r = mul_add(r, s, constant(-0.0158660002f));
    r = mul_add(r, s, constant(0.042472221f));
    r = mul_add(r, s, constant(-0.0749753043f));
    r = mul_add(r, s, constant(0.106448799f));
    r = mul_add(r, s, constant(-0.142070308f));
    r = mul_add(r, s, constant(0.199934542f));
    r = mul_add(r, s, constant(-0.333331466f));
    r = rewriter.create<MulFOp>(loc, r, s);
    return mul_add(r, x, x);
  }
};
@ -146,7 +255,11 @@ createLegalizeTrigonometricToApproximationPass() {
// Adds the approximation rewrite patterns named below to `patterns`, each
// constructed with `context`.
void PopulateTrigonometricToApproximationPatterns(
mlir::MLIRContext *context, OwningRewritePatternList *patterns) {
// NOTE(review): ApproximateTanhLowering is inserted here and again in the
// multi-pattern insert below. The single insert looks like a removed line of
// the diff shown next to its replacement — confirm against the real file.
patterns->insert<ApproximateTanhLowering>(context);
// clang-format off
patterns->insert<
ApproximateAtan2Lowering,
ApproximateTanhLowering>(context);
// clang-format on
}
} // namespace mhlo

View File

@ -0,0 +1,121 @@
// RUN: mlir-hlo-opt %s --mhlo-legalize-trigonometric-to-approximation --convert-std-to-llvm | mlir-cpu-runner -e main -entry-point-result=void --shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | FileCheck %s
func @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface }
// Prints one f32 scalar: the value is stored into a one-element buffer, which
// is then cast to an unranked memref and passed to @print_memref_f32.
func @print_f32(%value : f32) {
  %idx0 = constant 0 : index
  %slot = alloca() : memref<1xf32>
  store %value, %slot[%idx0] : memref<1xf32>
  %slot_unranked = memref_cast %slot : memref<1xf32> to memref<*xf32>
  call @print_memref_f32(%slot_unranked) : (memref<*xf32>) -> ()
  return
}
// Applies atan2 to the two f32 arguments, in the order given, and prints the
// result.
func @atan2_f32(%y : f32, %x : f32) {
  %angle = atan2 %y, %x : f32
  call @print_f32(%angle) : (f32) -> ()
  return
}
// Applies tanh to the f32 argument and prints the result.
func @tanh_f32(%value : f32) {
  %out = tanh %value : f32
  call @print_f32(%out) : (f32) -> ()
  return
}
// Test driver: calls the tanh and atan2 helpers on a fixed set of constants.
// Each call is followed by the pattern its printed value is matched against,
// so the order of the calls and of the patterns must stay in sync.
// NOTE(review): the single-digit patterns (such as the ones after tanh(0) and
// tanh(50)) can match any occurrence of that digit in the remaining output,
// so they constrain the result only weakly — consider tightening them.
func @main() {
// Some constants to use as arguments.
// Naming: `n` marks a negative value and `_` stands for the decimal point.
%cf_n50_0 = constant -50.0 : f32
%cf_n5_0 = constant -5.0 : f32
%cf_n3_0 = constant -3.0 : f32
%cf_n2_0 = constant -2.0 : f32
%cf_n1_0 = constant -1.0 : f32
%cf_n0_5 = constant -0.5 : f32
%cf_n0_1 = constant -0.1 : f32
%cf_0_0 = constant 0.0 : f32
%cf_0_1 = constant 0.1 : f32
%cf_0_5 = constant 0.5 : f32
%cf_1_0 = constant 1.0 : f32
%cf_2_0 = constant 2.0 : f32
%cf_3_0 = constant 3.0 : f32
%cf_5_0 = constant 5.0 : f32
%cf_50_0 = constant 50.0 : f32
// Tanh.
// Inputs run from -50 to 50 in increasing order.
call @tanh_f32(%cf_n50_0) : (f32) -> ()
// CHECK: -1
call @tanh_f32(%cf_n5_0) : (f32) -> ()
// CHECK: -0.999{{.*}}
call @tanh_f32(%cf_n3_0) : (f32) -> ()
// CHECK: -0.995{{.*}}
call @tanh_f32(%cf_n2_0) : (f32) -> ()
// CHECK: -0.964{{.*}}
call @tanh_f32(%cf_n1_0) : (f32) -> ()
// CHECK: -0.761{{.*}}
call @tanh_f32(%cf_n0_5) : (f32) -> ()
// CHECK: -0.462{{.*}}
call @tanh_f32(%cf_n0_1) : (f32) -> ()
// CHECK: -0.099{{.*}}
call @tanh_f32(%cf_0_0) : (f32) -> ()
// CHECK: 0
call @tanh_f32(%cf_0_1) : (f32) -> ()
// CHECK: 0.099{{.*}}
call @tanh_f32(%cf_0_5) : (f32) -> ()
// CHECK: 0.462{{.*}}
call @tanh_f32(%cf_1_0) : (f32) -> ()
// CHECK: 0.761{{.*}}
call @tanh_f32(%cf_2_0) : (f32) -> ()
// CHECK: 0.964{{.*}}
call @tanh_f32(%cf_3_0) : (f32) -> ()
// CHECK: 0.995{{.*}}
call @tanh_f32(%cf_5_0) : (f32) -> ()
// CHECK: 0.999{{.*}}
call @tanh_f32(%cf_50_0) : (f32) -> ()
// CHECK: 1
// Atan2 with divisor 1.
// The first argument runs from -50 to 50; the second is always 1.0.
call @atan2_f32(%cf_n50_0, %cf_1_0) : (f32, f32) -> ()
// CHECK: -1.550{{.*}}
call @atan2_f32(%cf_n5_0, %cf_1_0) : (f32, f32) -> ()
// CHECK: -1.373{{.*}}
call @atan2_f32(%cf_n3_0, %cf_1_0) : (f32, f32) -> ()
// CHECK: -1.249{{.*}}
call @atan2_f32(%cf_n2_0, %cf_1_0) : (f32, f32) -> ()
// CHECK: -1.107{{.*}}
call @atan2_f32(%cf_n1_0, %cf_1_0) : (f32, f32) -> ()
// CHECK: -0.785{{.*}}
call @atan2_f32(%cf_n0_5, %cf_1_0) : (f32, f32) -> ()
// CHECK: -0.463{{.*}}
call @atan2_f32(%cf_n0_1, %cf_1_0) : (f32, f32) -> ()
// CHECK: -0.099{{.*}}
call @atan2_f32(%cf_0_0, %cf_1_0) : (f32, f32) -> ()
// CHECK: 0
call @atan2_f32(%cf_0_1, %cf_1_0) : (f32, f32) -> ()
// CHECK: 0.099{{.*}}
call @atan2_f32(%cf_0_5, %cf_1_0) : (f32, f32) -> ()
// CHECK: 0.463{{.*}}
call @atan2_f32(%cf_1_0, %cf_1_0) : (f32, f32) -> ()
// CHECK: 0.785{{.*}}
call @atan2_f32(%cf_2_0, %cf_1_0) : (f32, f32) -> ()
// CHECK: 1.107{{.*}}
call @atan2_f32(%cf_3_0, %cf_1_0) : (f32, f32) -> ()
// CHECK: 1.249{{.*}}
call @atan2_f32(%cf_5_0, %cf_1_0) : (f32, f32) -> ()
// CHECK: 1.373{{.*}}
call @atan2_f32(%cf_50_0, %cf_1_0) : (f32, f32) -> ()
// CHECK: 1.550{{.*}}
// Atan2 all four quadrants.
// Every sign combination of (+-1.0, +-1.0). Infinite and NaN arguments are
// not exercised by this file.
call @atan2_f32(%cf_n1_0, %cf_n1_0) : (f32, f32) -> ()
// CHECK: -2.356{{.*}}
call @atan2_f32(%cf_n1_0, %cf_1_0) : (f32, f32) -> ()
// CHECK: -0.785{{.*}}
call @atan2_f32(%cf_1_0, %cf_n1_0) : (f32, f32) -> ()
// CHECK: 2.356{{.*}}
call @atan2_f32(%cf_1_0, %cf_1_0) : (f32, f32) -> ()
// CHECK: 0.785{{.*}}
return
}

View File

@ -119,3 +119,145 @@ func @tanh_f16(%arg0 : f16) -> f16 {
// CHECK: %[[VAL_43:.*]] = select %[[VAL_17]], %[[VAL_15]], %[[VAL_42]] : f32
// CHECK: %[[VAL_44:.*]] = fptrunc %[[VAL_43]] : f32 to f16
// CHECK: return %[[VAL_44]] : f16
// -----
// f64 case: the only expectation is that an `atan2` op is still present in
// the output, i.e. the op is not expanded for f64 operands.
// CHECK-LABEL: @atan2_f64
func @atan2_f64(%arg0 : f64, %arg1 : f64) -> f64 {
// CHECK: atan2
%res = atan2 %arg0, %arg1 : f64
return %res : f64
}
// -----
// f32 case: the op is expanded in place. The expectations pin the whole op
// sequence: the constants first, then the min/max ratio, the polynomial, the
// quadrant adjustments, the special cases, and the final copysign.
// CHECK-LABEL: func @atan2_f32
// CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32) -> f32
func @atan2_f32(%arg0 : f32, %arg1 : f32) -> f32 {
// Polynomial coefficients, then pi/2, 0, pi, NaN (0x7FC00000), 3*pi/4, pi/4
// and +inf (0x7F800000).
// CHECK: %[[CST:.*]] = constant 0.0027856871 : f32
// CHECK: %[[CST_0:.*]] = constant -1.586600e-02 : f32
// CHECK: %[[CST_1:.*]] = constant 0.042472221 : f32
// CHECK: %[[CST_2:.*]] = constant -0.0749753043 : f32
// CHECK: %[[CST_3:.*]] = constant 0.106448799 : f32
// CHECK: %[[CST_4:.*]] = constant -0.142070308 : f32
// CHECK: %[[CST_5:.*]] = constant 0.199934542 : f32
// CHECK: %[[CST_6:.*]] = constant -0.333331466 : f32
// CHECK: %[[CST_7:.*]] = constant 1.57079637 : f32
// CHECK: %[[CST_8:.*]] = constant 0.000000e+00 : f32
// CHECK: %[[CST_9:.*]] = constant 3.14159274 : f32
// CHECK: %[[CST_10:.*]] = constant 0x7FC00000 : f32
// CHECK: %[[CST_11:.*]] = constant 2.3561945 : f32
// CHECK: %[[CST_12:.*]] = constant 0.785398185 : f32
// CHECK: %[[CST_13:.*]] = constant 0x7F800000 : f32
// Ratio of the smaller to the larger absolute value.
// CHECK: %[[VAL_0:.*]] = absf %[[ARG1]] : f32
// CHECK: %[[VAL_1:.*]] = absf %[[ARG0]] : f32
// CHECK: %[[VAL_2:.*]] = cmpf "ole", %[[VAL_0]], %[[VAL_1]] : f32
// CHECK: %[[VAL_3:.*]] = select %[[VAL_2]], %[[VAL_0]], %[[VAL_1]] : f32
// CHECK: %[[VAL_4:.*]] = select %[[VAL_2]], %[[VAL_1]], %[[VAL_0]] : f32
// CHECK: %[[VAL_5:.*]] = divf %[[VAL_3]], %[[VAL_4]] : f32
// Polynomial in the squared ratio, evaluated as multiply/add pairs.
// CHECK: %[[VAL_6:.*]] = mulf %[[VAL_5]], %[[VAL_5]] : f32
// CHECK: %[[VAL_7:.*]] = mulf %[[CST]], %[[VAL_6]] : f32
// CHECK: %[[VAL_8:.*]] = addf %[[VAL_7]], %[[CST_0]] : f32
// CHECK: %[[VAL_9:.*]] = mulf %[[VAL_8]], %[[VAL_6]] : f32
// CHECK: %[[VAL_10:.*]] = addf %[[VAL_9]], %[[CST_1]] : f32
// CHECK: %[[VAL_11:.*]] = mulf %[[VAL_10]], %[[VAL_6]] : f32
// CHECK: %[[VAL_12:.*]] = addf %[[VAL_11]], %[[CST_2]] : f32
// CHECK: %[[VAL_13:.*]] = mulf %[[VAL_12]], %[[VAL_6]] : f32
// CHECK: %[[VAL_14:.*]] = addf %[[VAL_13]], %[[CST_3]] : f32
// CHECK: %[[VAL_15:.*]] = mulf %[[VAL_14]], %[[VAL_6]] : f32
// CHECK: %[[VAL_16:.*]] = addf %[[VAL_15]], %[[CST_4]] : f32
// CHECK: %[[VAL_17:.*]] = mulf %[[VAL_16]], %[[VAL_6]] : f32
// CHECK: %[[VAL_18:.*]] = addf %[[VAL_17]], %[[CST_5]] : f32
// CHECK: %[[VAL_19:.*]] = mulf %[[VAL_18]], %[[VAL_6]] : f32
// CHECK: %[[VAL_20:.*]] = addf %[[VAL_19]], %[[CST_6]] : f32
// CHECK: %[[VAL_21:.*]] = mulf %[[VAL_20]], %[[VAL_6]] : f32
// CHECK: %[[VAL_22:.*]] = mulf %[[VAL_21]], %[[VAL_5]] : f32
// CHECK: %[[VAL_23:.*]] = addf %[[VAL_22]], %[[VAL_5]] : f32
// Quadrant adjustments and the zero / NaN special cases.
// CHECK: %[[VAL_24:.*]] = subf %[[CST_7]], %[[VAL_23]] : f32
// CHECK: %[[VAL_25:.*]] = select %[[VAL_2]], %[[VAL_24]], %[[VAL_23]] : f32
// CHECK: %[[VAL_26:.*]] = cmpf "olt", %[[ARG1]], %[[CST_8]] : f32
// CHECK: %[[VAL_27:.*]] = subf %[[CST_9]], %[[VAL_25]] : f32
// CHECK: %[[VAL_28:.*]] = select %[[VAL_26]], %[[VAL_27]], %[[VAL_25]] : f32
// CHECK: %[[VAL_29:.*]] = select %[[VAL_26]], %[[CST_9]], %[[CST_8]] : f32
// CHECK: %[[VAL_30:.*]] = cmpf "oeq", %[[ARG0]], %[[CST_8]] : f32
// CHECK: %[[VAL_31:.*]] = select %[[VAL_30]], %[[VAL_29]], %[[VAL_28]] : f32
// CHECK: %[[VAL_32:.*]] = cmpf "uno", %[[ARG0]], %[[ARG1]] : f32
// CHECK: %[[VAL_35:.*]] = select %[[VAL_32]], %[[CST_10]], %[[VAL_31]] : f32
// NOTE(review): the infinity test below compares the raw arguments with +inf,
// so it can only hold when both are +inf. The 2.3561945 alternative, which is
// selected when the second argument is negative, then looks unreachable —
// confirm whether the absolute values were meant to be compared instead.
// CHECK: %[[VAL_36:.*]] = select %[[VAL_26]], %[[CST_11]], %[[CST_12]] : f32
// CHECK: %[[VAL_37:.*]] = cmpf "oeq", %[[ARG1]], %[[CST_13]] : f32
// CHECK: %[[VAL_38:.*]] = cmpf "oeq", %[[ARG0]], %[[CST_13]] : f32
// CHECK: %[[VAL_39:.*]] = and %[[VAL_37]], %[[VAL_38]] : i1
// CHECK: %[[VAL_40:.*]] = select %[[VAL_39]], %[[VAL_36]], %[[VAL_35]] : f32
// CHECK: %[[VAL_41:.*]] = copysign %[[VAL_40]], %[[ARG0]] : f32
// CHECK: return %[[VAL_41]] : f32
%res = atan2 %arg0, %arg1 : f32
return %res : f32
}
// -----
// f16 case: same op sequence as for f32, wrapped in an fpext of both
// arguments to f32 at the start and an fptrunc of the result to f16 at the
// end.
// CHECK-LABEL: @atan2_f16
// CHECK-SAME: (%[[ARG0:.*]]: f16, %[[ARG1:.*]]: f16) -> f16
func @atan2_f16(%arg0 : f16, %arg1 : f16) -> f16 {
// All constants are f32: the computation happens on the extended values.
// CHECK: %[[CST:.*]] = constant 0.0027856871 : f32
// CHECK: %[[CST_0:.*]] = constant -1.586600e-02 : f32
// CHECK: %[[CST_1:.*]] = constant 0.042472221 : f32
// CHECK: %[[CST_2:.*]] = constant -0.0749753043 : f32
// CHECK: %[[CST_3:.*]] = constant 0.106448799 : f32
// CHECK: %[[CST_4:.*]] = constant -0.142070308 : f32
// CHECK: %[[CST_5:.*]] = constant 0.199934542 : f32
// CHECK: %[[CST_6:.*]] = constant -0.333331466 : f32
// CHECK: %[[CST_7:.*]] = constant 1.57079637 : f32
// CHECK: %[[CST_8:.*]] = constant 0.000000e+00 : f32
// CHECK: %[[CST_9:.*]] = constant 3.14159274 : f32
// CHECK: %[[CST_10:.*]] = constant 0x7FC00000 : f32
// CHECK: %[[CST_11:.*]] = constant 2.3561945 : f32
// CHECK: %[[CST_12:.*]] = constant 0.785398185 : f32
// CHECK: %[[CST_13:.*]] = constant 0x7F800000 : f32
// Both arguments are extended to f32 before any arithmetic.
// CHECK: %[[VAL_0:.*]] = fpext %[[ARG0]] : f16 to f32
// CHECK: %[[VAL_1:.*]] = fpext %[[ARG1]] : f16 to f32
// Ratio of the smaller to the larger absolute value.
// CHECK: %[[VAL_2:.*]] = absf %[[VAL_1]] : f32
// CHECK: %[[VAL_3:.*]] = absf %[[VAL_0]] : f32
// CHECK: %[[VAL_4:.*]] = cmpf "ole", %[[VAL_2]], %[[VAL_3]] : f32
// CHECK: %[[VAL_5:.*]] = select %[[VAL_4]], %[[VAL_2]], %[[VAL_3]] : f32
// CHECK: %[[VAL_6:.*]] = select %[[VAL_4]], %[[VAL_3]], %[[VAL_2]] : f32
// CHECK: %[[VAL_7:.*]] = divf %[[VAL_5]], %[[VAL_6]] : f32
// Polynomial in the squared ratio, evaluated as multiply/add pairs.
// CHECK: %[[VAL_8:.*]] = mulf %[[VAL_7]], %[[VAL_7]] : f32
// CHECK: %[[VAL_9:.*]] = mulf %[[CST]], %[[VAL_8]] : f32
// CHECK: %[[VAL_10:.*]] = addf %[[VAL_9]], %[[CST_0]] : f32
// CHECK: %[[VAL_11:.*]] = mulf %[[VAL_10]], %[[VAL_8]] : f32
// CHECK: %[[VAL_12:.*]] = addf %[[VAL_11]], %[[CST_1]] : f32
// CHECK: %[[VAL_13:.*]] = mulf %[[VAL_12]], %[[VAL_8]] : f32
// CHECK: %[[VAL_14:.*]] = addf %[[VAL_13]], %[[CST_2]] : f32
// CHECK: %[[VAL_15:.*]] = mulf %[[VAL_14]], %[[VAL_8]] : f32
// CHECK: %[[VAL_16:.*]] = addf %[[VAL_15]], %[[CST_3]] : f32
// CHECK: %[[VAL_17:.*]] = mulf %[[VAL_16]], %[[VAL_8]] : f32
// CHECK: %[[VAL_18:.*]] = addf %[[VAL_17]], %[[CST_4]] : f32
// CHECK: %[[VAL_19:.*]] = mulf %[[VAL_18]], %[[VAL_8]] : f32
// CHECK: %[[VAL_20:.*]] = addf %[[VAL_19]], %[[CST_5]] : f32
// CHECK: %[[VAL_21:.*]] = mulf %[[VAL_20]], %[[VAL_8]] : f32
// CHECK: %[[VAL_22:.*]] = addf %[[VAL_21]], %[[CST_6]] : f32
// CHECK: %[[VAL_23:.*]] = mulf %[[VAL_22]], %[[VAL_8]] : f32
// CHECK: %[[VAL_24:.*]] = mulf %[[VAL_23]], %[[VAL_7]] : f32
// CHECK: %[[VAL_25:.*]] = addf %[[VAL_24]], %[[VAL_7]] : f32
// Quadrant adjustments and the zero / NaN special cases.
// CHECK: %[[VAL_26:.*]] = subf %[[CST_7]], %[[VAL_25]] : f32
// CHECK: %[[VAL_27:.*]] = select %[[VAL_4]], %[[VAL_26]], %[[VAL_25]] : f32
// CHECK: %[[VAL_28:.*]] = cmpf "olt", %[[VAL_1]], %[[CST_8]] : f32
// CHECK: %[[VAL_29:.*]] = subf %[[CST_9]], %[[VAL_27]] : f32
// CHECK: %[[VAL_30:.*]] = select %[[VAL_28]], %[[VAL_29]], %[[VAL_27]] : f32
// CHECK: %[[VAL_31:.*]] = select %[[VAL_28]], %[[CST_9]], %[[CST_8]] : f32
// CHECK: %[[VAL_32:.*]] = cmpf "oeq", %[[VAL_0]], %[[CST_8]] : f32
// CHECK: %[[VAL_33:.*]] = select %[[VAL_32]], %[[VAL_31]], %[[VAL_30]] : f32
// CHECK: %[[VAL_34:.*]] = cmpf "uno", %[[VAL_0]], %[[VAL_1]] : f32
// CHECK: %[[VAL_37:.*]] = select %[[VAL_34]], %[[CST_10]], %[[VAL_33]] : f32
// NOTE(review): the infinity test below compares the extended arguments with
// +inf, not their absolute values, so it can only hold when both are +inf.
// The 2.3561945 alternative, which is selected when the second argument is
// negative, then looks unreachable — confirm the intended operands.
// CHECK: %[[VAL_38:.*]] = select %[[VAL_28]], %[[CST_11]], %[[CST_12]] : f32
// CHECK: %[[VAL_39:.*]] = cmpf "oeq", %[[VAL_1]], %[[CST_13]] : f32
// CHECK: %[[VAL_40:.*]] = cmpf "oeq", %[[VAL_0]], %[[CST_13]] : f32
// CHECK: %[[VAL_41:.*]] = and %[[VAL_39]], %[[VAL_40]] : i1
// CHECK: %[[VAL_42:.*]] = select %[[VAL_41]], %[[VAL_38]], %[[VAL_37]] : f32
// CHECK: %[[VAL_43:.*]] = copysign %[[VAL_42]], %[[VAL_0]] : f32
// The f32 result is truncated back to f16 before it is returned.
// CHECK: %[[VAL_44:.*]] = fptrunc %[[VAL_43]] : f32 to f16
// CHECK: return %[[VAL_44]] : f16
%res = atan2 %arg0, %arg1 : f16
return %res : f16
}