[MLIR] Merge extent tensor casts into `shape_of` ops in broadcast moving
PiperOrigin-RevId: 370058002
This commit is contained in:
		
							parent
							
								
									890a79641e
								
							
						
					
					
						commit
						da5d252143
					
				|  | @ -305,6 +305,31 @@ struct MergeAssumingOpsPattern : public OpRewritePattern<shape::AssumingOp> { | |||
|   } | ||||
| }; | ||||
| 
 | ||||
| // Eliminate casted extent tensors. Instead, produce the concrete extent tensor
 | ||||
| // type where possible.
 | ||||
| struct CanonicalizeCastedShapeOfOpPattern | ||||
|     : public OpRewritePattern<tensor::CastOp> { | ||||
|   using OpRewritePattern<tensor::CastOp>::OpRewritePattern; | ||||
| 
 | ||||
|   LogicalResult matchAndRewrite(tensor::CastOp op, | ||||
|                                 PatternRewriter &rewriter) const override { | ||||
|     // Only merge tensor cast into `shape_of` ops.
 | ||||
|     auto shape_of_op = op.source().getDefiningOp<shape::ShapeOfOp>(); | ||||
|     if (!shape_of_op) return failure(); | ||||
| 
 | ||||
|     // Desired type must be an extent tensor type.
 | ||||
|     auto result_ty = op.getType().dyn_cast<RankedTensorType>(); | ||||
|     if (!result_ty || result_ty.getRank() != 1 || | ||||
|         !result_ty.getElementType().isIndex()) | ||||
|       return failure(); | ||||
| 
 | ||||
|     rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op, result_ty, | ||||
|                                                   shape_of_op.arg()); | ||||
|     if (shape_of_op->getUses().empty()) rewriter.eraseOp(shape_of_op); | ||||
|     return success(); | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| // TODO(frgossen): Only move up broadcasting operations if there is a consumer.
 | ||||
| struct MoveUpBroadcastInDimOpPattern | ||||
|     : public OpRewritePattern<DynamicBroadcastInDimOp> { | ||||
|  | @ -376,6 +401,7 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns( | |||
|     MLIRContext *context, OwningRewritePatternList *patterns) { | ||||
|   // clang-format off
 | ||||
|   patterns->insert< | ||||
|       CanonicalizeCastedShapeOfOpPattern, | ||||
|       InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>, | ||||
|       MergeAssumingOpsPattern, | ||||
|       MoveIntoAssumingOpPattern<shape::ShapeOfOp>, | ||||
|  |  | |||
|  | @ -311,6 +311,20 @@ func @do_not_merge_assuming_ops() { | |||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| // CHECK:      @eliminate_extent_tensor_cast | ||||
| // CHECK-SAME: (%[[ARG:.*]]: tensor<2x?x4xf32>) | ||||
| func @eliminate_extent_tensor_cast(%arg : tensor<2x?x4xf32>) { | ||||
|   // CHECK-NOT:  shape_of | ||||
|   // CHECK:      %[[RESULT:.*]] = shape.shape_of %[[ARG]] : tensor<2x?x4xf32> -> tensor<3xindex> | ||||
|   // CHECK-NEXT: "use"(%[[RESULT]]) : (tensor<3xindex>) -> () | ||||
|   %0 = shape.shape_of %arg : tensor<2x?x4xf32> -> tensor<?xindex> | ||||
|   %1 = tensor.cast %0 : tensor<?xindex> to tensor<3xindex> | ||||
|   "use"(%1) : (tensor<3xindex>) -> () | ||||
|   return | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| // Exemplary IR as it appears in the lowering of two subsequent `tf.Sub` ops. | ||||
| // CHECK-LABEL: @sub_sub | ||||
| // CHECK-SAME: (%[[ARG0:.*]]: tensor<?x32xf16>, %[[ARG1:.*]]: tensor<?x32xf16>, %[[ARG2:.*]]: tensor<?x?x32xf16>) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue