Add kernel definition for zeta operation.

PiperOrigin-RevId: 355575619
This commit is contained in:
Stephan Herhut 2021-02-04 01:26:44 -08:00 committed by TensorFlow MLIR Team
parent 945bf8768d
commit 60e1b6882c
1 changed files with 11 additions and 9 deletions

View File

@ -57,6 +57,9 @@ namespace {
sep fn(ErfOp) sep fn(ErfcOp) sep fn(IsInfOp) sep fn(LgammaOp) \
sep fn(SinhOp) sep fn(TanOp)
// TODO(herhut): Generate these out of op definitions.
#define MAP_CHLO_OPERATION_CWISE_BINARY(fn, sep) fn(ZetaOp)
template <typename OpTy>
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
target->addDynamicallyLegalOp<OpTy>([](OpTy op) {
@ -484,21 +487,20 @@ struct TransformUnrankedHloPass
void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
OwningRewritePatternList *patterns) {
#define MAP_UNARY(op) ElementwiseOpConversion<mhlo::op>
#define MAP_BINARY(op) ElementwiseOpConversion<mhlo::op>
#define MAP_CHLO_UNARY(op) ElementwiseOpConversion<chlo::op>
#define MAP_HLO(op) ElementwiseOpConversion<mhlo::op>
#define MAP_CHLO(op) ElementwiseOpConversion<chlo::op>
#define COMMA ,
// clang-format off
patterns->insert<
MAP_XLA_OPERATION_CWISE_UNARY(MAP_UNARY, COMMA),
MAP_XLA_OPERATION_CWISE_BINARY(MAP_BINARY, COMMA),
MAP_CHLO_OPERATION_CWISE_UNARY(MAP_CHLO_UNARY, COMMA),
MAP_XLA_OPERATION_CWISE_UNARY(MAP_HLO, COMMA),
MAP_XLA_OPERATION_CWISE_BINARY(MAP_HLO, COMMA),
MAP_CHLO_OPERATION_CWISE_UNARY(MAP_CHLO, COMMA),
MAP_CHLO_OPERATION_CWISE_BINARY(MAP_CHLO, COMMA),
ElementwiseOpConversion<mhlo::CompareOp>,
ElementwiseOpConversion<mhlo::SelectOp>>(context);
// clang-format on
#undef MAP_UNARY
#undef MAP_BINARY
#undef MAP_CHLO_UNARY
#undef MAP_HLO
#undef MAP_CHLO
#undef COMMA
chlo::PopulateForBroadcastingBinaryOp<
ConvertUnrankedDynamicBroadcastBinaryOp>(context, patterns);