Fix Cosh approximation for F16.
We should upcast F16 to F32 to prevent precision loss; e.g., cosh(-9) previously evaluated to 4042 instead of 4052. This also lets us enable the MLIR-generated kernel for the F16 type.

Also move the template instantiation for Sinh inside the #ifdef block; this was missed in a previous commit.

PiperOrigin-RevId: 378635042
parent 837a1de7c5
commit 6088eb697c
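The 4042-vs-4052 example is reproducible outside the compiler. The sketch below is an illustration, not the actual kernel: it evaluates the decomposition cosh(x) = e^(x + log(1/2)) + e^(-x + log(1/2)) used by this lowering (see the comment blocks in the diff below), once with every intermediate rounded to f16 and once in f32 with a single truncation at the end. It assumes a compiler/target with the _Float16 extension (e.g. recent Clang or GCC on x86-64).

#include <cmath>
#include <cstdio>

// Illustration only, assuming the _Float16 compiler extension: mimic an
// all-f16 evaluation by rounding every intermediate to half precision, and
// compare against computing in f32 and truncating only the final result.
int main() {
  const float x = -9.0f;

  // All-f16: log(1/2) rounds to -0.693359375, log(1/2) - x rounds to
  // 8.3046875, and exp(8.3046875) ~ 4042.8 rounds down to 4042.
  _Float16 log_half = (_Float16)std::log(0.5f);
  _Float16 a = (_Float16)(x + (float)log_half);
  _Float16 b = (_Float16)((float)log_half - x);
  _Float16 cosh_f16 =
      (_Float16)std::exp((float)a) + (_Float16)std::exp((float)b);

  // Upcast: compute in f32 (exp(8.3068528) ~ 4051.54) and truncate once at
  // the end; 4051.54 rounds to 4052 in f16.
  float log_half_f32 = std::log(0.5f);
  float cosh_f32 = std::exp(x + log_half_f32) + std::exp(log_half_f32 - x);

  std::printf("all-f16: %g\n", (float)cosh_f16);            // 4042
  std::printf("upcast:  %g\n", (float)(_Float16)cosh_f32);  // 4052
  return 0;
}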
@@ -648,6 +648,51 @@ Value MaterializeLgamma(ConversionPatternRewriter &rewriter, Location loc,
       lgamma);
 }
 
+
+// Express `cosh` as
+//   cosh(x) = (e^x + e^-x) / 2
+//           = e^(x + log(1/2)) + e^(-x + log(1/2))
+//
+// The second formulation avoids overflowing when e^x = inf but (e^x)/2 is not.
+//
+// This incorrectly overflows to inf for two f32 input values, namely
+// +/-89.4159851, due to rounding error when computing x +/- log(1/2). The
+// correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so
+// we deem this acceptable.
+Value MaterializeCoshApproximation(ConversionPatternRewriter &rewriter,
+                                   Location loc, ValueRange operands) {
+  CoshOp::Adaptor transformed(operands);
+  Value x = transformed.operand();
+
+  Value log_one_half =
+      rewriter.create<mhlo::LogOp>(loc, getConstantLike(rewriter, loc, 0.5, x));
+  Value exp_add = rewriter.create<mhlo::ExpOp>(
+      loc, rewriter.create<mhlo::AddOp>(loc, x, log_one_half));
+  Value exp_sub = rewriter.create<mhlo::ExpOp>(
+      loc, rewriter.create<mhlo::SubOp>(loc, log_one_half, x));
+  return rewriter.create<mhlo::AddOp>(loc, exp_add, exp_sub);
+}
+
+struct ConvertCoshOp : public OpConversionPattern<CoshOp> {
+  using OpConversionPattern<CoshOp>::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      CoshOp op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    CoshOp::Adaptor transformed(operands);
+    Value x = transformed.operand();
+    if (x.getType().cast<ShapedType>().getElementType().isa<ComplexType>()) {
+      // TODO(hinsu): Support operands with complex element types by always
+      // using the formula for large x. The compare op is not legal for complex
+      // numbers.
+      return failure();
+    }
+    rewriter.replaceOp(op,
+                       MaterializeWithUpcast(rewriter, op.getLoc(), operands,
+                                             rewriter.getF32Type(),
+                                             &MaterializeCoshApproximation));
+    return success();
+  }
+};
+
 // Compute the Digamma function using Lanczos' approximation from "A Precision
 // Approximation of the Gamma Function". SIAM Journal on Numerical Analysis
 // series B. Vol. 1:
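MaterializeWithUpcast, called above, is what routes F16 through F32; its implementation is outside this diff. The following is a minimal sketch of what such a helper presumably does, inferred from the convert-in/convert-out pattern visible in the f16 test added below. The body is an assumption, not this repo's code, and ranked tensors are assumed for simplicity.

// Sketch (assumption, not the code in this repo): upcast operands whose
// element type is narrower than min_precision_ty, materialize the
// approximation at the wider type, then truncate the result back.
Value MaterializeWithUpcastSketch(
    ConversionPatternRewriter &rewriter, Location loc, ValueRange operands,
    FloatType min_precision_ty,
    Value (*materialize)(ConversionPatternRewriter &, Location, ValueRange)) {
  auto ty = operands.front().getType().cast<RankedTensorType>();
  auto elem_ty = ty.getElementType().dyn_cast<FloatType>();
  // Nothing to do if the element type is already wide enough.
  if (!elem_ty || elem_ty.getWidth() >= min_precision_ty.getWidth())
    return materialize(rewriter, loc, operands);

  // Widen every operand (f16 -> f32 in this commit).
  SmallVector<Value, 2> upcast_operands;
  for (Value v : operands) {
    auto vty = v.getType().cast<RankedTensorType>();
    upcast_operands.push_back(rewriter.create<mhlo::ConvertOp>(
        loc, RankedTensorType::get(vty.getShape(), min_precision_ty), v));
  }

  // Materialize the approximation at the wide type, then truncate the
  // result back to the original narrow element type.
  Value result = materialize(rewriter, loc, upcast_operands);
  return rewriter.create<mhlo::ConvertOp>(
      loc, RankedTensorType::get(ty.getShape(), elem_ty), result);
}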
@@ -1318,7 +1363,8 @@ void PopulateDecomposeChloPatterns(MLIRContext *context,
 
   // Other patterns.
   // clang-format off
-  patterns->insert<ConvertDigammaOp,
+  patterns->insert<ConvertCoshOp,
+                   ConvertDigammaOp,
                    ConvertErfOp,
                    ConvertErfcOp,
                    ConvertLgammaOp,
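For orientation, a pass driver consuming these patterns might look roughly as follows. Apart from PopulateDecomposeChloPatterns, the names and signatures here are assumptions based on common MLIR idioms of this period, not part of this commit.

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical driver sketch (namespace qualifiers for the populate
// function omitted): collect the CHLO decomposition patterns, now
// including ConvertCoshOp, and apply them greedily to a function.
void DecomposeChloSketch(mlir::FuncOp func) {
  mlir::MLIRContext *context = func.getContext();
  mlir::OwningRewritePatternList patterns(context);
  PopulateDecomposeChloPatterns(context, &patterns);
  (void)mlir::applyPatternsAndFoldGreedily(func, std::move(patterns));
}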
@@ -255,36 +255,6 @@ def : Pat<(HLOClient_AtanhOp NonComplexElementType:$input),
 def : Pat<(HLOClient_ConjOp $v),
           (HLO_ComplexOp (HLO_RealOp $v), (HLO_NegOp (HLO_ImagOp $v)))>;
 
-// Express `cosh` as
-//   cosh(x) = (e^x + e^-x) / 2
-//           = e^(x + log(1/2)) + e^(-x + log(1/2))
-//
-// The second formulation avoids overflowing when e^x = inf but (e^x)/2 is not.
-//
-// This incorrectly overflows to inf for two f32 input values, namely
-// +/-89.4159851, due to rounding error when computing x +/- log(1/2). The
-// correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so
-// we deem this acceptable.
-def : Pat<(HLOClient_CoshOp NonComplexElementType:$input),
-  (HLO_AddOp
-    (HLO_ExpOp
-      (HLO_AddOp
-        $input,
-        (HLO_LogOp
-          (HLO_ConstantLike<"0.5"> $input)
-        )
-      )
-    ),
-    (HLO_ExpOp
-      (HLO_AddOp
-        (HLO_NegOp $input),
-        (HLO_LogOp
-          (HLO_ConstantLike<"0.5"> $input)
-        )
-      )
-    )
-  )>;
-
 // Express `is_inf` as
 //   is_inf(x) = is_pos_inf(|x|)
 def : Pat<(HLOClient_IsInfOp NonComplexElementType:$input),
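The overflow rationale in the comment block (removed from the TableGen file but kept in the C++ version above) is easy to verify numerically: since e^(log(1/2)) = 1/2, both formulations equal (e^x + e^-x)/2 exactly, but the naive form overflows f32 as soon as e^x does. A standalone check, independent of this repo:

#include <cmath>
#include <cstdio>

// Checks the claim from the comment block: in f32, the naive
// (e^x + e^-x) / 2 overflows once e^x exceeds max-float (x > ~88.72),
// while e^(x + log(1/2)) + e^(-x + log(1/2)) stays finite up to ~89.4.
int main() {
  const float x = 89.0f;
  const float naive = (std::exp(x) + std::exp(-x)) / 2.0f;  // inf: e^89 ~ 4.5e38
  const float log_half = std::log(0.5f);
  const float stable =
      std::exp(x + log_half) + std::exp(log_half - x);       // ~2.24e38, finite
  std::printf("naive  = %g\n", naive);                       // inf
  std::printf("stable = %g\n", stable);                      // ~2.2448e+38
  std::printf("cosh   = %g\n", std::cosh(x));                // reference value
  return 0;
}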
@@ -2187,3 +2187,32 @@ func @sinh_complex(%x : tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> {
   %1 = chlo.sinh %x : tensor<2xcomplex<f32>> -> tensor<2xcomplex<f32>>
   return %1 : tensor<2xcomplex<f32>>
 }
+
+// -----
+
+// CHECK-LABEL: @cosh_f32
+// CHECK-SAME: (%[[X:.*]]: tensor<f32>)
+func @cosh_f32(%x : tensor<f32>) -> tensor<f32> {
+  // CHECK: %[[HALF:.*]] = mhlo.constant dense<5.000000e-01> : tensor<f32>
+  // CHECK: %[[LOG_HALF:.*]] = "mhlo.log"(%[[HALF]]) : (tensor<f32>) -> tensor<f32>
+  // CHECK: %[[X_PLUS_LOG_HALF:.*]] = mhlo.add %[[X]], %[[LOG_HALF]] : tensor<f32>
+  // CHECK: %[[EXP_1:.*]] = "mhlo.exponential"(%[[X_PLUS_LOG_HALF]]) : (tensor<f32>) -> tensor<f32>
+  // CHECK: %[[LOG_HALF_MINUS_X:.*]] = mhlo.subtract %[[LOG_HALF]], %[[X]] : tensor<f32>
+  // CHECK: %[[EXP_2:.*]] = "mhlo.exponential"(%[[LOG_HALF_MINUS_X]]) : (tensor<f32>) -> tensor<f32>
+  // CHECK: %[[RESULT:.*]] = mhlo.add %[[EXP_1]], %[[EXP_2]] : tensor<f32>
+  // CHECK: return %[[RESULT]] : tensor<f32>
+  %1 = chlo.cosh %x : tensor<f32> -> tensor<f32>
+  return %1 : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: @cosh_f16
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<f16>)
+func @cosh_f16(%x : tensor<f16>) -> tensor<f16> {
+  // CHECK: "mhlo.convert"(%[[ARG0]]) : (tensor<f16>) -> tensor<f32>
+  // CHECK: %[[RES:.*]] = "mhlo.convert"(%{{.*}}) : (tensor<f32>) -> tensor<f16>
+  // CHECK: return %[[RES]]
+  %1 = chlo.cosh %x : tensor<f16> -> tensor<f16>
+  return %1 : tensor<f16>
+}