[MLIR][KernelGen] Legalize `atan` to approximation

PiperOrigin-RevId: 335417836
This commit is contained in:
A. Unique TensorFlower 2020-10-05 08:04:23 -07:00 committed by TensorFlow MLIR Team
parent 7f84a86cf5
commit bae0815ef0
4 changed files with 174 additions and 1 deletions

View File

@ -63,7 +63,7 @@ std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass();
std::unique_ptr<OperationPass<FuncOp>> createMhloFusionPass(); std::unique_ptr<OperationPass<FuncOp>> createMhloFusionPass();
/// Lowers trigonometric operations from the standard dialect to approximations /// Lowers trigonometric operations from the standard dialect to approximations
// that do not use intrinsics. /// that do not use intrinsics.
std::unique_ptr<OperationPass<FuncOp>> std::unique_ptr<OperationPass<FuncOp>>
createLegalizeTrigonometricToApproximationPass(); createLegalizeTrigonometricToApproximationPass();

View File

@ -235,6 +235,23 @@ class ApproximateAtan2Lowering
} }
}; };
class ApproximateAtanLowering
: public ApproximateOnExtendedF32Lowering<AtanOp> {
public:
explicit ApproximateAtanLowering(MLIRContext *ctx)
: ApproximateOnExtendedF32Lowering<AtanOp>(ctx) {}
// Reduce atan(x) to atan2(x, 1) to subsequently rely on an atan approximation
// for the argument range [-1, 1].
Value emitApproximation(ValueRange args, Location loc,
PatternRewriter &rewriter) const override {
Value x = args.front();
assert(x.getType().isF32());
Value one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1));
return rewriter.create<Atan2Op>(loc, x, one);
}
};
struct LegalizeTrigonometricToApproximationPass struct LegalizeTrigonometricToApproximationPass
: public PassWrapper<LegalizeTrigonometricToApproximationPass, : public PassWrapper<LegalizeTrigonometricToApproximationPass,
FunctionPass> { FunctionPass> {
@ -257,6 +274,7 @@ void PopulateTrigonometricToApproximationPatterns(
mlir::MLIRContext *context, OwningRewritePatternList *patterns) { mlir::MLIRContext *context, OwningRewritePatternList *patterns) {
// clang-format off // clang-format off
patterns->insert< patterns->insert<
ApproximateAtanLowering,
ApproximateAtan2Lowering, ApproximateAtan2Lowering,
ApproximateTanhLowering>(context); ApproximateTanhLowering>(context);
// clang-format on // clang-format on

View File

@ -13,6 +13,12 @@ func @print_f32(%arg : f32) {
} }
// Compute and print trigonometric function. // Compute and print trigonometric function.
func @atan_f32(%arg : f32) {
%res = atan %arg : f32
call @print_f32(%res) : (f32) -> ()
return
}
func @atan2_f32(%arg0 : f32, %arg1 : f32) { func @atan2_f32(%arg0 : f32, %arg1 : f32) {
%res = atan2 %arg0, %arg1 : f32 %res = atan2 %arg0, %arg1 : f32
call @print_f32(%res) : (f32) -> () call @print_f32(%res) : (f32) -> ()
@ -117,5 +123,37 @@ func @main() {
call @atan2_f32(%cf_1_0, %cf_1_0) : (f32, f32) -> () call @atan2_f32(%cf_1_0, %cf_1_0) : (f32, f32) -> ()
// CHECK: 0.785{{.*}} // CHECK: 0.785{{.*}}
// Atan.
call @atan_f32(%cf_n50_0) : (f32) -> ()
// CHECK: -1.550{{.*}}
call @atan_f32(%cf_n5_0) : (f32) -> ()
// CHECK: -1.373{{.*}}
call @atan_f32(%cf_n3_0) : (f32) -> ()
// CHECK: -1.249{{.*}}
call @atan_f32(%cf_n2_0) : (f32) -> ()
// CHECK: -1.107{{.*}}
call @atan_f32(%cf_n1_0) : (f32) -> ()
// CHECK: -0.785{{.*}}
call @atan_f32(%cf_n0_5) : (f32) -> ()
// CHECK: -0.463{{.*}}
call @atan_f32(%cf_n0_1) : (f32) -> ()
// CHECK: -0.099{{.*}}
call @atan_f32(%cf_0_0) : (f32) -> ()
// CHECK: 0
call @atan_f32(%cf_0_1) : (f32) -> ()
// CHECK: 0.099{{.*}}
call @atan_f32(%cf_0_5) : (f32) -> ()
// CHECK: 0.463{{.*}}
call @atan_f32(%cf_1_0) : (f32) -> ()
// CHECK: 0.785{{.*}}
call @atan_f32(%cf_2_0) : (f32) -> ()
// CHECK: 1.107{{.*}}
call @atan_f32(%cf_3_0) : (f32) -> ()
// CHECK: 1.249{{.*}}
call @atan_f32(%cf_5_0) : (f32) -> ()
// CHECK: 1.373{{.*}}
call @atan_f32(%cf_50_0) : (f32) -> ()
// CHECK: 1.550{{.*}}
return return
} }

View File

@ -261,3 +261,120 @@ func @atan2_f16(%arg0 : f16, %arg1 : f16) -> f16 {
%res = atan2 %arg0, %arg1 : f16 %res = atan2 %arg0, %arg1 : f16
return %res : f16 return %res : f16
} }
// -----
// CHECK-LABEL: @atan_f64
func @atan_f64(%arg : f64) -> f64 {
// CHECK: atan
%res = atan %arg : f64
return %res : f64
}
// -----
// CHECK-LABEL: func @atan_f32
// CHECK-SAME: (%[[ARG:.*]]: f32) -> f32
func @atan_f32(%arg : f32) -> f32 {
// CHECK: %[[CST:.*]] = constant 1.000000e+00 : f32
// CHECK: %[[CST_0:.*]] = constant 0.0027856871 : f32
// CHECK: %[[CST_1:.*]] = constant -1.586600e-02 : f32
// CHECK: %[[CST_2:.*]] = constant 0.042472221 : f32
// CHECK: %[[CST_3:.*]] = constant -0.0749753043 : f32
// CHECK: %[[CST_4:.*]] = constant 0.106448799 : f32
// CHECK: %[[CST_5:.*]] = constant -0.142070308 : f32
// CHECK: %[[CST_6:.*]] = constant 0.199934542 : f32
// CHECK: %[[CST_7:.*]] = constant -0.333331466 : f32
// CHECK: %[[CST_8:.*]] = constant 1.57079637 : f32
// CHECK: %[[CST_9:.*]] = constant 0.000000e+00 : f32
// CHECK: %[[CST_10:.*]] = constant 0x7FC00000 : f32
// CHECK: %[[VAL_0:.*]] = absf %[[CST]] : f32
// CHECK: %[[VAL_1:.*]] = absf %arg0 : f32
// CHECK: %[[VAL_2:.*]] = cmpf "ole", %[[VAL_0]], %[[VAL_1]] : f32
// CHECK: %[[VAL_3:.*]] = select %[[VAL_2]], %[[VAL_0]], %[[VAL_1]] : f32
// CHECK: %[[VAL_4:.*]] = select %[[VAL_2]], %[[VAL_1]], %[[VAL_0]] : f32
// CHECK: %[[VAL_5:.*]] = divf %[[VAL_3]], %[[VAL_4]] : f32
// CHECK: %[[VAL_6:.*]] = mulf %[[VAL_5]], %[[VAL_5]] : f32
// CHECK: %[[VAL_7:.*]] = mulf %[[CST_0]], %[[VAL_6]] : f32
// CHECK: %[[VAL_8:.*]] = addf %[[VAL_7]], %[[CST_1]] : f32
// CHECK: %[[VAL_9:.*]] = mulf %[[VAL_8]], %[[VAL_6]] : f32
// CHECK: %[[VAL_10:.*]] = addf %[[VAL_9]], %[[CST_2]] : f32
// CHECK: %[[VAL_11:.*]] = mulf %[[VAL_10]], %[[VAL_6]] : f32
// CHECK: %[[VAL_12:.*]] = addf %[[VAL_11]], %[[CST_3]] : f32
// CHECK: %[[VAL_13:.*]] = mulf %[[VAL_12]], %[[VAL_6]] : f32
// CHECK: %[[VAL_14:.*]] = addf %[[VAL_13]], %[[CST_4]] : f32
// CHECK: %[[VAL_15:.*]] = mulf %[[VAL_14]], %[[VAL_6]] : f32
// CHECK: %[[VAL_16:.*]] = addf %[[VAL_15]], %[[CST_5]] : f32
// CHECK: %[[VAL_17:.*]] = mulf %[[VAL_16]], %[[VAL_6]] : f32
// CHECK: %[[VAL_18:.*]] = addf %[[VAL_17]], %[[CST_6]] : f32
// CHECK: %[[VAL_19:.*]] = mulf %[[VAL_18]], %[[VAL_6]] : f32
// CHECK: %[[VAL_20:.*]] = addf %[[VAL_19]], %[[CST_7]] : f32
// CHECK: %[[VAL_21:.*]] = mulf %[[VAL_20]], %[[VAL_6]] : f32
// CHECK: %[[VAL_22:.*]] = mulf %[[VAL_21]], %[[VAL_5]] : f32
// CHECK: %[[VAL_23:.*]] = addf %[[VAL_22]], %[[VAL_5]] : f32
// CHECK: %[[VAL_24:.*]] = subf %[[CST_8]], %[[VAL_23]] : f32
// CHECK: %[[VAL_25:.*]] = select %[[VAL_2]], %[[VAL_24]], %[[VAL_23]] : f32
// CHECK: %[[VAL_26:.*]] = cmpf "oeq", %arg0, %[[CST_9]] : f32
// CHECK: %[[VAL_27:.*]] = select %[[VAL_26]], %[[CST_9]], %[[VAL_25]] : f32
// CHECK: %[[VAL_28:.*]] = cmpf "uno", %arg0, %[[CST]] : f32
// CHECK: %[[VAL_29:.*]] = select %[[VAL_28]], %[[CST_10]], %[[VAL_27]] : f32
// CHECK: %[[VAL_30:.*]] = copysign %[[VAL_29]], %arg0 : f32
// CHECK: return %[[VAL_30]] : f32
%res = atan %arg : f32
return %res : f32
}
// -----
// CHECK-LABEL: @atan_f16
// CHECK-SAME: (%[[ARG:.*]]: f16) -> f16
func @atan_f16(%arg : f16) -> f16 {
// CHECK: %[[CST:.*]] = constant 1.000000e+00 : f32
// CHECK: %[[CST_0:.*]] = constant 0.0027856871 : f32
// CHECK: %[[CST_1:.*]] = constant -1.586600e-02 : f32
// CHECK: %[[CST_2:.*]] = constant 0.042472221 : f32
// CHECK: %[[CST_3:.*]] = constant -0.0749753043 : f32
// CHECK: %[[CST_4:.*]] = constant 0.106448799 : f32
// CHECK: %[[CST_5:.*]] = constant -0.142070308 : f32
// CHECK: %[[CST_6:.*]] = constant 0.199934542 : f32
// CHECK: %[[CST_7:.*]] = constant -0.333331466 : f32
// CHECK: %[[CST_8:.*]] = constant 1.57079637 : f32
// CHECK: %[[CST_9:.*]] = constant 0.000000e+00 : f32
// CHECK: %[[CST_10:.*]] = constant 0x7FC00000 : f32
// CHECK: %[[VAL_0:.*]] = fpext %arg0 : f16 to f32
// CHECK: %[[VAL_1:.*]] = absf %[[CST]] : f32
// CHECK: %[[VAL_2:.*]] = absf %[[VAL_0]] : f32
// CHECK: %[[VAL_3:.*]] = cmpf "ole", %[[VAL_1]], %[[VAL_2]] : f32
// CHECK: %[[VAL_4:.*]] = select %[[VAL_3]], %[[VAL_1]], %[[VAL_2]] : f32
// CHECK: %[[VAL_5:.*]] = select %[[VAL_3]], %[[VAL_2]], %[[VAL_1]] : f32
// CHECK: %[[VAL_6:.*]] = divf %[[VAL_4]], %[[VAL_5]] : f32
// CHECK: %[[VAL_7:.*]] = mulf %[[VAL_6]], %[[VAL_6]] : f32
// CHECK: %[[VAL_8:.*]] = mulf %[[CST_0]], %[[VAL_7]] : f32
// CHECK: %[[VAL_9:.*]] = addf %[[VAL_8]], %[[CST_1]] : f32
// CHECK: %[[VAL_10:.*]] = mulf %[[VAL_9]], %[[VAL_7]] : f32
// CHECK: %[[VAL_11:.*]] = addf %[[VAL_10]], %[[CST_2]] : f32
// CHECK: %[[VAL_12:.*]] = mulf %[[VAL_11]], %[[VAL_7]] : f32
// CHECK: %[[VAL_13:.*]] = addf %[[VAL_12]], %[[CST_3]] : f32
// CHECK: %[[VAL_14:.*]] = mulf %[[VAL_13]], %[[VAL_7]] : f32
// CHECK: %[[VAL_15:.*]] = addf %[[VAL_14]], %[[CST_4]] : f32
// CHECK: %[[VAL_16:.*]] = mulf %[[VAL_15]], %[[VAL_7]] : f32
// CHECK: %[[VAL_17:.*]] = addf %[[VAL_16]], %[[CST_5]] : f32
// CHECK: %[[VAL_18:.*]] = mulf %[[VAL_17]], %[[VAL_7]] : f32
// CHECK: %[[VAL_19:.*]] = addf %[[VAL_18]], %[[CST_6]] : f32
// CHECK: %[[VAL_20:.*]] = mulf %[[VAL_19]], %[[VAL_7]] : f32
// CHECK: %[[VAL_21:.*]] = addf %[[VAL_20]], %[[CST_7]] : f32
// CHECK: %[[VAL_22:.*]] = mulf %[[VAL_21]], %[[VAL_7]] : f32
// CHECK: %[[VAL_23:.*]] = mulf %[[VAL_22]], %[[VAL_6]] : f32
// CHECK: %[[VAL_24:.*]] = addf %[[VAL_23]], %[[VAL_6]] : f32
// CHECK: %[[VAL_25:.*]] = subf %[[CST_8]], %[[VAL_24]] : f32
// CHECK: %[[VAL_26:.*]] = select %[[VAL_3]], %[[VAL_25]], %[[VAL_24]] : f32
// CHECK: %[[VAL_27:.*]] = cmpf "oeq", %[[VAL_0]], %[[CST_9]] : f32
// CHECK: %[[VAL_28:.*]] = select %[[VAL_27]], %[[CST_9]], %[[VAL_26]] : f32
// CHECK: %[[VAL_29:.*]] = cmpf "uno", %[[VAL_0]], %[[CST]] : f32
// CHECK: %[[VAL_30:.*]] = select %[[VAL_29]], %[[CST_10]], %[[VAL_28]] : f32
// CHECK: %[[VAL_31:.*]] = copysign %[[VAL_30]], %[[VAL_0]] : f32
// CHECK: %[[VAL_32:.*]] = fptrunc %[[VAL_31]] : f32 to f16
// CHECK: return %[[VAL_32]] : f16
%res = atan %arg : f16
return %res : f16
}