diff --git a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h
index c1d7ffc..fb55402 100644
--- a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h
+++ b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h
@@ -66,6 +66,11 @@ static Value getConstantLike(OpBuilder& b, Location loc, T constant,
   return b.create<ConstantLikeOp>(loc, getAttr(), val);
 }
 
+Value getConstantLike(OpBuilder& b, Location loc, const APFloat& constant,
+                      Value val);
+
+Value getConstantLikeMaxFiniteValue(OpBuilder& b, Location loc, Value val);
+
 }  // namespace chlo
 }  // namespace mlir
 
diff --git a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td
index 1adfa93..c633bd2 100644
--- a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td
+++ b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td
@@ -372,6 +372,18 @@ def HLOClient_AsinOp : HLOClient_UnaryElementwiseOp<"asin", [],
   }];
 }
 
+def HLOClient_AsinhOp : HLOClient_UnaryElementwiseOp<"asinh", [],
+    HLO_FpOrComplexTensor> {
+  let summary = "Asinh operation";
+
+  let description = [{
+    Returns `Asinh(operand)` element-wise.
+
+    $$
+    \asinh(x) = log(x + sqrt(x^2 + 1))
+    $$
+  }];
+}
 def HLOClient_AtanOp : HLOClient_UnaryElementwiseOp<"atan", [],
     HLO_FpOrComplexTensor> {
   let summary = "Atan operator";
diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td
index 461527f..84df353 100644
--- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td
+++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td
@@ -30,6 +30,9 @@ class ConstantSplat<string value> : NativeCodeCall<
 class HLO_ConstantLike<string value> : NativeCodeCall<
     "chlo::getConstantLike($_builder, $_loc, " # value # ", $0)">;
 
+def HLO_ConstantLikeMaxFiniteValue : NativeCodeCall<
+    "chlo::getConstantLikeMaxFiniteValue($_builder, $_loc, $0)">;
+
 def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">;
 
 def BinBroadcastDimensions : NativeCodeCall<
diff --git a/lib/Dialect/mhlo/IR/chlo_ops.cc b/lib/Dialect/mhlo/IR/chlo_ops.cc
index 9761e6a..822240b 100644
--- a/lib/Dialect/mhlo/IR/chlo_ops.cc
+++ b/lib/Dialect/mhlo/IR/chlo_ops.cc
@@ -15,6 +15,7 @@ limitations under the License.
 
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" +#include "llvm/ADT/APFloat.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/utils/broadcast_utils.h" #include "mlir/IR/Attributes.h" @@ -32,6 +33,18 @@ static LogicalResult Verify(T op) { return success(); } +Value getConstantLikeMaxFiniteValue(OpBuilder& b, Location loc, Value val) { + auto ty = getElementTypeOrSelf(val.getType()).cast(); + return getConstantLike( + b, loc, llvm::APFloat::getLargest(ty.getFloatSemantics()), val); +} + +Value getConstantLike(OpBuilder& b, Location loc, const APFloat& constant, + Value val) { + Type ty = getElementTypeOrSelf(val.getType()); + return b.create(loc, b.getFloatAttr(ty, constant), val); +} + //===----------------------------------------------------------------------===// // BinaryOps //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td index 5b9a300..a2b97a8 100644 --- a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td +++ b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td @@ -79,6 +79,94 @@ def : Pat<(HLOClient_AsinOp NonComplexElementType:$input), ) )>; +// Expand asinh to MHLO dialect as +// asinh(x) = log(x + sqrt(x^2 + 1)) +// +// If x^2 will overflow and x is positive, we can approximate x + sqrt(x^2 + 1) +// as 2*x and return log(2) + log(x). +// +// For small x, sqrt(x^2 + 1) will evaluate to 1 due to floating point +// arithmetic. However, we would like to retain the low order term of this, +// which is around 0.5 * x^2 using a binomial expansion. +// Let z = sqrt(a^2 + 1) +// The following rewrite retains the lower order term. 
+// log(a + sqrt(a^2 + 1)) +// = log((a + sqrt(a^2 + 1)) * (1 + sqrt(a^2 + 1)) / (1 + sqrt(a^2 + 1))) +// = log((a + a^2 + 1 + a * z + z) / (1 + z)) +// = log(1 + a + a^2 / (1 + z)) +// = log(1 + a + a^2 / (1 + sqrt(a^2 + 1))) +// +// If x is negative, the above would give us some trouble; we can't approximate +// the result as x + abs(x) = 0 but we are saved by the fact that asinh(-x) = +// -asinh(x). +def : Pat<(HLOClient_AsinhOp NonComplexElementType:$input), + (HLO_MulOp + (HLO_SignOp $input), + (HLO_SelectOp + (HLO_CompareOp + (HLO_AbsOp $input), + (HLO_SqrtOp + (HLO_ConstantLikeMaxFiniteValue $input) + ), + HLO_COMPARISON_DIRECTION_GE, + (HLO_DEFAULT_COMPARISON_TYPE) + ), + (HLO_AddOp + (HLO_LogOp + (HLO_AbsOp $input) + ), + (HLO_LogOp + (HLO_ConstantLike<"2"> $input) + ) + ), + (HLO_SelectOp + (HLO_CompareOp + (HLO_AbsOp $input), + (HLO_ConstantLike<"1"> $input), + HLO_COMPARISON_DIRECTION_LE, + (HLO_DEFAULT_COMPARISON_TYPE) + ), + (HLO_Log1pOp + (HLO_AddOp + (HLO_AbsOp $input), + (HLO_MulOp + (HLO_AbsOp $input), + (HLO_DivOp + (HLO_AbsOp $input), + (HLO_AddOp + (HLO_ConstantLike<"1"> $input), + (HLO_SqrtOp + (HLO_AddOp + (HLO_MulOp + (HLO_AbsOp $input), + (HLO_AbsOp $input) + ), + (HLO_ConstantLike<"1"> $input) + ) + ) + ) + ) + ) + ) + ), + (HLO_LogOp + (HLO_AddOp + (HLO_AbsOp $input), + (HLO_SqrtOp + (HLO_AddOp + (HLO_MulOp + (HLO_AbsOp $input), + (HLO_AbsOp $input) + ), + (HLO_ConstantLike<"1"> $input) + ) + ) + ) + ) + ) + ) + )>; + // Express `atan` as // atan(x) = atan2(x, 1) def : Pat<(HLOClient_AtanOp $input), diff --git a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc index cc9b3d9..70d5d38 100644 --- a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc +++ b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc @@ -50,9 +50,10 @@ namespace { sep fn(ShiftRightLogicalOp) sep fn(SubOp) sep fn(XorOp) // TODO(herhut): Generate these out of op definitions. 
-#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep)                            \
-  fn(AcosOp) sep fn(AsinOp) sep fn(AtanOp) sep fn(AtanhOp) sep fn(ConjOp) \
-      sep fn(CoshOp) sep fn(ErfOp) sep fn(ErfcOp) sep fn(SinhOp) sep fn(TanOp)
+#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep)                             \
+  fn(AcosOp) sep fn(AsinOp) sep fn(AsinhOp) sep fn(AtanOp) sep fn(AtanhOp) \
+      sep fn(ConjOp) sep fn(CoshOp) sep fn(ErfOp) sep fn(ErfcOp)           \
+      sep fn(SinhOp) sep fn(TanOp)
 
 template <typename OpTy>
 inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
diff --git a/tests/chlo_legalize_to_mhlo.mlir b/tests/chlo_legalize_to_mhlo.mlir
index fcc8bf6..1a8613f 100644
--- a/tests/chlo_legalize_to_mhlo.mlir
+++ b/tests/chlo_legalize_to_mhlo.mlir
@@ -1,5 +1,80 @@
 // RUN: mlir-hlo-opt --chlo-legalize-to-hlo --split-input-file %s | FileCheck %s
 
+// CHECK-LABEL: @asinh_bf16
+// CHECK-SAME: %[[ARG:.*]]: tensor<bf16>
+func @asinh_bf16(%arg : tensor<bf16>) -> tensor<bf16> {
+  // Check for the bf16-specific max value.
+  // CHECK: mhlo.constant dense<3.389{{.*}}e+38>
+  %result = "chlo.asinh"(%arg) : (tensor<bf16>) -> tensor<bf16>
+  return %result : tensor<bf16>
+}
+
+// CHECK-LABEL: @asinh_f16
+// CHECK-SAME: %[[ARG:.*]]: tensor<f16>
+func @asinh_f16(%arg : tensor<f16>) -> tensor<f16> {
+  // Check for the f16-specific max value.
+  // CHECK: mhlo.constant dense<6.550{{.*}}e+04>
+  %result = "chlo.asinh"(%arg) : (tensor<f16>) -> tensor<f16>
+  return %result : tensor<f16>
+}
+
+// CHECK-LABEL: @asinh_f32
+// CHECK-SAME: %[[ARG:.*]]: tensor<f32>
+func @asinh_f32(%arg : tensor<f32>) -> tensor<f32> {
+  // Check for the f32-specific max value.
+  // CHECK: mhlo.constant dense<3.402{{.*}}E+38>
+  %result = "chlo.asinh"(%arg) : (tensor<f32>) -> tensor<f32>
+  return %result : tensor<f32>
+}
+
+// CHECK-LABEL: @asinh_f64
+// CHECK-SAME: %[[ARG:.*]]: tensor<f64>
+func @asinh_f64(%arg : tensor<f64>) -> tensor<f64> {
+  // CHECK: %[[TMP_0:.*]] = "mhlo.sign"(%[[ARG]])
+  // CHECK: %[[TMP_1:.*]] = "mhlo.abs"(%[[ARG]])
+  // CHECK: %[[TMP_2:.*]] = mhlo.constant dense<1.797{{.*}}E+308>
+  // CHECK: %[[TMP_3:.*]] = "mhlo.sqrt"(%[[TMP_2]])
+  // CHECK: %[[TMP_4:.*]] = "mhlo.compare"(%[[TMP_1]], %[[TMP_3]]) {comparison_direction = "GE"}
+  // CHECK: %[[TMP_5:.*]] = "mhlo.abs"(%[[ARG]])
+  // CHECK: %[[TMP_6:.*]] = "mhlo.log"(%[[TMP_5]])
+  // CHECK: %[[TMP_7:.*]] = mhlo.constant dense<2.000{{.*}}e+00>
+  // CHECK: %[[TMP_8:.*]] = "mhlo.log"(%[[TMP_7]])
+  // CHECK: %[[TMP_9:.*]] = mhlo.add %[[TMP_6]], %[[TMP_8]]
+  // CHECK: %[[TMP_10:.*]] = "mhlo.abs"(%[[ARG]])
+  // CHECK: %[[TMP_11:.*]] = mhlo.constant dense<1.000{{.*}}e+00>
+  // CHECK: %[[TMP_12:.*]] = "mhlo.compare"(%[[TMP_10]], %[[TMP_11]]) {comparison_direction = "LE"}
+  // CHECK: %[[TMP_13:.*]] = "mhlo.abs"(%[[ARG]])
+  // CHECK: %[[TMP_14:.*]] = "mhlo.abs"(%[[ARG]])
+  // CHECK: %[[TMP_15:.*]] = "mhlo.abs"(%[[ARG]])
+  // CHECK: %[[TMP_16:.*]] = "mhlo.abs"(%[[ARG]])
+  // CHECK: %[[TMP_17:.*]] = "mhlo.abs"(%[[ARG]])
+  // CHECK: %[[TMP_18:.*]] = mhlo.multiply %[[TMP_16]], %[[TMP_17]]
+  // CHECK: %[[TMP_19:.*]] = mhlo.constant dense<1.000{{.*}}e+00>
+  // CHECK: %[[TMP_20:.*]] = mhlo.add %[[TMP_18]], %[[TMP_19]]
+  // CHECK: %[[TMP_21:.*]] = "mhlo.sqrt"(%[[TMP_20]])
+  // CHECK: %[[TMP_22:.*]] = mhlo.constant dense<1.000{{.*}}e+00>
+  // CHECK: %[[TMP_23:.*]] = mhlo.add %[[TMP_22]], %[[TMP_21]]
+  // CHECK: %[[TMP_24:.*]] = mhlo.divide %[[TMP_15]], %[[TMP_23]]
+  // CHECK: %[[TMP_25:.*]] = mhlo.multiply %[[TMP_14]], %[[TMP_24]]
+  // CHECK: %[[TMP_26:.*]] = mhlo.add %[[TMP_13]], %[[TMP_25]]
+  // CHECK: %[[TMP_27:.*]] = "mhlo.log_plus_one"(%[[TMP_26]])
+  // CHECK: %[[TMP_28:.*]] = "mhlo.abs"(%[[ARG]])
+  // CHECK: %[[TMP_29:.*]] = "mhlo.abs"(%[[ARG]])
+  // CHECK: %[[TMP_30:.*]] = "mhlo.abs"(%[[ARG]])
+  // CHECK: %[[TMP_31:.*]] = mhlo.multiply %[[TMP_29]], %[[TMP_30]]
+  // CHECK: %[[TMP_32:.*]] = mhlo.constant dense<1.000{{.*}}e+00>
+  // CHECK: %[[TMP_33:.*]] = mhlo.add %[[TMP_31]], %[[TMP_32]]
+  // CHECK: %[[TMP_34:.*]] = "mhlo.sqrt"(%[[TMP_33]])
+  // CHECK: %[[TMP_35:.*]] = mhlo.add %[[TMP_28]], %[[TMP_34]]
+  // CHECK: %[[TMP_36:.*]] = "mhlo.log"(%[[TMP_35]])
+  // CHECK: %[[TMP_37:.*]] = "mhlo.select"(%[[TMP_12]], %[[TMP_27]], %[[TMP_36]])
+  // CHECK: %[[TMP_38:.*]] = "mhlo.select"(%[[TMP_4]], %[[TMP_9]], %[[TMP_37]])
+  // CHECK: %[[RES:.*]] = mhlo.multiply %[[TMP_0]], %[[TMP_38]]
+  // CHECK: return %[[RES]]
+  %result = "chlo.asinh"(%arg) : (tensor<f64>) -> tensor<f64>
+  return %result : tensor<f64>
+}
+
 // Lower statically shaped `constant_like` to constant.
 // CHECK-LABEL: @constant_like_static_shape
 func @constant_like_static_shape(%arg : tensor<1x2xi64>) -> tensor<1x2xf32> {