[MLIR][HLO] Avoid duplicate cluster operands when merging

When merging rank specialization clusters, avoid duplicating operands. A fewer
number of operands usually allows better rank specialization.

PiperOrigin-RevId: 378445946
This commit is contained in:
A. Unique TensorFlower 2021-06-09 10:53:51 -07:00 committed by TensorFlow MLIR Team
parent b580722041
commit 9f67417b41
2 changed files with 16 additions and 10 deletions

View File

@ -181,7 +181,10 @@ struct MergeRankSpecializationClusterOpsPattern
SmallVector<Value, 8> new_operands; SmallVector<Value, 8> new_operands;
for (Value v : preceding_op.operands()) new_operands.push_back(v); for (Value v : preceding_op.operands()) new_operands.push_back(v);
for (Value v : op.operands()) { for (Value v : op.operands()) {
if (v.getDefiningOp() != preceding_op) new_operands.push_back(v); if (v.getDefiningOp() != preceding_op &&
!llvm::is_contained(preceding_op.operands(), v)) {
new_operands.push_back(v);
}
} }
// Merge cluster results. Consider only those results of the preceding // Merge cluster results. Consider only those results of the preceding
@ -221,7 +224,6 @@ struct MergeRankSpecializationClusterOpsPattern
// Map operands and copy operations of the second cluster. If they result // Map operands and copy operations of the second cluster. If they result
// from the preceeding cluster, we can simply map the corresponding value // from the preceeding cluster, we can simply map the corresponding value
// internally. // internally.
int64_t block_arg_offset = preceding_op->getNumOperands();
for (auto it : llvm::zip(body->getArguments(), op.operands())) { for (auto it : llvm::zip(body->getArguments(), op.operands())) {
Value block_arg, operand; Value block_arg, operand;
std::tie(block_arg, operand) = it; std::tie(block_arg, operand) = it;
@ -231,7 +233,8 @@ struct MergeRankSpecializationClusterOpsPattern
bvm.map(block_arg, bvm.map(block_arg,
bvm.lookup(preceding_yield_op.getOperand(where.getIndex()))); bvm.lookup(preceding_yield_op.getOperand(where.getIndex())));
} else { } else {
bvm.map(block_arg, new_body->getArgument(block_arg_offset++)); auto where = llvm::find(new_op.operands(), operand);
bvm.map(block_arg, new_body->getArgument(where.getIndex()));
} }
} }
for (Operation &nested_op : body->without_terminator()) { for (Operation &nested_op : body->without_terminator()) {

View File

@ -565,19 +565,22 @@ func @merge_clusters(%arg0: tensor<*xf64>, %arg1 : tensor<*xf64>)
// CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG0]], %[[ARG1]]) // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG0]], %[[ARG1]])
// CHECK: ^bb0(%[[ARG0_:.*]]: tensor<*xf64>, %[[ARG1_:.*]]: tensor<*xf64>): // CHECK: ^bb0(%[[ARG0_:.*]]: tensor<*xf64>, %[[ARG1_:.*]]: tensor<*xf64>):
// CHECK: %[[TMP0:.*]] = "mhlo.tanh"(%[[ARG0_]]) // CHECK: %[[TMP0:.*]] = "mhlo.tanh"(%[[ARG0_]])
// CHECK: %[[TMP1:.*]] = chlo.broadcast_add %[[TMP0]], %[[ARG1_]] // CHECK: %[[TMP1:.*]] = chlo.broadcast_add %[[TMP0]], %[[ARG0_]]
// CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP1]]) // CHECK: %[[TMP2:.*]] = chlo.broadcast_add %[[TMP1]], %[[ARG1_]]
// CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP2]])
// CHECK: return %[[RES]] // CHECK: return %[[RES]]
%0 = "chlo.rank_specialization_cluster"(%arg0) ({ %0 = "chlo.rank_specialization_cluster"(%arg0) ({
^bb0(%arg0_: tensor<*xf64>): ^bb0(%arg0_: tensor<*xf64>):
%1 = "mhlo.tanh"(%arg0_) : (tensor<*xf64>) -> tensor<*xf64> %1 = "mhlo.tanh"(%arg0_) : (tensor<*xf64>) -> tensor<*xf64>
"chlo.rank_specialization_cluster_yield"(%1) : (tensor<*xf64>) -> () "chlo.rank_specialization_cluster_yield"(%1) : (tensor<*xf64>) -> ()
}) : (tensor<*xf64>) -> (tensor<*xf64>) }) : (tensor<*xf64>) -> (tensor<*xf64>)
%2 = "chlo.rank_specialization_cluster"(%0, %arg1) ({ %2 = "chlo.rank_specialization_cluster"(%0, %arg0, %arg1) ({
^bb0(%3: tensor<*xf64>, %4: tensor<*xf64>): ^bb0(%3: tensor<*xf64>, %4: tensor<*xf64>, %5: tensor<*xf64>):
%5 = "chlo.broadcast_add"(%3, %4) %6 = "chlo.broadcast_add"(%3, %4)
: (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64> : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64>
"chlo.rank_specialization_cluster_yield"(%5) : (tensor<*xf64>) -> () %7 = "chlo.broadcast_add"(%6, %5)
}) : (tensor<*xf64>, tensor<*xf64>) -> (tensor<*xf64>) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64>
"chlo.rank_specialization_cluster_yield"(%7) : (tensor<*xf64>) -> ()
}) : (tensor<*xf64>, tensor<*xf64>, tensor<*xf64>) -> (tensor<*xf64>)
return %2 : tensor<*xf64> return %2 : tensor<*xf64>
} }