[MLIR][KernelGen] Add `tf.Asinh` kernels and complete their lowerings
PiperOrigin-RevId: 352773540
This commit is contained in:
parent
0e85b4d511
commit
ec5f5667e1
|
@ -66,6 +66,11 @@ static Value getConstantLike(OpBuilder& b, Location loc, T constant,
|
|||
return b.create<ConstantLikeOp>(loc, getAttr(), val);
|
||||
}
|
||||
|
||||
Value getConstantLike(OpBuilder& b, Location loc, const APFloat& constant,
|
||||
Value val);
|
||||
|
||||
Value getConstantLikeMaxFiniteValue(OpBuilder& b, Location loc, Value val);
|
||||
|
||||
} // namespace chlo
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -372,6 +372,18 @@ def HLOClient_AsinOp : HLOClient_UnaryElementwiseOp<"asin", [],
|
|||
}];
|
||||
}
|
||||
|
||||
def HLOClient_AsinhOp : HLOClient_UnaryElementwiseOp<"asinh", [],
    HLO_FpOrComplexTensor> {
  let summary = "Asinh operation";

  let description = [{
    Returns `Asinh(operand)` element-wise.

    $$
    \asinh(x) = log(x + sqrt(x^2 + 1))
    $$
  }];
}
|
||||
def HLOClient_AtanOp : HLOClient_UnaryElementwiseOp<"atan", [],
|
||||
HLO_FpOrComplexTensor> {
|
||||
let summary = "Atan operator";
|
||||
|
|
|
@ -30,6 +30,9 @@ class ConstantSplat<string value> : NativeCodeCall<
|
|||
// Materializes a chlo::ConstantLikeOp shaped like $0 holding `value`
// (forwarded verbatim into the C++ call to chlo::getConstantLike).
class HLO_ConstantLike<string value> : NativeCodeCall<
    "chlo::getConstantLike($_builder, $_loc, " # value # ", $0)">;

// Materializes a constant shaped like $0 holding the largest finite value
// of $0's floating-point element type.
def HLO_ConstantLikeMaxFiniteValue : NativeCodeCall<
    "chlo::getConstantLikeMaxFiniteValue($_builder, $_loc, $0)">;

def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">;
|
||||
|
||||
def BinBroadcastDimensions : NativeCodeCall<
|
||||
|
|
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
||||
|
||||
#include "llvm/ADT/APFloat.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/utils/broadcast_utils.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
|
@ -32,6 +33,18 @@ static LogicalResult Verify(T op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
// Returns a splat-like constant (via getConstantLike) holding the largest
// finite value of `val`'s floating-point element type, shaped like `val`.
// Precondition: `val`'s element type must be a FloatType (the cast asserts).
Value getConstantLikeMaxFiniteValue(OpBuilder& b, Location loc, Value val) {
  auto ty = getElementTypeOrSelf(val.getType()).cast<FloatType>();
  return getConstantLike(
      b, loc, llvm::APFloat::getLargest(ty.getFloatSemantics()), val);
}
|
||||
|
||||
// Creates a chlo::ConstantLikeOp shaped like `val` whose value attribute is
// `constant` converted to a FloatAttr of `val`'s element type.
Value getConstantLike(OpBuilder& b, Location loc, const APFloat& constant,
                      Value val) {
  Type ty = getElementTypeOrSelf(val.getType());
  return b.create<ConstantLikeOp>(loc, b.getFloatAttr(ty, constant), val);
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BinaryOps
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -79,6 +79,94 @@ def : Pat<(HLOClient_AsinOp NonComplexElementType:$input),
|
|||
)
|
||||
)>;
|
||||
|
||||
// Expand asinh to MHLO dialect as
// asinh(x) = log(x + sqrt(x^2 + 1))
//
// If x^2 will overflow and x is positive, we can approximate x + sqrt(x^2 + 1)
// as 2*x and return log(2) + log(x).
//
// For small x, sqrt(x^2 + 1) will evaluate to 1 due to floating point
// arithmetic. However, we would like to retain the low order term of this,
// which is around 0.5 * x^2 using a binomial expansion.
// Let z = sqrt(a^2 + 1)
// The following rewrite retains the lower order term.
// log(a + sqrt(a^2 + 1))
// = log((a + sqrt(a^2 + 1)) * (1 + sqrt(a^2 + 1)) / (1 + sqrt(a^2 + 1)))
// = log((a + a^2 + 1 + a * z + z) / (1 + z))
// = log(1 + a + a^2 / (1 + z))
// = log(1 + a + a^2 / (1 + sqrt(a^2 + 1)))
//
// If x is negative, the above would give us some trouble; we can't approximate
// the result as x + abs(x) = 0 but we are saved by the fact that asinh(-x) =
// -asinh(x).
def : Pat<(HLOClient_AsinhOp NonComplexElementType:$input),
  (HLO_MulOp
    (HLO_SignOp $input),
    (HLO_SelectOp
      // Overflow guard: |x| >= sqrt(max_finite) implies x^2 would overflow.
      (HLO_CompareOp
        (HLO_AbsOp $input),
        (HLO_SqrtOp
          (HLO_ConstantLikeMaxFiniteValue $input)
        ),
        HLO_COMPARISON_DIRECTION_GE,
        (HLO_DEFAULT_COMPARISON_TYPE)
      ),
      // Large |x|: log(|x|) + log(2).
      (HLO_AddOp
        (HLO_LogOp
          (HLO_AbsOp $input)
        ),
        (HLO_LogOp
          (HLO_ConstantLike<"2"> $input)
        )
      ),
      (HLO_SelectOp
        // Small-|x| guard: |x| <= 1.
        (HLO_CompareOp
          (HLO_AbsOp $input),
          (HLO_ConstantLike<"1"> $input),
          HLO_COMPARISON_DIRECTION_LE,
          (HLO_DEFAULT_COMPARISON_TYPE)
        ),
        // Small |x|: log1p(|x| + |x| * (|x| / (1 + sqrt(|x|^2 + 1)))).
        (HLO_Log1pOp
          (HLO_AddOp
            (HLO_AbsOp $input),
            (HLO_MulOp
              (HLO_AbsOp $input),
              (HLO_DivOp
                (HLO_AbsOp $input),
                (HLO_AddOp
                  (HLO_ConstantLike<"1"> $input),
                  (HLO_SqrtOp
                    (HLO_AddOp
                      (HLO_MulOp
                        (HLO_AbsOp $input),
                        (HLO_AbsOp $input)
                      ),
                      (HLO_ConstantLike<"1"> $input)
                    )
                  )
                )
              )
            )
          )
        ),
        // Otherwise: the direct formula log(|x| + sqrt(|x|^2 + 1)).
        (HLO_LogOp
          (HLO_AddOp
            (HLO_AbsOp $input),
            (HLO_SqrtOp
              (HLO_AddOp
                (HLO_MulOp
                  (HLO_AbsOp $input),
                  (HLO_AbsOp $input)
                ),
                (HLO_ConstantLike<"1"> $input)
              )
            )
          )
        )
      )
    )
  )>;
|
||||
|
||||
// Express `atan` as
|
||||
// atan(x) = atan2(x, 1)
|
||||
def : Pat<(HLOClient_AtanOp $input),
|
||||
|
|
|
@ -51,8 +51,9 @@ namespace {
|
|||
|
||||
// TODO(herhut): Generate these out of op definitions.
|
||||
// X-macro listing all CHLO unary element-wise ops; the diff view retained
// both the pre- and post-change expansions, which would list most ops twice.
// Keep only the updated expansion (which adds AsinhOp).
#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep)                            \
  fn(AcosOp) sep fn(AsinOp) sep fn(AsinhOp) sep fn(AtanOp) sep fn(AtanhOp) \
      sep fn(ConjOp) sep fn(CoshOp) sep fn(ErfOp) sep fn(ErfcOp)           \
      sep fn(SinhOp) sep fn(TanOp)
|
||||
|
||||
template <typename OpTy>
|
||||
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
|
||||
|
|
|
@ -1,5 +1,80 @@
|
|||
// RUN: mlir-hlo-opt --chlo-legalize-to-hlo --split-input-file %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: @asinh_bf16
// CHECK-SAME: %[[ARG:.*]]: tensor<bf16>
func @asinh_bf16(%arg : tensor<bf16>) -> tensor<bf16> {
  // Check for the bf16-specific max value.
  // CHECK: mhlo.constant dense<3.389{{.*}}e+38>
  %result = "chlo.asinh"(%arg) : (tensor<bf16>) -> tensor<bf16>
  return %result : tensor<bf16>
}
|
||||
|
||||
// CHECK-LABEL: @asinh_f16
// CHECK-SAME: %[[ARG:.*]]: tensor<f16>
func @asinh_f16(%arg : tensor<f16>) -> tensor<f16> {
  // Check for the f16-specific max value.
  // CHECK: mhlo.constant dense<6.550{{.*}}e+04>
  %result = "chlo.asinh"(%arg) : (tensor<f16>) -> tensor<f16>
  return %result : tensor<f16>
}
|
||||
|
||||
// CHECK-LABEL: @asinh_f32
// CHECK-SAME: %[[ARG:.*]]: tensor<f32>
func @asinh_f32(%arg : tensor<f32>) -> tensor<f32> {
  // Check for the f32-specific max value.
  // CHECK: mhlo.constant dense<3.402{{.*}}E+38>
  %result = "chlo.asinh"(%arg) : (tensor<f32>) -> tensor<f32>
  return %result : tensor<f32>
}
|
||||
|
||||
// CHECK-LABEL: @asinh_f64
// CHECK-SAME: %[[ARG:.*]]: tensor<f64>
func @asinh_f64(%arg : tensor<f64>) -> tensor<f64> {
  // CHECK: %[[TMP_0:.*]] = "mhlo.sign"(%[[ARG]])
  // CHECK: %[[TMP_1:.*]] = "mhlo.abs"(%[[ARG]])
  // CHECK: %[[TMP_2:.*]] = mhlo.constant dense<1.797{{.*}}E+308>
  // CHECK: %[[TMP_3:.*]] = "mhlo.sqrt"(%[[TMP_2]])
  // CHECK: %[[TMP_4:.*]] = "mhlo.compare"(%[[TMP_1]], %[[TMP_3]]) {comparison_direction = "GE"}
  // CHECK: %[[TMP_5:.*]] = "mhlo.abs"(%[[ARG]])
  // CHECK: %[[TMP_6:.*]] = "mhlo.log"(%[[TMP_5]])
  // CHECK: %[[TMP_7:.*]] = mhlo.constant dense<2.000{{.*}}e+00>
  // CHECK: %[[TMP_8:.*]] = "mhlo.log"(%[[TMP_7]])
  // CHECK: %[[TMP_9:.*]] = mhlo.add %[[TMP_6]], %[[TMP_8]]
  // CHECK: %[[TMP_10:.*]] = "mhlo.abs"(%[[ARG]])
  // CHECK: %[[TMP_11:.*]] = mhlo.constant dense<1.000{{.*}}e+00>
  // CHECK: %[[TMP_12:.*]] = "mhlo.compare"(%[[TMP_10]], %[[TMP_11]]) {comparison_direction = "LE"}
  // CHECK: %[[TMP_13:.*]] = "mhlo.abs"(%[[ARG]])
  // CHECK: %[[TMP_14:.*]] = "mhlo.abs"(%[[ARG]])
  // CHECK: %[[TMP_15:.*]] = "mhlo.abs"(%[[ARG]])
  // CHECK: %[[TMP_16:.*]] = "mhlo.abs"(%[[ARG]])
  // CHECK: %[[TMP_17:.*]] = "mhlo.abs"(%[[ARG]])
  // CHECK: %[[TMP_18:.*]] = mhlo.multiply %[[TMP_16]], %[[TMP_17]]
  // CHECK: %[[TMP_19:.*]] = mhlo.constant dense<1.000{{.*}}e+00>
  // CHECK: %[[TMP_20:.*]] = mhlo.add %[[TMP_18]], %[[TMP_19]]
  // CHECK: %[[TMP_21:.*]] = "mhlo.sqrt"(%[[TMP_20]])
  // CHECK: %[[TMP_22:.*]] = mhlo.constant dense<1.000{{.*}}e+00>
  // CHECK: %[[TMP_23:.*]] = mhlo.add %[[TMP_22]], %[[TMP_21]]
  // CHECK: %[[TMP_24:.*]] = mhlo.divide %[[TMP_15]], %[[TMP_23]]
  // CHECK: %[[TMP_25:.*]] = mhlo.multiply %[[TMP_14]], %[[TMP_24]]
  // CHECK: %[[TMP_26:.*]] = mhlo.add %[[TMP_13]], %[[TMP_25]]
  // CHECK: %[[TMP_27:.*]] = "mhlo.log_plus_one"(%[[TMP_26]])
  // CHECK: %[[TMP_28:.*]] = "mhlo.abs"(%[[ARG]])
  // CHECK: %[[TMP_29:.*]] = "mhlo.abs"(%[[ARG]])
  // CHECK: %[[TMP_30:.*]] = "mhlo.abs"(%[[ARG]])
  // CHECK: %[[TMP_31:.*]] = mhlo.multiply %[[TMP_29]], %[[TMP_30]]
  // CHECK: %[[TMP_32:.*]] = mhlo.constant dense<1.000{{.*}}e+00>
  // CHECK: %[[TMP_33:.*]] = mhlo.add %[[TMP_31]], %[[TMP_32]]
  // CHECK: %[[TMP_34:.*]] = "mhlo.sqrt"(%[[TMP_33]])
  // CHECK: %[[TMP_35:.*]] = mhlo.add %[[TMP_28]], %[[TMP_34]]
  // CHECK: %[[TMP_36:.*]] = "mhlo.log"(%[[TMP_35]])
  // CHECK: %[[TMP_37:.*]] = "mhlo.select"(%[[TMP_12]], %[[TMP_27]], %[[TMP_36]])
  // CHECK: %[[TMP_38:.*]] = "mhlo.select"(%[[TMP_4]], %[[TMP_9]], %[[TMP_37]])
  // CHECK: %[[RES:.*]] = mhlo.multiply %[[TMP_0]], %[[TMP_38]]
  // CHECK: return %[[RES]]
  %result = "chlo.asinh"(%arg) : (tensor<f64>) -> tensor<f64>
  return %result : tensor<f64>
}
|
||||
|
||||
// Lower statically shaped `constant_like` to constant.
|
||||
// CHECK-LABEL: @constant_like_static_shape
|
||||
func @constant_like_static_shape(%arg : tensor<1x2xi64>) -> tensor<1x2xf32> {
|
||||
|
|
Loading…
Reference in New Issue