From 76341f3720343a30b84065449ff6bf5697705191 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 14 May 2021 04:38:10 -0700 Subject: [PATCH] [MLIR][HLO] Add mixed test for `rank-specialization-cluster` pass PiperOrigin-RevId: 373762814 --- tests/rank-specialization.mlir | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/tests/rank-specialization.mlir b/tests/rank-specialization.mlir index aff12b9..cc0e132 100644 --- a/tests/rank-specialization.mlir +++ b/tests/rank-specialization.mlir @@ -85,3 +85,31 @@ func @tan(%arg : tensor<*xf32>) -> tensor<*xf32> { %2 = chlo.tan %1 : tensor<*xf32> -> tensor<*xf32> return %2 : tensor<*xf32> } + +// ----- + +// Composition of unary/binary CHLO and unary MHLO ops. +// CHECK-LABEL: @mixed +// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>, %[[ARG2:.*]]: tensor<*xf32>) +func @mixed(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>) + -> tensor<*xf32> { + // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG2]], %[[ARG1]], %[[ARG0]]) + // CHECK: ^bb0(%[[ARG2_:.*]]: tensor<*xf32>, %[[ARG1_:.*]]: tensor<*xf32>, %[[ARG0_:.*]]: tensor<*xf32>) + // CHECK: %[[TMP0:.*]] = chlo.tan %[[ARG0_]] + // CHECK: %[[TMP1:.*]] = "mhlo.sqrt"(%[[ARG1_]]) + // CHECK: %[[TMP2:.*]] = chlo.broadcast_multiply %[[TMP0]], %[[TMP1]] + // CHECK: %[[TMP3:.*]] = chlo.broadcast_add %[[TMP2]], %[[ARG2_]] + // CHECK: %[[TMP4:.*]] = "mhlo.sqrt"(%[[TMP3]]) + // CHECK: %[[TMP5:.*]] = chlo.tan %[[TMP4]] + // CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP5]]) + // CHECK: return %[[RES]] + %0 = chlo.tan %arg0 : tensor<*xf32> -> tensor<*xf32> + %1 = "mhlo.sqrt"(%arg1) : (tensor<*xf32>) -> tensor<*xf32> + %2 = chlo.broadcast_multiply %0, %1 + : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + %3 = chlo.broadcast_add %2, %arg2 + : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + %4 = "mhlo.sqrt"(%3) : (tensor<*xf32>) -> tensor<*xf32> + %5 = chlo.tan %4 : tensor<*xf32> -> tensor<*xf32> + return %5 : tensor<*xf32> +}