Add kernel definition for zeta operation.
PiperOrigin-RevId: 355575619
This commit is contained in:
parent
945bf8768d
commit
60e1b6882c
|
@ -57,6 +57,9 @@ namespace {
|
|||
sep fn(ErfOp) sep fn(ErfcOp) sep fn(IsInfOp) sep fn(LgammaOp) \
|
||||
sep fn(SinhOp) sep fn(TanOp)
|
||||
|
||||
// TODO(herhut): Generate these out of op definitions.
|
||||
#define MAP_CHLO_OPERATION_CWISE_BINARY(fn, sep) fn(ZetaOp)
|
||||
|
||||
template <typename OpTy>
|
||||
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
|
||||
target->addDynamicallyLegalOp<OpTy>([](OpTy op) {
|
||||
|
@ -484,21 +487,20 @@ struct TransformUnrankedHloPass
|
|||
|
||||
void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
|
||||
OwningRewritePatternList *patterns) {
|
||||
#define MAP_UNARY(op) ElementwiseOpConversion<mhlo::op>
|
||||
#define MAP_BINARY(op) ElementwiseOpConversion<mhlo::op>
|
||||
#define MAP_CHLO_UNARY(op) ElementwiseOpConversion<chlo::op>
|
||||
#define MAP_HLO(op) ElementwiseOpConversion<mhlo::op>
|
||||
#define MAP_CHLO(op) ElementwiseOpConversion<chlo::op>
|
||||
#define COMMA ,
|
||||
// clang-format off
|
||||
patterns->insert<
|
||||
MAP_XLA_OPERATION_CWISE_UNARY(MAP_UNARY, COMMA),
|
||||
MAP_XLA_OPERATION_CWISE_BINARY(MAP_BINARY, COMMA),
|
||||
MAP_CHLO_OPERATION_CWISE_UNARY(MAP_CHLO_UNARY, COMMA),
|
||||
MAP_XLA_OPERATION_CWISE_UNARY(MAP_HLO, COMMA),
|
||||
MAP_XLA_OPERATION_CWISE_BINARY(MAP_HLO, COMMA),
|
||||
MAP_CHLO_OPERATION_CWISE_UNARY(MAP_CHLO, COMMA),
|
||||
MAP_CHLO_OPERATION_CWISE_BINARY(MAP_CHLO, COMMA),
|
||||
ElementwiseOpConversion<mhlo::CompareOp>,
|
||||
ElementwiseOpConversion<mhlo::SelectOp>>(context);
|
||||
// clang-format on
|
||||
#undef MAP_UNARY
|
||||
#undef MAP_BINARY
|
||||
#undef MAP_CHLO_UNARY
|
||||
#undef MAP_HLO
|
||||
#undef MAP_CHLO
|
||||
#undef COMMA
|
||||
chlo::PopulateForBroadcastingBinaryOp<
|
||||
ConvertUnrankedDynamicBroadcastBinaryOp>(context, patterns);
|
||||
|
|
Loading…
Reference in New Issue