Extend unranked to ranked pattern for hlo operations to all unary and binary ops.
As this is essentially always the same pattern, only one operation is tested. PiperOrigin-RevId: 323525418
This commit is contained in:
parent
b7c4314e7f
commit
effd3fb4f9
|
@ -31,6 +31,22 @@ namespace mlir {
|
|||
namespace mhlo {
|
||||
namespace {
|
||||
|
||||
// TODO(herhut): Generate these out of op definitions.
|
||||
#define MAP_XLA_OPERATION_CWISE_UNARY(fn, sep) \
|
||||
fn(AbsOp) sep fn(CeilOp) sep fn(ClzOp) sep fn(CosOp) sep fn(ExpOp) \
|
||||
sep fn(Expm1Op) sep fn(FloorOp) sep fn(ImagOp) sep fn(IsFiniteOp) \
|
||||
sep fn(LogOp) sep fn(Log1pOp) sep fn(LogisticOp) sep fn(NotOp) \
|
||||
sep fn(NegOp) sep fn(PopulationCountOp) sep fn(RealOp) \
|
||||
sep fn(RoundOp) sep fn(RsqrtOp) sep fn(SignOp) sep fn(SinOp) \
|
||||
sep fn(SqrtOp) sep fn(TanhOp)
|
||||
|
||||
// TODO(herhut): Generate these out of op definitions.
|
||||
#define MAP_XLA_OPERATION_CWISE_BINARY(fn, sep) \
|
||||
fn(AddOp) sep fn(Atan2Op) sep fn(ComplexOp) sep fn(DivOp) sep fn(MaxOp) \
|
||||
sep fn(MinOp) sep fn(MulOp) sep fn(PowOp) sep fn(RemOp) \
|
||||
sep fn(ShiftLeftOp) sep fn(ShiftRightArithmeticOp) \
|
||||
sep fn(ShiftRightLogicalOp) sep fn(SubOp)
|
||||
|
||||
// TODO(frgossen): Make it variadic.
|
||||
template <typename OpTy>
|
||||
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
|
||||
|
@ -154,8 +170,10 @@ struct TransformUnrankedHloPass
|
|||
target.addLegalDialect<MhloDialect, StandardOpsDialect,
|
||||
shape::ShapeDialect>();
|
||||
target.addLegalOp<FuncOp>();
|
||||
AddLegalOpOnRankedTensor<SqrtOp>(&target);
|
||||
AddLegalOpOnRankedTensor<AddOp>(&target);
|
||||
#define ADD_LEGAL(op) AddLegalOpOnRankedTensor<op>(&target)
|
||||
MAP_XLA_OPERATION_CWISE_UNARY(ADD_LEGAL, ;);
|
||||
MAP_XLA_OPERATION_CWISE_BINARY(ADD_LEGAL, ;);
|
||||
#undef ADD_LEGAL
|
||||
|
||||
// Populate rewrite patterns.
|
||||
OwningRewritePatternList patterns;
|
||||
|
@ -173,9 +191,16 @@ void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
|
|||
OwningRewritePatternList *patterns) {
|
||||
// TODO(frgossen): Populate all unary and binary operations.
|
||||
// clang-format off
|
||||
#define MAP_UNARY(op) UnaryElementwiseOpConversion<op>
|
||||
#define MAP_BINARY(op) BinaryElementwiseOpConversion<op>
|
||||
#define COMMA ,
|
||||
patterns->insert<
|
||||
BinaryElementwiseOpConversion<AddOp>,
|
||||
UnaryElementwiseOpConversion<SqrtOp>>(context);
|
||||
MAP_XLA_OPERATION_CWISE_UNARY(MAP_UNARY, COMMA),
|
||||
MAP_XLA_OPERATION_CWISE_BINARY(MAP_BINARY, COMMA)
|
||||
>(context);
|
||||
#undef MAP_UNARY
|
||||
#undef MAP_BINARY
|
||||
#undef COMMA
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue