Add kernel definition for zeta operation.
PiperOrigin-RevId: 355575619
This commit is contained in:
parent
945bf8768d
commit
60e1b6882c
|
@ -57,6 +57,9 @@ namespace {
|
||||||
sep fn(ErfOp) sep fn(ErfcOp) sep fn(IsInfOp) sep fn(LgammaOp) \
|
sep fn(ErfOp) sep fn(ErfcOp) sep fn(IsInfOp) sep fn(LgammaOp) \
|
||||||
sep fn(SinhOp) sep fn(TanOp)
|
sep fn(SinhOp) sep fn(TanOp)
|
||||||
|
|
||||||
|
// TODO(herhut): Generate these out of op definitions.
|
||||||
|
#define MAP_CHLO_OPERATION_CWISE_BINARY(fn, sep) fn(ZetaOp)
|
||||||
|
|
||||||
template <typename OpTy>
|
template <typename OpTy>
|
||||||
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
|
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
|
||||||
target->addDynamicallyLegalOp<OpTy>([](OpTy op) {
|
target->addDynamicallyLegalOp<OpTy>([](OpTy op) {
|
||||||
|
@ -484,21 +487,20 @@ struct TransformUnrankedHloPass
|
||||||
|
|
||||||
void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
|
void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
|
||||||
OwningRewritePatternList *patterns) {
|
OwningRewritePatternList *patterns) {
|
||||||
#define MAP_UNARY(op) ElementwiseOpConversion<mhlo::op>
|
#define MAP_HLO(op) ElementwiseOpConversion<mhlo::op>
|
||||||
#define MAP_BINARY(op) ElementwiseOpConversion<mhlo::op>
|
#define MAP_CHLO(op) ElementwiseOpConversion<chlo::op>
|
||||||
#define MAP_CHLO_UNARY(op) ElementwiseOpConversion<chlo::op>
|
|
||||||
#define COMMA ,
|
#define COMMA ,
|
||||||
// clang-format off
|
// clang-format off
|
||||||
patterns->insert<
|
patterns->insert<
|
||||||
MAP_XLA_OPERATION_CWISE_UNARY(MAP_UNARY, COMMA),
|
MAP_XLA_OPERATION_CWISE_UNARY(MAP_HLO, COMMA),
|
||||||
MAP_XLA_OPERATION_CWISE_BINARY(MAP_BINARY, COMMA),
|
MAP_XLA_OPERATION_CWISE_BINARY(MAP_HLO, COMMA),
|
||||||
MAP_CHLO_OPERATION_CWISE_UNARY(MAP_CHLO_UNARY, COMMA),
|
MAP_CHLO_OPERATION_CWISE_UNARY(MAP_CHLO, COMMA),
|
||||||
|
MAP_CHLO_OPERATION_CWISE_BINARY(MAP_CHLO, COMMA),
|
||||||
ElementwiseOpConversion<mhlo::CompareOp>,
|
ElementwiseOpConversion<mhlo::CompareOp>,
|
||||||
ElementwiseOpConversion<mhlo::SelectOp>>(context);
|
ElementwiseOpConversion<mhlo::SelectOp>>(context);
|
||||||
// clang-format on
|
// clang-format on
|
||||||
#undef MAP_UNARY
|
#undef MAP_HLO
|
||||||
#undef MAP_BINARY
|
#undef MAP_CHLO
|
||||||
#undef MAP_CHLO_UNARY
|
|
||||||
#undef COMMA
|
#undef COMMA
|
||||||
chlo::PopulateForBroadcastingBinaryOp<
|
chlo::PopulateForBroadcastingBinaryOp<
|
||||||
ConvertUnrankedDynamicBroadcastBinaryOp>(context, patterns);
|
ConvertUnrankedDynamicBroadcastBinaryOp>(context, patterns);
|
||||||
|
|
Loading…
Reference in New Issue