Support up to rank 8 in rank specialization for SelectOp.
PiperOrigin-RevId: 367406557
This commit is contained in:
parent
f068d26843
commit
cc607bc72d
|
@ -497,7 +497,11 @@ struct ConvertUnrankedDynamicBroadcastNaryOp
|
|||
// Put each subsequent rank specialization inside the else statement of the
|
||||
// previous one.
|
||||
OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener());
|
||||
constexpr int kMaxRankSpecialization = 5;
|
||||
|
||||
// Tensorflow supports up to rank 8 for SelectOp (currently the only op with
|
||||
// arity > 2 that we support), but only up to rank 5 for binary ops. We want
|
||||
// to preserve this behavior.
|
||||
const int kMaxRankSpecialization = operands.size() > 2 ? 8 : 5;
|
||||
for (int i = 2; i < kMaxRankSpecialization; i++) {
|
||||
auto inner_if = createIfOpForRankSpecializedBroadcastAndOp(
|
||||
else_builder, op, greater_rank, i);
|
||||
|
|
|
@ -383,3 +383,6 @@ func @selectUnrankedUnrankedUnranked(
|
|||
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?xi1>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
||||
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?xi1>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?x?xi1>, tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
|
||||
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?x?x?xi1>, tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32>
|
||||
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?x?x?x?xi1>, tensor<?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?x?xf32>
|
||||
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?x?x?x?x?xi1>, tensor<?x?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?x?x?xf32>
|
||||
|
|
Loading…
Reference in New Issue