// File: mlir-hlo/tests/rank-specialization.mlir

// RUN: mlir-hlo-opt %s --split-input-file --mhlo-rank-specialization-cluster | FileCheck %s
// CHECK-LABEL: @add_mul
// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>, %[[ARG2:.*]]: tensor<*xf32>)
func @add_mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>,
              %arg2 : tensor<*xf32>) -> tensor<*xf32> {
  // Test input: a broadcasting multiply of %arg0 and %arg1 whose result is
  // fed, together with %arg2, into a broadcasting add. Every tensor involved
  // is unranked.
  //
  // Expected output: one cluster op that takes the three function arguments
  // as operands (in the order %arg2, %arg0, %arg1), holds both the multiply
  // and the add in its region, and yields the result of the add. The function
  // then returns the result of the cluster op.

  // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG2]], %[[ARG0]], %[[ARG1]]) ( {
  // CHECK: ^bb0(%[[ARG2_:.*]]: tensor<*xf32>, %[[ARG0_:.*]]: tensor<*xf32>, %[[ARG1_:.*]]: tensor<*xf32>):
  // CHECK: %[[TMP:.*]] = chlo.broadcast_multiply %[[ARG0_]], %[[ARG1_]]
  // CHECK: %[[INNER_RES:.*]] = chlo.broadcast_add %[[TMP]], %[[ARG2_]]
  // CHECK: "chlo.rank_specialization_cluster_yield"(%[[INNER_RES]])
  // CHECK: }) : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
  // CHECK: return %[[RES]]
  %product = chlo.broadcast_multiply %arg0, %arg1
      : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
  %sum = chlo.broadcast_add %product, %arg2
      : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
  return %sum : tensor<*xf32>
}
// -----
// Chain of unary MHLO operations on an unranked tensor forms one cluster.
// CHECK-LABEL: @sqrt
// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
func @sqrt(%arg : tensor<*xf32>) -> tensor<*xf32> {
  // Test input: three square-root ops applied back to back to an unranked
  // tensor, each consuming the result of the previous one.
  //
  // Expected output: a single cluster op taking the function argument as its
  // only operand, with all three square-root ops inside its region and the
  // last one yielded. The function returns the result of the cluster op.

  // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG]])
  // CHECK: ^bb0(%[[ARG_:.*]]: tensor<*xf32>):
  // CHECK: %[[TMP0:.*]] = "mhlo.sqrt"(%[[ARG_]])
  // CHECK: %[[TMP1:.*]] = "mhlo.sqrt"(%[[TMP0]])
  // CHECK: %[[TMP2:.*]] = "mhlo.sqrt"(%[[TMP1]])
  // CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP2]])
  // CHECK: return %[[RES]]
  %first = "mhlo.sqrt"(%arg) : (tensor<*xf32>) -> tensor<*xf32>
  %second = "mhlo.sqrt"(%first) : (tensor<*xf32>) -> tensor<*xf32>
  %third = "mhlo.sqrt"(%second) : (tensor<*xf32>) -> tensor<*xf32>
  return %third : tensor<*xf32>
}
// -----
// Don't cluster operations whose operands and results are all ranked.
// CHECK-LABEL: @sqrt_ranked
// CHECK-SAME: (%[[ARG:.*]]: tensor<3x?xf32>)
func @sqrt_ranked(%arg: tensor<3x?xf32>) -> tensor<3x?xf32> {
  // Test input: the same chain of three square-root ops as in the unranked
  // test, but on a ranked tensor of type tensor<3x?xf32>.
  //
  // Expected output: no cluster op and no cluster yield op anywhere in the
  // function; the negative check below matches either name.

  // CHECK-NOT: rank_specialization_cluster
  %first = "mhlo.sqrt"(%arg) : (tensor<3x?xf32>) -> tensor<3x?xf32>
  %second = "mhlo.sqrt"(%first) : (tensor<3x?xf32>) -> tensor<3x?xf32>
  %third = "mhlo.sqrt"(%second) : (tensor<3x?xf32>) -> tensor<3x?xf32>
  return %third : tensor<3x?xf32>
}