[MLIR][MHLO] Generalize extent tensor cast elimination in bcast moving

PiperOrigin-RevId: 370112887
Authored by A. Unique TensorFlower on 2021-04-23 10:51:38 -07:00; committed by TensorFlow MLIR Team
parent ab1ccbaa6e
commit 0569b7f7a4
2 changed files with 11 additions and 26 deletions
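
For context, the rewrite that `CanonicalizeCastedShapeOfOpPattern` performs looks
as follows. This is an illustrative MLIR snippet mirroring the test case further
below; the value names are invented:

    // Before: the extent tensor is produced at the generic type, then cast.
    %shape = shape.shape_of %arg : tensor<2x?x4xf32> -> tensor<?xindex>
    %cast = tensor.cast %shape : tensor<?xindex> to tensor<3xindex>
    "use"(%cast) : (tensor<3xindex>) -> ()

    // After: shape.shape_of yields the concrete extent tensor type directly;
    // the cast is folded away and a now-dead shape_of op is erased.
    %shape = shape.shape_of %arg : tensor<2x?x4xf32> -> tensor<3xindex>
    "use"(%shape) : (tensor<3xindex>) -> ()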


@@ -305,18 +305,17 @@ struct MergeAssumingOpsPattern : public OpRewritePattern<shape::AssumingOp> {
   }
 };
 
-// Eliminate extent tensor casts. Instead, produce the concrete extent tensor
+// Eliminate casted extent tensors. Instead, produce the concrete extent tensor
 // type where possible.
-template <typename OpTy>
-struct CanonicalizeCastedExtentTensorOpPattern
+struct CanonicalizeCastedShapeOfOpPattern
     : public OpRewritePattern<tensor::CastOp> {
   using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(tensor::CastOp op,
                                 PatternRewriter &rewriter) const override {
-    // Only merge tensor cast into a producer op if we know it supports it.
-    auto producer_op = op.source().getDefiningOp<OpTy>();
-    if (!producer_op) return failure();
+    // Only merge tensor cast into `shape_of` ops.
+    auto shape_of_op = op.source().getDefiningOp<shape::ShapeOfOp>();
+    if (!shape_of_op) return failure();
 
     // Desired type must be an extent tensor type.
     auto result_ty = op.getType().dyn_cast<RankedTensorType>();
@@ -324,9 +323,9 @@ struct CanonicalizeCastedExtentTensorOpPattern
         !result_ty.getElementType().isIndex())
       return failure();
 
-    rewriter.replaceOpWithNewOp<OpTy>(op, result_ty, producer_op->getOperands(),
-                                      producer_op->getAttrs());
-    if (producer_op->getUses().empty()) rewriter.eraseOp(producer_op);
+    rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op, result_ty,
+                                                  shape_of_op.arg());
+    if (shape_of_op->getUses().empty()) rewriter.eraseOp(shape_of_op);
     return success();
   }
 };
@@ -402,8 +401,7 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
     MLIRContext *context, OwningRewritePatternList *patterns) {
   // clang-format off
   patterns->insert<
-      CanonicalizeCastedExtentTensorOpPattern<shape::ShapeOfOp>,
-      CanonicalizeCastedExtentTensorOpPattern<shape::BroadcastOp>,
+      CanonicalizeCastedShapeOfOpPattern,
       InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>,
       MergeAssumingOpsPattern,
       MoveIntoAssumingOpPattern<shape::ShapeOfOp>,


@@ -311,9 +311,9 @@ func @do_not_merge_assuming_ops() {
 
 // -----
 
-// CHECK: @merge_extent_tensor_cast_into_shape_of
+// CHECK: @eliminate_extent_tensor_cast
 // CHECK-SAME: (%[[ARG:.*]]: tensor<2x?x4xf32>)
-func @merge_extent_tensor_cast_into_shape_of(%arg : tensor<2x?x4xf32>) {
+func @eliminate_extent_tensor_cast(%arg : tensor<2x?x4xf32>) {
   // CHECK-NOT: shape_of
   // CHECK: %[[RESULT:.*]] = shape.shape_of %[[ARG]] : tensor<2x?x4xf32> -> tensor<3xindex>
   // CHECK-NEXT: "use"(%[[RESULT]]) : (tensor<3xindex>) -> ()
@@ -325,19 +325,6 @@ func @merge_extent_tensor_cast_into_shape_of(%arg : tensor<2x?x4xf32>) {
 
 // -----
 
-// CHECK: @merge_extent_tensor_cast_into_broadcast
-// CHECK-SAME: (%[[ARG0:.*]]: tensor<3xindex>, %[[ARG1:.*]]: tensor<3xindex>)
-func @merge_extent_tensor_cast_into_broadcast(%arg0 : tensor<3xindex>, %arg1 : tensor<3xindex>) {
-  // CHECK: %[[RESULT:.*]] = shape.broadcast %[[ARG0]], %[[ARG1]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex>
-  // CHECK: "use"(%[[RESULT]]) : (tensor<3xindex>) -> ()
-  %0 = shape.broadcast %arg0, %arg1 : tensor<3xindex>, tensor<3xindex> -> tensor<?xindex>
-  %1 = tensor.cast %0 : tensor<?xindex> to tensor<3xindex>
-  "use"(%1) : (tensor<3xindex>) -> ()
-  return
-}
-
-// -----
-
 // Exemplary IR as it appears in the lowering of two subsequent `tf.Sub` ops.
 // CHECK-LABEL: @sub_sub
 // CHECK-SAME: (%[[ARG0:.*]]: tensor<?x32xf16>, %[[ARG1:.*]]: tensor<?x32xf16>, %[[ARG2:.*]]: tensor<?x?x32xf16>)