Fix Sinh approximation for F16.

We should upcast F16 to F32 to prevent precision loss.
E.g. sinh(-9) would evaluate to -4042 previously instead of -4052.
This allows to enable the MLIR generated kernel for F16 type.

PiperOrigin-RevId: 377901896
This commit is contained in:
Adrian Kuegel 2021-06-07 06:37:20 -07:00 committed by TensorFlow MLIR Team
parent fc723380e6
commit 5315997402
3 changed files with 103 additions and 42 deletions

View File

@ -28,6 +28,7 @@ limitations under the License.
#include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h" #include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir-hlo/utils/broadcast_utils.h" #include "mlir-hlo/utils/broadcast_utils.h"
#include "mlir-hlo/utils/hlo_utils.h"
#include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
@ -989,6 +990,68 @@ struct ConvertPolygammaOp : public OpConversionPattern<PolygammaOp> {
} }
}; };
Value MaterializeSinhApproximationForLargeX(ConversionPatternRewriter &rewriter,
Location loc, ValueRange operands) {
SinhOp::Adaptor transformed(operands);
Value x = transformed.operand();
Value log_one_half =
rewriter.create<mhlo::LogOp>(loc, getConstantLike(rewriter, loc, 0.5, x));
Value exp_add = rewriter.create<mhlo::ExpOp>(
loc, rewriter.create<mhlo::AddOp>(loc, x, log_one_half));
Value exp_sub = rewriter.create<mhlo::ExpOp>(
loc, rewriter.create<mhlo::SubOp>(loc, log_one_half, x));
return rewriter.create<mhlo::SubOp>(loc, exp_add, exp_sub);
}
// Express `sinh` as
// sinh(x) = (e^x - e^-x) / 2 if |x| < 1
// = e^(x + log(1/2)) - e^(-x + log(1/2)) otherwise.
Value MaterializeSinhApproximation(ConversionPatternRewriter &rewriter,
Location loc, ValueRange operands) {
Value large_sinh_result =
MaterializeSinhApproximationForLargeX(rewriter, loc, operands);
SinhOp::Adaptor transformed(operands);
Value x = transformed.operand();
const StringAttr kLT = rewriter.getStringAttr(
mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
Value exp_x = rewriter.create<mhlo::ExpOp>(loc, x);
Value exp_neg_x =
rewriter.create<mhlo::ExpOp>(loc, rewriter.create<mhlo::NegOp>(loc, x));
Value exp_difference = rewriter.create<mhlo::SubOp>(loc, exp_x, exp_neg_x);
Value two = getConstantLike(rewriter, loc, 2.0, x);
Value small_sinh_result =
rewriter.create<mhlo::DivOp>(loc, exp_difference, two);
Value abs_x = rewriter.create<mhlo::AbsOp>(loc, x);
Value one = getConstantLike(rewriter, loc, 1.0, x);
Value abs_x_lt_one = rewriter.create<mhlo::CompareOp>(loc, abs_x, one, kLT);
return rewriter.create<mhlo::SelectOp>(loc, abs_x_lt_one, small_sinh_result,
large_sinh_result);
}
struct ConvertSinhOp : public OpConversionPattern<SinhOp> {
using OpConversionPattern<SinhOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
SinhOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
SinhOp::Adaptor transformed(operands);
Value x = transformed.operand();
if (x.getType().cast<ShapedType>().getElementType().isa<ComplexType>()) {
// TODO(hinsu): Support operands with complex element types by always
// using the formula for large x. The compare op is not legal for complex
// numbers.
return failure();
}
rewriter.replaceOp(op,
MaterializeWithUpcast(rewriter, op.getLoc(), operands,
rewriter.getF32Type(),
&MaterializeSinhApproximation));
return success();
}
};
struct ConvertZetaOp : public OpConversionPattern<ZetaOp> { struct ConvertZetaOp : public OpConversionPattern<ZetaOp> {
using OpConversionPattern<ZetaOp>::OpConversionPattern; using OpConversionPattern<ZetaOp>::OpConversionPattern;
LogicalResult matchAndRewrite( LogicalResult matchAndRewrite(
@ -1248,6 +1311,7 @@ void PopulateDecomposeChloPatterns(MLIRContext *context,
ConvertErfcOp, ConvertErfcOp,
ConvertLgammaOp, ConvertLgammaOp,
ConvertPolygammaOp, ConvertPolygammaOp,
ConvertSinhOp,
ConvertZetaOp>(context); ConvertZetaOp>(context);
// clang-format on // clang-format on
} }

View File

@ -312,48 +312,6 @@ def : Pat<(HLOClient_IsNegInfOp NonComplexElementType:$input),
(HLO_DEFAULT_COMPARISON_TYPE) (HLO_DEFAULT_COMPARISON_TYPE)
)>; )>;
// Express `sinh` as
// sinh(x) = (e^x - e^-x) / 2 if |x| < 1
// = e^(x + log(1/2)) - e^(-x + log(1/2)) otherwise.
// TODO(hinsu): Support operands with complex element types by always using the
// second formula. The compare op below is not legal for complex numbers.
def : Pat<(HLOClient_SinhOp NonComplexElementType:$input),
(HLO_SelectOp
(HLO_CompareOp
(HLO_AbsOp $input),
(HLO_ConstantLike<"1"> $input),
HLO_COMPARISON_DIRECTION_LT,
(HLO_DEFAULT_COMPARISON_TYPE)
),
(HLO_DivOp
(HLO_SubOp
(HLO_ExpOp $input),
(HLO_ExpOp
(HLO_NegOp $input)
)
),
(HLO_ConstantLike<"2"> $input)
),
(HLO_SubOp
(HLO_ExpOp
(HLO_AddOp
$input,
(HLO_LogOp
(HLO_ConstantLike<"0.5"> $input)
)
)
),
(HLO_ExpOp
(HLO_SubOp
(HLO_LogOp
(HLO_ConstantLike<"0.5"> $input)
),
$input
)
)
)
)>;
// Express tan in MHLO dialect as // Express tan in MHLO dialect as
// tan(x) = sin(x) / cos(x). // tan(x) = sin(x) / cos(x).
def : Pat<(HLOClient_TanOp NonComplexElementType:$input), def : Pat<(HLOClient_TanOp NonComplexElementType:$input),

View File

@ -2123,3 +2123,42 @@ func @polygamma_f16(%lhs : tensor<f16>, %rhs : tensor<f16>) -> tensor<f16> {
%1 = chlo.polygamma %lhs, %rhs : tensor<f16>, tensor<f16> -> tensor<f16> %1 = chlo.polygamma %lhs, %rhs : tensor<f16>, tensor<f16> -> tensor<f16>
return %1 : tensor<f16> return %1 : tensor<f16>
} }
// ----
// CHECK-LABEL: @sinh_f32
// CHECK-SAME: (%[[X:.*]]: tensor<f32>)
func @sinh_f32(%x : tensor<f32>) -> tensor<f32> {
// CHECK: %[[HALF:.*]] = mhlo.constant dense<5.000000e-01> : tensor<f32>
// CHECK: %[[LOG_HALF:.*]] = "mhlo.log"(%[[HALF]]) : (tensor<f32>) -> tensor<f32>
// CHECK: %[[X_PLUS_LOG_HALF:.*]] = mhlo.add %[[X]], %[[LOG_HALF]] : tensor<f32>
// CHECK: %[[EXP_1:.*]] = "mhlo.exponential"(%[[X_PLUS_LOG_HALF]]) : (tensor<f32>) -> tensor<f32>
// CHECK: %[[LOG_HALF_MINUS_X:.*]] = mhlo.subtract %[[LOG_HALF]], %[[X]] : tensor<f32>
// CHECK: %[[EXP_2:.*]] = "mhlo.exponential"(%[[LOG_HALF_MINUS_X]]) : (tensor<f32>) -> tensor<f32>
// CHECK: %[[LARGE_SINH_RESULT:.*]] = mhlo.subtract %[[EXP_1]], %[[EXP_2]] : tensor<f32>
// CHECK: %[[EXP_X:.*]] = "mhlo.exponential"(%[[X]]) : (tensor<f32>) -> tensor<f32>
// CHECK: %[[NEG_X:.*]] = "mhlo.negate"(%[[X]]) : (tensor<f32>) -> tensor<f32>
// CHECK: %[[EXP_NEG_X:.*]] = "mhlo.exponential"(%[[NEG_X]]) : (tensor<f32>) -> tensor<f32>
// CHECK: %[[EXP_X_MINUS_EXP_NEG_X:.*]] = mhlo.subtract %[[EXP_X]], %[[EXP_NEG_X]] : tensor<f32>
// CHECK: %[[TWO:.*]] = mhlo.constant dense<2.000000e+00> : tensor<f32>
// CHECK: %[[SMALL_SINH_RESULT:.*]] = mhlo.divide %[[EXP_X_MINUS_EXP_NEG_X]], %[[TWO]] : tensor<f32>
// CHECK: %[[ABS_X:.*]] = "mhlo.abs"(%[[X]]) : (tensor<f32>) -> tensor<f32>
// CHECK: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
// CHECK: %[[ABS_X_LT_ONE:.*]] = "mhlo.compare"(%[[ABS_X]], %[[ONE]]) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: %[[RESULT:.*]] = "mhlo.select"(%[[ABS_X_LT_ONE]], %[[SMALL_SINH_RESULT]], %[[LARGE_SINH_RESULT]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: return %[[RESULT]] : tensor<f32>
%1 = chlo.sinh %x : tensor<f32> -> tensor<f32>
return %1 : tensor<f32>
}
// ----
// CHECK-LABEL: @sinh_f16
// CHECK-SAME: (%[[ARG0:.*]]: tensor<f16>)
func @sinh_f16(%x : tensor<f16>) -> tensor<f16> {
// CHECK: "mhlo.convert"(%[[ARG0]]) : (tensor<f16>) -> tensor<f32>
// CHECK: %[[RES:.*]] = "mhlo.convert"(%{{.*}}) : (tensor<f32>) -> tensor<f16>
// CHECK: return %[[RES]]
%1 = chlo.sinh %x : tensor<f16> -> tensor<f16>
return %1 : tensor<f16>
}