[MLIR][MHLO] Generalize extent tensor cast elimination in bcast moving
PiperOrigin-RevId: 370085141
This commit is contained in:
parent
da5d252143
commit
21e9365718
|
@ -305,17 +305,18 @@ struct MergeAssumingOpsPattern : public OpRewritePattern<shape::AssumingOp> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Eliminate casted extent tensors. Instead, produce the concrete extent tensor
|
// Eliminate extent tensor casts. Instead, produce the concrete extent tensor
|
||||||
// type where possible.
|
// type where possible.
|
||||||
struct CanonicalizeCastedShapeOfOpPattern
|
template <typename OpTy>
|
||||||
|
struct CanonicalizeCastedExtentTensorOpPattern
|
||||||
: public OpRewritePattern<tensor::CastOp> {
|
: public OpRewritePattern<tensor::CastOp> {
|
||||||
using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
|
using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(tensor::CastOp op,
|
LogicalResult matchAndRewrite(tensor::CastOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
// Only merge tensor cast into `shape_of` ops.
|
// Only merge tensor cast into a producer op if we know it supports it.
|
||||||
auto shape_of_op = op.source().getDefiningOp<shape::ShapeOfOp>();
|
auto producer_op = op.source().getDefiningOp<OpTy>();
|
||||||
if (!shape_of_op) return failure();
|
if (!producer_op) return failure();
|
||||||
|
|
||||||
// Desired type must be an extent tensor type.
|
// Desired type must be an extent tensor type.
|
||||||
auto result_ty = op.getType().dyn_cast<RankedTensorType>();
|
auto result_ty = op.getType().dyn_cast<RankedTensorType>();
|
||||||
|
@ -323,9 +324,9 @@ struct CanonicalizeCastedShapeOfOpPattern
|
||||||
!result_ty.getElementType().isIndex())
|
!result_ty.getElementType().isIndex())
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op, result_ty,
|
rewriter.replaceOpWithNewOp<OpTy>(op, result_ty, producer_op->getOperands(),
|
||||||
shape_of_op.arg());
|
producer_op->getAttrs());
|
||||||
if (shape_of_op->getUses().empty()) rewriter.eraseOp(shape_of_op);
|
if (producer_op->getUses().empty()) rewriter.eraseOp(producer_op);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -401,7 +402,8 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
|
||||||
MLIRContext *context, OwningRewritePatternList *patterns) {
|
MLIRContext *context, OwningRewritePatternList *patterns) {
|
||||||
// clang-format off
|
// clang-format off
|
||||||
patterns->insert<
|
patterns->insert<
|
||||||
CanonicalizeCastedShapeOfOpPattern,
|
CanonicalizeCastedExtentTensorOpPattern<shape::ShapeOfOp>,
|
||||||
|
CanonicalizeCastedExtentTensorOpPattern<shape::BroadcastOp>,
|
||||||
InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>,
|
InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>,
|
||||||
MergeAssumingOpsPattern,
|
MergeAssumingOpsPattern,
|
||||||
MoveIntoAssumingOpPattern<shape::ShapeOfOp>,
|
MoveIntoAssumingOpPattern<shape::ShapeOfOp>,
|
||||||
|
|
|
@ -311,9 +311,9 @@ func @do_not_merge_assuming_ops() {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// CHECK: @eliminate_extent_tensor_cast
|
// CHECK: @merge_extent_tensor_cast_into_shape_of
|
||||||
// CHECK-SAME: (%[[ARG:.*]]: tensor<2x?x4xf32>)
|
// CHECK-SAME: (%[[ARG:.*]]: tensor<2x?x4xf32>)
|
||||||
func @eliminate_extent_tensor_cast(%arg : tensor<2x?x4xf32>) {
|
func @merge_extent_tensor_cast_into_shape_of(%arg : tensor<2x?x4xf32>) {
|
||||||
// CHECK-NOT: shape_of
|
// CHECK-NOT: shape_of
|
||||||
// CHECK: %[[RESULT:.*]] = shape.shape_of %[[ARG]] : tensor<2x?x4xf32> -> tensor<3xindex>
|
// CHECK: %[[RESULT:.*]] = shape.shape_of %[[ARG]] : tensor<2x?x4xf32> -> tensor<3xindex>
|
||||||
// CHECK-NEXT: "use"(%[[RESULT]]) : (tensor<3xindex>) -> ()
|
// CHECK-NEXT: "use"(%[[RESULT]]) : (tensor<3xindex>) -> ()
|
||||||
|
@ -325,6 +325,19 @@ func @eliminate_extent_tensor_cast(%arg : tensor<2x?x4xf32>) {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK: @merge_extent_tensor_cast_into_broadcast
|
||||||
|
// CHECK-SAME: (%[[ARG0:.*]]: tensor<3xindex>, %[[ARG1:.*]]: tensor<3xindex>)
|
||||||
|
func @merge_extent_tensor_cast_into_broadcast(%arg0 : tensor<3xindex>, %arg1 : tensor<3xindex>) {
|
||||||
|
// CHECK: %[[RESULT:.*]] = shape.broadcast %[[ARG0]], %[[ARG1]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex>
|
||||||
|
// CHECK: "use"(%[[RESULT]]) : (tensor<3xindex>) -> ()
|
||||||
|
%0 = shape.broadcast %arg0, %arg1 : tensor<3xindex>, tensor<3xindex> -> tensor<?xindex>
|
||||||
|
%1 = tensor.cast %0 : tensor<?xindex> to tensor<3xindex>
|
||||||
|
"use"(%1) : (tensor<3xindex>) -> ()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// Exemplary IR as it appears in the lowering of two subsequent `tf.Sub` ops.
|
// Exemplary IR as it appears in the lowering of two subsequent `tf.Sub` ops.
|
||||||
// CHECK-LABEL: @sub_sub
|
// CHECK-LABEL: @sub_sub
|
||||||
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x32xf16>, %[[ARG1:.*]]: tensor<?x32xf16>, %[[ARG2:.*]]: tensor<?x?x32xf16>)
|
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x32xf16>, %[[ARG1:.*]]: tensor<?x32xf16>, %[[ARG2:.*]]: tensor<?x?x32xf16>)
|
||||||
|
|
Loading…
Reference in New Issue