Extend unranked to ranked pattern for hlo operations to all unary and binary ops.

As this is essentially always the same pattern, only one operation is tested.

PiperOrigin-RevId: 323525418
This commit is contained in:
Stephan Herhut 2020-07-28 00:55:58 -07:00 committed by TensorFlow MLIR Team
parent b7c4314e7f
commit effd3fb4f9
1 changed files with 29 additions and 4 deletions

View File

@ -31,6 +31,22 @@ namespace mlir {
namespace mhlo {
namespace {
// TODO(herhut): Generate these out of op definitions.
#define MAP_XLA_OPERATION_CWISE_UNARY(fn, sep) \
fn(AbsOp) sep fn(CeilOp) sep fn(ClzOp) sep fn(CosOp) sep fn(ExpOp) \
sep fn(Expm1Op) sep fn(FloorOp) sep fn(ImagOp) sep fn(IsFiniteOp) \
sep fn(LogOp) sep fn(Log1pOp) sep fn(LogisticOp) sep fn(NotOp) \
sep fn(NegOp) sep fn(PopulationCountOp) sep fn(RealOp) \
sep fn(RoundOp) sep fn(RsqrtOp) sep fn(SignOp) sep fn(SinOp) \
sep fn(SqrtOp) sep fn(TanhOp)
// TODO(herhut): Generate these out of op definitions.
#define MAP_XLA_OPERATION_CWISE_BINARY(fn, sep) \
fn(AddOp) sep fn(Atan2Op) sep fn(ComplexOp) sep fn(DivOp) sep fn(MaxOp) \
sep fn(MinOp) sep fn(MulOp) sep fn(PowOp) sep fn(RemOp) \
sep fn(ShiftLeftOp) sep fn(ShiftRightArithmeticOp) \
sep fn(ShiftRightLogicalOp) sep fn(SubOp)
// TODO(frgossen): Make it variadic.
template <typename OpTy>
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
@ -154,8 +170,10 @@ struct TransformUnrankedHloPass
target.addLegalDialect<MhloDialect, StandardOpsDialect,
shape::ShapeDialect>();
target.addLegalOp<FuncOp>();
AddLegalOpOnRankedTensor<SqrtOp>(&target);
AddLegalOpOnRankedTensor<AddOp>(&target);
#define ADD_LEGAL(op) AddLegalOpOnRankedTensor<op>(&target)
MAP_XLA_OPERATION_CWISE_UNARY(ADD_LEGAL, ;);
MAP_XLA_OPERATION_CWISE_BINARY(ADD_LEGAL, ;);
#undef ADD_LEGAL
// Populate rewrite patterns.
OwningRewritePatternList patterns;
@ -173,9 +191,16 @@ void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
OwningRewritePatternList *patterns) {
// TODO(frgossen): Populate all unary and binary operations.
// clang-format off
#define MAP_UNARY(op) UnaryElementwiseOpConversion<op>
#define MAP_BINARY(op) BinaryElementwiseOpConversion<op>
#define COMMA ,
patterns->insert<
BinaryElementwiseOpConversion<AddOp>,
UnaryElementwiseOpConversion<SqrtOp>>(context);
MAP_XLA_OPERATION_CWISE_UNARY(MAP_UNARY, COMMA),
MAP_XLA_OPERATION_CWISE_BINARY(MAP_BINARY, COMMA)
>(context);
#undef MAP_UNARY
#undef MAP_BINARY
#undef COMMA
// clang-format on
}