[MLIR][HLO] Avoid duplicate cluster operands when merging
When merging rank specialization clusters, avoid duplicating operands. A fewer number of operands usually allows better rank specialization. PiperOrigin-RevId: 378445946
This commit is contained in:
parent
b580722041
commit
9f67417b41
|
@ -181,7 +181,10 @@ struct MergeRankSpecializationClusterOpsPattern
|
|||
SmallVector<Value, 8> new_operands;
|
||||
for (Value v : preceding_op.operands()) new_operands.push_back(v);
|
||||
for (Value v : op.operands()) {
|
||||
if (v.getDefiningOp() != preceding_op) new_operands.push_back(v);
|
||||
if (v.getDefiningOp() != preceding_op &&
|
||||
!llvm::is_contained(preceding_op.operands(), v)) {
|
||||
new_operands.push_back(v);
|
||||
}
|
||||
}
|
||||
|
||||
// Merge cluster results. Consider only those results of the preceding
|
||||
|
@ -221,7 +224,6 @@ struct MergeRankSpecializationClusterOpsPattern
|
|||
// Map operands and copy operations of the second cluster. If they result
|
||||
// from the preceeding cluster, we can simply map the corresponding value
|
||||
// internally.
|
||||
int64_t block_arg_offset = preceding_op->getNumOperands();
|
||||
for (auto it : llvm::zip(body->getArguments(), op.operands())) {
|
||||
Value block_arg, operand;
|
||||
std::tie(block_arg, operand) = it;
|
||||
|
@ -231,7 +233,8 @@ struct MergeRankSpecializationClusterOpsPattern
|
|||
bvm.map(block_arg,
|
||||
bvm.lookup(preceding_yield_op.getOperand(where.getIndex())));
|
||||
} else {
|
||||
bvm.map(block_arg, new_body->getArgument(block_arg_offset++));
|
||||
auto where = llvm::find(new_op.operands(), operand);
|
||||
bvm.map(block_arg, new_body->getArgument(where.getIndex()));
|
||||
}
|
||||
}
|
||||
for (Operation &nested_op : body->without_terminator()) {
|
||||
|
|
|
@ -565,19 +565,22 @@ func @merge_clusters(%arg0: tensor<*xf64>, %arg1 : tensor<*xf64>)
|
|||
// CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG0]], %[[ARG1]])
|
||||
// CHECK: ^bb0(%[[ARG0_:.*]]: tensor<*xf64>, %[[ARG1_:.*]]: tensor<*xf64>):
|
||||
// CHECK: %[[TMP0:.*]] = "mhlo.tanh"(%[[ARG0_]])
|
||||
// CHECK: %[[TMP1:.*]] = chlo.broadcast_add %[[TMP0]], %[[ARG1_]]
|
||||
// CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP1]])
|
||||
// CHECK: %[[TMP1:.*]] = chlo.broadcast_add %[[TMP0]], %[[ARG0_]]
|
||||
// CHECK: %[[TMP2:.*]] = chlo.broadcast_add %[[TMP1]], %[[ARG1_]]
|
||||
// CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP2]])
|
||||
// CHECK: return %[[RES]]
|
||||
%0 = "chlo.rank_specialization_cluster"(%arg0) ({
|
||||
^bb0(%arg0_: tensor<*xf64>):
|
||||
%1 = "mhlo.tanh"(%arg0_) : (tensor<*xf64>) -> tensor<*xf64>
|
||||
"chlo.rank_specialization_cluster_yield"(%1) : (tensor<*xf64>) -> ()
|
||||
}) : (tensor<*xf64>) -> (tensor<*xf64>)
|
||||
%2 = "chlo.rank_specialization_cluster"(%0, %arg1) ({
|
||||
^bb0(%3: tensor<*xf64>, %4: tensor<*xf64>):
|
||||
%5 = "chlo.broadcast_add"(%3, %4)
|
||||
%2 = "chlo.rank_specialization_cluster"(%0, %arg0, %arg1) ({
|
||||
^bb0(%3: tensor<*xf64>, %4: tensor<*xf64>, %5: tensor<*xf64>):
|
||||
%6 = "chlo.broadcast_add"(%3, %4)
|
||||
: (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64>
|
||||
"chlo.rank_specialization_cluster_yield"(%5) : (tensor<*xf64>) -> ()
|
||||
}) : (tensor<*xf64>, tensor<*xf64>) -> (tensor<*xf64>)
|
||||
%7 = "chlo.broadcast_add"(%6, %5)
|
||||
: (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64>
|
||||
"chlo.rank_specialization_cluster_yield"(%7) : (tensor<*xf64>) -> ()
|
||||
}) : (tensor<*xf64>, tensor<*xf64>, tensor<*xf64>) -> (tensor<*xf64>)
|
||||
return %2 : tensor<*xf64>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue