Add chlo.acosh operation and associated lowerings.

PiperOrigin-RevId: 352839289
This commit is contained in:
Stephan Herhut 2021-01-20 11:42:12 -08:00 committed by TensorFlow MLIR Team
parent a7e645f37e
commit 70a351f301
4 changed files with 98 additions and 3 deletions

View File

@ -359,6 +359,20 @@ def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos", [],
}];
}
def HLOClient_AcoshOp : HLOClient_UnaryElementwiseOp<"acosh", [],
HLO_FpOrComplexTensor> {
let summary = "Acosh operation";
let description = [{
Returns `Acosh(operand)` element-wise.
$$
\acosh(x) = log(x + sqrt(x^2 - 1)) if x >= -1
\acosh(x) = nan if x < -1
$$
}];
}
def HLOClient_AsinOp : HLOClient_UnaryElementwiseOp<"asin", [],
HLO_FpOrComplexTensor> {
let summary = "Asin operator";
@ -384,6 +398,7 @@ def HLOClient_AsinhOp : HLOClient_UnaryElementwiseOp<"asinh", [],
$$
}];
}
def HLOClient_AtanOp : HLOClient_UnaryElementwiseOp<"atan", [],
HLO_FpOrComplexTensor> {
let summary = "Atan operator";

View File

@ -60,6 +60,58 @@ def : Pat<(HLOClient_AcosOp NonComplexElementType:$input),
(HLO_ConstantLike<"M_PI"> $input)
)>;
// Expand acosh to MHLO dialect as follows:
// acosh(x) = log(x + sqrt(x^2 - 1)) if x >= -1
// = log(x + sqrt((x+1)*(x-1)))
// acosh(x) = nan if x < -1
//
// If x^2 will overflow, we approximate sqrt(x^2 - 1) == x and compute as
// log(2*x) = log(2) + log(x). (Note this works because negative x never
// overflows; x < -1 simply yields nan.
def : Pat<(HLOClient_AcoshOp NonComplexElementType:$input),
(HLO_SelectOp
(HLO_CompareOp
$input,
(HLO_ConstantLike<"-1"> $input),
HLO_COMPARISON_DIRECTION_LT,
(HLO_DEFAULT_COMPARISON_TYPE)
),
(HLO_ConstantLike<"NAN"> $input),
(HLO_SelectOp
(HLO_CompareOp
$input,
(HLO_SqrtOp
(HLO_ConstantLikeMaxFiniteValue $input)
),
HLO_COMPARISON_DIRECTION_GE,
(HLO_DEFAULT_COMPARISON_TYPE)
),
(HLO_AddOp
(HLO_LogOp $input),
(HLO_LogOp
(HLO_ConstantLike<"2"> $input)
)
),
(HLO_LogOp
(HLO_AddOp
$input,
(HLO_SqrtOp
(HLO_MulOp
(HLO_AddOp
(HLO_ConstantLike<"1"> $input),
$input
),
(HLO_AddOp
(HLO_ConstantLike<"-1"> $input),
$input
)
)
)
)
)
)
)>;
// Expand asin to MHLO dialect as follows:
// asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2)))
def : Pat<(HLOClient_AsinOp NonComplexElementType:$input),

View File

@ -51,9 +51,9 @@ namespace {
// TODO(herhut): Generate these out of op definitions.
#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \
fn(AcosOp) sep fn(AsinOp) sep fn(AsinhOp) sep fn(AtanOp) sep fn(AtanhOp) \
sep fn(ConjOp) sep fn(CoshOp) sep fn(ErfOp) sep fn(ErfcOp) \
sep fn(SinhOp) sep fn(TanOp)
fn(AcosOp) sep fn(AcoshOp) sep fn(AsinOp) sep fn(AsinhOp) sep fn(AtanOp) \
sep fn(AtanhOp) sep fn(ConjOp) sep fn(CoshOp) sep fn(ErfOp) \
sep fn(ErfcOp) sep fn(SinhOp) sep fn(TanOp)
template <typename OpTy>
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {

View File

@ -171,3 +171,31 @@ func @erf_f16(%arg : tensor<f16>) -> tensor<f16> {
%1 = "chlo.erf"(%arg) : (tensor<f16>) -> tensor<f16>
return %1 : tensor<f16>
}
// CHECK-LABEL: @acosh
// CHECK-SAME: %[[ARG:.*]]: tensor<f16>
func @acosh(%arg: tensor<f16>) -> tensor<f16> {
// CHECK: %[[MINUSONE:.*]] = mhlo.constant dense<-1.000000e+00>
// CHECK: %[[CMP:.*]] = "mhlo.compare"(%[[ARG]], %[[MINUSONE]]) {comparison_direction = "LT"}
// CHECK: %[[MAX:.*]] = mhlo.constant dense<6.550400e+04>
// CHECK: %[[SQRTMAX:.*]] = "mhlo.sqrt"(%[[MAX]])
// CHECK: %[[OVERFLOW:.*]] = "mhlo.compare"(%[[ARG]], %[[SQRTMAX]]) {comparison_direction = "GE"}
// CHECK: %[[LOGARG:.*]] = "mhlo.log"(%[[ARG]])
// CHECK: %[[TWO:.*]] = mhlo.constant dense<2.000000e+00>
// CHECK: %[[LOGTWO:.*]] = "mhlo.log"(%[[TWO]])
// CHECK: %[[OFLRES:.*]] = mhlo.add %[[LOGARG]], %[[LOGTWO]]
// CHECK: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00>
// CHECK: %[[ARGPONE:.*]] = mhlo.add %[[ONE]], %[[ARG]]
// CHECK: %[[MINUSONE2:.*]] = mhlo.constant dense<-1.000000e+00>
// CHECK: %[[ARGMONE:.*]] = mhlo.add %[[MINUSONE2]], %[[ARG]]
// CHECK: %[[MUL:.*]] = mhlo.multiply %[[ARGPONE]], %[[ARGMONE]]
// CHECK: %[[SQRT:.*]] = "mhlo.sqrt"(%[[MUL]])
// CHECK: %[[APSQRT:.*]] = mhlo.add %[[ARG]], %[[SQRT]]
// CHECK: %[[LOGAPMUL:.*]] = "mhlo.log"(%[[APSQRT]])
// CHECK: %[[SEL1:.*]] = "mhlo.select"(%[[OVERFLOW]], %[[OFLRES]], %[[LOGAPMUL]])
// CHECK: %[[NAN:.*]] = mhlo.constant dense<0x7E00>
// CHECK: %[[RESULT:.*]] = "mhlo.select"(%[[CMP]], %[[NAN]], %[[SEL1]])
// CHECK: return %[[RESULT]]
%1 = "chlo.acosh"(%arg) : (tensor<f16>) -> tensor<f16>
return %1 : tensor<f16>
}