[MLIR][HLO] Avoid duplicate cluster operands when merging
When merging rank specialization clusters, avoid duplicating operands. Fewer operands usually allow better rank specialization. PiperOrigin-RevId: 378445946
This commit is contained in:
parent
b580722041
commit
9f67417b41
|
@ -181,7 +181,10 @@ struct MergeRankSpecializationClusterOpsPattern
|
||||||
SmallVector<Value, 8> new_operands;
|
SmallVector<Value, 8> new_operands;
|
||||||
for (Value v : preceding_op.operands()) new_operands.push_back(v);
|
for (Value v : preceding_op.operands()) new_operands.push_back(v);
|
||||||
for (Value v : op.operands()) {
|
for (Value v : op.operands()) {
|
||||||
if (v.getDefiningOp() != preceding_op) new_operands.push_back(v);
|
if (v.getDefiningOp() != preceding_op &&
|
||||||
|
!llvm::is_contained(preceding_op.operands(), v)) {
|
||||||
|
new_operands.push_back(v);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Merge cluster results. Consider only those results of the preceding
|
// Merge cluster results. Consider only those results of the preceding
|
||||||
|
@ -221,7 +224,6 @@ struct MergeRankSpecializationClusterOpsPattern
|
||||||
// Map operands and copy operations of the second cluster. If they result
|
// Map operands and copy operations of the second cluster. If they result
|
||||||
// from the preceding cluster, we can simply map the corresponding value
|
// from the preceding cluster, we can simply map the corresponding value
|
||||||
// internally.
|
// internally.
|
||||||
int64_t block_arg_offset = preceding_op->getNumOperands();
|
|
||||||
for (auto it : llvm::zip(body->getArguments(), op.operands())) {
|
for (auto it : llvm::zip(body->getArguments(), op.operands())) {
|
||||||
Value block_arg, operand;
|
Value block_arg, operand;
|
||||||
std::tie(block_arg, operand) = it;
|
std::tie(block_arg, operand) = it;
|
||||||
|
@ -231,7 +233,8 @@ struct MergeRankSpecializationClusterOpsPattern
|
||||||
bvm.map(block_arg,
|
bvm.map(block_arg,
|
||||||
bvm.lookup(preceding_yield_op.getOperand(where.getIndex())));
|
bvm.lookup(preceding_yield_op.getOperand(where.getIndex())));
|
||||||
} else {
|
} else {
|
||||||
bvm.map(block_arg, new_body->getArgument(block_arg_offset++));
|
auto where = llvm::find(new_op.operands(), operand);
|
||||||
|
bvm.map(block_arg, new_body->getArgument(where.getIndex()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (Operation &nested_op : body->without_terminator()) {
|
for (Operation &nested_op : body->without_terminator()) {
|
||||||
|
|
|
@ -565,19 +565,22 @@ func @merge_clusters(%arg0: tensor<*xf64>, %arg1 : tensor<*xf64>)
|
||||||
// CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG0]], %[[ARG1]])
|
// CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG0]], %[[ARG1]])
|
||||||
// CHECK: ^bb0(%[[ARG0_:.*]]: tensor<*xf64>, %[[ARG1_:.*]]: tensor<*xf64>):
|
// CHECK: ^bb0(%[[ARG0_:.*]]: tensor<*xf64>, %[[ARG1_:.*]]: tensor<*xf64>):
|
||||||
// CHECK: %[[TMP0:.*]] = "mhlo.tanh"(%[[ARG0_]])
|
// CHECK: %[[TMP0:.*]] = "mhlo.tanh"(%[[ARG0_]])
|
||||||
// CHECK: %[[TMP1:.*]] = chlo.broadcast_add %[[TMP0]], %[[ARG1_]]
|
// CHECK: %[[TMP1:.*]] = chlo.broadcast_add %[[TMP0]], %[[ARG0_]]
|
||||||
// CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP1]])
|
// CHECK: %[[TMP2:.*]] = chlo.broadcast_add %[[TMP1]], %[[ARG1_]]
|
||||||
|
// CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP2]])
|
||||||
// CHECK: return %[[RES]]
|
// CHECK: return %[[RES]]
|
||||||
%0 = "chlo.rank_specialization_cluster"(%arg0) ({
|
%0 = "chlo.rank_specialization_cluster"(%arg0) ({
|
||||||
^bb0(%arg0_: tensor<*xf64>):
|
^bb0(%arg0_: tensor<*xf64>):
|
||||||
%1 = "mhlo.tanh"(%arg0_) : (tensor<*xf64>) -> tensor<*xf64>
|
%1 = "mhlo.tanh"(%arg0_) : (tensor<*xf64>) -> tensor<*xf64>
|
||||||
"chlo.rank_specialization_cluster_yield"(%1) : (tensor<*xf64>) -> ()
|
"chlo.rank_specialization_cluster_yield"(%1) : (tensor<*xf64>) -> ()
|
||||||
}) : (tensor<*xf64>) -> (tensor<*xf64>)
|
}) : (tensor<*xf64>) -> (tensor<*xf64>)
|
||||||
%2 = "chlo.rank_specialization_cluster"(%0, %arg1) ({
|
%2 = "chlo.rank_specialization_cluster"(%0, %arg0, %arg1) ({
|
||||||
^bb0(%3: tensor<*xf64>, %4: tensor<*xf64>):
|
^bb0(%3: tensor<*xf64>, %4: tensor<*xf64>, %5: tensor<*xf64>):
|
||||||
%5 = "chlo.broadcast_add"(%3, %4)
|
%6 = "chlo.broadcast_add"(%3, %4)
|
||||||
: (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64>
|
: (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64>
|
||||||
"chlo.rank_specialization_cluster_yield"(%5) : (tensor<*xf64>) -> ()
|
%7 = "chlo.broadcast_add"(%6, %5)
|
||||||
}) : (tensor<*xf64>, tensor<*xf64>) -> (tensor<*xf64>)
|
: (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64>
|
||||||
|
"chlo.rank_specialization_cluster_yield"(%7) : (tensor<*xf64>) -> ()
|
||||||
|
}) : (tensor<*xf64>, tensor<*xf64>, tensor<*xf64>) -> (tensor<*xf64>)
|
||||||
return %2 : tensor<*xf64>
|
return %2 : tensor<*xf64>
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue