[MLIR][HLO] Allow rank specialization clustering with `chlo.broadcast_select` op
PiperOrigin-RevId: 373379990
This commit is contained in:
parent
e260aa771c
commit
596918a6f1
|
@ -694,11 +694,12 @@ def HLOClient_BroadcastCompareOp : HLOClient_BroadcastBinaryElementwiseOp<
|
|||
// Broadcasting select op
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def HLOClient_BroadcastSelectOp : HLOClient_Op<
|
||||
"broadcast_select",
|
||||
[NoSideEffect,
|
||||
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
|
||||
["inferReturnTypeComponents"]>]> {
|
||||
def HLOClient_BroadcastSelectOp : HLOClient_Op<"broadcast_select", [
|
||||
NoSideEffect,
|
||||
HLOClient_Broadcasting,
|
||||
HLOClient_BroadcastingElementwise,
|
||||
DeclareOpInterfaceMethods<InferShapedTypeOpInterface, [
|
||||
"inferReturnTypeComponents"]>]> {
|
||||
string summary = "Select operator (with optional numpy-style broadcasting)";
|
||||
|
||||
string description = [{
|
||||
|
|
|
@ -49,3 +49,20 @@ func @sqrt_ranked(%arg: tensor<3x?xf32>) -> tensor<3x?xf32> {
|
|||
%2 = "mhlo.sqrt"(%1) : (tensor<3x?xf32>) -> tensor<3x?xf32>
|
||||
return %2 : tensor<3x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Ternary operation.
|
||||
// CHECK-LABEL: @select_mixed
|
||||
// CHECK-SAME: (%[[PRED:.*]]: tensor<*xi1>, %[[ARG1:.*]]: tensor<*xf32>, %[[ARG2:.*]]: tensor<2xf32>)
|
||||
func @select_mixed(%pred: tensor<*xi1>, %arg1: tensor<*xf32>,
|
||||
%arg2: tensor<2xf32>) -> tensor<*xf32> {
|
||||
// CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[PRED]], %[[ARG1]], %[[ARG2]])
|
||||
// CHECK: ^bb0(%[[PRED_:.*]]: tensor<*xi1>, %[[ARG1_:.*]]: tensor<*xf32>, %[[ARG2_:.*]]: tensor<2xf32>)
|
||||
// CHECK: %[[TMP:.*]] = chlo.broadcast_select %[[PRED_]], %[[ARG1_]], %[[ARG2_]]
|
||||
// CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP]])
|
||||
// CHECK: return %[[RES]]
|
||||
%0 = "chlo.broadcast_select"(%pred, %arg1, %arg2)
|
||||
: (tensor<*xi1>, tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue