[MLIR][HLO] Add `shape.broadcast` canonicalization to unblock broadcast moving
PiperOrigin-RevId: 372120309
This commit is contained in:
parent 6bc854f5d9
commit d8c40b691c
@@ -330,6 +330,36 @@ struct CanonicalizeCastedShapeOfOpPattern
   }
 };
 
+// TODO(frgossen): Remove this once it has landed upstream.
+struct CanonicalizeBroadcastPattern
+    : public OpRewritePattern<shape::BroadcastOp> {
+  using OpRewritePattern<shape::BroadcastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(shape::BroadcastOp op,
+                                PatternRewriter &rewriter) const override {
+    // Only concretize dynamic extent tensor result types.
+    auto resultTy = op.getType().dyn_cast<RankedTensorType>();
+    if (!resultTy || !resultTy.isDynamicDim(0)) return failure();
+
+    // Infer resulting shape rank if possible.
+    int64_t maxRank = 0;
+    for (Value shape : op.shapes()) {
+      if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
+        // Cannot infer resulting shape rank if any operand is dynamically
+        // ranked.
+        if (extentTensorTy.isDynamicDim(0)) return failure();
+        maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
+      }
+    }
+
+    auto newOp = rewriter.create<shape::BroadcastOp>(
+        op.getLoc(), RankedTensorType::get({maxRank}, rewriter.getIndexType()),
+        op.shapes());
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
+    return success();
+  }
+};
+
 // TODO(frgossen): Only move up broadcasting operations if there is a consumer.
 struct MoveUpBroadcastInDimOpPattern
     : public OpRewritePattern<DynamicBroadcastInDimOp> {
@@ -401,6 +431,7 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
     MLIRContext *context, OwningRewritePatternList *patterns) {
   // clang-format off
   patterns->insert<
+      CanonicalizeBroadcastPattern,
       CanonicalizeCastedShapeOfOpPattern,
       InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>,
       MergeAssumingOpsPattern,
@@ -411,6 +442,7 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
       MoveUpBroadcastInDimOpPattern,
       ShapeReificationPattern>(context);
   // clang-format on
+  tensor::CastOp::getCanonicalizationPatterns(*patterns, context);
 }
 
 std::unique_ptr<FunctionPass> createMoveUpDynamicBroadcastsForFusionPass() {
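To illustrate the effect of the new pattern, here is a small hypothetical before/after sketch (the function names and the operand ranks 2 and 3 are invented for illustration, not taken from this commit). When all operands are statically ranked extent tensors, the dynamically shaped `shape.broadcast` result is concretized to the maximal operand rank, and a `tensor.cast` restores the original dynamic type for existing users:

// Before: the broadcast result is the dynamic extent tensor type.
func @before(%a: tensor<2xindex>, %b: tensor<3xindex>) -> tensor<?xindex> {
  %0 = shape.broadcast %a, %b
      : tensor<2xindex>, tensor<3xindex> -> tensor<?xindex>
  return %0 : tensor<?xindex>
}

// After: the broadcast is concretized to the maximal operand rank (3), and a
// tensor.cast preserves the originally expected type for existing uses.
func @after(%a: tensor<2xindex>, %b: tensor<3xindex>) -> tensor<?xindex> {
  %0 = shape.broadcast %a, %b
      : tensor<2xindex>, tensor<3xindex> -> tensor<3xindex>
  %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
  return %1 : tensor<?xindex>
}

The pattern returns failure when the result type is already static, or when any ranked operand has a dynamic extent-tensor size, since the resulting rank cannot be inferred in that case.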
@@ -337,7 +337,6 @@ func @sub_sub(%arg0: tensor<?x32xf16>, %arg1 : tensor<?x32xf16>,
   // CHECK-DAG:  %[[WITNESS1:.*]] = shape.cstr_broadcastable %[[SHAPE2]], %[[SHAPE0]], %[[SHAPE1]]
   // CHECK-DAG:  %[[COMBINED_WITNESS:.*]] = shape.assuming_all %[[WITNESS0]], %[[WITNESS1]]
   // CHECK:      %[[ASSUMING_RESULT:.*]] = shape.assuming %[[COMBINED_WITNESS]]
-  // CHECK-SAME: {
   // CHECK:        %[[BCASTED_SHAPE01:.*]] = shape.broadcast %[[SHAPE0]], %[[SHAPE1]]
   // CHECK:        %[[BCASTED_SHAPE012:.*]] = shape.broadcast %[[SHAPE2]], %[[BCASTED_SHAPE01]]
   // CHECK:        %[[BCASTED_ARG2:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG2]], %[[BCASTED_SHAPE012]]) {broadcast_dimensions = dense<[0, 1, 2]>
@@ -346,7 +345,6 @@ func @sub_sub(%arg0: tensor<?x32xf16>, %arg1 : tensor<?x32xf16>,
   // CHECK:        %[[TMP:.*]] = mhlo.subtract %[[BCASTED_ARG0]], %[[BCASTED_ARG1]]
   // CHECK:        %[[RESULT:.*]] = mhlo.subtract %[[BCASTED_ARG2]], %[[TMP]]
   // CHECK:        shape.assuming_yield %[[RESULT]]
-  // CHECK:      }
   // CHECK:      return %[[ASSUMING_RESULT]]
   %0 = shape.shape_of %arg0 : tensor<?x32xf16> -> tensor<2xindex>
   %1 = shape.shape_of %arg1 : tensor<?x32xf16> -> tensor<2xindex>
@@ -354,9 +352,10 @@ func @sub_sub(%arg0: tensor<?x32xf16>, %arg1 : tensor<?x32xf16>,
   %3 = shape.assuming %2 -> (tensor<?x32xf16>) {
     %8 = shape.shape_of %arg0 : tensor<?x32xf16> -> tensor<2xindex>
     %9 = shape.shape_of %arg1 : tensor<?x32xf16> -> tensor<2xindex>
-    %10 = shape.broadcast %8, %9 : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex>
-    %12 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %10) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x32xf16>, tensor<2xindex>) -> tensor<?x32xf16>
-    %13 = "mhlo.dynamic_broadcast_in_dim"(%arg1, %10) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x32xf16>, tensor<2xindex>) -> tensor<?x32xf16>
+    %10 = shape.broadcast %8, %9 : tensor<2xindex>, tensor<2xindex> -> tensor<?xindex>
+    %11 = tensor.cast %10 : tensor<?xindex> to tensor<2xindex>
+    %12 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %11) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x32xf16>, tensor<2xindex>) -> tensor<?x32xf16>
+    %13 = "mhlo.dynamic_broadcast_in_dim"(%arg1, %11) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x32xf16>, tensor<2xindex>) -> tensor<?x32xf16>
     %14 = mhlo.subtract %12, %13 : tensor<?x32xf16>
     shape.assuming_yield %14 : tensor<?x32xf16>
   }
@@ -366,9 +365,10 @@ func @sub_sub(%arg0: tensor<?x32xf16>, %arg1 : tensor<?x32xf16>,
   %7 = shape.assuming %6 -> (tensor<?x?x32xf16>) {
     %8 = shape.shape_of %arg2 : tensor<?x?x32xf16> -> tensor<3xindex>
     %9 = shape.shape_of %3 : tensor<?x32xf16> -> tensor<2xindex>
-    %10 = shape.broadcast %8, %9 : tensor<3xindex>, tensor<2xindex> -> tensor<3xindex>
-    %12 = "mhlo.dynamic_broadcast_in_dim"(%arg2, %10) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<?x?x32xf16>, tensor<3xindex>) -> tensor<?x?x32xf16>
-    %13 = "mhlo.dynamic_broadcast_in_dim"(%3, %10) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<?x32xf16>, tensor<3xindex>) -> tensor<?x?x32xf16>
+    %10 = shape.broadcast %8, %9 : tensor<3xindex>, tensor<2xindex> -> tensor<?xindex>
+    %11 = tensor.cast %10 : tensor<?xindex> to tensor<3xindex>
+    %12 = "mhlo.dynamic_broadcast_in_dim"(%arg2, %11) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<?x?x32xf16>, tensor<3xindex>) -> tensor<?x?x32xf16>
+    %13 = "mhlo.dynamic_broadcast_in_dim"(%3, %11) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<?x32xf16>, tensor<3xindex>) -> tensor<?x?x32xf16>
     %14 = mhlo.subtract %12, %13 : tensor<?x?x32xf16>
     shape.assuming_yield %14 : tensor<?x?x32xf16>
   }
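The pass now also registers the `tensor.cast` canonicalization patterns, which is intended to clean up the casts that `CanonicalizeBroadcastPattern` introduces. As a hedged sketch (the function name and shapes are hypothetical, and the exact behavior depends on the upstream `tensor.cast` canonicalizations), a cast to the dynamic extent-tensor type followed by a cast back to the concrete type should fold away:

func @cast_roundtrip(%arg0: tensor<2xindex>) -> tensor<2xindex> {
  // Cast to the dynamic extent tensor type and back to the concrete type.
  %0 = tensor.cast %arg0 : tensor<2xindex> to tensor<?xindex>
  %1 = tensor.cast %0 : tensor<?xindex> to tensor<2xindex>
  // Chained-cast canonicalization is expected to fold this pair, so uses of
  // %1 are replaced with %arg0 directly.
  return %1 : tensor<2xindex>
}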