[MLIR][KernelGen] Add `tf.Asinh` kernels and complete their lowerings
PiperOrigin-RevId: 351989552
This commit is contained in:
parent
316f630728
commit
791d5afd28
|
@ -66,6 +66,8 @@ static Value getConstantLike(OpBuilder& b, Location loc, T constant,
|
||||||
return b.create<ConstantLikeOp>(loc, getAttr(), val);
|
return b.create<ConstantLikeOp>(loc, getAttr(), val);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Value getConstantLikeMaxFiniteValue(OpBuilder& b, Location loc, Value val);
|
||||||
|
|
||||||
} // namespace chlo
|
} // namespace chlo
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
|
|
|
@ -372,6 +372,18 @@ def HLOClient_AsinOp : HLOClient_UnaryElementwiseOp<"asin", [],
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def HLOClient_AsinhOp : HLOClient_UnaryElementwiseOp<"asinh", [],
|
||||||
|
HLO_FpOrComplexTensor> {
|
||||||
|
let summary = "Asinh operation";
|
||||||
|
|
||||||
|
let description = [{
|
||||||
|
Returns `Asinh(operand)` element-wise.
|
||||||
|
|
||||||
|
$$
|
||||||
|
\asinh(x) = log(x + sqrt(x^2 + 1))
|
||||||
|
$$
|
||||||
|
}];
|
||||||
|
}
|
||||||
def HLOClient_AtanOp : HLOClient_UnaryElementwiseOp<"atan", [],
|
def HLOClient_AtanOp : HLOClient_UnaryElementwiseOp<"atan", [],
|
||||||
HLO_FpOrComplexTensor> {
|
HLO_FpOrComplexTensor> {
|
||||||
let summary = "Atan operator";
|
let summary = "Atan operator";
|
||||||
|
|
|
@ -30,6 +30,9 @@ class ConstantSplat<string value> : NativeCodeCall<
|
||||||
class HLO_ConstantLike<string value> : NativeCodeCall<
|
class HLO_ConstantLike<string value> : NativeCodeCall<
|
||||||
"chlo::getConstantLike($_builder, $_loc, " # value # ", $0)">;
|
"chlo::getConstantLike($_builder, $_loc, " # value # ", $0)">;
|
||||||
|
|
||||||
|
def HLO_ConstantLikeMaxFiniteValue : NativeCodeCall<
|
||||||
|
"chlo::getConstantLikeMaxFiniteValue($_builder, $_loc, $0)">;
|
||||||
|
|
||||||
def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">;
|
def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">;
|
||||||
|
|
||||||
def BinBroadcastDimensions : NativeCodeCall<
|
def BinBroadcastDimensions : NativeCodeCall<
|
||||||
|
|
|
@ -32,6 +32,20 @@ static LogicalResult Verify(T op) {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static constexpr float kF16MaxFiniteValue = 0x1.ffcP15;
|
||||||
|
|
||||||
|
Value getConstantLikeMaxFiniteValue(OpBuilder& b, Location loc, Value val) {
|
||||||
|
Type ty = getElementTypeOrSelf(val.getType());
|
||||||
|
if (ty.isF16()) {
|
||||||
|
return getConstantLike(b, loc, kF16MaxFiniteValue, val);
|
||||||
|
} else if (ty.isF32()) {
|
||||||
|
return getConstantLike(b, loc, std::numeric_limits<float>::max(), val);
|
||||||
|
} else if (ty.isF64()) {
|
||||||
|
return getConstantLike(b, loc, std::numeric_limits<double>::max(), val);
|
||||||
|
}
|
||||||
|
llvm_unreachable("unhandled type");
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// BinaryOps
|
// BinaryOps
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -79,6 +79,94 @@ def : Pat<(HLOClient_AsinOp NonComplexElementType:$input),
|
||||||
)
|
)
|
||||||
)>;
|
)>;
|
||||||
|
|
||||||
|
// Expand asinh to MHLO dialect as
|
||||||
|
// asinh(x) = log(x + sqrt(x^2 + 1))
|
||||||
|
//
|
||||||
|
// If x^2 will overflow and x is positive, we can approximate x + sqrt(x^2 + 1)
|
||||||
|
// as 2*x and return log(2) + log(x).
|
||||||
|
//
|
||||||
|
// For small x, sqrt(x^2 + 1) will evaluate to 1 due to floating point
|
||||||
|
// arithmetic. However, we would like to retain the low order term of this,
|
||||||
|
// which is around 0.5 * x^2 using a binomial expansion.
|
||||||
|
// Let z = sqrt(a^2 + 1)
|
||||||
|
// The following rewrite retains the lower order term.
|
||||||
|
// log(a + sqrt(a^2 + 1))
|
||||||
|
// = log((a + sqrt(a^2 + 1)) * (1 + sqrt(a^2 + 1)) / (1 + sqrt(a^2 + 1)))
|
||||||
|
// = log((a + a^2 + 1 + a * z + z) / (1 + z))
|
||||||
|
// = log(1 + a + a^2 / (1 + z))
|
||||||
|
// = log(1 + a + a^2 / (1 + sqrt(a^2 + 1)))
|
||||||
|
//
|
||||||
|
// If x is negative, the above would give us some trouble; we can't approximate
|
||||||
|
// the result as x + abs(x) = 0 but we are saved by the fact that asinh(-x) =
|
||||||
|
// -asinh(x).
|
||||||
|
def : Pat<(HLOClient_AsinhOp NonComplexElementType:$input),
|
||||||
|
(HLO_MulOp
|
||||||
|
(HLO_SignOp $input),
|
||||||
|
(HLO_SelectOp
|
||||||
|
(HLO_CompareOp
|
||||||
|
(HLO_AbsOp $input),
|
||||||
|
(HLO_SqrtOp
|
||||||
|
(HLO_ConstantLikeMaxFiniteValue $input)
|
||||||
|
),
|
||||||
|
HLO_COMPARISON_DIRECTION_GE,
|
||||||
|
(HLO_DEFAULT_COMPARISON_TYPE)
|
||||||
|
),
|
||||||
|
(HLO_AddOp
|
||||||
|
(HLO_LogOp
|
||||||
|
(HLO_AbsOp $input)
|
||||||
|
),
|
||||||
|
(HLO_LogOp
|
||||||
|
(HLO_ConstantLike<"2"> $input)
|
||||||
|
)
|
||||||
|
),
|
||||||
|
(HLO_SelectOp
|
||||||
|
(HLO_CompareOp
|
||||||
|
(HLO_AbsOp $input),
|
||||||
|
(HLO_ConstantLike<"1"> $input),
|
||||||
|
HLO_COMPARISON_DIRECTION_LE,
|
||||||
|
(HLO_DEFAULT_COMPARISON_TYPE)
|
||||||
|
),
|
||||||
|
(HLO_Log1pOp
|
||||||
|
(HLO_AddOp
|
||||||
|
(HLO_AbsOp $input),
|
||||||
|
(HLO_MulOp
|
||||||
|
(HLO_AbsOp $input),
|
||||||
|
(HLO_DivOp
|
||||||
|
(HLO_AbsOp $input),
|
||||||
|
(HLO_AddOp
|
||||||
|
(HLO_ConstantLike<"1"> $input),
|
||||||
|
(HLO_SqrtOp
|
||||||
|
(HLO_AddOp
|
||||||
|
(HLO_MulOp
|
||||||
|
(HLO_AbsOp $input),
|
||||||
|
(HLO_AbsOp $input)
|
||||||
|
),
|
||||||
|
(HLO_ConstantLike<"1"> $input)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
),
|
||||||
|
(HLO_LogOp
|
||||||
|
(HLO_AddOp
|
||||||
|
(HLO_AbsOp $input),
|
||||||
|
(HLO_SqrtOp
|
||||||
|
(HLO_AddOp
|
||||||
|
(HLO_MulOp
|
||||||
|
(HLO_AbsOp $input),
|
||||||
|
(HLO_AbsOp $input)
|
||||||
|
),
|
||||||
|
(HLO_ConstantLike<"1"> $input)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)>;
|
||||||
|
|
||||||
// Express `atan` as
|
// Express `atan` as
|
||||||
// atan(x) = atan2(x, 1)
|
// atan(x) = atan2(x, 1)
|
||||||
def : Pat<(HLOClient_AtanOp $input),
|
def : Pat<(HLOClient_AtanOp $input),
|
||||||
|
|
|
@ -51,8 +51,8 @@ namespace {
|
||||||
|
|
||||||
// TODO(herhut): Generate these out of op definitions.
|
// TODO(herhut): Generate these out of op definitions.
|
||||||
#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \
|
#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \
|
||||||
fn(AcosOp) sep fn(AsinOp) sep fn(AtanOp) sep fn(ConjOp) sep fn(CoshOp) \
|
fn(AcosOp) sep fn(AsinOp) sep fn(AsinhOp) sep fn(AtanOp) sep fn(ConjOp) \
|
||||||
sep fn(ErfOp) sep fn(ErfcOp) sep fn(SinhOp) sep fn(TanOp)
|
sep fn(CoshOp) sep fn(ErfOp) sep fn(ErfcOp) sep fn(SinhOp) sep fn(TanOp)
|
||||||
|
|
||||||
template <typename OpTy>
|
template <typename OpTy>
|
||||||
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
|
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
|
||||||
|
|
Loading…
Reference in New Issue