[MLIR][HLO] Add mixed test for `rank-specialization-cluster` pass
PiperOrigin-RevId: 373762814
This commit is contained in:
parent
9248f0a182
commit
76341f3720
|
@ -85,3 +85,31 @@ func @tan(%arg : tensor<*xf32>) -> tensor<*xf32> {
|
|||
%2 = chlo.tan %1 : tensor<*xf32> -> tensor<*xf32>
|
||||
return %2 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Composition of unary/binary CHLO and unary MHLO ops.
|
||||
// CHECK-LABEL: @mixed
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>, %[[ARG2:.*]]: tensor<*xf32>)
|
||||
func @mixed(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>)
|
||||
-> tensor<*xf32> {
|
||||
// CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG2]], %[[ARG1]], %[[ARG0]])
|
||||
// CHECK: ^bb0(%[[ARG2_:.*]]: tensor<*xf32>, %[[ARG1_:.*]]: tensor<*xf32>, %[[ARG0_:.*]]: tensor<*xf32>)
|
||||
// CHECK: %[[TMP0:.*]] = chlo.tan %[[ARG0_]]
|
||||
// CHECK: %[[TMP1:.*]] = "mhlo.sqrt"(%[[ARG1_]])
|
||||
// CHECK: %[[TMP2:.*]] = chlo.broadcast_multiply %[[TMP0]], %[[TMP1]]
|
||||
// CHECK: %[[TMP3:.*]] = chlo.broadcast_add %[[TMP2]], %[[ARG2_]]
|
||||
// CHECK: %[[TMP4:.*]] = "mhlo.sqrt"(%[[TMP3]])
|
||||
// CHECK: %[[TMP5:.*]] = chlo.tan %[[TMP4]]
|
||||
// CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP5]])
|
||||
// CHECK: return %[[RES]]
|
||||
%0 = chlo.tan %arg0 : tensor<*xf32> -> tensor<*xf32>
|
||||
%1 = "mhlo.sqrt"(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
%2 = chlo.broadcast_multiply %0, %1
|
||||
: (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
%3 = chlo.broadcast_add %2, %arg2
|
||||
: (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
%4 = "mhlo.sqrt"(%3) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
%5 = chlo.tan %4 : tensor<*xf32> -> tensor<*xf32>
|
||||
return %5 : tensor<*xf32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue