diff --git a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc index 61f51af..3b52288 100644 --- a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc +++ b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc @@ -57,6 +57,9 @@ namespace { sep fn(ErfOp) sep fn(ErfcOp) sep fn(IsInfOp) sep fn(LgammaOp) \ sep fn(SinhOp) sep fn(TanOp) +// TODO(herhut): Generate these out of op definitions. +#define MAP_CHLO_OPERATION_CWISE_BINARY(fn, sep) fn(ZetaOp) + template inline void AddLegalOpOnRankedTensor(ConversionTarget *target) { target->addDynamicallyLegalOp([](OpTy op) { @@ -484,21 +487,20 @@ struct TransformUnrankedHloPass void PopulateTransformUnrankedHloPatterns(MLIRContext *context, OwningRewritePatternList *patterns) { -#define MAP_UNARY(op) ElementwiseOpConversion -#define MAP_BINARY(op) ElementwiseOpConversion -#define MAP_CHLO_UNARY(op) ElementwiseOpConversion +#define MAP_HLO(op) ElementwiseOpConversion +#define MAP_CHLO(op) ElementwiseOpConversion #define COMMA , // clang-format off patterns->insert< - MAP_XLA_OPERATION_CWISE_UNARY(MAP_UNARY, COMMA), - MAP_XLA_OPERATION_CWISE_BINARY(MAP_BINARY, COMMA), - MAP_CHLO_OPERATION_CWISE_UNARY(MAP_CHLO_UNARY, COMMA), + MAP_XLA_OPERATION_CWISE_UNARY(MAP_HLO, COMMA), + MAP_XLA_OPERATION_CWISE_BINARY(MAP_HLO, COMMA), + MAP_CHLO_OPERATION_CWISE_UNARY(MAP_CHLO, COMMA), + MAP_CHLO_OPERATION_CWISE_BINARY(MAP_CHLO, COMMA), ElementwiseOpConversion, ElementwiseOpConversion>(context); // clang-format on -#undef MAP_UNARY -#undef MAP_BINARY -#undef MAP_CHLO_UNARY +#undef MAP_HLO +#undef MAP_CHLO #undef COMMA chlo::PopulateForBroadcastingBinaryOp< ConvertUnrankedDynamicBroadcastBinaryOp>(context, patterns);