[MLIR][MHLO] Add pattern to inline broadcasted shapes
Simplify reasoning about `cstr_broadcastable` ops in the `mhlo-move-up-dynamic-broadcasts-for-fusion` pass.

PiperOrigin-RevId: 365560893
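For illustration, the new pattern rewrites IR like the following (a minimal sketch distilled from the regression test added below; the SSA names match that test):

  // Before: the constraint consumes an intermediate shape.broadcast.
  %0 = shape.broadcast %a, %b : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
  %1 = shape.cstr_broadcastable %0, %c : tensor<?xindex>, tensor<?xindex>

  // After: the broadcast's operands are inlined into the constraint, so the
  // witness is expressed directly in terms of the original shapes.
  %1 = shape.cstr_broadcastable %a, %b, %c : tensor<?xindex>, tensor<?xindex>, tensor<?xindex>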
parent fb819c1de8
commit 85a306d356
@@ -69,6 +69,33 @@ struct ShapeReificationPattern : public OpRewritePattern<shape::ShapeOfOp> {
   }
 };
 
+template <typename OpTy>
+struct InlineBroadcastedShapeOperandsPattern : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    // Find all the shape operands, direct and indirect.
+    SmallVector<Value, 8> inlined_operands;
+    for (Value direct : op->getOperands()) {
+      if (auto bcast_op = direct.getDefiningOp<shape::BroadcastOp>()) {
+        for (Value indirect : bcast_op->getOperands())
+          inlined_operands.push_back(indirect);
+      } else {
+        inlined_operands.push_back(direct);
+      }
+    }
+
+    // Only rewrite if it makes a difference.
+    if (inlined_operands.size() == op.getNumOperands()) return failure();
+
+    // Inline shape operands.
+    rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(),
+                                      inlined_operands, op->getAttrs());
+    return success();
+  }
+};
+
 // TODO(frgossen): Only move up broadcasting operations if there is a consumer.
 struct MoveUpBroadcastInDimOpPattern
     : public OpRewritePattern<DynamicBroadcastInDimOp> {
@@ -124,12 +151,9 @@ struct MoveUpDynamicBroadcastsForFusionPass
   }
 
   void runOnFunction() override {
     // Populate rewrite patterns.
     MLIRContext *ctx = &getContext();
     RewritePatternSet patterns(ctx);
     mhlo::PopulateMoveUpDynamicBroadcastsForFusionPatterns(ctx, &patterns);
 
     // Apply transformation.
     if (failed(
         applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)))) {
       return signalPassFailure();
@@ -142,8 +166,10 @@ struct MoveUpDynamicBroadcastsForFusionPass
 void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
     MLIRContext *context, OwningRewritePatternList *patterns) {
   // clang-format off
-  patterns->insert<ShapeReificationPattern,
-                   MoveUpBroadcastInDimOpPattern>(context);
+  patterns->insert<
+      InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>,
+      MoveUpBroadcastInDimOpPattern,
+      ShapeReificationPattern>(context);
   // clang-format on
 }
 
@@ -97,3 +97,18 @@ func @cast_sub(%arg0: tensor<?x32xi16>, %arg1: tensor<?x?x32xf16>)
   }
   return %4 : tensor<?x?x32xf16>
 }
+
+// -----
+
+// CHECK-LABEL: @inline_bcasted_shape_operands
+// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>, %[[B:.*]]: tensor<?xindex>, %[[C:.*]]: tensor<?xindex>)
+func @inline_bcasted_shape_operands(%a : tensor<?xindex>, %b : tensor<?xindex>,
+                                    %c : tensor<?xindex>) -> !shape.witness {
+  // CHECK-NOT: shape.broadcast
+  // CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[A]], %[[B]], %[[C]]
+  // CHECK: return %[[WITNESS]] : !shape.witness
+  %0 = shape.broadcast %a, %b : tensor<?xindex>, tensor<?xindex>
+      -> tensor<?xindex>
+  %1 = shape.cstr_broadcastable %0, %c : tensor<?xindex>, tensor<?xindex>
+  return %1 : !shape.witness
+}
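The `// -----` marker separates independent test cases, which implies the file is processed with split-input-file handling. A typical RUN header for such a test would look like the following (a hypothetical sketch; the actual header is outside this excerpt, and the `mlir-hlo-opt` driver name is an assumption):

  // RUN: mlir-hlo-opt %s --mhlo-move-up-dynamic-broadcasts-for-fusion --split-input-file | FileCheck %s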