[MLIR][MHLO] Add pattern to inline broadcasted shapes
Simplify reasoning about `cstr_broadcastable` ops in the `mhlo-move-up-dynamic-broadcasts-for-fusion` pass. PiperOrigin-RevId: 365560893
This commit is contained in:
parent
fb819c1de8
commit
85a306d356
|
@ -69,6 +69,33 @@ struct ShapeReificationPattern : public OpRewritePattern<shape::ShapeOfOp> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename OpTy>
|
||||||
|
struct InlineBroadcastedShapeOperandsPattern : public OpRewritePattern<OpTy> {
|
||||||
|
using OpRewritePattern<OpTy>::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(OpTy op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
// Find all the shape operands, direct and indirect.
|
||||||
|
SmallVector<Value, 8> inlined_operands;
|
||||||
|
for (Value direct : op->getOperands()) {
|
||||||
|
if (auto bcast_op = direct.getDefiningOp<shape::BroadcastOp>()) {
|
||||||
|
for (Value indirect : bcast_op->getOperands())
|
||||||
|
inlined_operands.push_back(indirect);
|
||||||
|
} else {
|
||||||
|
inlined_operands.push_back(direct);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only rewrite if it makes a difference.
|
||||||
|
if (inlined_operands.size() == op.getNumOperands()) return failure();
|
||||||
|
|
||||||
|
// Inline shape operands.
|
||||||
|
rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(),
|
||||||
|
inlined_operands, op->getAttrs());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// TODO(frgossen): Only move up broadcasting operations if there is a consumer.
|
// TODO(frgossen): Only move up broadcasting operations if there is a consumer.
|
||||||
struct MoveUpBroadcastInDimOpPattern
|
struct MoveUpBroadcastInDimOpPattern
|
||||||
: public OpRewritePattern<DynamicBroadcastInDimOp> {
|
: public OpRewritePattern<DynamicBroadcastInDimOp> {
|
||||||
|
@ -124,12 +151,9 @@ struct MoveUpDynamicBroadcastsForFusionPass
|
||||||
}
|
}
|
||||||
|
|
||||||
void runOnFunction() override {
|
void runOnFunction() override {
|
||||||
// Populate rewrite patterns.
|
|
||||||
MLIRContext *ctx = &getContext();
|
MLIRContext *ctx = &getContext();
|
||||||
RewritePatternSet patterns(ctx);
|
RewritePatternSet patterns(ctx);
|
||||||
mhlo::PopulateMoveUpDynamicBroadcastsForFusionPatterns(ctx, &patterns);
|
mhlo::PopulateMoveUpDynamicBroadcastsForFusionPatterns(ctx, &patterns);
|
||||||
|
|
||||||
// Apply transformation.
|
|
||||||
if (failed(
|
if (failed(
|
||||||
applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)))) {
|
applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)))) {
|
||||||
return signalPassFailure();
|
return signalPassFailure();
|
||||||
|
@ -142,8 +166,10 @@ struct MoveUpDynamicBroadcastsForFusionPass
|
||||||
void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
|
void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
|
||||||
MLIRContext *context, OwningRewritePatternList *patterns) {
|
MLIRContext *context, OwningRewritePatternList *patterns) {
|
||||||
// clang-format off
|
// clang-format off
|
||||||
patterns->insert<ShapeReificationPattern,
|
patterns->insert<
|
||||||
MoveUpBroadcastInDimOpPattern>(context);
|
InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>,
|
||||||
|
MoveUpBroadcastInDimOpPattern,
|
||||||
|
ShapeReificationPattern>(context);
|
||||||
// clang-format on
|
// clang-format on
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -97,3 +97,18 @@ func @cast_sub(%arg0: tensor<?x32xi16>, %arg1: tensor<?x?x32xf16>)
|
||||||
}
|
}
|
||||||
return %4 : tensor<?x?x32xf16>
|
return %4 : tensor<?x?x32xf16>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @inline_bcasted_shape_operands
|
||||||
|
// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>, %[[B:.*]]: tensor<?xindex>, %[[C:.*]]: tensor<?xindex>)
|
||||||
|
func @inline_bcasted_shape_operands(%a : tensor<?xindex>, %b : tensor<?xindex>,
|
||||||
|
%c : tensor<?xindex>) -> !shape.witness {
|
||||||
|
// CHECK-NOT: shape.broadcast
|
||||||
|
// CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[A]], %[[B]], %[[C]]
|
||||||
|
// CHECK: return %[[WITNESS]] : !shape.witness
|
||||||
|
%0 = shape.broadcast %a, %b : tensor<?xindex>, tensor<?xindex>
|
||||||
|
-> tensor<?xindex>
|
||||||
|
%1 = shape.cstr_broadcastable %0, %c : tensor<?xindex>, tensor<?xindex>
|
||||||
|
return %1 : !shape.witness
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue