Add kernel definition for zeta operation.

PiperOrigin-RevId: 355575619
This commit is contained in:
Stephan Herhut 2021-02-04 01:26:44 -08:00 committed by TensorFlow MLIR Team
parent 945bf8768d
commit 60e1b6882c
1 changed files with 11 additions and 9 deletions

View File

@ -57,6 +57,9 @@ namespace {
sep fn(ErfOp) sep fn(ErfcOp) sep fn(IsInfOp) sep fn(LgammaOp) \ sep fn(ErfOp) sep fn(ErfcOp) sep fn(IsInfOp) sep fn(LgammaOp) \
sep fn(SinhOp) sep fn(TanOp) sep fn(SinhOp) sep fn(TanOp)
// TODO(herhut): Generate these out of op definitions.
#define MAP_CHLO_OPERATION_CWISE_BINARY(fn, sep) fn(ZetaOp)
template <typename OpTy> template <typename OpTy>
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) { inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
target->addDynamicallyLegalOp<OpTy>([](OpTy op) { target->addDynamicallyLegalOp<OpTy>([](OpTy op) {
@ -484,21 +487,20 @@ struct TransformUnrankedHloPass
void PopulateTransformUnrankedHloPatterns(MLIRContext *context, void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
OwningRewritePatternList *patterns) { OwningRewritePatternList *patterns) {
#define MAP_UNARY(op) ElementwiseOpConversion<mhlo::op> #define MAP_HLO(op) ElementwiseOpConversion<mhlo::op>
#define MAP_BINARY(op) ElementwiseOpConversion<mhlo::op> #define MAP_CHLO(op) ElementwiseOpConversion<chlo::op>
#define MAP_CHLO_UNARY(op) ElementwiseOpConversion<chlo::op>
#define COMMA , #define COMMA ,
// clang-format off // clang-format off
patterns->insert< patterns->insert<
MAP_XLA_OPERATION_CWISE_UNARY(MAP_UNARY, COMMA), MAP_XLA_OPERATION_CWISE_UNARY(MAP_HLO, COMMA),
MAP_XLA_OPERATION_CWISE_BINARY(MAP_BINARY, COMMA), MAP_XLA_OPERATION_CWISE_BINARY(MAP_HLO, COMMA),
MAP_CHLO_OPERATION_CWISE_UNARY(MAP_CHLO_UNARY, COMMA), MAP_CHLO_OPERATION_CWISE_UNARY(MAP_CHLO, COMMA),
MAP_CHLO_OPERATION_CWISE_BINARY(MAP_CHLO, COMMA),
ElementwiseOpConversion<mhlo::CompareOp>, ElementwiseOpConversion<mhlo::CompareOp>,
ElementwiseOpConversion<mhlo::SelectOp>>(context); ElementwiseOpConversion<mhlo::SelectOp>>(context);
// clang-format on // clang-format on
#undef MAP_UNARY #undef MAP_HLO
#undef MAP_BINARY #undef MAP_CHLO
#undef MAP_CHLO_UNARY
#undef COMMA #undef COMMA
chlo::PopulateForBroadcastingBinaryOp< chlo::PopulateForBroadcastingBinaryOp<
ConvertUnrankedDynamicBroadcastBinaryOp>(context, patterns); ConvertUnrankedDynamicBroadcastBinaryOp>(context, patterns);