diff --git a/lib/Dialect/mhlo/transforms/rank_specialization.cc b/lib/Dialect/mhlo/transforms/rank_specialization.cc index 91aa421..3f8310f 100644 --- a/lib/Dialect/mhlo/transforms/rank_specialization.cc +++ b/lib/Dialect/mhlo/transforms/rank_specialization.cc @@ -181,7 +181,10 @@ struct MergeRankSpecializationClusterOpsPattern SmallVector new_operands; for (Value v : preceding_op.operands()) new_operands.push_back(v); for (Value v : op.operands()) { - if (v.getDefiningOp() != preceding_op) new_operands.push_back(v); + if (v.getDefiningOp() != preceding_op && + !llvm::is_contained(preceding_op.operands(), v)) { + new_operands.push_back(v); + } } // Merge cluster results. Consider only those results of the preceding @@ -221,7 +224,6 @@ struct MergeRankSpecializationClusterOpsPattern // Map operands and copy operations of the second cluster. If they result // from the preceeding cluster, we can simply map the corresponding value // internally. - int64_t block_arg_offset = preceding_op->getNumOperands(); for (auto it : llvm::zip(body->getArguments(), op.operands())) { Value block_arg, operand; std::tie(block_arg, operand) = it; @@ -231,7 +233,8 @@ struct MergeRankSpecializationClusterOpsPattern bvm.map(block_arg, bvm.lookup(preceding_yield_op.getOperand(where.getIndex()))); } else { - bvm.map(block_arg, new_body->getArgument(block_arg_offset++)); + auto where = llvm::find(new_op.operands(), operand); + bvm.map(block_arg, new_body->getArgument(where.getIndex())); } } for (Operation &nested_op : body->without_terminator()) { diff --git a/tests/rank-specialization.mlir b/tests/rank-specialization.mlir index b216695..f7d406b 100644 --- a/tests/rank-specialization.mlir +++ b/tests/rank-specialization.mlir @@ -565,19 +565,22 @@ func @merge_clusters(%arg0: tensor<*xf64>, %arg1 : tensor<*xf64>) // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG0]], %[[ARG1]]) // CHECK: ^bb0(%[[ARG0_:.*]]: tensor<*xf64>, %[[ARG1_:.*]]: tensor<*xf64>): // CHECK: %[[TMP0:.*]] = "mhlo.tanh"(%[[ARG0_]]) - // CHECK: %[[TMP1:.*]] = chlo.broadcast_add %[[TMP0]], %[[ARG1_]] - // CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP1]]) + // CHECK: %[[TMP1:.*]] = chlo.broadcast_add %[[TMP0]], %[[ARG0_]] + // CHECK: %[[TMP2:.*]] = chlo.broadcast_add %[[TMP1]], %[[ARG1_]] + // CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP2]]) // CHECK: return %[[RES]] %0 = "chlo.rank_specialization_cluster"(%arg0) ({ ^bb0(%arg0_: tensor<*xf64>): %1 = "mhlo.tanh"(%arg0_) : (tensor<*xf64>) -> tensor<*xf64> "chlo.rank_specialization_cluster_yield"(%1) : (tensor<*xf64>) -> () }) : (tensor<*xf64>) -> (tensor<*xf64>) - %2 = "chlo.rank_specialization_cluster"(%0, %arg1) ({ - ^bb0(%3: tensor<*xf64>, %4: tensor<*xf64>): - %5 = "chlo.broadcast_add"(%3, %4) + %2 = "chlo.rank_specialization_cluster"(%0, %arg0, %arg1) ({ + ^bb0(%3: tensor<*xf64>, %4: tensor<*xf64>, %5: tensor<*xf64>): + %6 = "chlo.broadcast_add"(%3, %4) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64> - "chlo.rank_specialization_cluster_yield"(%5) : (tensor<*xf64>) -> () - }) : (tensor<*xf64>, tensor<*xf64>) -> (tensor<*xf64>) + %7 = "chlo.broadcast_add"(%6, %5) + : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64> + "chlo.rank_specialization_cluster_yield"(%7) : (tensor<*xf64>) -> () + }) : (tensor<*xf64>, tensor<*xf64>, tensor<*xf64>) -> (tensor<*xf64>) return %2 : tensor<*xf64> }