Add Broadcasting and BroadcastingElementwise traits to ConstantLikeOp.

This allows to include such ops in rank specialization clusters.

PiperOrigin-RevId: 378380915
This commit is contained in:
Adrian Kuegel 2021-06-09 05:08:36 -07:00 committed by TensorFlow MLIR Team
parent b9e45007d5
commit b6d8160611
2 changed files with 18 additions and 2 deletions

View File

@ -581,8 +581,8 @@ def HLOClient_TanOp : HLOClient_UnaryElementwiseOp<"tan",
}
def HLOClient_ConstantLikeOp : HLOClient_Op<"constant_like",
[NoSideEffect, SameOperandsAndResultShape,
InferTypeOpInterface,
[NoSideEffect, HLOClient_Broadcasting, HLOClient_BroadcastingElementwise,
SameOperandsAndResultShape, InferTypeOpInterface,
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
NativeOpTrait<"InferTensorType">]> {

View File

@ -19,6 +19,22 @@ func @add_mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>,
return %1 : tensor<*xf32>
}
// CHECK-LABEL: @compare_const_like
// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>)
func @compare_const_like(%arg0 : tensor<*xf32>) -> tensor<*xi1> {
// CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG0]]) ( {
// CHECK: ^bb0(%[[ARG1:.*]]: tensor<*xf32>):
// CHECK: %[[ZERO:.*]] = "chlo.constant_like"(%[[ARG1]]) {value = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
// CHECK: %[[CMP_GT:.*]] = chlo.broadcast_compare %[[ARG1]], %[[ZERO]] {comparison_direction = "GT"} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xi1>
// CHECK: "chlo.rank_specialization_cluster_yield"(%[[CMP_GT]]) : (tensor<*xi1>) -> ()
// CHECK: }) : (tensor<*xf32>) -> tensor<*xi1>
// CHECK: return %[[RES]] : tensor<*xi1>
%0 = "chlo.constant_like"(%arg0) {value = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
%1 = chlo.broadcast_compare %arg0, %0 {comparison_direction = "GT"}
: (tensor<*xf32>, tensor<*xf32>) -> tensor<*xi1>
return %1 : tensor<*xi1>
}
// CHECK-SCF-LABEL: @add_mul
// CHECK-SCF-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>, %[[ARG2:.*]]: tensor<*xf32>)
// CHECK-SCF-DAG: %[[C1:.*]] = constant 1