Support up to rank 8 in rank specialization for SelectOp.
PiperOrigin-RevId: 367406557
This commit is contained in:
		
							parent
							
								
									f068d26843
								
							
						
					
					
						commit
						cc607bc72d
					
				|  | @ -497,7 +497,11 @@ struct ConvertUnrankedDynamicBroadcastNaryOp | |||
|     // Put each subsequent rank specialization inside the else statement of the
 | ||||
|     // previous one.
 | ||||
|     OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener()); | ||||
|     constexpr int kMaxRankSpecialization = 5; | ||||
| 
 | ||||
|     // TensorFlow supports up to rank 8 for SelectOp (currently the only op with
 | ||||
|     // arity > 2 that we support), but only up to rank 5 for binary ops. We want
 | ||||
|     // to preserve this behavior.
 | ||||
|     const int kMaxRankSpecialization = operands.size() > 2 ? 8 : 5; | ||||
|     for (int i = 2; i < kMaxRankSpecialization; i++) { | ||||
|       auto inner_if = createIfOpForRankSpecializedBroadcastAndOp( | ||||
|           else_builder, op, greater_rank, i); | ||||
|  |  | |||
|  | @ -383,3 +383,6 @@ func @selectUnrankedUnrankedUnranked( | |||
| // CHECK:      chlo.broadcast_select {{.*}} : (tensor<?x?x?xi1>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32> | ||||
| // CHECK:      chlo.broadcast_select {{.*}} : (tensor<?x?x?x?xi1>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> | ||||
| // CHECK:      chlo.broadcast_select {{.*}} : (tensor<?x?x?x?x?xi1>, tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> | ||||
| // CHECK:      chlo.broadcast_select {{.*}} : (tensor<?x?x?x?x?x?xi1>, tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32> | ||||
| // CHECK:      chlo.broadcast_select {{.*}} : (tensor<?x?x?x?x?x?x?xi1>, tensor<?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?x?xf32> | ||||
| // CHECK:      chlo.broadcast_select {{.*}} : (tensor<?x?x?x?x?x?x?x?xi1>, tensor<?x?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?x?x?xf32> | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue