[MLIR][KerneGen] Lower `tf.Atan` all the way to LLVM
PiperOrigin-RevId: 334843070
This commit is contained in:
parent
aa65e49ad2
commit
5f303440da
|
@ -360,19 +360,6 @@ def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos", [],
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def HLOClient_AtanOp : HLOClient_UnaryElementwiseOp<"atan", [],
|
|
||||||
HLO_FpOrComplexTensor> {
|
|
||||||
let summary = "Atan operator";
|
|
||||||
|
|
||||||
let description = [{
|
|
||||||
Returns `Atan(operand)` element-wise.
|
|
||||||
|
|
||||||
$$
|
|
||||||
\atan(x) = \atan2(x, 1)
|
|
||||||
$$
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
def HLOClient_SinhOp : HLOClient_UnaryElementwiseOp<"sinh", [],
|
def HLOClient_SinhOp : HLOClient_UnaryElementwiseOp<"sinh", [],
|
||||||
HLO_FpOrComplexTensor> {
|
HLO_FpOrComplexTensor> {
|
||||||
let summary = "Sinh operation";
|
let summary = "Sinh operation";
|
||||||
|
|
|
@ -149,15 +149,6 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::AndOp>(Location loc,
|
||||||
loc, result_types, args, b);
|
loc, result_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::Atan2Op>(Location loc,
|
|
||||||
ArrayRef<Type> result_types,
|
|
||||||
ArrayRef<Value> args,
|
|
||||||
OpBuilder* b) {
|
|
||||||
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::Atan2Op>{}(
|
|
||||||
loc, result_types, args, b);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename PredicateType>
|
template <typename PredicateType>
|
||||||
inline Optional<PredicateType> getCmpPredicate(StringRef comparison_direction) {
|
inline Optional<PredicateType> getCmpPredicate(StringRef comparison_direction) {
|
||||||
return llvm::None;
|
return llvm::None;
|
||||||
|
|
|
@ -49,14 +49,6 @@ def : Pat<(HLOClient_AcosOp $input),
|
||||||
),
|
),
|
||||||
(HLO_ConstantLike<"M_PI"> $input))>;
|
(HLO_ConstantLike<"M_PI"> $input))>;
|
||||||
|
|
||||||
// Express `atan` as
|
|
||||||
// atan(x) = atan2(x, 1)
|
|
||||||
def : Pat<(HLOClient_AtanOp $input),
|
|
||||||
(HLO_Atan2Op
|
|
||||||
$input,
|
|
||||||
(HLO_ConstantLike<"1"> $input)
|
|
||||||
)>;
|
|
||||||
|
|
||||||
// Express `sinh` as
|
// Express `sinh` as
|
||||||
// sinh(x) = (e^x - e^-x) / 2 if |x| < 1
|
// sinh(x) = (e^x - e^-x) / 2 if |x| < 1
|
||||||
// = e^(x + log(1/2)) - e^(-x + log(1/2)) otherwise.
|
// = e^(x + log(1/2)) - e^(-x + log(1/2)) otherwise.
|
||||||
|
@ -103,3 +95,4 @@ def : Pat<(HLOClient_TanOp $input),
|
||||||
(HLO_SinOp $input),
|
(HLO_SinOp $input),
|
||||||
(HLO_CosOp $input)
|
(HLO_CosOp $input)
|
||||||
)>;
|
)>;
|
||||||
|
|
||||||
|
|
|
@ -822,7 +822,6 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
||||||
PointwiseToLinalgConverter<lmhlo::AbsOp>,
|
PointwiseToLinalgConverter<lmhlo::AbsOp>,
|
||||||
PointwiseToLinalgConverter<lmhlo::AddOp>,
|
PointwiseToLinalgConverter<lmhlo::AddOp>,
|
||||||
PointwiseToLinalgConverter<lmhlo::AndOp>,
|
PointwiseToLinalgConverter<lmhlo::AndOp>,
|
||||||
PointwiseToLinalgConverter<lmhlo::Atan2Op>,
|
|
||||||
PointwiseToLinalgConverter<lmhlo::CeilOp>,
|
PointwiseToLinalgConverter<lmhlo::CeilOp>,
|
||||||
PointwiseToLinalgConverter<lmhlo::CompareOp>,
|
PointwiseToLinalgConverter<lmhlo::CompareOp>,
|
||||||
PointwiseToLinalgConverter<lmhlo::ComplexOp>,
|
PointwiseToLinalgConverter<lmhlo::ComplexOp>,
|
||||||
|
@ -933,7 +932,6 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
||||||
PointwiseToLinalgConverter<mhlo::AbsOp, false>,
|
PointwiseToLinalgConverter<mhlo::AbsOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::AddOp, false>,
|
PointwiseToLinalgConverter<mhlo::AddOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::AndOp, false>,
|
PointwiseToLinalgConverter<mhlo::AndOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::Atan2Op, false>,
|
|
||||||
PointwiseToLinalgConverter<mhlo::CeilOp, false>,
|
PointwiseToLinalgConverter<mhlo::CeilOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::CompareOp, false>,
|
PointwiseToLinalgConverter<mhlo::CompareOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::ComplexOp, false>,
|
PointwiseToLinalgConverter<mhlo::ComplexOp, false>,
|
||||||
|
|
|
@ -48,7 +48,7 @@ namespace {
|
||||||
|
|
||||||
// TODO(herhut): Generate these out of op definitions.
|
// TODO(herhut): Generate these out of op definitions.
|
||||||
#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \
|
#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \
|
||||||
fn(AcosOp) sep fn(AtanOp) sep fn(SinhOp) sep fn(TanOp)
|
fn(TanOp) sep fn(AcosOp) sep fn(SinhOp)
|
||||||
|
|
||||||
template <typename OpTy>
|
template <typename OpTy>
|
||||||
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
|
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
|
||||||
|
|
Loading…
Reference in New Issue