[MLIR] Merge extent tensor casts into `shape_of` ops in broadcast moving
PiperOrigin-RevId: 370058002
This commit is contained in:
parent
890a79641e
commit
da5d252143
|
@ -305,6 +305,31 @@ struct MergeAssumingOpsPattern : public OpRewritePattern<shape::AssumingOp> {
|
|||
}
|
||||
};
|
||||
|
||||
// Eliminate casted extent tensors. Instead, produce the concrete extent tensor
|
||||
// type where possible.
|
||||
struct CanonicalizeCastedShapeOfOpPattern
|
||||
: public OpRewritePattern<tensor::CastOp> {
|
||||
using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(tensor::CastOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Only merge tensor cast into `shape_of` ops.
|
||||
auto shape_of_op = op.source().getDefiningOp<shape::ShapeOfOp>();
|
||||
if (!shape_of_op) return failure();
|
||||
|
||||
// Desired type must be an extent tensor type.
|
||||
auto result_ty = op.getType().dyn_cast<RankedTensorType>();
|
||||
if (!result_ty || result_ty.getRank() != 1 ||
|
||||
!result_ty.getElementType().isIndex())
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op, result_ty,
|
||||
shape_of_op.arg());
|
||||
if (shape_of_op->getUses().empty()) rewriter.eraseOp(shape_of_op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// TODO(frgossen): Only move up broadcasting operations if there is a consumer.
|
||||
struct MoveUpBroadcastInDimOpPattern
|
||||
: public OpRewritePattern<DynamicBroadcastInDimOp> {
|
||||
|
@ -376,6 +401,7 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
|
|||
MLIRContext *context, OwningRewritePatternList *patterns) {
|
||||
// clang-format off
|
||||
patterns->insert<
|
||||
CanonicalizeCastedShapeOfOpPattern,
|
||||
InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>,
|
||||
MergeAssumingOpsPattern,
|
||||
MoveIntoAssumingOpPattern<shape::ShapeOfOp>,
|
||||
|
|
|
@ -311,6 +311,20 @@ func @do_not_merge_assuming_ops() {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK: @eliminate_extent_tensor_cast
|
||||
// CHECK-SAME: (%[[ARG:.*]]: tensor<2x?x4xf32>)
|
||||
func @eliminate_extent_tensor_cast(%arg : tensor<2x?x4xf32>) {
|
||||
// CHECK-NOT: shape_of
|
||||
// CHECK: %[[RESULT:.*]] = shape.shape_of %[[ARG]] : tensor<2x?x4xf32> -> tensor<3xindex>
|
||||
// CHECK-NEXT: "use"(%[[RESULT]]) : (tensor<3xindex>) -> ()
|
||||
%0 = shape.shape_of %arg : tensor<2x?x4xf32> -> tensor<?xindex>
|
||||
%1 = tensor.cast %0 : tensor<?xindex> to tensor<3xindex>
|
||||
"use"(%1) : (tensor<3xindex>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Exemplary IR as it appears in the lowering of two subsequent `tf.Sub` ops.
|
||||
// CHECK-LABEL: @sub_sub
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x32xf16>, %[[ARG1:.*]]: tensor<?x32xf16>, %[[ARG2:.*]]: tensor<?x?x32xf16>)
|
||||
|
|
Loading…
Reference in New Issue