From 875803e5e1481e033174a3df5621c7407d93316a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 12 May 2021 04:45:06 -0700 Subject: [PATCH] [MLIR][HLO] Add more tests for `rank-specialization-cluster` pass PiperOrigin-RevId: 373343750 --- tests/rank-specialization.mlir | 34 +++++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/tests/rank-specialization.mlir b/tests/rank-specialization.mlir index 8e8b646..eb505cf 100644 --- a/tests/rank-specialization.mlir +++ b/tests/rank-specialization.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt %s --mhlo-rank-specialization-cluster | FileCheck %s +// RUN: mlir-hlo-opt %s --split-input-file --mhlo-rank-specialization-cluster | FileCheck %s // CHECK-LABEL: @add_mul // CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>, %[[ARG2:.*]]: tensor<*xf32>) @@ -17,3 +17,35 @@ func @add_mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> return %1 : tensor<*xf32> } + +// ----- + +// Unary MHLO operation. +// CHECK-LABEL: @sqrt +// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) +func @sqrt(%arg : tensor<*xf32>) -> tensor<*xf32> { + // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG]]) + // CHECK: ^bb0(%[[ARG_:.*]]: tensor<*xf32>): + // CHECK: %[[TMP0:.*]] = "mhlo.sqrt"(%[[ARG_]]) + // CHECK: %[[TMP1:.*]] = "mhlo.sqrt"(%[[TMP0]]) + // CHECK: %[[TMP2:.*]] = "mhlo.sqrt"(%[[TMP1]]) + // CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP2]]) + // CHECK: return %[[RES]] + %0 = "mhlo.sqrt"(%arg) : (tensor<*xf32>) -> tensor<*xf32> + %1 = "mhlo.sqrt"(%0) : (tensor<*xf32>) -> tensor<*xf32> + %2 = "mhlo.sqrt"(%1) : (tensor<*xf32>) -> tensor<*xf32> + return %2 : tensor<*xf32> +} + +// ----- + +// Don't cluster single ranked operation. +// CHECK-LABEL: @sqrt_ranked +// CHECK-SAME: (%[[ARG:.*]]: tensor<3x?xf32>) +func @sqrt_ranked(%arg: tensor<3x?xf32>) -> tensor<3x?xf32> { + // CHECK-NOT: rank_specialization_cluster + %0 = "mhlo.sqrt"(%arg) : (tensor<3x?xf32>) -> tensor<3x?xf32> + %1 = "mhlo.sqrt"(%0) : (tensor<3x?xf32>) -> tensor<3x?xf32> + %2 = "mhlo.sqrt"(%1) : (tensor<3x?xf32>) -> tensor<3x?xf32> + return %2 : tensor<3x?xf32> +}