Add chlo.acosh operation and associated lowerings.
PiperOrigin-RevId: 352839289
This commit is contained in:
parent
a7e645f37e
commit
70a351f301
|
@ -359,6 +359,20 @@ def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos", [],
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def HLOClient_AcoshOp : HLOClient_UnaryElementwiseOp<"acosh", [],
|
||||||
|
HLO_FpOrComplexTensor> {
|
||||||
|
let summary = "Acosh operation";
|
||||||
|
|
||||||
|
let description = [{
|
||||||
|
Returns `Acosh(operand)` element-wise.
|
||||||
|
|
||||||
|
$$
|
||||||
|
\acosh(x) = log(x + sqrt(x^2 - 1)) if x >= -1
|
||||||
|
\acosh(x) = nan if x < -1
|
||||||
|
$$
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def HLOClient_AsinOp : HLOClient_UnaryElementwiseOp<"asin", [],
|
def HLOClient_AsinOp : HLOClient_UnaryElementwiseOp<"asin", [],
|
||||||
HLO_FpOrComplexTensor> {
|
HLO_FpOrComplexTensor> {
|
||||||
let summary = "Asin operator";
|
let summary = "Asin operator";
|
||||||
|
@ -384,6 +398,7 @@ def HLOClient_AsinhOp : HLOClient_UnaryElementwiseOp<"asinh", [],
|
||||||
$$
|
$$
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def HLOClient_AtanOp : HLOClient_UnaryElementwiseOp<"atan", [],
|
def HLOClient_AtanOp : HLOClient_UnaryElementwiseOp<"atan", [],
|
||||||
HLO_FpOrComplexTensor> {
|
HLO_FpOrComplexTensor> {
|
||||||
let summary = "Atan operator";
|
let summary = "Atan operator";
|
||||||
|
|
|
@ -60,6 +60,58 @@ def : Pat<(HLOClient_AcosOp NonComplexElementType:$input),
|
||||||
(HLO_ConstantLike<"M_PI"> $input)
|
(HLO_ConstantLike<"M_PI"> $input)
|
||||||
)>;
|
)>;
|
||||||
|
|
||||||
|
// Expand acosh to MHLO dialect as follows:
|
||||||
|
// acosh(x) = log(x + sqrt(x^2 - 1)) if x >= -1
|
||||||
|
// = log(x + sqrt((x+1)*(x-1)))
|
||||||
|
// acosh(x) = nan if x < -1
|
||||||
|
//
|
||||||
|
// If x^2 will overflow, we approximate sqrt(x^2 - 1) == x and compute as
|
||||||
|
// log(2*x) = log(2) + log(x). (Note this works because negative x never
|
||||||
|
// overflows; x < -1 simply yields nan.
|
||||||
|
def : Pat<(HLOClient_AcoshOp NonComplexElementType:$input),
|
||||||
|
(HLO_SelectOp
|
||||||
|
(HLO_CompareOp
|
||||||
|
$input,
|
||||||
|
(HLO_ConstantLike<"-1"> $input),
|
||||||
|
HLO_COMPARISON_DIRECTION_LT,
|
||||||
|
(HLO_DEFAULT_COMPARISON_TYPE)
|
||||||
|
),
|
||||||
|
(HLO_ConstantLike<"NAN"> $input),
|
||||||
|
(HLO_SelectOp
|
||||||
|
(HLO_CompareOp
|
||||||
|
$input,
|
||||||
|
(HLO_SqrtOp
|
||||||
|
(HLO_ConstantLikeMaxFiniteValue $input)
|
||||||
|
),
|
||||||
|
HLO_COMPARISON_DIRECTION_GE,
|
||||||
|
(HLO_DEFAULT_COMPARISON_TYPE)
|
||||||
|
),
|
||||||
|
(HLO_AddOp
|
||||||
|
(HLO_LogOp $input),
|
||||||
|
(HLO_LogOp
|
||||||
|
(HLO_ConstantLike<"2"> $input)
|
||||||
|
)
|
||||||
|
),
|
||||||
|
(HLO_LogOp
|
||||||
|
(HLO_AddOp
|
||||||
|
$input,
|
||||||
|
(HLO_SqrtOp
|
||||||
|
(HLO_MulOp
|
||||||
|
(HLO_AddOp
|
||||||
|
(HLO_ConstantLike<"1"> $input),
|
||||||
|
$input
|
||||||
|
),
|
||||||
|
(HLO_AddOp
|
||||||
|
(HLO_ConstantLike<"-1"> $input),
|
||||||
|
$input
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)>;
|
||||||
|
|
||||||
// Expand asin to MHLO dialect as follows:
|
// Expand asin to MHLO dialect as follows:
|
||||||
// asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2)))
|
// asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2)))
|
||||||
def : Pat<(HLOClient_AsinOp NonComplexElementType:$input),
|
def : Pat<(HLOClient_AsinOp NonComplexElementType:$input),
|
||||||
|
|
|
@ -51,9 +51,9 @@ namespace {
|
||||||
|
|
||||||
// TODO(herhut): Generate these out of op definitions.
|
// TODO(herhut): Generate these out of op definitions.
|
||||||
#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \
|
#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \
|
||||||
fn(AcosOp) sep fn(AsinOp) sep fn(AsinhOp) sep fn(AtanOp) sep fn(AtanhOp) \
|
fn(AcosOp) sep fn(AcoshOp) sep fn(AsinOp) sep fn(AsinhOp) sep fn(AtanOp) \
|
||||||
sep fn(ConjOp) sep fn(CoshOp) sep fn(ErfOp) sep fn(ErfcOp) \
|
sep fn(AtanhOp) sep fn(ConjOp) sep fn(CoshOp) sep fn(ErfOp) \
|
||||||
sep fn(SinhOp) sep fn(TanOp)
|
sep fn(ErfcOp) sep fn(SinhOp) sep fn(TanOp)
|
||||||
|
|
||||||
template <typename OpTy>
|
template <typename OpTy>
|
||||||
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
|
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
|
||||||
|
|
|
@ -171,3 +171,31 @@ func @erf_f16(%arg : tensor<f16>) -> tensor<f16> {
|
||||||
%1 = "chlo.erf"(%arg) : (tensor<f16>) -> tensor<f16>
|
%1 = "chlo.erf"(%arg) : (tensor<f16>) -> tensor<f16>
|
||||||
return %1 : tensor<f16>
|
return %1 : tensor<f16>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @acosh
|
||||||
|
// CHECK-SAME: %[[ARG:.*]]: tensor<f16>
|
||||||
|
func @acosh(%arg: tensor<f16>) -> tensor<f16> {
|
||||||
|
// CHECK: %[[MINUSONE:.*]] = mhlo.constant dense<-1.000000e+00>
|
||||||
|
// CHECK: %[[CMP:.*]] = "mhlo.compare"(%[[ARG]], %[[MINUSONE]]) {comparison_direction = "LT"}
|
||||||
|
// CHECK: %[[MAX:.*]] = mhlo.constant dense<6.550400e+04>
|
||||||
|
// CHECK: %[[SQRTMAX:.*]] = "mhlo.sqrt"(%[[MAX]])
|
||||||
|
// CHECK: %[[OVERFLOW:.*]] = "mhlo.compare"(%[[ARG]], %[[SQRTMAX]]) {comparison_direction = "GE"}
|
||||||
|
// CHECK: %[[LOGARG:.*]] = "mhlo.log"(%[[ARG]])
|
||||||
|
// CHECK: %[[TWO:.*]] = mhlo.constant dense<2.000000e+00>
|
||||||
|
// CHECK: %[[LOGTWO:.*]] = "mhlo.log"(%[[TWO]])
|
||||||
|
// CHECK: %[[OFLRES:.*]] = mhlo.add %[[LOGARG]], %[[LOGTWO]]
|
||||||
|
// CHECK: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00>
|
||||||
|
// CHECK: %[[ARGPONE:.*]] = mhlo.add %[[ONE]], %[[ARG]]
|
||||||
|
// CHECK: %[[MINUSONE2:.*]] = mhlo.constant dense<-1.000000e+00>
|
||||||
|
// CHECK: %[[ARGMONE:.*]] = mhlo.add %[[MINUSONE2]], %[[ARG]]
|
||||||
|
// CHECK: %[[MUL:.*]] = mhlo.multiply %[[ARGPONE]], %[[ARGMONE]]
|
||||||
|
// CHECK: %[[SQRT:.*]] = "mhlo.sqrt"(%[[MUL]])
|
||||||
|
// CHECK: %[[APSQRT:.*]] = mhlo.add %[[ARG]], %[[SQRT]]
|
||||||
|
// CHECK: %[[LOGAPMUL:.*]] = "mhlo.log"(%[[APSQRT]])
|
||||||
|
// CHECK: %[[SEL1:.*]] = "mhlo.select"(%[[OVERFLOW]], %[[OFLRES]], %[[LOGAPMUL]])
|
||||||
|
// CHECK: %[[NAN:.*]] = mhlo.constant dense<0x7E00>
|
||||||
|
// CHECK: %[[RESULT:.*]] = "mhlo.select"(%[[CMP]], %[[NAN]], %[[SEL1]])
|
||||||
|
// CHECK: return %[[RESULT]]
|
||||||
|
%1 = "chlo.acosh"(%arg) : (tensor<f16>) -> tensor<f16>
|
||||||
|
return %1 : tensor<f16>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue