From 0e85b4d511cb6ae55278e7b18946025d842f65ce Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 19 Jan 2021 10:50:33 -0800 Subject: [PATCH] [MLIR][KernelGen] Add `tf.Asinh` kernels and complete their lowerings PiperOrigin-RevId: 352604725 --- include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h | 2 - include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td | 12 --- include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td | 3 - lib/Dialect/mhlo/IR/chlo_ops.cc | 14 --- .../chlo_legalize_to_hlo_patterns.td | 88 ------------------- .../mhlo/transforms/transform_unranked_hlo.cc | 7 +- 6 files changed, 3 insertions(+), 123 deletions(-) diff --git a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h index 3f28937..c1d7ffc 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h +++ b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h @@ -66,8 +66,6 @@ static Value getConstantLike(OpBuilder& b, Location loc, T constant, return b.create(loc, getAttr(), val); } -Value getConstantLikeMaxFiniteValue(OpBuilder& b, Location loc, Value val); - } // namespace chlo } // namespace mlir diff --git a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td index c633bd2..1adfa93 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td @@ -372,18 +372,6 @@ def HLOClient_AsinOp : HLOClient_UnaryElementwiseOp<"asin", [], }]; } -def HLOClient_AsinhOp : HLOClient_UnaryElementwiseOp<"asinh", [], - HLO_FpOrComplexTensor> { - let summary = "Asinh operation"; - - let description = [{ - Returns `Asinh(operand)` element-wise. - - $$ - \asinh(x) = log(x + sqrt(x^2 + 1)) - $$ - }]; -} def HLOClient_AtanOp : HLOClient_UnaryElementwiseOp<"atan", [], HLO_FpOrComplexTensor> { let summary = "Atan operator"; diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td index 84df353..461527f 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td @@ -30,9 +30,6 @@ class ConstantSplat : NativeCodeCall< class HLO_ConstantLike : NativeCodeCall< "chlo::getConstantLike($_builder, $_loc, " # value # ", $0)">; -def HLO_ConstantLikeMaxFiniteValue : NativeCodeCall< - "chlo::getConstantLikeMaxFiniteValue($_builder, $_loc, $0)">; - def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">; def BinBroadcastDimensions : NativeCodeCall< diff --git a/lib/Dialect/mhlo/IR/chlo_ops.cc b/lib/Dialect/mhlo/IR/chlo_ops.cc index fa6cc01..9761e6a 100644 --- a/lib/Dialect/mhlo/IR/chlo_ops.cc +++ b/lib/Dialect/mhlo/IR/chlo_ops.cc @@ -32,20 +32,6 @@ static LogicalResult Verify(T op) { return success(); } -static constexpr float kF16MaxFiniteValue = 0x1.ffcP15; - -Value getConstantLikeMaxFiniteValue(OpBuilder& b, Location loc, Value val) { - Type ty = getElementTypeOrSelf(val.getType()); - if (ty.isF16()) { - return getConstantLike(b, loc, kF16MaxFiniteValue, val); - } else if (ty.isF32()) { - return getConstantLike(b, loc, std::numeric_limits::max(), val); - } else if (ty.isF64()) { - return getConstantLike(b, loc, std::numeric_limits::max(), val); - } - llvm_unreachable("unhandled type"); -} - //===----------------------------------------------------------------------===// // BinaryOps //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td index a2b97a8..5b9a300 100644 --- a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td +++ b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_patterns.td @@ -79,94 +79,6 @@ def : Pat<(HLOClient_AsinOp NonComplexElementType:$input), ) )>; -// Expand asinh to MHLO dialect as -// asinh(x) = log(x + sqrt(x^2 + 1)) -// -// If x^2 will overflow and x is positive, we can approximate x + sqrt(x^2 + 1) -// as 2*x and return log(2) + log(x). -// -// For small x, sqrt(x^2 + 1) will evaluate to 1 due to floating point -// arithmetic. However, we would like to retain the low order term of this, -// which is around 0.5 * x^2 using a binomial expansion. -// Let z = sqrt(a^2 + 1) -// The following rewrite retains the lower order term. -// log(a + sqrt(a^2 + 1)) -// = log((a + sqrt(a^2 + 1)) * (1 + sqrt(a^2 + 1)) / (1 + sqrt(a^2 + 1))) -// = log((a + a^2 + 1 + a * z + z) / (1 + z)) -// = log(1 + a + a^2 / (1 + z)) -// = log(1 + a + a^2 / (1 + sqrt(a^2 + 1))) -// -// If x is negative, the above would give us some trouble; we can't approximate -// the result as x + abs(x) = 0 but we are saved by the fact that asinh(-x) = -// -asinh(x). -def : Pat<(HLOClient_AsinhOp NonComplexElementType:$input), - (HLO_MulOp - (HLO_SignOp $input), - (HLO_SelectOp - (HLO_CompareOp - (HLO_AbsOp $input), - (HLO_SqrtOp - (HLO_ConstantLikeMaxFiniteValue $input) - ), - HLO_COMPARISON_DIRECTION_GE, - (HLO_DEFAULT_COMPARISON_TYPE) - ), - (HLO_AddOp - (HLO_LogOp - (HLO_AbsOp $input) - ), - (HLO_LogOp - (HLO_ConstantLike<"2"> $input) - ) - ), - (HLO_SelectOp - (HLO_CompareOp - (HLO_AbsOp $input), - (HLO_ConstantLike<"1"> $input), - HLO_COMPARISON_DIRECTION_LE, - (HLO_DEFAULT_COMPARISON_TYPE) - ), - (HLO_Log1pOp - (HLO_AddOp - (HLO_AbsOp $input), - (HLO_MulOp - (HLO_AbsOp $input), - (HLO_DivOp - (HLO_AbsOp $input), - (HLO_AddOp - (HLO_ConstantLike<"1"> $input), - (HLO_SqrtOp - (HLO_AddOp - (HLO_MulOp - (HLO_AbsOp $input), - (HLO_AbsOp $input) - ), - (HLO_ConstantLike<"1"> $input) - ) - ) - ) - ) - ) - ) - ), - (HLO_LogOp - (HLO_AddOp - (HLO_AbsOp $input), - (HLO_SqrtOp - (HLO_AddOp - (HLO_MulOp - (HLO_AbsOp $input), - (HLO_AbsOp $input) - ), - (HLO_ConstantLike<"1"> $input) - ) - ) - ) - ) - ) - ) - )>; - // Express `atan` as // atan(x) = atan2(x, 1) def : Pat<(HLOClient_AtanOp $input), diff --git a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc index 70d5d38..cc9b3d9 100644 --- a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc +++ b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc @@ -50,10 +50,9 @@ namespace { sep fn(ShiftRightLogicalOp) sep fn(SubOp) sep fn(XorOp) // TODO(herhut): Generate these out of op definitions. -#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \ - fn(AcosOp) sep fn(AsinOp) sep fn(AsinhOp) sep fn(AtanOp) sep fn(AtanhOp) \ - sep fn(ConjOp) sep fn(CoshOp) sep fn(ErfOp) sep fn(ErfcOp) \ - sep fn(SinhOp) sep fn(TanOp) +#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \ + fn(AcosOp) sep fn(AsinOp) sep fn(AtanOp) sep fn(AtanhOp) sep fn(ConjOp) \ + sep fn(CoshOp) sep fn(ErfOp) sep fn(ErfcOp) sep fn(SinhOp) sep fn(TanOp) template inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {