diff --git a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td index b7c5417..c9e46ea 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td @@ -694,11 +694,12 @@ def HLOClient_BroadcastCompareOp : HLOClient_BroadcastBinaryElementwiseOp< // Broadcasting select op //===----------------------------------------------------------------------===// -def HLOClient_BroadcastSelectOp : HLOClient_Op< - "broadcast_select", - [NoSideEffect, - DeclareOpInterfaceMethods]> { +def HLOClient_BroadcastSelectOp : HLOClient_Op<"broadcast_select", [ + NoSideEffect, + HLOClient_Broadcasting, + HLOClient_BroadcastingElementwise, + DeclareOpInterfaceMethods]> { string summary = "Select operator (with optional numpy-style broadcasting)"; string description = [{ diff --git a/tests/rank-specialization.mlir b/tests/rank-specialization.mlir index eb505cf..6258ef0 100644 --- a/tests/rank-specialization.mlir +++ b/tests/rank-specialization.mlir @@ -49,3 +49,20 @@ func @sqrt_ranked(%arg: tensor<3x?xf32>) -> tensor<3x?xf32> { %2 = "mhlo.sqrt"(%1) : (tensor<3x?xf32>) -> tensor<3x?xf32> return %2 : tensor<3x?xf32> } + +// ----- + +// Ternary operation. +// CHECK-LABEL: @select_mixed +// CHECK-SAME: (%[[PRED:.*]]: tensor<*xi1>, %[[ARG1:.*]]: tensor<*xf32>, %[[ARG2:.*]]: tensor<2xf32>) +func @select_mixed(%pred: tensor<*xi1>, %arg1: tensor<*xf32>, + %arg2: tensor<2xf32>) -> tensor<*xf32> { + // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[PRED]], %[[ARG1]], %[[ARG2]]) + // CHECK: ^bb0(%[[PRED_:.*]]: tensor<*xi1>, %[[ARG1_:.*]]: tensor<*xf32>, %[[ARG2_:.*]]: tensor<2xf32>) + // CHECK: %[[TMP:.*]] = chlo.broadcast_select %[[PRED_]], %[[ARG1_]], %[[ARG2_]] + // CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP]]) + // CHECK: return %[[RES]] + %0 = "chlo.broadcast_select"(%pred, %arg1, %arg2) + : (tensor<*xi1>, tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +}