Support up to rank 8 in rank specialization for SelectOp.
PiperOrigin-RevId: 367406557
This commit is contained in:
parent f068d26843
commit cc607bc72d
@@ -497,7 +497,11 @@ struct ConvertUnrankedDynamicBroadcastNaryOp
     // Put each subsequent rank specialization inside the else statement of the
     // previous one.
     OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener());
-    constexpr int kMaxRankSpecialization = 5;
+    // Tensorflow supports up to rank 8 for SelectOp (currently the only op with
+    // arity > 2 that we support), but only up to rank 5 for binary ops. We want
+    // to preserve this behavior.
+    const int kMaxRankSpecialization = operands.size() > 2 ? 8 : 5;
     for (int i = 2; i < kMaxRankSpecialization; i++) {
       auto inner_if = createIfOpForRankSpecializedBroadcastAndOp(
           else_builder, op, greater_rank, i);
@@ -383,3 +383,6 @@ func @selectUnrankedUnrankedUnranked(
 // CHECK:      chlo.broadcast_select {{.*}} : (tensor<?x?x?xi1>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
 // CHECK:      chlo.broadcast_select {{.*}} : (tensor<?x?x?x?xi1>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
 // CHECK:      chlo.broadcast_select {{.*}} : (tensor<?x?x?x?x?xi1>, tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+// CHECK:      chlo.broadcast_select {{.*}} : (tensor<?x?x?x?x?x?xi1>, tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32>
+// CHECK:      chlo.broadcast_select {{.*}} : (tensor<?x?x?x?x?x?x?xi1>, tensor<?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?x?xf32>
+// CHECK:      chlo.broadcast_select {{.*}} : (tensor<?x?x?x?x?x?x?x?xi1>, tensor<?x?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?x?x?xf32>
		Loading…
	
		Reference in New Issue