From cc607bc72db476dcaf6058a2eef9a4d8d408c45f Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Thu, 8 Apr 2021 04:55:00 -0700 Subject: [PATCH] Support up to rank 8 in rank specialization for SelectOp. PiperOrigin-RevId: 367406557 --- lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc | 6 +++++- tests/hlo-transform-unranked.mlir | 3 +++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc index 503faad..097e029 100644 --- a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc +++ b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc @@ -497,7 +497,11 @@ struct ConvertUnrankedDynamicBroadcastNaryOp // Put each subsequent rank specialization inside the else statement of the // previous one. OpBuilder else_builder = if_op.getElseBodyBuilder(rewriter.getListener()); - constexpr int kMaxRankSpecialization = 5; + + // Tensorflow supports up to rank 8 for SelectOp (currently the only op with + // arity > 2 that we support), but only up to rank 5 for binary ops. We want + // to preserve this behavior. + const int kMaxRankSpecialization = operands.size() > 2 ? 8 : 5; for (int i = 2; i < kMaxRankSpecialization; i++) { auto inner_if = createIfOpForRankSpecializedBroadcastAndOp( else_builder, op, greater_rank, i); diff --git a/tests/hlo-transform-unranked.mlir b/tests/hlo-transform-unranked.mlir index 687b0eb..148ab28 100644 --- a/tests/hlo-transform-unranked.mlir +++ b/tests/hlo-transform-unranked.mlir @@ -383,3 +383,6 @@ func @selectUnrankedUnrankedUnranked( // CHECK: chlo.broadcast_select {{.*}} : (tensor, tensor, tensor) -> tensor // CHECK: chlo.broadcast_select {{.*}} : (tensor, tensor, tensor) -> tensor // CHECK: chlo.broadcast_select {{.*}} : (tensor, tensor, tensor) -> tensor +// CHECK: chlo.broadcast_select {{.*}} : (tensor, tensor, tensor) -> tensor +// CHECK: chlo.broadcast_select {{.*}} : (tensor, tensor, tensor) -> tensor +// CHECK: chlo.broadcast_select {{.*}} : (tensor, tensor, tensor) -> tensor