[MLIR][KerneGen] Lower `tf.Atan` all the way to LLVM
PiperOrigin-RevId: 334810730
This commit is contained in:
		
							parent
							
								
									73f461080c
								
							
						
					
					
						commit
						458e861254
					
				| 
						 | 
					@ -360,6 +360,19 @@ def HLOClient_AcosOp : HLOClient_UnaryElementwiseOp<"acos", [],
 | 
				
			||||||
  }];
 | 
					  }];
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def HLOClient_AtanOp : HLOClient_UnaryElementwiseOp<"atan", [],
 | 
				
			||||||
 | 
					    HLO_FpOrComplexTensor> {
 | 
				
			||||||
 | 
					  let summary = "Atan operator";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  let description = [{
 | 
				
			||||||
 | 
					    Returns `Atan(operand)` element-wise.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    $$
 | 
				
			||||||
 | 
					    \atan(x) = \atan2(x, 1)
 | 
				
			||||||
 | 
					    $$
 | 
				
			||||||
 | 
					  }];
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def HLOClient_SinhOp : HLOClient_UnaryElementwiseOp<"sinh", [],
 | 
					def HLOClient_SinhOp : HLOClient_UnaryElementwiseOp<"sinh", [],
 | 
				
			||||||
    HLO_FpOrComplexTensor> {
 | 
					    HLO_FpOrComplexTensor> {
 | 
				
			||||||
  let summary = "Sinh operation";
 | 
					  let summary = "Sinh operation";
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -149,6 +149,15 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::AndOp>(Location loc,
 | 
				
			||||||
      loc, result_types, args, b);
 | 
					      loc, result_types, args, b);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					template <>
 | 
				
			||||||
 | 
					inline Value MapLhloOpToStdScalarOp<lmhlo::Atan2Op>(Location loc,
 | 
				
			||||||
 | 
					                                                    ArrayRef<Type> result_types,
 | 
				
			||||||
 | 
					                                                    ArrayRef<Value> args,
 | 
				
			||||||
 | 
					                                                    OpBuilder* b) {
 | 
				
			||||||
 | 
					  return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::Atan2Op>{}(
 | 
				
			||||||
 | 
					      loc, result_types, args, b);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <typename PredicateType>
 | 
					template <typename PredicateType>
 | 
				
			||||||
inline Optional<PredicateType> getCmpPredicate(StringRef comparison_direction) {
 | 
					inline Optional<PredicateType> getCmpPredicate(StringRef comparison_direction) {
 | 
				
			||||||
  return llvm::None;
 | 
					  return llvm::None;
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -49,6 +49,14 @@ def : Pat<(HLOClient_AcosOp $input),
 | 
				
			||||||
    ),
 | 
					    ),
 | 
				
			||||||
    (HLO_ConstantLike<"M_PI"> $input))>;
 | 
					    (HLO_ConstantLike<"M_PI"> $input))>;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Express `atan` as
 | 
				
			||||||
 | 
					//   atan(x) = atan2(x, 1)
 | 
				
			||||||
 | 
					def : Pat<(HLOClient_AtanOp $input),
 | 
				
			||||||
 | 
					  (HLO_Atan2Op
 | 
				
			||||||
 | 
					    $input,
 | 
				
			||||||
 | 
					    (HLO_ConstantLike<"1"> $input)
 | 
				
			||||||
 | 
					  )>;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Express `sinh` as
 | 
					// Express `sinh` as
 | 
				
			||||||
//   sinh(x) = (e^x - e^-x) / 2                     if |x| < 1
 | 
					//   sinh(x) = (e^x - e^-x) / 2                     if |x| < 1
 | 
				
			||||||
//           = e^(x + log(1/2)) - e^(-x + log(1/2)) otherwise.
 | 
					//           = e^(x + log(1/2)) - e^(-x + log(1/2)) otherwise.
 | 
				
			||||||
| 
						 | 
					@ -95,4 +103,3 @@ def : Pat<(HLOClient_TanOp $input),
 | 
				
			||||||
    (HLO_SinOp $input),
 | 
					    (HLO_SinOp $input),
 | 
				
			||||||
    (HLO_CosOp $input)
 | 
					    (HLO_CosOp $input)
 | 
				
			||||||
  )>;
 | 
					  )>;
 | 
				
			||||||
 | 
					 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -822,6 +822,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
 | 
				
			||||||
                   PointwiseToLinalgConverter<lmhlo::AbsOp>,
 | 
					                   PointwiseToLinalgConverter<lmhlo::AbsOp>,
 | 
				
			||||||
                   PointwiseToLinalgConverter<lmhlo::AddOp>,
 | 
					                   PointwiseToLinalgConverter<lmhlo::AddOp>,
 | 
				
			||||||
                   PointwiseToLinalgConverter<lmhlo::AndOp>,
 | 
					                   PointwiseToLinalgConverter<lmhlo::AndOp>,
 | 
				
			||||||
 | 
					                   PointwiseToLinalgConverter<lmhlo::Atan2Op>,
 | 
				
			||||||
                   PointwiseToLinalgConverter<lmhlo::CeilOp>,
 | 
					                   PointwiseToLinalgConverter<lmhlo::CeilOp>,
 | 
				
			||||||
                   PointwiseToLinalgConverter<lmhlo::CompareOp>,
 | 
					                   PointwiseToLinalgConverter<lmhlo::CompareOp>,
 | 
				
			||||||
                   PointwiseToLinalgConverter<lmhlo::ComplexOp>,
 | 
					                   PointwiseToLinalgConverter<lmhlo::ComplexOp>,
 | 
				
			||||||
| 
						 | 
					@ -932,6 +933,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
 | 
				
			||||||
               PointwiseToLinalgConverter<mhlo::AbsOp, false>,
 | 
					               PointwiseToLinalgConverter<mhlo::AbsOp, false>,
 | 
				
			||||||
               PointwiseToLinalgConverter<mhlo::AddOp, false>,
 | 
					               PointwiseToLinalgConverter<mhlo::AddOp, false>,
 | 
				
			||||||
               PointwiseToLinalgConverter<mhlo::AndOp, false>,
 | 
					               PointwiseToLinalgConverter<mhlo::AndOp, false>,
 | 
				
			||||||
 | 
					               PointwiseToLinalgConverter<mhlo::Atan2Op, false>,
 | 
				
			||||||
               PointwiseToLinalgConverter<mhlo::CeilOp, false>,
 | 
					               PointwiseToLinalgConverter<mhlo::CeilOp, false>,
 | 
				
			||||||
               PointwiseToLinalgConverter<mhlo::CompareOp, false>,
 | 
					               PointwiseToLinalgConverter<mhlo::CompareOp, false>,
 | 
				
			||||||
               PointwiseToLinalgConverter<mhlo::ComplexOp, false>,
 | 
					               PointwiseToLinalgConverter<mhlo::ComplexOp, false>,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -48,7 +48,7 @@ namespace {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// TODO(herhut): Generate these out of op definitions.
 | 
					// TODO(herhut): Generate these out of op definitions.
 | 
				
			||||||
#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \
 | 
					#define MAP_CHLO_OPERATION_CWISE_UNARY(fn, sep) \
 | 
				
			||||||
  fn(TanOp) sep fn(AcosOp) sep fn(SinhOp)
 | 
					  fn(AcosOp) sep fn(AtanOp) sep fn(SinhOp) sep fn(TanOp)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <typename OpTy>
 | 
					template <typename OpTy>
 | 
				
			||||||
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
 | 
					inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue