[MLIR][MHLO] Generalize extent tensor cast elimination in bcast moving
PiperOrigin-RevId: 370112887
This commit is contained in:
		
							parent
							
								
									ab1ccbaa6e
								
							
						
					
					
						commit
						0569b7f7a4
					
				|  | @ -305,18 +305,17 @@ struct MergeAssumingOpsPattern : public OpRewritePattern<shape::AssumingOp> { | ||||||
|   } |   } | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| // Eliminate extent tensor casts. Instead, produce the concrete extent tensor
 | // Eliminate casted extent tensors. Instead, produce the concrete extent tensor
 | ||||||
| // type where possible.
 | // type where possible.
 | ||||||
| template <typename OpTy> | struct CanonicalizeCastedShapeOfOpPattern | ||||||
| struct CanonicalizeCastedExtentTensorOpPattern |  | ||||||
|     : public OpRewritePattern<tensor::CastOp> { |     : public OpRewritePattern<tensor::CastOp> { | ||||||
|   using OpRewritePattern<tensor::CastOp>::OpRewritePattern; |   using OpRewritePattern<tensor::CastOp>::OpRewritePattern; | ||||||
| 
 | 
 | ||||||
|   LogicalResult matchAndRewrite(tensor::CastOp op, |   LogicalResult matchAndRewrite(tensor::CastOp op, | ||||||
|                                 PatternRewriter &rewriter) const override { |                                 PatternRewriter &rewriter) const override { | ||||||
|     // Only merge tensor cast into a producer op if we know it supports it.
 |     // Only merge tensor cast into `shape_of` ops.
 | ||||||
|     auto producer_op = op.source().getDefiningOp<OpTy>(); |     auto shape_of_op = op.source().getDefiningOp<shape::ShapeOfOp>(); | ||||||
|     if (!producer_op) return failure(); |     if (!shape_of_op) return failure(); | ||||||
| 
 | 
 | ||||||
|     // Desired type must be an extent tensor type.
 |     // Desired type must be an extent tensor type.
 | ||||||
|     auto result_ty = op.getType().dyn_cast<RankedTensorType>(); |     auto result_ty = op.getType().dyn_cast<RankedTensorType>(); | ||||||
|  | @ -324,9 +323,9 @@ struct CanonicalizeCastedExtentTensorOpPattern | ||||||
|         !result_ty.getElementType().isIndex()) |         !result_ty.getElementType().isIndex()) | ||||||
|       return failure(); |       return failure(); | ||||||
| 
 | 
 | ||||||
|     rewriter.replaceOpWithNewOp<OpTy>(op, result_ty, producer_op->getOperands(), |     rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op, result_ty, | ||||||
|                                       producer_op->getAttrs()); |                                                   shape_of_op.arg()); | ||||||
|     if (producer_op->getUses().empty()) rewriter.eraseOp(producer_op); |     if (shape_of_op->getUses().empty()) rewriter.eraseOp(shape_of_op); | ||||||
|     return success(); |     return success(); | ||||||
|   } |   } | ||||||
| }; | }; | ||||||
|  | @ -402,8 +401,7 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns( | ||||||
|     MLIRContext *context, OwningRewritePatternList *patterns) { |     MLIRContext *context, OwningRewritePatternList *patterns) { | ||||||
|   // clang-format off
 |   // clang-format off
 | ||||||
|   patterns->insert< |   patterns->insert< | ||||||
|       CanonicalizeCastedExtentTensorOpPattern<shape::ShapeOfOp>, |       CanonicalizeCastedShapeOfOpPattern, | ||||||
|       CanonicalizeCastedExtentTensorOpPattern<shape::BroadcastOp>, |  | ||||||
|       InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>, |       InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>, | ||||||
|       MergeAssumingOpsPattern, |       MergeAssumingOpsPattern, | ||||||
|       MoveIntoAssumingOpPattern<shape::ShapeOfOp>, |       MoveIntoAssumingOpPattern<shape::ShapeOfOp>, | ||||||
|  |  | ||||||
|  | @ -311,9 +311,9 @@ func @do_not_merge_assuming_ops() { | ||||||
| 
 | 
 | ||||||
| // ----- | // ----- | ||||||
| 
 | 
 | ||||||
| // CHECK:      @merge_extent_tensor_cast_into_shape_of | // CHECK:      @eliminate_extent_tensor_cast | ||||||
| // CHECK-SAME: (%[[ARG:.*]]: tensor<2x?x4xf32>) | // CHECK-SAME: (%[[ARG:.*]]: tensor<2x?x4xf32>) | ||||||
| func @merge_extent_tensor_cast_into_shape_of(%arg : tensor<2x?x4xf32>) { | func @eliminate_extent_tensor_cast(%arg : tensor<2x?x4xf32>) { | ||||||
|   // CHECK-NOT:  shape_of |   // CHECK-NOT:  shape_of | ||||||
|   // CHECK:      %[[RESULT:.*]] = shape.shape_of %[[ARG]] : tensor<2x?x4xf32> -> tensor<3xindex> |   // CHECK:      %[[RESULT:.*]] = shape.shape_of %[[ARG]] : tensor<2x?x4xf32> -> tensor<3xindex> | ||||||
|   // CHECK-NEXT: "use"(%[[RESULT]]) : (tensor<3xindex>) -> () |   // CHECK-NEXT: "use"(%[[RESULT]]) : (tensor<3xindex>) -> () | ||||||
|  | @ -325,19 +325,6 @@ func @merge_extent_tensor_cast_into_shape_of(%arg : tensor<2x?x4xf32>) { | ||||||
| 
 | 
 | ||||||
| // ----- | // ----- | ||||||
| 
 | 
 | ||||||
| // CHECK:      @merge_extent_tensor_cast_into_broadcast |  | ||||||
| // CHECK-SAME: (%[[ARG0:.*]]: tensor<3xindex>, %[[ARG1:.*]]: tensor<3xindex>) |  | ||||||
| func @merge_extent_tensor_cast_into_broadcast(%arg0 : tensor<3xindex>, %arg1 : tensor<3xindex>) { |  | ||||||
|   // CHECK: %[[RESULT:.*]] = shape.broadcast %[[ARG0]], %[[ARG1]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex> |  | ||||||
|   // CHECK: "use"(%[[RESULT]]) : (tensor<3xindex>) -> () |  | ||||||
|   %0 = shape.broadcast %arg0, %arg1 : tensor<3xindex>, tensor<3xindex> -> tensor<?xindex> |  | ||||||
|   %1 = tensor.cast %0 : tensor<?xindex> to tensor<3xindex> |  | ||||||
|   "use"(%1) : (tensor<3xindex>) -> () |  | ||||||
|   return |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // ----- |  | ||||||
| 
 |  | ||||||
| // Exemplary IR as it appears in the lowering of two subsequent `tf.Sub` ops. | // Exemplary IR as it appears in the lowering of two subsequent `tf.Sub` ops. | ||||||
| // CHECK-LABEL: @sub_sub | // CHECK-LABEL: @sub_sub | ||||||
| // CHECK-SAME: (%[[ARG0:.*]]: tensor<?x32xf16>, %[[ARG1:.*]]: tensor<?x32xf16>, %[[ARG2:.*]]: tensor<?x?x32xf16>) | // CHECK-SAME: (%[[ARG0:.*]]: tensor<?x32xf16>, %[[ARG1:.*]]: tensor<?x32xf16>, %[[ARG2:.*]]: tensor<?x?x32xf16>) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue