Fix Sinh approximation for F16.

We should upcast F16 to F32 to prevent precision loss. For example, sinh(-9) previously evaluated to -4042 instead of -4052. This also lets us enable the MLIR-generated kernel for the F16 type.

PiperOrigin-RevId: 377901896
This commit is contained in: commit 5315997402 (parent fc723380e6)
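To see the precision loss concretely, here is a small numpy sketch (illustrative only, not part of this change): it mimics the old all-F16 evaluation of the large-|x| branch and the new upcast-to-F32 evaluation, reproducing the -4042 vs. -4052 figures from the commit message.

import numpy as np

x = np.float16(-9.0)

# Old behavior: evaluate the large-|x| branch entirely in f16.
log_half = np.log(np.float16(0.5))         # log(1/2) rounded to f16
exp_add = np.exp(x + log_half)             # e^(x + log(1/2)), every step in f16
exp_sub = np.exp(log_half - x)             # e^(log(1/2) - x), every step in f16
print(exp_add - exp_sub)                   # -4042.0

# New behavior: upcast to f32, compute, round back to f16.
print(np.float16(np.sinh(np.float32(x))))  # -4052.0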
@@ -28,6 +28,7 @@ limitations under the License.
 #include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
 #include "mlir-hlo/utils/broadcast_utils.h"
+#include "mlir-hlo/utils/hlo_utils.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/Shape/IR/Shape.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
@@ -989,6 +990,68 @@ struct ConvertPolygammaOp : public OpConversionPattern<PolygammaOp> {
   }
 };
 
+Value MaterializeSinhApproximationForLargeX(ConversionPatternRewriter &rewriter,
+                                            Location loc, ValueRange operands) {
+  SinhOp::Adaptor transformed(operands);
+  Value x = transformed.operand();
+
+  Value log_one_half =
+      rewriter.create<mhlo::LogOp>(loc, getConstantLike(rewriter, loc, 0.5, x));
+  Value exp_add = rewriter.create<mhlo::ExpOp>(
+      loc, rewriter.create<mhlo::AddOp>(loc, x, log_one_half));
+  Value exp_sub = rewriter.create<mhlo::ExpOp>(
+      loc, rewriter.create<mhlo::SubOp>(loc, log_one_half, x));
+  return rewriter.create<mhlo::SubOp>(loc, exp_add, exp_sub);
+}
+
+// Express `sinh` as
+//   sinh(x) = (e^x - e^-x) / 2                     if |x| < 1
+//           = e^(x + log(1/2)) - e^(-x + log(1/2)) otherwise.
+Value MaterializeSinhApproximation(ConversionPatternRewriter &rewriter,
+                                   Location loc, ValueRange operands) {
+  Value large_sinh_result =
+      MaterializeSinhApproximationForLargeX(rewriter, loc, operands);
+
+  SinhOp::Adaptor transformed(operands);
+  Value x = transformed.operand();
+  const StringAttr kLT = rewriter.getStringAttr(
+      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
+  Value exp_x = rewriter.create<mhlo::ExpOp>(loc, x);
+  Value exp_neg_x =
+      rewriter.create<mhlo::ExpOp>(loc, rewriter.create<mhlo::NegOp>(loc, x));
+  Value exp_difference = rewriter.create<mhlo::SubOp>(loc, exp_x, exp_neg_x);
+  Value two = getConstantLike(rewriter, loc, 2.0, x);
+  Value small_sinh_result =
+      rewriter.create<mhlo::DivOp>(loc, exp_difference, two);
+
+  Value abs_x = rewriter.create<mhlo::AbsOp>(loc, x);
+  Value one = getConstantLike(rewriter, loc, 1.0, x);
+  Value abs_x_lt_one = rewriter.create<mhlo::CompareOp>(loc, abs_x, one, kLT);
+  return rewriter.create<mhlo::SelectOp>(loc, abs_x_lt_one, small_sinh_result,
+                                         large_sinh_result);
+}
+
+struct ConvertSinhOp : public OpConversionPattern<SinhOp> {
+  using OpConversionPattern<SinhOp>::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      SinhOp op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    SinhOp::Adaptor transformed(operands);
+    Value x = transformed.operand();
+    if (x.getType().cast<ShapedType>().getElementType().isa<ComplexType>()) {
+      // TODO(hinsu): Support operands with complex element types by always
+      // using the formula for large x. The compare op is not legal for complex
+      // numbers.
+      return failure();
+    }
+    rewriter.replaceOp(op,
+                       MaterializeWithUpcast(rewriter, op.getLoc(), operands,
+                                             rewriter.getF32Type(),
+                                             &MaterializeSinhApproximation));
+    return success();
+  }
+};
+
 struct ConvertZetaOp : public OpConversionPattern<ZetaOp> {
   using OpConversionPattern<ZetaOp>::OpConversionPattern;
   LogicalResult matchAndRewrite(
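For reference, the large-|x| form materialized above is an exact algebraic rewriting of sinh, not a further approximation; only its floating-point rounding and overflow behavior differ:

\begin{aligned}
\sinh(x) &= \tfrac{1}{2}\,(e^{x} - e^{-x}) \\
         &= e^{\log(1/2)}\,e^{x} - e^{\log(1/2)}\,e^{-x} \\
         &= e^{x + \log(1/2)} - e^{-x + \log(1/2)}.
\end{aligned}

Folding the 1/2 into the exponents keeps each intermediate exponential at about half the magnitude of e^|x|, so this form overflows roughly log(2) later in |x| than the naive (e^x - e^-x) / 2 would.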
@@ -1248,6 +1311,7 @@ void PopulateDecomposeChloPatterns(MLIRContext *context,
                      ConvertErfcOp,
                      ConvertLgammaOp,
                      ConvertPolygammaOp,
+                     ConvertSinhOp,
                      ConvertZetaOp>(context);
   // clang-format on
 }
@@ -312,48 +312,6 @@ def : Pat<(HLOClient_IsNegInfOp NonComplexElementType:$input),
     (HLO_DEFAULT_COMPARISON_TYPE)
   )>;
 
-// Express `sinh` as
-//   sinh(x) = (e^x - e^-x) / 2                     if |x| < 1
-//           = e^(x + log(1/2)) - e^(-x + log(1/2)) otherwise.
-// TODO(hinsu): Support operands with complex element types by always using the
-// second formula. The compare op below is not legal for complex numbers.
-def : Pat<(HLOClient_SinhOp NonComplexElementType:$input),
-  (HLO_SelectOp
-    (HLO_CompareOp
-      (HLO_AbsOp $input),
-      (HLO_ConstantLike<"1"> $input),
-      HLO_COMPARISON_DIRECTION_LT,
-      (HLO_DEFAULT_COMPARISON_TYPE)
-    ),
-    (HLO_DivOp
-      (HLO_SubOp
-        (HLO_ExpOp $input),
-        (HLO_ExpOp
-          (HLO_NegOp $input)
-        )
-      ),
-      (HLO_ConstantLike<"2"> $input)
-    ),
-    (HLO_SubOp
-      (HLO_ExpOp
-        (HLO_AddOp
-          $input,
-          (HLO_LogOp
-            (HLO_ConstantLike<"0.5"> $input)
-          )
-        )
-      ),
-      (HLO_ExpOp
-        (HLO_SubOp
-          (HLO_LogOp
-            (HLO_ConstantLike<"0.5"> $input)
-          ),
-          $input
-        )
-      )
-    )
-  )>;
-
 // Express tan in MHLO dialect as
 //   tan(x) = sin(x) / cos(x).
 def : Pat<(HLOClient_TanOp NonComplexElementType:$input),
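The removed TableGen pattern encoded the same two-branch formula that the C++ lowering now materializes; moving it to C++ is where the F32 upcast gets applied. The large-|x| branch matters because dividing by 2 only after exponentiating hits the overflow limit sooner. A numpy sketch of the effect in f32 (illustrative, not part of the change):

import numpy as np

x = np.float32(89.0)    # sinh(89) ~ 2.24e38, still below the f32 max (~3.4e38)

# Naive branch: e^89 already overflows f32, so the result is inf.
naive = (np.exp(x) - np.exp(-x)) / np.float32(2.0)

# Large-|x| branch: e^(89 + log(1/2)) = e^88.3... stays finite.
log_half = np.log(np.float32(0.5))
rewritten = np.exp(x + log_half) - np.exp(log_half - x)

print(naive, rewritten)  # inf  ~2.24e+38 (numpy also warns about the overflow)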
@@ -2123,3 +2123,42 @@ func @polygamma_f16(%lhs : tensor<f16>, %rhs : tensor<f16>) -> tensor<f16> {
   %1 = chlo.polygamma %lhs, %rhs : tensor<f16>, tensor<f16> -> tensor<f16>
   return %1 : tensor<f16>
 }
+
+// -----
+
+// CHECK-LABEL: @sinh_f32
+// CHECK-SAME: (%[[X:.*]]: tensor<f32>)
+func @sinh_f32(%x : tensor<f32>) -> tensor<f32> {
+  // CHECK: %[[HALF:.*]] = mhlo.constant dense<5.000000e-01> : tensor<f32>
+  // CHECK: %[[LOG_HALF:.*]] = "mhlo.log"(%[[HALF]]) : (tensor<f32>) -> tensor<f32>
+  // CHECK: %[[X_PLUS_LOG_HALF:.*]] = mhlo.add %[[X]], %[[LOG_HALF]] : tensor<f32>
+  // CHECK: %[[EXP_1:.*]] = "mhlo.exponential"(%[[X_PLUS_LOG_HALF]]) : (tensor<f32>) -> tensor<f32>
+  // CHECK: %[[LOG_HALF_MINUS_X:.*]] = mhlo.subtract %[[LOG_HALF]], %[[X]] : tensor<f32>
+  // CHECK: %[[EXP_2:.*]] = "mhlo.exponential"(%[[LOG_HALF_MINUS_X]]) : (tensor<f32>) -> tensor<f32>
+  // CHECK: %[[LARGE_SINH_RESULT:.*]] = mhlo.subtract %[[EXP_1]], %[[EXP_2]] : tensor<f32>
+  // CHECK: %[[EXP_X:.*]] = "mhlo.exponential"(%[[X]]) : (tensor<f32>) -> tensor<f32>
+  // CHECK: %[[NEG_X:.*]] = "mhlo.negate"(%[[X]]) : (tensor<f32>) -> tensor<f32>
+  // CHECK: %[[EXP_NEG_X:.*]] = "mhlo.exponential"(%[[NEG_X]]) : (tensor<f32>) -> tensor<f32>
+  // CHECK: %[[EXP_X_MINUS_EXP_NEG_X:.*]] = mhlo.subtract %[[EXP_X]], %[[EXP_NEG_X]] : tensor<f32>
+  // CHECK: %[[TWO:.*]] = mhlo.constant dense<2.000000e+00> : tensor<f32>
+  // CHECK: %[[SMALL_SINH_RESULT:.*]] = mhlo.divide %[[EXP_X_MINUS_EXP_NEG_X]], %[[TWO]] : tensor<f32>
+  // CHECK: %[[ABS_X:.*]] = "mhlo.abs"(%[[X]]) : (tensor<f32>) -> tensor<f32>
+  // CHECK: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
+  // CHECK: %[[ABS_X_LT_ONE:.*]] = "mhlo.compare"(%[[ABS_X]], %[[ONE]]) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+  // CHECK: %[[RESULT:.*]] = "mhlo.select"(%[[ABS_X_LT_ONE]], %[[SMALL_SINH_RESULT]], %[[LARGE_SINH_RESULT]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+  // CHECK: return %[[RESULT]] : tensor<f32>
+  %1 = chlo.sinh %x : tensor<f32> -> tensor<f32>
+  return %1 : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: @sinh_f16
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<f16>)
+func @sinh_f16(%x : tensor<f16>) -> tensor<f16> {
+  // CHECK: "mhlo.convert"(%[[ARG0]]) : (tensor<f16>) -> tensor<f32>
+  // CHECK: %[[RES:.*]] = "mhlo.convert"(%{{.*}}) : (tensor<f32>) -> tensor<f16>
+  // CHECK: return %[[RES]]
+  %1 = chlo.sinh %x : tensor<f16> -> tensor<f16>
+  return %1 : tensor<f16>
+}