From 60e1b6882c33ce82d9376b7f754aa6259afa7426 Mon Sep 17 00:00:00 2001 From: Stephan Herhut Date: Thu, 4 Feb 2021 01:26:44 -0800 Subject: [PATCH] Add kernel definition for zeta operation. PiperOrigin-RevId: 355575619 --- .../mhlo/transforms/transform_unranked_hlo.cc | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc index 61f51af..3b52288 100644 --- a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc +++ b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc @@ -57,6 +57,9 @@ namespace { sep fn(ErfOp) sep fn(ErfcOp) sep fn(IsInfOp) sep fn(LgammaOp) \ sep fn(SinhOp) sep fn(TanOp) +// TODO(herhut): Generate these out of op definitions. +#define MAP_CHLO_OPERATION_CWISE_BINARY(fn, sep) fn(ZetaOp) + template inline void AddLegalOpOnRankedTensor(ConversionTarget *target) { target->addDynamicallyLegalOp([](OpTy op) { @@ -484,21 +487,20 @@ struct TransformUnrankedHloPass void PopulateTransformUnrankedHloPatterns(MLIRContext *context, OwningRewritePatternList *patterns) { -#define MAP_UNARY(op) ElementwiseOpConversion -#define MAP_BINARY(op) ElementwiseOpConversion -#define MAP_CHLO_UNARY(op) ElementwiseOpConversion +#define MAP_HLO(op) ElementwiseOpConversion +#define MAP_CHLO(op) ElementwiseOpConversion #define COMMA , // clang-format off patterns->insert< - MAP_XLA_OPERATION_CWISE_UNARY(MAP_UNARY, COMMA), - MAP_XLA_OPERATION_CWISE_BINARY(MAP_BINARY, COMMA), - MAP_CHLO_OPERATION_CWISE_UNARY(MAP_CHLO_UNARY, COMMA), + MAP_XLA_OPERATION_CWISE_UNARY(MAP_HLO, COMMA), + MAP_XLA_OPERATION_CWISE_BINARY(MAP_HLO, COMMA), + MAP_CHLO_OPERATION_CWISE_UNARY(MAP_CHLO, COMMA), + MAP_CHLO_OPERATION_CWISE_BINARY(MAP_CHLO, COMMA), ElementwiseOpConversion, ElementwiseOpConversion>(context); // clang-format on -#undef MAP_UNARY -#undef MAP_BINARY -#undef MAP_CHLO_UNARY +#undef MAP_HLO +#undef MAP_CHLO #undef COMMA chlo::PopulateForBroadcastingBinaryOp< ConvertUnrankedDynamicBroadcastBinaryOp>(context, patterns);