Support up to rank 8 in rank specialization for SelectOp.

PiperOrigin-RevId: 367406557
This commit is contained in:
Adrian Kuegel 2021-04-08 04:55:00 -07:00 committed by TensorFlow MLIR Team
parent f068d26843
commit cc607bc72d
2 changed files with 8 additions and 1 deletions

View File

@ -497,7 +497,11 @@ struct ConvertUnrankedDynamicBroadcastNaryOp
// Put each subsequent rank specialization inside the else statement of the
// previous one.
OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener());
constexpr int kMaxRankSpecialization = 5;
// Tensorflow supports up to rank 8 for SelectOp (currently the only op with
// arity > 2 that we support), but only up to rank 5 for binary ops. We want
// to preserve this behavior.
const int kMaxRankSpecialization = operands.size() > 2 ? 8 : 5;
for (int i = 2; i < kMaxRankSpecialization; i++) {
  auto inner_if = createIfOpForRankSpecializedBroadcastAndOp(
      else_builder, op, greater_rank, i);

View File

@ -383,3 +383,6 @@ func @selectUnrankedUnrankedUnranked(
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?xi1>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?xi1>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?x?xi1>, tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?x?x?xi1>, tensor<?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?xf32>
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?x?x?x?xi1>, tensor<?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?x?xf32>
// CHECK: chlo.broadcast_select {{.*}} : (tensor<?x?x?x?x?x?x?x?xi1>, tensor<?x?x?x?x?x?x?x?xf32>, tensor<?x?x?x?x?x?x?x?xf32>) -> tensor<?x?x?x?x?x?x?x?xf32>