[MLIR][HLO] Add more tests for `rank-specialization-cluster` pass
PiperOrigin-RevId: 373343750
This commit is contained in:
parent
b95162f182
commit
875803e5e1
|
@ -1,4 +1,4 @@
|
||||||
// RUN: mlir-hlo-opt %s --mhlo-rank-specialization-cluster | FileCheck %s
|
// RUN: mlir-hlo-opt %s --split-input-file --mhlo-rank-specialization-cluster | FileCheck %s
|
||||||
|
|
||||||
// CHECK-LABEL: @add_mul
|
// CHECK-LABEL: @add_mul
|
||||||
// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>, %[[ARG2:.*]]: tensor<*xf32>)
|
// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>, %[[ARG2:.*]]: tensor<*xf32>)
|
||||||
|
@ -17,3 +17,35 @@ func @add_mul(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>,
|
||||||
: (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
: (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||||
return %1 : tensor<*xf32>
|
return %1 : tensor<*xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Unary MHLO operation.
|
||||||
|
// CHECK-LABEL: @sqrt
|
||||||
|
// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
|
||||||
|
func @sqrt(%arg : tensor<*xf32>) -> tensor<*xf32> {
|
||||||
|
// CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG]])
|
||||||
|
// CHECK: ^bb0(%[[ARG_:.*]]: tensor<*xf32>):
|
||||||
|
// CHECK: %[[TMP0:.*]] = "mhlo.sqrt"(%[[ARG_]])
|
||||||
|
// CHECK: %[[TMP1:.*]] = "mhlo.sqrt"(%[[TMP0]])
|
||||||
|
// CHECK: %[[TMP2:.*]] = "mhlo.sqrt"(%[[TMP1]])
|
||||||
|
// CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP2]])
|
||||||
|
// CHECK: return %[[RES]]
|
||||||
|
%0 = "mhlo.sqrt"(%arg) : (tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
%1 = "mhlo.sqrt"(%0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
%2 = "mhlo.sqrt"(%1) : (tensor<*xf32>) -> tensor<*xf32>
|
||||||
|
return %2 : tensor<*xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Don't cluster single ranked operation.
|
||||||
|
// CHECK-LABEL: @sqrt_ranked
|
||||||
|
// CHECK-SAME: (%[[ARG:.*]]: tensor<3x?xf32>)
|
||||||
|
func @sqrt_ranked(%arg: tensor<3x?xf32>) -> tensor<3x?xf32> {
|
||||||
|
// CHECK-NOT: rank_specialization_cluster
|
||||||
|
%0 = "mhlo.sqrt"(%arg) : (tensor<3x?xf32>) -> tensor<3x?xf32>
|
||||||
|
%1 = "mhlo.sqrt"(%0) : (tensor<3x?xf32>) -> tensor<3x?xf32>
|
||||||
|
%2 = "mhlo.sqrt"(%1) : (tensor<3x?xf32>) -> tensor<3x?xf32>
|
||||||
|
return %2 : tensor<3x?xf32>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue