diff --git a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc index 503faad..097e029 100644 --- a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc +++ b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc @@ -497,7 +497,11 @@ struct ConvertUnrankedDynamicBroadcastNaryOp // Put each subsequent rank specialization inside the else statement of the // previous one. OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener()); - constexpr int kMaxRankSpecialization = 5; + + // Tensorflow supports up to rank 8 for SelectOp (currently the only op with + // arity > 2 that we support), but only up to rank 5 for binary ops. We want + // to preserve this behavior. + const int kMaxRankSpecialization = operands.size() > 2 ? 8 : 5; for (int i = 2; i < kMaxRankSpecialization; i++) { auto inner_if = createIfOpForRankSpecializedBroadcastAndOp( else_builder, op, greater_rank, i); diff --git a/tests/hlo-transform-unranked.mlir b/tests/hlo-transform-unranked.mlir index 687b0eb..148ab28 100644 --- a/tests/hlo-transform-unranked.mlir +++ b/tests/hlo-transform-unranked.mlir @@ -383,3 +383,6 @@ func @selectUnrankedUnrankedUnranked( // CHECK: chlo.broadcast_select {{.*}} : (tensor, tensor, tensor) -> tensor // CHECK: chlo.broadcast_select {{.*}} : (tensor, tensor, tensor) -> tensor // CHECK: chlo.broadcast_select {{.*}} : (tensor, tensor, tensor) -> tensor +// CHECK: chlo.broadcast_select {{.*}} : (tensor, tensor, tensor) -> tensor +// CHECK: chlo.broadcast_select {{.*}} : (tensor, tensor, tensor) -> tensor +// CHECK: chlo.broadcast_select {{.*}} : (tensor, tensor, tensor) -> tensor