diff --git a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc index bc5d8d2..c62cbbd 100644 --- a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc +++ b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc @@ -28,6 +28,7 @@ limitations under the License. #include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "mlir-hlo/utils/broadcast_utils.h" +#include "mlir-hlo/utils/hlo_utils.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" @@ -989,6 +990,68 @@ struct ConvertPolygammaOp : public OpConversionPattern { } }; +Value MaterializeSinhApproximationForLargeX(ConversionPatternRewriter &rewriter, + Location loc, ValueRange operands) { + SinhOp::Adaptor transformed(operands); + Value x = transformed.operand(); + + Value log_one_half = + rewriter.create(loc, getConstantLike(rewriter, loc, 0.5, x)); + Value exp_add = rewriter.create( + loc, rewriter.create(loc, x, log_one_half)); + Value exp_sub = rewriter.create( + loc, rewriter.create(loc, log_one_half, x)); + return rewriter.create(loc, exp_add, exp_sub); +} + +// Express `sinh` as +// sinh(x) = (e^x - e^-x) / 2 if |x| < 1 +// = e^(x + log(1/2)) - e^(-x + log(1/2)) otherwise. +Value MaterializeSinhApproximation(ConversionPatternRewriter &rewriter, + Location loc, ValueRange operands) { + Value large_sinh_result = + MaterializeSinhApproximationForLargeX(rewriter, loc, operands); + + SinhOp::Adaptor transformed(operands); + Value x = transformed.operand(); + const StringAttr kLT = rewriter.getStringAttr( + mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT)); + Value exp_x = rewriter.create(loc, x); + Value exp_neg_x = + rewriter.create(loc, rewriter.create(loc, x)); + Value exp_difference = rewriter.create(loc, exp_x, exp_neg_x); + Value two = getConstantLike(rewriter, loc, 2.0, x); + Value small_sinh_result = + rewriter.create(loc, exp_difference, two); + + Value abs_x = rewriter.create(loc, x); + Value one = getConstantLike(rewriter, loc, 1.0, x); + Value abs_x_lt_one = rewriter.create(loc, abs_x, one, kLT); + return rewriter.create(loc, abs_x_lt_one, small_sinh_result, + large_sinh_result); +} + +struct ConvertSinhOp : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + SinhOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + SinhOp::Adaptor transformed(operands); + Value x = transformed.operand(); + if (x.getType().cast().getElementType().isa()) { + // TODO(hinsu): Support operands with complex element types by always + // using the formula for large x. The compare op is not legal for complex + // numbers. + return failure(); + } + rewriter.replaceOp(op, + MaterializeWithUpcast(rewriter, op.getLoc(), operands, + rewriter.getF32Type(), + &MaterializeSinhApproximation)); + return success(); + } +}; + struct ConvertZetaOp : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( @@ -1248,6 +1311,7 @@ void PopulateDecomposeChloPatterns(MLIRContext *context, ConvertErfcOp, ConvertLgammaOp, ConvertPolygammaOp, + ConvertSinhOp, ConvertZetaOp>(context); // clang-format on } diff --git a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td index 5266686..559abad 100644 --- a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td +++ b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td @@ -312,48 +312,6 @@ def : Pat<(HLOClient_IsNegInfOp NonComplexElementType:$input), (HLO_DEFAULT_COMPARISON_TYPE) )>; -// Express `sinh` as -// sinh(x) = (e^x - e^-x) / 2 if |x| < 1 -// = e^(x + log(1/2)) - e^(-x + log(1/2)) otherwise. -// TODO(hinsu): Support operands with complex element types by always using the -// second formula. The compare op below is not legal for complex numbers. -def : Pat<(HLOClient_SinhOp NonComplexElementType:$input), - (HLO_SelectOp - (HLO_CompareOp - (HLO_AbsOp $input), - (HLO_ConstantLike<"1"> $input), - HLO_COMPARISON_DIRECTION_LT, - (HLO_DEFAULT_COMPARISON_TYPE) - ), - (HLO_DivOp - (HLO_SubOp - (HLO_ExpOp $input), - (HLO_ExpOp - (HLO_NegOp $input) - ) - ), - (HLO_ConstantLike<"2"> $input) - ), - (HLO_SubOp - (HLO_ExpOp - (HLO_AddOp - $input, - (HLO_LogOp - (HLO_ConstantLike<"0.5"> $input) - ) - ) - ), - (HLO_ExpOp - (HLO_SubOp - (HLO_LogOp - (HLO_ConstantLike<"0.5"> $input) - ), - $input - ) - ) - ) - )>; - // Express tan in MHLO dialect as // tan(x) = sin(x) / cos(x). def : Pat<(HLOClient_TanOp NonComplexElementType:$input), diff --git a/tests/chlo_legalize_to_mhlo.mlir b/tests/chlo_legalize_to_mhlo.mlir index 6d1ecaa..50d985a 100644 --- a/tests/chlo_legalize_to_mhlo.mlir +++ b/tests/chlo_legalize_to_mhlo.mlir @@ -2123,3 +2123,42 @@ func @polygamma_f16(%lhs : tensor, %rhs : tensor) -> tensor { %1 = chlo.polygamma %lhs, %rhs : tensor, tensor -> tensor return %1 : tensor } + +// ---- + +// CHECK-LABEL: @sinh_f32 +// CHECK-SAME: (%[[X:.*]]: tensor) +func @sinh_f32(%x : tensor) -> tensor { + // CHECK: %[[HALF:.*]] = mhlo.constant dense<5.000000e-01> : tensor + // CHECK: %[[LOG_HALF:.*]] = "mhlo.log"(%[[HALF]]) : (tensor) -> tensor + // CHECK: %[[X_PLUS_LOG_HALF:.*]] = mhlo.add %[[X]], %[[LOG_HALF]] : tensor + // CHECK: %[[EXP_1:.*]] = "mhlo.exponential"(%[[X_PLUS_LOG_HALF]]) : (tensor) -> tensor + // CHECK: %[[LOG_HALF_MINUS_X:.*]] = mhlo.subtract %[[LOG_HALF]], %[[X]] : tensor + // CHECK: %[[EXP_2:.*]] = "mhlo.exponential"(%[[LOG_HALF_MINUS_X]]) : (tensor) -> tensor + // CHECK: %[[LARGE_SINH_RESULT:.*]] = mhlo.subtract %[[EXP_1]], %[[EXP_2]] : tensor + // CHECK: %[[EXP_X:.*]] = "mhlo.exponential"(%[[X]]) : (tensor) -> tensor + // CHECK: %[[NEG_X:.*]] = "mhlo.negate"(%[[X]]) : (tensor) -> tensor + // CHECK: %[[EXP_NEG_X:.*]] = "mhlo.exponential"(%[[NEG_X]]) : (tensor) -> tensor + // CHECK: %[[EXP_X_MINUS_EXP_NEG_X:.*]] = mhlo.subtract %[[EXP_X]], %[[EXP_NEG_X]] : tensor + // CHECK: %[[TWO:.*]] = mhlo.constant dense<2.000000e+00> : tensor + // CHECK: %[[SMALL_SINH_RESULT:.*]] = mhlo.divide %[[EXP_X_MINUS_EXP_NEG_X]], %[[TWO]] : tensor + // CHECK: %[[ABS_X:.*]] = "mhlo.abs"(%[[X]]) : (tensor) -> tensor + // CHECK: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor + // CHECK: %[[ABS_X_LT_ONE:.*]] = "mhlo.compare"(%[[ABS_X]], %[[ONE]]) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + // CHECK: %[[RESULT:.*]] = "mhlo.select"(%[[ABS_X_LT_ONE]], %[[SMALL_SINH_RESULT]], %[[LARGE_SINH_RESULT]]) : (tensor, tensor, tensor) -> tensor + // CHECK: return %[[RESULT]] : tensor + %1 = chlo.sinh %x : tensor -> tensor + return %1 : tensor +} + +// ---- + +// CHECK-LABEL: @sinh_f16 +// CHECK-SAME: (%[[ARG0:.*]]: tensor) +func @sinh_f16(%x : tensor) -> tensor { + // CHECK: "mhlo.convert"(%[[ARG0]]) : (tensor) -> tensor + // CHECK: %[[RES:.*]] = "mhlo.convert"(%{{.*}}) : (tensor) -> tensor + // CHECK: return %[[RES]] + %1 = chlo.sinh %x : tensor -> tensor + return %1 : tensor +}