[MLIR][MHLO] Apply patterns in MoveUpDynamicBroadcastsForFusionPass greedily
PiperOrigin-RevId: 365556488
This commit is contained in:
		
							parent
							
								
									238c1d8a92
								
							
						
					
					
						commit
						fb819c1de8
					
				|  | @ -33,76 +33,60 @@ limitations under the License. | |||
| #include "mlir/IR/PatternMatch.h" | ||||
| #include "mlir/Interfaces/InferTypeOpInterface.h" | ||||
| #include "mlir/Pass/Pass.h" | ||||
| #include "mlir/Transforms/DialectConversion.h" | ||||
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||||
| 
 | ||||
| namespace mlir { | ||||
| namespace mhlo { | ||||
| namespace { | ||||
| 
 | ||||
| bool IsShapeOfOpMovable(Value arg) { | ||||
|   return arg.getDefiningOp<InferShapedTypeOpInterface>(); | ||||
| } | ||||
| 
 | ||||
| struct ShapeOfOpConversion : public OpConversionPattern<shape::ShapeOfOp> { | ||||
|   explicit ShapeOfOpConversion(MLIRContext *context) | ||||
|       : OpConversionPattern<shape::ShapeOfOp>(context) { | ||||
| struct ShapeReificationPattern : public OpRewritePattern<shape::ShapeOfOp> { | ||||
|   explicit ShapeReificationPattern(MLIRContext *context) | ||||
|       : OpRewritePattern<shape::ShapeOfOp>(context) { | ||||
|     // Recursively reify until we hit an op that doesn't support it.
 | ||||
|     setHasBoundedRewriteRecursion(); | ||||
|   } | ||||
| 
 | ||||
|   LogicalResult matchAndRewrite( | ||||
|       shape::ShapeOfOp op, ArrayRef<Value> operands, | ||||
|       ConversionPatternRewriter &rewriter) const override { | ||||
|     shape::ShapeOfOp::Adaptor transformed(operands); | ||||
| 
 | ||||
|   LogicalResult matchAndRewrite(shape::ShapeOfOp op, | ||||
|                                 PatternRewriter &rewriter) const override { | ||||
|     // Only reify shape computation if operand allows for it.
 | ||||
|     if (!IsShapeOfOpMovable(transformed.arg())) return failure(); | ||||
|     auto shape_origin = op.arg().getDefiningOp<InferShapedTypeOpInterface>(); | ||||
|     if (!shape_origin) return failure(); | ||||
| 
 | ||||
|     auto shape_origin = | ||||
|         transformed.arg().getDefiningOp<InferShapedTypeOpInterface>(); | ||||
|     llvm::SmallVector<Value, 1> reified_shapes; | ||||
|     if (failed(shape_origin.reifyReturnTypeShapes(rewriter, reified_shapes))) | ||||
|     llvm::SmallVector<Value, 1> reifications; | ||||
|     if (failed(shape_origin.reifyReturnTypeShapes(rewriter, reifications))) | ||||
|       return failure(); | ||||
|     assert(reifications.size() == 1); | ||||
|     Value reified_shape = reifications.front(); | ||||
| 
 | ||||
|     assert(reified_shapes.size() == 1); | ||||
|     Value reified_shape = reified_shapes.front(); | ||||
|     // Insert cast if needed.
 | ||||
|     if (reified_shape.getType() != op.getType()) { | ||||
|       reified_shape = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(), | ||||
|                                                       reified_shape); | ||||
|     } | ||||
| 
 | ||||
|     rewriter.replaceOp(op, reified_shapes.front()); | ||||
|     rewriter.replaceOp(op, reified_shape); | ||||
|     return success(); | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| // We can only move up broadcasting ops that apply to the result of a
 | ||||
| // shape-preserving operation.
 | ||||
| bool isDynamicBroadcastInDimOpMovable(Value operand) { | ||||
|   Operation *producer_op = operand.getDefiningOp(); | ||||
|   return producer_op != nullptr && | ||||
|          producer_op->hasTrait<OpTrait::SameOperandsAndResultShape>() && | ||||
|          producer_op->hasTrait<OpTrait::Elementwise>(); | ||||
| } | ||||
| 
 | ||||
| // TODO(frgossen): Only move up broadcasting operations if there is a consumer.
 | ||||
| struct MoveUpBroadcastInDimOpConversion | ||||
|     : public OpConversionPattern<DynamicBroadcastInDimOp> { | ||||
|   explicit MoveUpBroadcastInDimOpConversion(MLIRContext *context) | ||||
|       : OpConversionPattern<DynamicBroadcastInDimOp>(context) {} | ||||
| struct MoveUpBroadcastInDimOpPattern | ||||
|     : public OpRewritePattern<DynamicBroadcastInDimOp> { | ||||
|   using OpRewritePattern<DynamicBroadcastInDimOp>::OpRewritePattern; | ||||
| 
 | ||||
|   LogicalResult matchAndRewrite( | ||||
|       DynamicBroadcastInDimOp bcast_op, ArrayRef<Value> operands, | ||||
|       ConversionPatternRewriter &rewriter) const override { | ||||
|     DynamicBroadcastInDimOp::Adaptor transformed(operands); | ||||
|     if (!isDynamicBroadcastInDimOpMovable(transformed.operand())) | ||||
|   LogicalResult matchAndRewrite(DynamicBroadcastInDimOp bcast_op, | ||||
|                                 PatternRewriter &rewriter) const override { | ||||
|     Operation *producer_op = bcast_op.operand().getDefiningOp(); | ||||
|     if (!producer_op || | ||||
|         !producer_op->hasTrait<OpTrait::SameOperandsAndResultShape>() || | ||||
|         !producer_op->hasTrait<OpTrait::Elementwise>()) { | ||||
|       return failure(); | ||||
|     } | ||||
| 
 | ||||
|     // Materialize broadcast on operands.
 | ||||
|     SmallVector<Value, 2> bcasted_operands; | ||||
|     Location loc = bcast_op.getLoc(); | ||||
|     ArrayRef<int64_t> ty_shape = bcast_op.getType().getShape(); | ||||
|     Operation *producer_op = transformed.operand().getDefiningOp(); | ||||
|     for (Value operand : producer_op->getOperands()) { | ||||
|       // The broadcast only works on ranked operations.
 | ||||
|       auto operand_ty = operand.getType().dyn_cast<RankedTensorType>(); | ||||
|  | @ -114,7 +98,7 @@ struct MoveUpBroadcastInDimOpConversion | |||
|       auto bcasted_operand_ty = | ||||
|           RankedTensorType::get(ty_shape, operand_ty.getElementType()); | ||||
|       bcasted_operands.push_back(rewriter.create<DynamicBroadcastInDimOp>( | ||||
|           loc, bcasted_operand_ty, operand, transformed.output_dimensions(), | ||||
|           loc, bcasted_operand_ty, operand, bcast_op.output_dimensions(), | ||||
|           bcast_op.broadcast_dimensions())); | ||||
|     } | ||||
| 
 | ||||
|  | @ -140,18 +124,14 @@ struct MoveUpDynamicBroadcastsForFusionPass | |||
|   } | ||||
| 
 | ||||
|   void runOnFunction() override { | ||||
|     // Setup target legality.
 | ||||
|     MLIRContext &ctx = getContext(); | ||||
|     ConversionTarget target(ctx); | ||||
|     PopulateMoveUpDynamicBroadcastsForFusionLegality(&target); | ||||
| 
 | ||||
|     // Populate rewrite patterns.
 | ||||
|     OwningRewritePatternList patterns(&ctx); | ||||
|     mhlo::PopulateMoveUpDynamicBroadcastsForFusionPatterns(&ctx, &patterns); | ||||
|     MLIRContext *ctx = &getContext(); | ||||
|     RewritePatternSet patterns(ctx); | ||||
|     mhlo::PopulateMoveUpDynamicBroadcastsForFusionPatterns(ctx, &patterns); | ||||
| 
 | ||||
|     // Apply transformation.
 | ||||
|     if (failed(applyPartialConversion(getFunction(), target, | ||||
|                                       std::move(patterns)))) { | ||||
|     if (failed( | ||||
|             applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)))) { | ||||
|       return signalPassFailure(); | ||||
|     } | ||||
|   } | ||||
|  | @ -159,23 +139,11 @@ struct MoveUpDynamicBroadcastsForFusionPass | |||
| 
 | ||||
| }  // namespace
 | ||||
| 
 | ||||
| void PopulateMoveUpDynamicBroadcastsForFusionLegality( | ||||
|     ConversionTarget *target) { | ||||
|   target->addLegalDialect<MhloDialect, StandardOpsDialect, shape::ShapeDialect, | ||||
|                           tensor::TensorDialect>(); | ||||
|   target->addDynamicallyLegalOp<shape::ShapeOfOp>( | ||||
|       [](shape::ShapeOfOp op) { return !IsShapeOfOpMovable(op.arg()); }); | ||||
|   target->addDynamicallyLegalOp<DynamicBroadcastInDimOp>( | ||||
|       [](DynamicBroadcastInDimOp op) { | ||||
|         return !isDynamicBroadcastInDimOpMovable(op.operand()); | ||||
|       }); | ||||
| } | ||||
| 
 | ||||
| void PopulateMoveUpDynamicBroadcastsForFusionPatterns( | ||||
|     MLIRContext *context, OwningRewritePatternList *patterns) { | ||||
|   // clang-format off
 | ||||
|   patterns->insert<ShapeOfOpConversion, | ||||
|                    MoveUpBroadcastInDimOpConversion>(context); | ||||
|   patterns->insert<ShapeReificationPattern, | ||||
|                    MoveUpBroadcastInDimOpPattern>(context); | ||||
|   // clang-format on
 | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -5,7 +5,8 @@ | |||
| // CHECK-SAME: (%[[ARG:.*]]: tensor<?x32xi16>) | ||||
| func @shape_of_unary(%arg : tensor<?x32xi16>) { | ||||
|   // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG]] : tensor<?x32xi16> -> tensor<2xindex> | ||||
|   // CHECK: "use"(%[[SHAPE]]) | ||||
|   // CHECK: %[[CASTED:.*]] = tensor.cast %[[SHAPE]] : tensor<2xindex> to tensor<?xindex> | ||||
|   // CHECK: "use"(%[[CASTED]]) | ||||
|   %0 = "mhlo.convert"(%arg) : (tensor<?x32xi16>) -> tensor<?x32xf16> | ||||
|   %1 = shape.shape_of %0 : tensor<?x32xf16> -> tensor<?xindex> | ||||
|   "use"(%1) : (tensor<?xindex>) -> () | ||||
|  | @ -19,7 +20,8 @@ func @shape_of_unary(%arg : tensor<?x32xi16>) { | |||
| // CHECK-SAME: (%[[ARG0:.*]]: tensor<?x32xf16>, %[[ARG1:.*]]: tensor<?x32xf16>) | ||||
| func @shape_of_nary(%arg0 : tensor<?x32xf16>, %arg1 : tensor<?x32xf16>) { | ||||
|   // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[ARG0]] : tensor<?x32xf16> -> tensor<2xindex> | ||||
|   // CHECK: "use"(%[[SHAPE]]) | ||||
|   // CHECK: %[[CASTED:.*]] = tensor.cast %[[SHAPE]] : tensor<2xindex> to tensor<?xindex> | ||||
|   // CHECK: "use"(%[[CASTED]]) | ||||
|   %0 = mhlo.subtract %arg0, %arg1 : tensor<?x32xf16> | ||||
|   %1 = mhlo.subtract %0, %arg1 : tensor<?x32xf16> | ||||
|   %2 = shape.shape_of %1 : tensor<?x32xf16> -> tensor<?xindex> | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue