[MLIR][MHLO] Add pattern to inline broadcasted shapes

Simplify reasoning about `cstr_broadcastable` ops in the
`mhlo-move-up-dynamic-broadcasts-for-fusion` pass.

PiperOrigin-RevId: 365560893
This commit is contained in:
A. Unique TensorFlower 2021-03-29 06:31:43 -07:00 committed by TensorFlow MLIR Team
parent fb819c1de8
commit 85a306d356
2 changed files with 46 additions and 5 deletions

View File

@ -69,6 +69,33 @@ struct ShapeReificationPattern : public OpRewritePattern<shape::ShapeOfOp> {
}
};
template <typename OpTy>
struct InlineBroadcastedShapeOperandsPattern : public OpRewritePattern<OpTy> {
using OpRewritePattern<OpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
// Find all the shape operands, direct and indirect.
SmallVector<Value, 8> inlined_operands;
for (Value direct : op->getOperands()) {
if (auto bcast_op = direct.getDefiningOp<shape::BroadcastOp>()) {
for (Value indirect : bcast_op->getOperands())
inlined_operands.push_back(indirect);
} else {
inlined_operands.push_back(direct);
}
}
// Only rewrite if it makes a difference.
if (inlined_operands.size() == op.getNumOperands()) return failure();
// Inline shape operands.
rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(),
inlined_operands, op->getAttrs());
return success();
}
};
// TODO(frgossen): Only move up broadcasting operations if there is a consumer.
struct MoveUpBroadcastInDimOpPattern
: public OpRewritePattern<DynamicBroadcastInDimOp> {
@ -124,12 +151,9 @@ struct MoveUpDynamicBroadcastsForFusionPass
}
void runOnFunction() override {
// Populate rewrite patterns.
MLIRContext *ctx = &getContext();
RewritePatternSet patterns(ctx);
mhlo::PopulateMoveUpDynamicBroadcastsForFusionPatterns(ctx, &patterns);
// Apply transformation.
if (failed(
applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)))) {
return signalPassFailure();
@ -142,8 +166,10 @@ struct MoveUpDynamicBroadcastsForFusionPass
void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
MLIRContext *context, OwningRewritePatternList *patterns) {
// clang-format off
patterns->insert<ShapeReificationPattern,
MoveUpBroadcastInDimOpPattern>(context);
patterns->insert<
InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>,
MoveUpBroadcastInDimOpPattern,
ShapeReificationPattern>(context);
// clang-format on
}

View File

@ -97,3 +97,18 @@ func @cast_sub(%arg0: tensor<?x32xi16>, %arg1: tensor<?x?x32xf16>)
}
return %4 : tensor<?x?x32xf16>
}
// -----
// CHECK-LABEL: @inline_bcasted_shape_operands
// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>, %[[B:.*]]: tensor<?xindex>, %[[C:.*]]: tensor<?xindex>)
func @inline_bcasted_shape_operands(%a : tensor<?xindex>, %b : tensor<?xindex>,
%c : tensor<?xindex>) -> !shape.witness {
// CHECK-NOT: shape.broadcast
// CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[A]], %[[B]], %[[C]]
// CHECK: return %[[WITNESS]] : !shape.witness
%0 = shape.broadcast %a, %b : tensor<?xindex>, tensor<?xindex>
-> tensor<?xindex>
%1 = shape.cstr_broadcastable %0, %c : tensor<?xindex>, tensor<?xindex>
return %1 : !shape.witness
}