Add Broadcasting and BroadcastingElementwise traits to ConstantLikeOp.
This allows to include such ops in rank specialization clusters. PiperOrigin-RevId: 378380915
This commit is contained in:
parent
b9e45007d5
commit
b6d8160611
|
@ -581,8 +581,8 @@ def HLOClient_TanOp : HLOClient_UnaryElementwiseOp<"tan",
|
||||||
}
|
}
|
||||||
|
|
||||||
def HLOClient_ConstantLikeOp : HLOClient_Op<"constant_like",
|
def HLOClient_ConstantLikeOp : HLOClient_Op<"constant_like",
|
||||||
[NoSideEffect, SameOperandsAndResultShape,
|
[NoSideEffect, HLOClient_Broadcasting, HLOClient_BroadcastingElementwise,
|
||||||
InferTypeOpInterface,
|
SameOperandsAndResultShape, InferTypeOpInterface,
|
||||||
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
|
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
|
||||||
["inferReturnTypeComponents"]>,
|
["inferReturnTypeComponents"]>,
|
||||||
NativeOpTrait<"InferTensorType">]> {
|
NativeOpTrait<"InferTensorType">]> {
|
||||||
|
|
|
@ -19,6 +19,22 @@ func @add_mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>,
|
||||||
return %1 : tensor<*xf32>
|
return %1 : tensor<*xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @compare_const_like
|
||||||
|
// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>)
|
||||||
|
func @compare_const_like(%arg0 : tensor<*xf32>) -> tensor<*xi1> {
|
||||||
|
// CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG0]]) ( {
|
||||||
|
// CHECK: ^bb0(%[[ARG1:.*]]: tensor<*xf32>):
|
||||||
|
// CHECK: %[[ZERO:.*]] = "chlo.constant_like"(%[[ARG1]]) {value = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
// CHECK: %[[CMP_GT:.*]] = chlo.broadcast_compare %[[ARG1]], %[[ZERO]] {comparison_direction = "GT"} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xi1>
|
||||||
|
// CHECK: "chlo.rank_specialization_cluster_yield"(%[[CMP_GT]]) : (tensor<*xi1>) -> ()
|
||||||
|
// CHECK: }) : (tensor<*xf32>) -> tensor<*xi1>
|
||||||
|
// CHECK: return %[[RES]] : tensor<*xi1>
|
||||||
|
%0 = "chlo.constant_like"(%arg0) {value = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
%1 = chlo.broadcast_compare %arg0, %0 {comparison_direction = "GT"}
|
||||||
|
: (tensor<*xf32>, tensor<*xf32>) -> tensor<*xi1>
|
||||||
|
return %1 : tensor<*xi1>
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-SCF-LABEL: @add_mul
|
// CHECK-SCF-LABEL: @add_mul
|
||||||
// CHECK-SCF-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>, %[[ARG2:.*]]: tensor<*xf32>)
|
// CHECK-SCF-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>, %[[ARG2:.*]]: tensor<*xf32>)
|
||||||
// CHECK-SCF-DAG: %[[C1:.*]] = constant 1
|
// CHECK-SCF-DAG: %[[C1:.*]] = constant 1
|
||||||
|
|
Loading…
Reference in New Issue