[MLIR][HLO] Allow rank specialization clustering with `chlo.broadcast_select` op

PiperOrigin-RevId: 373379990
This commit is contained in:
A. Unique TensorFlower 2021-05-12 08:55:45 -07:00 committed by TensorFlow MLIR Team
parent e260aa771c
commit 596918a6f1
2 changed files with 23 additions and 5 deletions

View File

@ -694,11 +694,12 @@ def HLOClient_BroadcastCompareOp : HLOClient_BroadcastBinaryElementwiseOp<
// Broadcasting select op // Broadcasting select op
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
def HLOClient_BroadcastSelectOp : HLOClient_Op< def HLOClient_BroadcastSelectOp : HLOClient_Op<"broadcast_select", [
"broadcast_select", NoSideEffect,
[NoSideEffect, HLOClient_Broadcasting,
DeclareOpInterfaceMethods<InferShapedTypeOpInterface, HLOClient_BroadcastingElementwise,
["inferReturnTypeComponents"]>]> { DeclareOpInterfaceMethods<InferShapedTypeOpInterface, [
"inferReturnTypeComponents"]>]> {
string summary = "Select operator (with optional numpy-style broadcasting)"; string summary = "Select operator (with optional numpy-style broadcasting)";
string description = [{ string description = [{

View File

@ -49,3 +49,20 @@ func @sqrt_ranked(%arg: tensor<3x?xf32>) -> tensor<3x?xf32> {
%2 = "mhlo.sqrt"(%1) : (tensor<3x?xf32>) -> tensor<3x?xf32> %2 = "mhlo.sqrt"(%1) : (tensor<3x?xf32>) -> tensor<3x?xf32>
return %2 : tensor<3x?xf32> return %2 : tensor<3x?xf32>
} }
// -----
// Ternary operation.
// CHECK-LABEL: @select_mixed
// CHECK-SAME: (%[[PRED:.*]]: tensor<*xi1>, %[[ARG1:.*]]: tensor<*xf32>, %[[ARG2:.*]]: tensor<2xf32>)
func @select_mixed(%pred: tensor<*xi1>, %arg1: tensor<*xf32>,
%arg2: tensor<2xf32>) -> tensor<*xf32> {
// CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[PRED]], %[[ARG1]], %[[ARG2]])
// CHECK: ^bb0(%[[PRED_:.*]]: tensor<*xi1>, %[[ARG1_:.*]]: tensor<*xf32>, %[[ARG2_:.*]]: tensor<2xf32>)
// CHECK: %[[TMP:.*]] = chlo.broadcast_select %[[PRED_]], %[[ARG1_]], %[[ARG2_]]
// CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP]])
// CHECK: return %[[RES]]
%0 = "chlo.broadcast_select"(%pred, %arg1, %arg2)
: (tensor<*xi1>, tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}