[MLIR][HLO] Add `shape.broadcast` canonicalization to unblock broadcast moving
PiperOrigin-RevId: 372120309
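
For context, a rough before/after sketch of the rewrite that the new CanonicalizeBroadcastPattern performs (illustrative only, not part of the commit; the value names %a and %b and the 2-element extent tensors are borrowed from the test below). A shape.broadcast whose extent-tensor result type tensor<?xindex> leaves the broadcasted rank dynamic is re-created with the inferred maximum operand rank, and a tensor.cast restores the original type for existing users:

// Before: the broadcasted rank is not reflected in the result type.
%result = shape.broadcast %a, %b : tensor<2xindex>, tensor<2xindex> -> tensor<?xindex>

// After: the broadcast is concretized to rank 2 and cast back to the old type.
%concrete = shape.broadcast %a, %b : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex>
%result = tensor.cast %concrete : tensor<2xindex> to tensor<?xindex>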
parent 6bc854f5d9
commit d8c40b691c
@@ -330,6 +330,36 @@ struct CanonicalizeCastedShapeOfOpPattern
   }
 };
 
+// TODO(frgossen): Remove this once it has landed upstream.
+struct CanonicalizeBroadcastPattern
+    : public OpRewritePattern<shape::BroadcastOp> {
+  using OpRewritePattern<shape::BroadcastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(shape::BroadcastOp op,
+                                PatternRewriter &rewriter) const override {
+    // Only concretize dynamic extent tensor result types.
+    auto resultTy = op.getType().dyn_cast<RankedTensorType>();
+    if (!resultTy || !resultTy.isDynamicDim(0)) return failure();
+
+    // Infer resulting shape rank if possible.
+    int64_t maxRank = 0;
+    for (Value shape : op.shapes()) {
+      if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
+        // Cannot infer resulting shape rank if any operand is dynamically
+        // ranked.
+        if (extentTensorTy.isDynamicDim(0)) return failure();
+        maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
+      }
+    }
+
+    auto newOp = rewriter.create<shape::BroadcastOp>(
+        op.getLoc(), RankedTensorType::get({maxRank}, rewriter.getIndexType()),
+        op.shapes());
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
+    return success();
+  }
+};
+
 // TODO(frgossen): Only move up broadcasting operations if there is a consumer.
 struct MoveUpBroadcastInDimOpPattern
     : public OpRewritePattern<DynamicBroadcastInDimOp> {
@@ -401,6 +431,7 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
     MLIRContext *context, OwningRewritePatternList *patterns) {
   // clang-format off
   patterns->insert<
+      CanonicalizeBroadcastPattern,
       CanonicalizeCastedShapeOfOpPattern,
       InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>,
       MergeAssumingOpsPattern,
@@ -411,6 +442,7 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
       MoveUpBroadcastInDimOpPattern,
       ShapeReificationPattern>(context);
   // clang-format on
+  tensor::CastOp::getCanonicalizationPatterns(*patterns, context);
 }
 
 std::unique_ptr<FunctionPass> createMoveUpDynamicBroadcastsForFusionPass() {
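The pass also registers the tensor::CastOp canonicalization patterns (see the last hunk above), so the cast introduced by CanonicalizeBroadcastPattern folds with a pre-existing cast to a static extent-tensor type. A sketch of the net effect on IR like the updated test below (value names are made up, not taken from the test):

// Input: dynamic broadcast result followed by a cast to the static type.
%0 = shape.broadcast %lhs, %rhs : tensor<2xindex>, tensor<2xindex> -> tensor<?xindex>
%1 = tensor.cast %0 : tensor<?xindex> to tensor<2xindex>

// After the pattern plus tensor.cast canonicalization, the cast chain folds
// away and the broadcast produces the static type directly, consistent with
// the CHECK lines that match the shape.broadcast results.
%1 = shape.broadcast %lhs, %rhs : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex>
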
@@ -337,7 +337,6 @@ func @sub_sub(%arg0: tensor<?x32xf16>, %arg1 : tensor<?x32xf16>,
   // CHECK-DAG:  %[[WITNESS1:.*]] = shape.cstr_broadcastable %[[SHAPE2]], %[[SHAPE0]], %[[SHAPE1]]
   // CHECK-DAG:  %[[COMBINED_WITNESS:.*]] = shape.assuming_all %[[WITNESS0]], %[[WITNESS1]]
   // CHECK:      %[[ASSUMING_RESULT:.*]] = shape.assuming %[[COMBINED_WITNESS]]
-  // CHECK-SAME: {
   // CHECK:        %[[BCASTED_SHAPE01:.*]] = shape.broadcast %[[SHAPE0]], %[[SHAPE1]]
   // CHECK:        %[[BCASTED_SHAPE012:.*]] = shape.broadcast %[[SHAPE2]], %[[BCASTED_SHAPE01]]
   // CHECK:        %[[BCASTED_ARG2:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG2]], %[[BCASTED_SHAPE012]]) {broadcast_dimensions = dense<[0, 1, 2]>
@@ -346,7 +345,6 @@ func @sub_sub(%arg0: tensor<?x32xf16>, %arg1 : tensor<?x32xf16>,
   // CHECK:        %[[TMP:.*]] = mhlo.subtract %[[BCASTED_ARG0]], %[[BCASTED_ARG1]]
   // CHECK:        %[[RESULT:.*]] = mhlo.subtract %[[BCASTED_ARG2]], %[[TMP]]
   // CHECK:        shape.assuming_yield %[[RESULT]]
-  // CHECK:      }
   // CHECK:      return %[[ASSUMING_RESULT]]
   %0 = shape.shape_of %arg0 : tensor<?x32xf16> -> tensor<2xindex>
   %1 = shape.shape_of %arg1 : tensor<?x32xf16> -> tensor<2xindex>
@@ -354,9 +352,10 @@ func @sub_sub(%arg0: tensor<?x32xf16>, %arg1 : tensor<?x32xf16>,
   %3 = shape.assuming %2 -> (tensor<?x32xf16>) {
     %8 = shape.shape_of %arg0 : tensor<?x32xf16> -> tensor<2xindex>
     %9 = shape.shape_of %arg1 : tensor<?x32xf16> -> tensor<2xindex>
-    %10 = shape.broadcast %8, %9 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex>
-    %12 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %10) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x32xf16>, tensor<2xindex>) -> tensor<?x32xf16>
-    %13 = "mhlo.dynamic_broadcast_in_dim"(%arg1, %10) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x32xf16>, tensor<2xindex>) -> tensor<?x32xf16>
+    %10 = shape.broadcast %8, %9 : tensor<2xindex>, tensor<2xindex> -> tensor<?xindex>
+    %11 = tensor.cast %10 : tensor<?xindex> to tensor<2xindex>
+    %12 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %11) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x32xf16>, tensor<2xindex>) -> tensor<?x32xf16>
+    %13 = "mhlo.dynamic_broadcast_in_dim"(%arg1, %11) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x32xf16>, tensor<2xindex>) -> tensor<?x32xf16>
     %14 = mhlo.subtract %12, %13 : tensor<?x32xf16>
     shape.assuming_yield %14 : tensor<?x32xf16>
   }
@@ -366,9 +365,10 @@ func @sub_sub(%arg0: tensor<?x32xf16>, %arg1 : tensor<?x32xf16>,
   %7 = shape.assuming %6 -> (tensor<?x?x32xf16>) {
     %8 = shape.shape_of %arg2 : tensor<?x?x32xf16> -> tensor<3xindex>
     %9 = shape.shape_of %3 : tensor<?x32xf16> -> tensor<2xindex>
-    %10 = shape.broadcast %8, %9 : tensor<3xindex>, tensor<2xindex> -> tensor<3xindex>
-    %12 = "mhlo.dynamic_broadcast_in_dim"(%arg2, %10) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<?x?x32xf16>, tensor<3xindex>) -> tensor<?x?x32xf16>
-    %13 = "mhlo.dynamic_broadcast_in_dim"(%3, %10) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<?x32xf16>, tensor<3xindex>) -> tensor<?x?x32xf16>
+    %10 = shape.broadcast %8, %9 : tensor<3xindex>, tensor<2xindex> -> tensor<?xindex>
+    %11 = tensor.cast %10 : tensor<?xindex> to tensor<3xindex>
+    %12 = "mhlo.dynamic_broadcast_in_dim"(%arg2, %11) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<?x?x32xf16>, tensor<3xindex>) -> tensor<?x?x32xf16>
+    %13 = "mhlo.dynamic_broadcast_in_dim"(%3, %11) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<?x32xf16>, tensor<3xindex>) -> tensor<?x?x32xf16>
     %14 = mhlo.subtract %12, %13 : tensor<?x?x32xf16>
     shape.assuming_yield %14 : tensor<?x?x32xf16>
   }