Support up to rank 8 in rank specialization for SelectOp.
PiperOrigin-RevId: 367406557
This commit is contained in:
parent
f068d26843
commit
cc607bc72d
|
@ -497,7 +497,11 @@ struct ConvertUnrankedDynamicBroadcastNaryOp
|
||||||
// Put each subsequent rank specialization inside the else statement of the
|
// Put each subsequent rank specialization inside the else statement of the
|
||||||
// previous one.
|
// previous one.
|
||||||
OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener());
|
OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener());
|
||||||
constexpr int kMaxRankSpecialization = 5;
|
|
||||||
|
// Tensorflow supports up to rank 8 for SelectOp (currently the only op with
|
||||||
|
// arity > 2 that we support), but only up to rank 5 for binary ops. We want
|
||||||
|
// to preserve this behavior.
|
||||||
|
const int kMaxRankSpecialization = operands.size() > 2 ? 8 : 5;
|
||||||
for (int i = 2; i < kMaxRankSpecialization; i++) {
|
for (int i = 2; i < kMaxRankSpecialization; i++) {
|
||||||
auto inner_if = createIfOpForRankSpecializedBroadcastAndOp(
|
auto inner_if = createIfOpForRankSpecializedBroadcastAndOp(
|
||||||
else_builder, op, greater_rank, i);
|
else_builder, op, greater_rank, i);
|
||||||
|
|
|
@ -383,3 +383,6 @@ func @selectUnrankedUnrankedUnranked(
|
||||||
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?xi1>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?xi1>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
||||||
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?xi1>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
|
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?xi1>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
|
||||||
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?x?xi1>, tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
|
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?x?xi1>, tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
|
||||||
|
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?x?x?xi1>, tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32>
|
||||||
|
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?x?x?x?xi1>, tensor<?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?x?xf32>
|
||||||
|
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?x?x?x?x?xi1>, tensor<?x?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?x?x?xf32>
|
||||||
|
|
Loading…
Reference in New Issue