[MLIR][HLO] Allow rank specialization clustering with `chlo.broadcast_select` op
PiperOrigin-RevId: 373379990
This commit is contained in:
parent
e260aa771c
commit
596918a6f1
|
@ -694,11 +694,12 @@ def HLOClient_BroadcastCompareOp : HLOClient_BroadcastBinaryElementwiseOp<
|
||||||
// Broadcasting select op
|
// Broadcasting select op
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def HLOClient_BroadcastSelectOp : HLOClient_Op<
|
def HLOClient_BroadcastSelectOp : HLOClient_Op<"broadcast_select", [
|
||||||
"broadcast_select",
|
NoSideEffect,
|
||||||
[NoSideEffect,
|
HLOClient_Broadcasting,
|
||||||
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
|
HLOClient_BroadcastingElementwise,
|
||||||
["inferReturnTypeComponents"]>]> {
|
DeclareOpInterfaceMethods<InferShapedTypeOpInterface, [
|
||||||
|
"inferReturnTypeComponents"]>]> {
|
||||||
string summary = "Select operator (with optional numpy-style broadcasting)";
|
string summary = "Select operator (with optional numpy-style broadcasting)";
|
||||||
|
|
||||||
string description = [{
|
string description = [{
|
||||||
|
|
|
@ -49,3 +49,20 @@ func @sqrt_ranked(%arg: tensor<3x?xf32>) -> tensor<3x?xf32> {
|
||||||
%2 = "mhlo.sqrt"(%1) : (tensor<3x?xf32>) -> tensor<3x?xf32>
|
%2 = "mhlo.sqrt"(%1) : (tensor<3x?xf32>) -> tensor<3x?xf32>
|
||||||
return %2 : tensor<3x?xf32>
|
return %2 : tensor<3x?xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Ternary operation.
|
||||||
|
// CHECK-LABEL: @select_mixed
|
||||||
|
// CHECK-SAME: (%[[PRED:.*]]: tensor<*xi1>, %[[ARG1:.*]]: tensor<*xf32>, %[[ARG2:.*]]: tensor<2xf32>)
|
||||||
|
func @select_mixed(%pred: tensor<*xi1>, %arg1: tensor<*xf32>,
|
||||||
|
%arg2: tensor<2xf32>) -> tensor<*xf32> {
|
||||||
|
// CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[PRED]], %[[ARG1]], %[[ARG2]])
|
||||||
|
// CHECK: ^bb0(%[[PRED_:.*]]: tensor<*xi1>, %[[ARG1_:.*]]: tensor<*xf32>, %[[ARG2_:.*]]: tensor<2xf32>)
|
||||||
|
// CHECK: %[[TMP:.*]] = chlo.broadcast_select %[[PRED_]], %[[ARG1_]], %[[ARG2_]]
|
||||||
|
// CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP]])
|
||||||
|
// CHECK: return %[[RES]]
|
||||||
|
%0 = "chlo.broadcast_select"(%pred, %arg1, %arg2)
|
||||||
|
: (tensor<*xi1>, tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32>
|
||||||
|
return %0 : tensor<*xf32>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue