[MLIR][HLO] Add mixed test for `rank-specialization-cluster` pass

PiperOrigin-RevId: 373762814
This commit is contained in:
A. Unique TensorFlower 2021-05-14 04:38:10 -07:00 committed by TensorFlow MLIR Team
parent 9248f0a182
commit 76341f3720
1 changed files with 28 additions and 0 deletions

View File

@ -85,3 +85,31 @@ func @tan(%arg : tensor<*xf32>) -> tensor<*xf32> {
%2 = chlo.tan %1 : tensor<*xf32> -> tensor<*xf32>
return %2 : tensor<*xf32>
}
// -----
// Composition of unary/binary CHLO and unary MHLO ops.
// CHECK-LABEL: @mixed
// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>, %[[ARG2:.*]]: tensor<*xf32>)
func @mixed(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>)
-> tensor<*xf32> {
// CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG2]], %[[ARG1]], %[[ARG0]])
// CHECK: ^bb0(%[[ARG2_:.*]]: tensor<*xf32>, %[[ARG1_:.*]]: tensor<*xf32>, %[[ARG0_:.*]]: tensor<*xf32>)
// CHECK: %[[TMP0:.*]] = chlo.tan %[[ARG0_]]
// CHECK: %[[TMP1:.*]] = "mhlo.sqrt"(%[[ARG1_]])
// CHECK: %[[TMP2:.*]] = chlo.broadcast_multiply %[[TMP0]], %[[TMP1]]
// CHECK: %[[TMP3:.*]] = chlo.broadcast_add %[[TMP2]], %[[ARG2_]]
// CHECK: %[[TMP4:.*]] = "mhlo.sqrt"(%[[TMP3]])
// CHECK: %[[TMP5:.*]] = chlo.tan %[[TMP4]]
// CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP5]])
// CHECK: return %[[RES]]
%0 = chlo.tan %arg0 : tensor<*xf32> -> tensor<*xf32>
%1 = "mhlo.sqrt"(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
%2 = chlo.broadcast_multiply %0, %1
: (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
%3 = chlo.broadcast_add %2, %arg2
: (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
%4 = "mhlo.sqrt"(%3) : (tensor<*xf32>) -> tensor<*xf32>
%5 = chlo.tan %4 : tensor<*xf32> -> tensor<*xf32>
return %5 : tensor<*xf32>
}