[MLIR][HLO] Implement `RegionBranchOpInterface` for rank specialization cluster
PiperOrigin-RevId: 373163196
This commit is contained in:
		
							parent
							
								
									1432db02e7
								
							
						
					
					
						commit
						7f7a86ad0d
					
				|  | @ -757,6 +757,7 @@ def HLOClient_MinimumBroadcastShapesOp : | ||||||
| 
 | 
 | ||||||
| def HLOClient_RankSpecializationClusterOp | def HLOClient_RankSpecializationClusterOp | ||||||
|     : HLOClient_Op<"rank_specialization_cluster", [ |     : HLOClient_Op<"rank_specialization_cluster", [ | ||||||
|  |     DeclareOpInterfaceMethods<RegionBranchOpInterface>, | ||||||
|     SingleBlockImplicitTerminator<"RankSpecializationClusterYieldOp">, |     SingleBlockImplicitTerminator<"RankSpecializationClusterYieldOp">, | ||||||
|     RecursiveSideEffects]> { |     RecursiveSideEffects]> { | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -422,7 +422,21 @@ LogicalResult BroadcastSelectOp::inferReturnTypeComponents( | ||||||
| // RankSpecializationClusterOp
 | // RankSpecializationClusterOp
 | ||||||
| //===----------------------------------------------------------------------===//
 | //===----------------------------------------------------------------------===//
 | ||||||
| 
 | 
 | ||||||
|  | void RankSpecializationClusterOp::getSuccessorRegions( | ||||||
|  |     Optional<unsigned> index, ArrayRef<Attribute> operands, | ||||||
|  |     SmallVectorImpl<RegionSuccessor>& regions) { | ||||||
|  |   // RankSpecializationClusterOp has unconditional control flows into the region
 | ||||||
|  |   // and back to the parent, so return the correct RegionSuccessor purely based
 | ||||||
|  |   // on the index being None or 0.
 | ||||||
|  |   if (index.hasValue()) { | ||||||
|  |     regions.push_back(RegionSuccessor(getResults())); | ||||||
|  |     return; | ||||||
|  |   } | ||||||
|  |   regions.push_back(RegionSuccessor(&body())); | ||||||
|  | } | ||||||
|  | 
 | ||||||
| static LogicalResult Verify(RankSpecializationClusterOp op) { | static LogicalResult Verify(RankSpecializationClusterOp op) { | ||||||
|  |   if (failed(RegionBranchOpInterface::verifyTypes(op))) return failure(); | ||||||
|   if (op.body().getArgumentTypes() != op.getOperandTypes()) |   if (op.body().getArgumentTypes() != op.getOperandTypes()) | ||||||
|     return op.emitOpError() << "block argument types must match operand types"; |     return op.emitOpError() << "block argument types must match operand types"; | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -42,6 +42,19 @@ func @rank_specialization_cluster(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, | ||||||
| 
 | 
 | ||||||
| // ----- | // ----- | ||||||
| 
 | 
 | ||||||
|  | func @rank_specialization_cluster(%arg0 : tensor<*xf32>, | ||||||
|  |     %arg1 : tensor<*xf32>) -> tensor<*xf32> { | ||||||
|  |   // expected-error @+1{{source has 2 operands, but target successor needs 1}} | ||||||
|  |   %0 = "chlo.rank_specialization_cluster"(%arg0, %arg1) ({ | ||||||
|  |   ^bb0(%arg0_ : tensor<*xf32>, %arg1_ : tensor<*xf32>): | ||||||
|  |     "chlo.rank_specialization_cluster_yield"(%arg0_, %arg1_) | ||||||
|  |         : (tensor<*xf32>, tensor<*xf32>) -> () | ||||||
|  |   }) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> | ||||||
|  |   return %0 : tensor<*xf32> | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // ----- | ||||||
|  | 
 | ||||||
| func @rank_specialization_cluster(%arg0 : tensor<*xf32>) -> tensor<*xf32> { | func @rank_specialization_cluster(%arg0 : tensor<*xf32>) -> tensor<*xf32> { | ||||||
|   // expected-error @+1{{block argument types must match operand types}} |   // expected-error @+1{{block argument types must match operand types}} | ||||||
|   %0 = "chlo.rank_specialization_cluster"(%arg0) ({ |   %0 = "chlo.rank_specialization_cluster"(%arg0) ({ | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue