diff --git a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td index c633bd2..68c641c 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td @@ -359,6 +359,20 @@ def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos", [], }]; } +def HLOClient_AcoshOp : HLOClient_UnaryElementwiseOp<"acosh", [], + HLO_FpOrComplexTensor> { + let summary = "Acosh operation"; + + let description = [{ + Returns `Acosh(operand)` element-wise. + + $$ + \acosh(x) = log(x + sqrt(x^2 - 1)) if x >= -1 + \acosh(x) = nan if x < -1 + $$ + }]; +} + def HLOClient_AsinOp : HLOClient_UnaryElementwiseOp<"asin", [], HLO_FpOrComplexTensor> { let summary = "Asin operator"; @@ -384,6 +398,7 @@ def HLOClient_AsinhOp : HLOClient_UnaryElementwiseOp<"asinh", [], $$ }]; } + def HLOClient_AtanOp : HLOClient_UnaryElementwiseOp<"atan", [], HLO_FpOrComplexTensor> { let summary = "Atan operator"; diff --git a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td index a2b97a8..492e714 100644 --- a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td +++ b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td @@ -60,6 +60,58 @@ def : Pat<(HLOClient_AcosOp NonComplexElementType:$input), (HLO_ConstantLike<"M_PI"> $input) )>; +// Expand acosh to MHLO dialect as follows: +// acosh(x) = log(x + sqrt(x^2 - 1)) if x >= -1 +// = log(x + sqrt((x+1)*(x-1))) +// acosh(x) = nan if x < -1 +// +// If x^2 will overflow, we approximate sqrt(x^2 - 1) == x and compute as +// log(2*x) = log(2) + log(x). (Note this works because negative x never +// overflows; x < -1 simply yields nan. +def : Pat<(HLOClient_AcoshOp NonComplexElementType:$input), + (HLO_SelectOp + (HLO_CompareOp + $input, + (HLO_ConstantLike<"-1"> $input), + HLO_COMPARISON_DIRECTION_LT, + (HLO_DEFAULT_COMPARISON_TYPE) + ), + (HLO_ConstantLike<"NAN"> $input), + (HLO_SelectOp + (HLO_CompareOp + $input, + (HLO_SqrtOp + (HLO_ConstantLikeMaxFiniteValue $input) + ), + HLO_COMPARISON_DIRECTION_GE, + (HLO_DEFAULT_COMPARISON_TYPE) + ), + (HLO_AddOp + (HLO_LogOp $input), + (HLO_LogOp + (HLO_ConstantLike<"2"> $input) + ) + ), + (HLO_LogOp + (HLO_AddOp + $input, + (HLO_SqrtOp + (HLO_MulOp + (HLO_AddOp + (HLO_ConstantLike<"1"> $input), + $input + ), + (HLO_AddOp + (HLO_ConstantLike<"-1"> $input), + $input + ) + ) + ) + ) + ) + ) + )>; + // Expand asin to MHLO dialect as follows: // asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2))) def : Pat<(HLOClient_AsinOp NonComplexElementType:$input), diff --git a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc index b359217..063d077 100644 --- a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc +++ b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc @@ -51,9 +51,9 @@ namespace { // TODO(herhut): Generate these out of op definitions. #define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \ - fn(AcosOp) sep fn(AsinOp) sep fn(AsinhOp) sep fn(AtanOp) sep fn(AtanhOp) \ - sep fn(ConjOp) sep fn(CoshOp) sep fn(ErfOp) sep fn(ErfcOp) \ - sep fn(SinhOp) sep fn(TanOp) + fn(AcosOp) sep fn(AcoshOp) sep fn(AsinOp) sep fn(AsinhOp) sep fn(AtanOp) \ + sep fn(AtanhOp) sep fn(ConjOp) sep fn(CoshOp) sep fn(ErfOp) \ + sep fn(ErfcOp) sep fn(SinhOp) sep fn(TanOp) template inline void AddLegalOpOnRankedTensor(ConversionTarget *target) { diff --git a/tests/chlo_legalize_to_mhlo.mlir b/tests/chlo_legalize_to_mhlo.mlir index 1a8613f..ccc9c92 100644 --- a/tests/chlo_legalize_to_mhlo.mlir +++ b/tests/chlo_legalize_to_mhlo.mlir @@ -171,3 +171,31 @@ func @erf_f16(%arg : tensor) -> tensor { %1 = "chlo.erf"(%arg) : (tensor) -> tensor return %1 : tensor } + +// CHECK-LABEL: @acosh +// CHECK-SAME: %[[ARG:.*]]: tensor +func @acosh(%arg: tensor) -> tensor { + // CHECK: %[[MINUSONE:.*]] = mhlo.constant dense<-1.000000e+00> + // CHECK: %[[CMP:.*]] = "mhlo.compare"(%[[ARG]], %[[MINUSONE]]) {comparison_direction = "LT"} + // CHECK: %[[MAX:.*]] = mhlo.constant dense<6.550400e+04> + // CHECK: %[[SQRTMAX:.*]] = "mhlo.sqrt"(%[[MAX]]) + // CHECK: %[[OVERFLOW:.*]] = "mhlo.compare"(%[[ARG]], %[[SQRTMAX]]) {comparison_direction = "GE"} + // CHECK: %[[LOGARG:.*]] = "mhlo.log"(%[[ARG]]) + // CHECK: %[[TWO:.*]] = mhlo.constant dense<2.000000e+00> + // CHECK: %[[LOGTWO:.*]] = "mhlo.log"(%[[TWO]]) + // CHECK: %[[OFLRES:.*]] = mhlo.add %[[LOGARG]], %[[LOGTWO]] + // CHECK: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> + // CHECK: %[[ARGPONE:.*]] = mhlo.add %[[ONE]], %[[ARG]] + // CHECK: %[[MINUSONE2:.*]] = mhlo.constant dense<-1.000000e+00> + // CHECK: %[[ARGMONE:.*]] = mhlo.add %[[MINUSONE2]], %[[ARG]] + // CHECK: %[[MUL:.*]] = mhlo.multiply %[[ARGPONE]], %[[ARGMONE]] + // CHECK: %[[SQRT:.*]] = "mhlo.sqrt"(%[[MUL]]) + // CHECK: %[[APSQRT:.*]] = mhlo.add %[[ARG]], %[[SQRT]] + // CHECK: %[[LOGAPMUL:.*]] = "mhlo.log"(%[[APSQRT]]) + // CHECK: %[[SEL1:.*]] = "mhlo.select"(%[[OVERFLOW]], %[[OFLRES]], %[[LOGAPMUL]]) + // CHECK: %[[NAN:.*]] = mhlo.constant dense<0x7E00> + // CHECK: %[[RESULT:.*]] = "mhlo.select"(%[[CMP]], %[[NAN]], %[[SEL1]]) + // CHECK: return %[[RESULT]] + %1 = "chlo.acosh"(%arg) : (tensor) -> tensor + return %1 : tensor +}