From bae0815ef064c2662ae663599cc1db56ec4835e2 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 5 Oct 2020 08:04:23 -0700 Subject: [PATCH] [MLIR][KernelGen] Legalize `atan` to approximation PiperOrigin-RevId: 335417836 --- .../mlir-hlo/Dialect/mhlo/transforms/passes.h | 2 +- ...legalize_trigonometric_to_approximation.cc | 18 +++ ...galize-trigonometric-to-approximation.mlir | 38 ++++++ ...galize-trigonometric-to-approximation.mlir | 117 ++++++++++++++++++ 4 files changed, 174 insertions(+), 1 deletion(-) diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/passes.h b/include/mlir-hlo/Dialect/mhlo/transforms/passes.h index 2c0735a..080cde5 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/passes.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/passes.h @@ -63,7 +63,7 @@ std::unique_ptr> createSinkConstantsToControlFlowPass(); std::unique_ptr> createMhloFusionPass(); /// Lowers trigonometric operations from the standard dialect to approximations -// that do not use intrinsics. +/// that do not use intrinsics. std::unique_ptr> createLegalizeTrigonometricToApproximationPass(); diff --git a/lib/Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc b/lib/Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc index e24cd9c..1003086 100644 --- a/lib/Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc +++ b/lib/Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc @@ -235,6 +235,23 @@ class ApproximateAtan2Lowering } }; +class ApproximateAtanLowering + : public ApproximateOnExtendedF32Lowering { + public: + explicit ApproximateAtanLowering(MLIRContext *ctx) + : ApproximateOnExtendedF32Lowering(ctx) {} + + // Reduce atan(x) to atan2(x, 1) to subsequently rely on an atan approximation + // for the argument range [-1, 1]. + Value emitApproximation(ValueRange args, Location loc, + PatternRewriter &rewriter) const override { + Value x = args.front(); + assert(x.getType().isF32()); + Value one = rewriter.create(loc, rewriter.getF32FloatAttr(1)); + return rewriter.create(loc, x, one); + } +}; + struct LegalizeTrigonometricToApproximationPass : public PassWrapper { @@ -257,6 +274,7 @@ void PopulateTrigonometricToApproximationPatterns( mlir::MLIRContext *context, OwningRewritePatternList *patterns) { // clang-format off patterns->insert< + ApproximateAtanLowering, ApproximateAtan2Lowering, ApproximateTanhLowering>(context); // clang-format on diff --git a/tests/end2end/legalize-trigonometric-to-approximation.mlir b/tests/end2end/legalize-trigonometric-to-approximation.mlir index 48b9d3b..a7e77b7 100644 --- a/tests/end2end/legalize-trigonometric-to-approximation.mlir +++ b/tests/end2end/legalize-trigonometric-to-approximation.mlir @@ -13,6 +13,12 @@ func @print_f32(%arg : f32) { } // Compute and print trigonometric function. +func @atan_f32(%arg : f32) { + %res = atan %arg : f32 + call @print_f32(%res) : (f32) -> () + return +} + func @atan2_f32(%arg0 : f32, %arg1 : f32) { %res = atan2 %arg0, %arg1 : f32 call @print_f32(%res) : (f32) -> () @@ -117,5 +123,37 @@ func @main() { call @atan2_f32(%cf_1_0, %cf_1_0) : (f32, f32) -> () // CHECK: 0.785{{.*}} + // Atan. + call @atan_f32(%cf_n50_0) : (f32) -> () + // CHECK: -1.550{{.*}} + call @atan_f32(%cf_n5_0) : (f32) -> () + // CHECK: -1.373{{.*}} + call @atan_f32(%cf_n3_0) : (f32) -> () + // CHECK: -1.249{{.*}} + call @atan_f32(%cf_n2_0) : (f32) -> () + // CHECK: -1.107{{.*}} + call @atan_f32(%cf_n1_0) : (f32) -> () + // CHECK: -0.785{{.*}} + call @atan_f32(%cf_n0_5) : (f32) -> () + // CHECK: -0.463{{.*}} + call @atan_f32(%cf_n0_1) : (f32) -> () + // CHECK: -0.099{{.*}} + call @atan_f32(%cf_0_0) : (f32) -> () + // CHECK: 0 + call @atan_f32(%cf_0_1) : (f32) -> () + // CHECK: 0.099{{.*}} + call @atan_f32(%cf_0_5) : (f32) -> () + // CHECK: 0.463{{.*}} + call @atan_f32(%cf_1_0) : (f32) -> () + // CHECK: 0.785{{.*}} + call @atan_f32(%cf_2_0) : (f32) -> () + // CHECK: 1.107{{.*}} + call @atan_f32(%cf_3_0) : (f32) -> () + // CHECK: 1.249{{.*}} + call @atan_f32(%cf_5_0) : (f32) -> () + // CHECK: 1.373{{.*}} + call @atan_f32(%cf_50_0) : (f32) -> () + // CHECK: 1.550{{.*}} + return } diff --git a/tests/legalize-trigonometric-to-approximation.mlir b/tests/legalize-trigonometric-to-approximation.mlir index 43278ff..c25545c 100644 --- a/tests/legalize-trigonometric-to-approximation.mlir +++ b/tests/legalize-trigonometric-to-approximation.mlir @@ -261,3 +261,120 @@ func @atan2_f16(%arg0 : f16, %arg1 : f16) -> f16 { %res = atan2 %arg0, %arg1 : f16 return %res : f16 } + +// ----- + +// CHECK-LABEL: @atan_f64 +func @atan_f64(%arg : f64) -> f64 { + // CHECK: atan + %res = atan %arg : f64 + return %res : f64 +} + +// ----- + +// CHECK-LABEL: func @atan_f32 +// CHECK-SAME: (%[[ARG:.*]]: f32) -> f32 +func @atan_f32(%arg : f32) -> f32 { + // CHECK: %[[CST:.*]] = constant 1.000000e+00 : f32 + // CHECK: %[[CST_0:.*]] = constant 0.0027856871 : f32 + // CHECK: %[[CST_1:.*]] = constant -1.586600e-02 : f32 + // CHECK: %[[CST_2:.*]] = constant 0.042472221 : f32 + // CHECK: %[[CST_3:.*]] = constant -0.0749753043 : f32 + // CHECK: %[[CST_4:.*]] = constant 0.106448799 : f32 + // CHECK: %[[CST_5:.*]] = constant -0.142070308 : f32 + // CHECK: %[[CST_6:.*]] = constant 0.199934542 : f32 + // CHECK: %[[CST_7:.*]] = constant -0.333331466 : f32 + // CHECK: %[[CST_8:.*]] = constant 1.57079637 : f32 + // CHECK: %[[CST_9:.*]] = constant 0.000000e+00 : f32 + // CHECK: %[[CST_10:.*]] = constant 0x7FC00000 : f32 + // CHECK: %[[VAL_0:.*]] = absf %[[CST]] : f32 + // CHECK: %[[VAL_1:.*]] = absf %arg0 : f32 + // CHECK: %[[VAL_2:.*]] = cmpf "ole", %[[VAL_0]], %[[VAL_1]] : f32 + // CHECK: %[[VAL_3:.*]] = select %[[VAL_2]], %[[VAL_0]], %[[VAL_1]] : f32 + // CHECK: %[[VAL_4:.*]] = select %[[VAL_2]], %[[VAL_1]], %[[VAL_0]] : f32 + // CHECK: %[[VAL_5:.*]] = divf %[[VAL_3]], %[[VAL_4]] : f32 + // CHECK: %[[VAL_6:.*]] = mulf %[[VAL_5]], %[[VAL_5]] : f32 + // CHECK: %[[VAL_7:.*]] = mulf %[[CST_0]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_8:.*]] = addf %[[VAL_7]], %[[CST_1]] : f32 + // CHECK: %[[VAL_9:.*]] = mulf %[[VAL_8]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_10:.*]] = addf %[[VAL_9]], %[[CST_2]] : f32 + // CHECK: %[[VAL_11:.*]] = mulf %[[VAL_10]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_12:.*]] = addf %[[VAL_11]], %[[CST_3]] : f32 + // CHECK: %[[VAL_13:.*]] = mulf %[[VAL_12]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_14:.*]] = addf %[[VAL_13]], %[[CST_4]] : f32 + // CHECK: %[[VAL_15:.*]] = mulf %[[VAL_14]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_16:.*]] = addf %[[VAL_15]], %[[CST_5]] : f32 + // CHECK: %[[VAL_17:.*]] = mulf %[[VAL_16]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_18:.*]] = addf %[[VAL_17]], %[[CST_6]] : f32 + // CHECK: %[[VAL_19:.*]] = mulf %[[VAL_18]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_20:.*]] = addf %[[VAL_19]], %[[CST_7]] : f32 + // CHECK: %[[VAL_21:.*]] = mulf %[[VAL_20]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_22:.*]] = mulf %[[VAL_21]], %[[VAL_5]] : f32 + // CHECK: %[[VAL_23:.*]] = addf %[[VAL_22]], %[[VAL_5]] : f32 + // CHECK: %[[VAL_24:.*]] = subf %[[CST_8]], %[[VAL_23]] : f32 + // CHECK: %[[VAL_25:.*]] = select %[[VAL_2]], %[[VAL_24]], %[[VAL_23]] : f32 + // CHECK: %[[VAL_26:.*]] = cmpf "oeq", %arg0, %[[CST_9]] : f32 + // CHECK: %[[VAL_27:.*]] = select %[[VAL_26]], %[[CST_9]], %[[VAL_25]] : f32 + // CHECK: %[[VAL_28:.*]] = cmpf "uno", %arg0, %[[CST]] : f32 + // CHECK: %[[VAL_29:.*]] = select %[[VAL_28]], %[[CST_10]], %[[VAL_27]] : f32 + // CHECK: %[[VAL_30:.*]] = copysign %[[VAL_29]], %arg0 : f32 + // CHECK: return %[[VAL_30]] : f32 + %res = atan %arg : f32 + return %res : f32 +} + +// ----- + +// CHECK-LABEL: @atan_f16 +// CHECK-SAME: (%[[ARG:.*]]: f16) -> f16 +func @atan_f16(%arg : f16) -> f16 { + // CHECK: %[[CST:.*]] = constant 1.000000e+00 : f32 + // CHECK: %[[CST_0:.*]] = constant 0.0027856871 : f32 + // CHECK: %[[CST_1:.*]] = constant -1.586600e-02 : f32 + // CHECK: %[[CST_2:.*]] = constant 0.042472221 : f32 + // CHECK: %[[CST_3:.*]] = constant -0.0749753043 : f32 + // CHECK: %[[CST_4:.*]] = constant 0.106448799 : f32 + // CHECK: %[[CST_5:.*]] = constant -0.142070308 : f32 + // CHECK: %[[CST_6:.*]] = constant 0.199934542 : f32 + // CHECK: %[[CST_7:.*]] = constant -0.333331466 : f32 + // CHECK: %[[CST_8:.*]] = constant 1.57079637 : f32 + // CHECK: %[[CST_9:.*]] = constant 0.000000e+00 : f32 + // CHECK: %[[CST_10:.*]] = constant 0x7FC00000 : f32 + // CHECK: %[[VAL_0:.*]] = fpext %arg0 : f16 to f32 + // CHECK: %[[VAL_1:.*]] = absf %[[CST]] : f32 + // CHECK: %[[VAL_2:.*]] = absf %[[VAL_0]] : f32 + // CHECK: %[[VAL_3:.*]] = cmpf "ole", %[[VAL_1]], %[[VAL_2]] : f32 + // CHECK: %[[VAL_4:.*]] = select %[[VAL_3]], %[[VAL_1]], %[[VAL_2]] : f32 + // CHECK: %[[VAL_5:.*]] = select %[[VAL_3]], %[[VAL_2]], %[[VAL_1]] : f32 + // CHECK: %[[VAL_6:.*]] = divf %[[VAL_4]], %[[VAL_5]] : f32 + // CHECK: %[[VAL_7:.*]] = mulf %[[VAL_6]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_8:.*]] = mulf %[[CST_0]], %[[VAL_7]] : f32 + // CHECK: %[[VAL_9:.*]] = addf %[[VAL_8]], %[[CST_1]] : f32 + // CHECK: %[[VAL_10:.*]] = mulf %[[VAL_9]], %[[VAL_7]] : f32 + // CHECK: %[[VAL_11:.*]] = addf %[[VAL_10]], %[[CST_2]] : f32 + // CHECK: %[[VAL_12:.*]] = mulf %[[VAL_11]], %[[VAL_7]] : f32 + // CHECK: %[[VAL_13:.*]] = addf %[[VAL_12]], %[[CST_3]] : f32 + // CHECK: %[[VAL_14:.*]] = mulf %[[VAL_13]], %[[VAL_7]] : f32 + // CHECK: %[[VAL_15:.*]] = addf %[[VAL_14]], %[[CST_4]] : f32 + // CHECK: %[[VAL_16:.*]] = mulf %[[VAL_15]], %[[VAL_7]] : f32 + // CHECK: %[[VAL_17:.*]] = addf %[[VAL_16]], %[[CST_5]] : f32 + // CHECK: %[[VAL_18:.*]] = mulf %[[VAL_17]], %[[VAL_7]] : f32 + // CHECK: %[[VAL_19:.*]] = addf %[[VAL_18]], %[[CST_6]] : f32 + // CHECK: %[[VAL_20:.*]] = mulf %[[VAL_19]], %[[VAL_7]] : f32 + // CHECK: %[[VAL_21:.*]] = addf %[[VAL_20]], %[[CST_7]] : f32 + // CHECK: %[[VAL_22:.*]] = mulf %[[VAL_21]], %[[VAL_7]] : f32 + // CHECK: %[[VAL_23:.*]] = mulf %[[VAL_22]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_24:.*]] = addf %[[VAL_23]], %[[VAL_6]] : f32 + // CHECK: %[[VAL_25:.*]] = subf %[[CST_8]], %[[VAL_24]] : f32 + // CHECK: %[[VAL_26:.*]] = select %[[VAL_3]], %[[VAL_25]], %[[VAL_24]] : f32 + // CHECK: %[[VAL_27:.*]] = cmpf "oeq", %[[VAL_0]], %[[CST_9]] : f32 + // CHECK: %[[VAL_28:.*]] = select %[[VAL_27]], %[[CST_9]], %[[VAL_26]] : f32 + // CHECK: %[[VAL_29:.*]] = cmpf "uno", %[[VAL_0]], %[[CST]] : f32 + // CHECK: %[[VAL_30:.*]] = select %[[VAL_29]], %[[CST_10]], %[[VAL_28]] : f32 + // CHECK: %[[VAL_31:.*]] = copysign %[[VAL_30]], %[[VAL_0]] : f32 + // CHECK: %[[VAL_32:.*]] = fptrunc %[[VAL_31]] : f32 to f16 + // CHECK: return %[[VAL_32]] : f16 + %res = atan %arg : f16 + return %res : f16 +}