[MLIR] Merge extent tensor casts into `shape_of` ops in broadcast moving
PiperOrigin-RevId: 370058002
This commit is contained in:
		
							parent
							
								
									890a79641e
								
							
						
					
					
						commit
						da5d252143
					
				| 
						 | 
					@ -305,6 +305,31 @@ struct MergeAssumingOpsPattern : public OpRewritePattern<shape::AssumingOp> {
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Eliminate casted extent tensors. Instead, produce the concrete extent tensor
 | 
				
			||||||
 | 
					// type where possible.
 | 
				
			||||||
 | 
					struct CanonicalizeCastedShapeOfOpPattern
 | 
				
			||||||
 | 
					    : public OpRewritePattern<tensor::CastOp> {
 | 
				
			||||||
 | 
					  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  LogicalResult matchAndRewrite(tensor::CastOp op,
 | 
				
			||||||
 | 
					                                PatternRewriter &rewriter) const override {
 | 
				
			||||||
 | 
					    // Only merge tensor cast into `shape_of` ops.
 | 
				
			||||||
 | 
					    auto shape_of_op = op.source().getDefiningOp<shape::ShapeOfOp>();
 | 
				
			||||||
 | 
					    if (!shape_of_op) return failure();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Desired type must be an extent tensor type.
 | 
				
			||||||
 | 
					    auto result_ty = op.getType().dyn_cast<RankedTensorType>();
 | 
				
			||||||
 | 
					    if (!result_ty || result_ty.getRank() != 1 ||
 | 
				
			||||||
 | 
					        !result_ty.getElementType().isIndex())
 | 
				
			||||||
 | 
					      return failure();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    rewriter.replaceOpWithNewOp<shape::ShapeOfOp>(op, result_ty,
 | 
				
			||||||
 | 
					                                                  shape_of_op.arg());
 | 
				
			||||||
 | 
					    if (shape_of_op->getUses().empty()) rewriter.eraseOp(shape_of_op);
 | 
				
			||||||
 | 
					    return success();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// TODO(frgossen): Only move up broadcasting operations if there is a consumer.
 | 
					// TODO(frgossen): Only move up broadcasting operations if there is a consumer.
 | 
				
			||||||
struct MoveUpBroadcastInDimOpPattern
 | 
					struct MoveUpBroadcastInDimOpPattern
 | 
				
			||||||
    : public OpRewritePattern<DynamicBroadcastInDimOp> {
 | 
					    : public OpRewritePattern<DynamicBroadcastInDimOp> {
 | 
				
			||||||
| 
						 | 
					@ -376,6 +401,7 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
 | 
				
			||||||
    MLIRContext *context, OwningRewritePatternList *patterns) {
 | 
					    MLIRContext *context, OwningRewritePatternList *patterns) {
 | 
				
			||||||
  // clang-format off
 | 
					  // clang-format off
 | 
				
			||||||
  patterns->insert<
 | 
					  patterns->insert<
 | 
				
			||||||
 | 
					      CanonicalizeCastedShapeOfOpPattern,
 | 
				
			||||||
      InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>,
 | 
					      InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>,
 | 
				
			||||||
      MergeAssumingOpsPattern,
 | 
					      MergeAssumingOpsPattern,
 | 
				
			||||||
      MoveIntoAssumingOpPattern<shape::ShapeOfOp>,
 | 
					      MoveIntoAssumingOpPattern<shape::ShapeOfOp>,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -311,6 +311,20 @@ func @do_not_merge_assuming_ops() {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// -----
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// CHECK:      @eliminate_extent_tensor_cast
 | 
				
			||||||
 | 
					// CHECK-SAME: (%[[ARG:.*]]: tensor<2x?x4xf32>)
 | 
				
			||||||
 | 
					func @eliminate_extent_tensor_cast(%arg : tensor<2x?x4xf32>) {
 | 
				
			||||||
 | 
					  // CHECK-NOT:  shape_of
 | 
				
			||||||
 | 
					  // CHECK:      %[[RESULT:.*]] = shape.shape_of %[[ARG]] : tensor<2x?x4xf32> -> tensor<3xindex>
 | 
				
			||||||
 | 
					  // CHECK-NEXT: "use"(%[[RESULT]]) : (tensor<3xindex>) -> ()
 | 
				
			||||||
 | 
					  %0 = shape.shape_of %arg : tensor<2x?x4xf32> -> tensor<?xindex>
 | 
				
			||||||
 | 
					  %1 = tensor.cast %0 : tensor<?xindex> to tensor<3xindex>
 | 
				
			||||||
 | 
					  "use"(%1) : (tensor<3xindex>) -> ()
 | 
				
			||||||
 | 
					  return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// -----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Exemplary IR as it appears in the lowering of two subsequent `tf.Sub` ops.
 | 
					// Exemplary IR as it appears in the lowering of two subsequent `tf.Sub` ops.
 | 
				
			||||||
// CHECK-LABEL: @sub_sub
 | 
					// CHECK-LABEL: @sub_sub
 | 
				
			||||||
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x32xf16>, %[[ARG1:.*]]: tensor<?x32xf16>, %[[ARG2:.*]]: tensor<?x?x32xf16>)
 | 
					// CHECK-SAME: (%[[ARG0:.*]]: tensor<?x32xf16>, %[[ARG1:.*]]: tensor<?x32xf16>, %[[ARG2:.*]]: tensor<?x?x32xf16>)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue