diff --git a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h index c1d7ffc..3f28937 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h +++ b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h @@ -66,6 +66,8 @@ static Value getConstantLike(OpBuilder& b, Location loc, T constant, return b.create(loc, getAttr(), val); } +Value getConstantLikeMaxFiniteValue(OpBuilder& b, Location loc, Value val); + } // namespace chlo } // namespace mlir diff --git a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td index 37e6727..558da58 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td @@ -372,6 +372,18 @@ def HLOClient_AsinOp : HLOClient_UnaryElementwiseOp<"asin", [], }]; } +def HLOClient_AsinhOp : HLOClient_UnaryElementwiseOp<"asinh", [], + HLO_FpOrComplexTensor> { + let summary = "Asinh operation"; + + let description = [{ + Returns `Asinh(operand)` element-wise. + + $$ + \asinh(x) = log(x + sqrt(x^2 + 1)) + $$ + }]; +} def HLOClient_AtanOp : HLOClient_UnaryElementwiseOp<"atan", [], HLO_FpOrComplexTensor> { let summary = "Atan operator"; diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td index 461527f..84df353 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td @@ -30,6 +30,9 @@ class ConstantSplat : NativeCodeCall< class HLO_ConstantLike : NativeCodeCall< "chlo::getConstantLike($_builder, $_loc, " # value # ", $0)">; +def HLO_ConstantLikeMaxFiniteValue : NativeCodeCall< + "chlo::getConstantLikeMaxFiniteValue($_builder, $_loc, $0)">; + def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">; def BinBroadcastDimensions : NativeCodeCall< diff --git a/lib/Dialect/mhlo/IR/chlo_ops.cc b/lib/Dialect/mhlo/IR/chlo_ops.cc index 9761e6a..fa6cc01 100644 --- a/lib/Dialect/mhlo/IR/chlo_ops.cc +++ b/lib/Dialect/mhlo/IR/chlo_ops.cc @@ -32,6 +32,20 @@ static LogicalResult Verify(T op) { return success(); } +static constexpr float kF16MaxFiniteValue = 0x1.ffcP15; + +Value getConstantLikeMaxFiniteValue(OpBuilder& b, Location loc, Value val) { + Type ty = getElementTypeOrSelf(val.getType()); + if (ty.isF16()) { + return getConstantLike(b, loc, kF16MaxFiniteValue, val); + } else if (ty.isF32()) { + return getConstantLike(b, loc, std::numeric_limits::max(), val); + } else if (ty.isF64()) { + return getConstantLike(b, loc, std::numeric_limits::max(), val); + } + llvm_unreachable("unhandled type"); +} + //===----------------------------------------------------------------------===// // BinaryOps //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td index 3d89f92..b8b6abb 100644 --- a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td +++ b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td @@ -79,6 +79,94 @@ def : Pat<(HLOClient_AsinOp NonComplexElementType:$input), ) )>; +// Expand asinh to MHLO dialect as +// asinh(x) = log(x + sqrt(x^2 + 1)) +// +// If x^2 will overflow and x is positive, we can approximate x + sqrt(x^2 + 1) +// as 2*x and return log(2) + log(x). +// +// For small x, sqrt(x^2 + 1) will evaluate to 1 due to floating point +// arithmetic. However, we would like to retain the low order term of this, +// which is around 0.5 * x^2 using a binomial expansion. +// Let z = sqrt(a^2 + 1) +// The following rewrite retains the lower order term. +// log(a + sqrt(a^2 + 1)) +// = log((a + sqrt(a^2 + 1)) * (1 + sqrt(a^2 + 1)) / (1 + sqrt(a^2 + 1))) +// = log((a + a^2 + 1 + a * z + z) / (1 + z)) +// = log(1 + a + a^2 / (1 + z)) +// = log(1 + a + a^2 / (1 + sqrt(a^2 + 1))) +// +// If x is negative, the above would give us some trouble; we can't approximate +// the result as x + abs(x) = 0 but we are saved by the fact that asinh(-x) = +// -asinh(x). +def : Pat<(HLOClient_AsinhOp NonComplexElementType:$input), + (HLO_MulOp + (HLO_SignOp $input), + (HLO_SelectOp + (HLO_CompareOp + (HLO_AbsOp $input), + (HLO_SqrtOp + (HLO_ConstantLikeMaxFiniteValue $input) + ), + HLO_COMPARISON_DIRECTION_GE, + (HLO_DEFAULT_COMPARISON_TYPE) + ), + (HLO_AddOp + (HLO_LogOp + (HLO_AbsOp $input) + ), + (HLO_LogOp + (HLO_ConstantLike<"2"> $input) + ) + ), + (HLO_SelectOp + (HLO_CompareOp + (HLO_AbsOp $input), + (HLO_ConstantLike<"1"> $input), + HLO_COMPARISON_DIRECTION_LE, + (HLO_DEFAULT_COMPARISON_TYPE) + ), + (HLO_Log1pOp + (HLO_AddOp + (HLO_AbsOp $input), + (HLO_MulOp + (HLO_AbsOp $input), + (HLO_DivOp + (HLO_AbsOp $input), + (HLO_AddOp + (HLO_ConstantLike<"1"> $input), + (HLO_SqrtOp + (HLO_AddOp + (HLO_MulOp + (HLO_AbsOp $input), + (HLO_AbsOp $input) + ), + (HLO_ConstantLike<"1"> $input) + ) + ) + ) + ) + ) + ) + ), + (HLO_LogOp + (HLO_AddOp + (HLO_AbsOp $input), + (HLO_SqrtOp + (HLO_AddOp + (HLO_MulOp + (HLO_AbsOp $input), + (HLO_AbsOp $input) + ), + (HLO_ConstantLike<"1"> $input) + ) + ) + ) + ) + ) + ) + )>; + // Express `atan` as // atan(x) = atan2(x, 1) def : Pat<(HLOClient_AtanOp $input), diff --git a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc index 152be21..bd6d891 100644 --- a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc +++ b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc @@ -50,9 +50,9 @@ namespace { sep fn(ShiftRightLogicalOp) sep fn(SubOp) sep fn(XorOp) // TODO(herhut): Generate these out of op definitions. -#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \ - fn(AcosOp) sep fn(AsinOp) sep fn(AtanOp) sep fn(ConjOp) sep fn(CoshOp) \ - sep fn(ErfOp) sep fn(ErfcOp) sep fn(SinhOp) sep fn(TanOp) +#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \ + fn(AcosOp) sep fn(AsinOp) sep fn(AsinhOp) sep fn(AtanOp) sep fn(ConjOp) \ + sep fn(CoshOp) sep fn(ErfOp) sep fn(ErfcOp) sep fn(SinhOp) sep fn(TanOp) template inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {