diff --git a/lib/Dialect/mhlo/transforms/rank_specialization.cc b/lib/Dialect/mhlo/transforms/rank_specialization.cc
index 3f8310f..4db229c 100644
--- a/lib/Dialect/mhlo/transforms/rank_specialization.cc
+++ b/lib/Dialect/mhlo/transforms/rank_specialization.cc
@@ -14,6 +14,7 @@ limitations under the License.
 ==============================================================================*/
 
+#include "llvm/ADT/EquivalenceClasses.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/SmallVector.h"
@@ -38,7 +39,8 @@ limitations under the License.
 
 namespace mlir {
 
-/// Needed to build `llvm::SmallSet`s of `mlir::Value`s.
+/// Needed to build `llvm::SmallSet`s and `llvm::EquivalenceClasses` of
+/// `mlir::Value`s.
 static bool operator<(const Value &lhs, const Value &rhs) {
   return lhs.getAsOpaquePointer() < rhs.getAsOpaquePointer();
 }
@@ -308,18 +310,6 @@ Type DeriveUnrankedTensorTypes(Type ty) {
   return ty;
 }
 
-Optional<Value> FindUniqueNonScalar(ValueRange values) {
-  Value unique_non_scalar;
-  for (Value v : values) {
-    if (!IsScalarTensorType(v.getType())) {
-      if (unique_non_scalar) return llvm::None;
-      unique_non_scalar = v;
-    }
-  }
-  if (!unique_non_scalar) return llvm::None;
-  return unique_non_scalar;
-}
-
 SmallVector<Value, 8> MaterializeRankedOperations(
     OpBuilder &b, Location loc, BlockAndValueMapping &bvm,
     chlo::RankSpecializationClusterOp op) {
@@ -375,20 +365,35 @@ SmallVector<Value, 8> MaterializeFinalReshape(
       }));
 }
 
+Value MaterializeFlatShape(OpBuilder &b, Location loc, ValueRange same_shapes) {
+  assert(!same_shapes.empty() && "Expected at least one shape.");
+  Value shape = same_shapes.size() == 1
+                    ? same_shapes.front()
+                    : b.create<shape::AnyOp>(loc, same_shapes.front().getType(),
+                                             same_shapes);
+  return b.create<tensor::FromElementsOp>(
+      loc,
+      b.create<shape::NumElementsOp>(loc, b.getIndexType(), shape).result());
+}
+
 Value MaterializeScalarRankSpecializationCase(
     OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
-    const SmallVector<Value, 8> &shapes, int64_t non_scalar_idx,
+    const SmallVector<Value, 8> &shapes, ValueRange non_scalars_of_same_shape,
    function_ref<void(OpBuilder &, Location)> else_builder_fn) {
-  // Materialize predicate: All operands except one are scalars.
+  // Materialize predicate: All operands are scalars, except the expected
+  // non-scalars.
   Value one = b.create<ConstantIndexOp>(loc, 1);
   Value all_others_are_scalar;
-  for (auto it : llvm::enumerate(shapes)) {
-    if (it.index() == non_scalar_idx) continue;
-    // For statically known scalars, there is no need to test.
-    if (IsScalarTensorType(op.getOperand(it.index()).getType())) continue;
+  for (auto it : llvm::zip(op.operands(), shapes)) {
+    Value operand, shape;
+    std::tie(operand, shape) = it;
+    if (llvm::is_contained(non_scalars_of_same_shape, operand) ||
+        IsScalarTensorType(operand.getType())) {
+      continue;
+    }
     auto literal = b.create<CmpIOp>(loc, CmpIPredicate::eq,
-                                    b.create<shape::NumElementsOp>(loc, it.value()), one);
+                                    b.create<shape::NumElementsOp>(loc, shape), one);
     all_others_are_scalar =
         all_others_are_scalar
             ? b.create<mlir::AndOp>(loc, all_others_are_scalar, literal)
@@ -399,24 +404,31 @@ Value MaterializeScalarRankSpecializationCase(
   auto if_op = b.create<scf::IfOp>(
       loc, op->getResultTypes(), all_others_are_scalar,
       [&](OpBuilder &b, Location loc) {
-        // Flatten the non-scalar operand.
-        Value flat_shape = b.create<tensor::FromElementsOp>(
-            loc, b.create<shape::NumElementsOp>(loc, b.getIndexType(),
-                                                shapes[non_scalar_idx])
-                     .result());
-        Value non_scalar_operand = op.operands()[non_scalar_idx];
-        Value flat_non_scalar_operand = b.create<mhlo::DynamicReshapeOp>(
-            loc,
-            DeriveRankedTensorTypes(non_scalar_operand.getType(), /*rank=*/1),
-            non_scalar_operand, flat_shape);
+        // Compute flat non-scalar shape.
+        SmallVector<Value, 4> non_scalar_shapes;
+        for (auto it : llvm::zip(op.operands(), shapes)) {
+          Value operand, shape;
+          std::tie(operand, shape) = it;
+          if (llvm::is_contained(non_scalars_of_same_shape, operand))
+            non_scalar_shapes.push_back(shape);
+        }
+        Value flat_shape = MaterializeFlatShape(b, loc, non_scalar_shapes);
 
         // Derive ranked operands.
         auto ranked_operands =
             llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
-              if (v == non_scalar_operand) return flat_non_scalar_operand;
+              if (IsScalarTensorType(v.getType())) return v;
+              if (!llvm::is_contained(non_scalars_of_same_shape, v)) {
+                return b
+                    .create<mhlo::ReshapeOp>(
+                        loc, DeriveRankedTensorTypes(v.getType(), /*rank=*/0),
+                        v)
+                    .getResult();
+              }
               return b
-                  .create<mhlo::ReshapeOp>(
-                      loc, DeriveRankedTensorTypes(v.getType(), /*rank=*/0), v)
+                  .create<mhlo::DynamicReshapeOp>(
+                      loc, DeriveRankedTensorTypes(v.getType(), /*rank=*/1), v,
+                      flat_shape)
                   .getResult();
             }));
@@ -464,14 +476,7 @@ Value MaterializeEqualShapesRankSpecializationCase(
       loc, op->getResultTypes(), all_shapes_eq_or_scalar,
       [&](OpBuilder &b, Location loc) {
         // Flatten non-scalar operands.
-        Value shape = non_scalar_shapes.front();
-        for (Value s : llvm::drop_begin(non_scalar_shapes)) {
-          shape = b.create<shape::AnyOp>(loc, shape.getType(),
-                                         ValueRange{shape, s});
-        }
-        Value flat_shape = b.create<tensor::FromElementsOp>(
-            loc, b.create<shape::NumElementsOp>(loc, b.getIndexType(), shape)
-                     .result());
+        Value flat_shape = MaterializeFlatShape(b, loc, non_scalar_shapes);
         auto flat_operands =
             llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
               if (IsScalarTensorType(v.getType())) return v;
@@ -631,18 +636,15 @@ Value MaterializeDefaultRankSpecializationCases(
       });
 }
 
-SmallVector<Value, 8> MaterializeRankSpecializationForSingleNonScalarOperand(
+SmallVector<Value, 8>
+MaterializeRankSpecializationForSingleNonScalarShapeEquivalenceClass(
     OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
-    Value non_scalar_operand) {
-  // Flatten the non-scalar operand.
-  Value flat_shape = b.create<tensor::FromElementsOp>(
-      loc, b.create<shape::NumElementsOp>(
-               loc, b.getIndexType(),
-               b.create<shape::ShapeOfOp>(loc, non_scalar_operand))
-               .result());
-  Value flat_non_scalar_operand = b.create<mhlo::DynamicReshapeOp>(
-      loc, DeriveRankedTensorTypes(non_scalar_operand.getType(), /*rank=*/1),
-      non_scalar_operand, flat_shape);
+    ValueRange non_scalars_of_same_shape) {
+  // Compute flat operand shape.
+  auto non_scalar_shapes = llvm::to_vector<4>(llvm::map_range(
+      non_scalars_of_same_shape,
+      [&](Value v) { return b.create<shape::ShapeOfOp>(loc, v).result(); }));
+  Value flat_shape = MaterializeFlatShape(b, loc, non_scalar_shapes);
 
   // Materialize ranked variants for the element-wise operations.
   BlockAndValueMapping bvm;
@@ -650,8 +652,14 @@ SmallVector<Value, 8> MaterializeRankSpecializationForSingleNonScalarOperand(
     Value operand;
     Value bb_arg;
     std::tie(bb_arg, operand) = it;
-    bvm.map(bb_arg,
-            operand == non_scalar_operand ? flat_non_scalar_operand : operand);
+    if (!IsScalarTensorType(operand.getType())) {
+      assert(llvm::is_contained(non_scalars_of_same_shape, operand) &&
+             "Expected all non-scalars in the same shape equivalence class.");
+      operand = b.create<mhlo::DynamicReshapeOp>(
+          loc, DeriveRankedTensorTypes(operand.getType(), /*rank=*/1), operand,
+          flat_shape);
+    }
+    bvm.map(bb_arg, operand);
   }
   SmallVector<Value, 8> unshaped_results =
       MaterializeRankedOperations(b, loc, bvm, op);
@@ -660,24 +668,24 @@ SmallVector<Value, 8> MaterializeRankSpecializationForSingleNonScalarOperand(
   return MaterializeFinalReshape(b, loc, op, unshaped_results);
 }
 
-Value MaterializeRankSpecializationForTwoNonScalarOperands(
+Value MaterializeRankSpecializationForTwoNonScalarShapeEquivalenceClasses(
     OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
-    ValueRange non_scalar_operands, int64_t max_target_rank) {
-  assert(non_scalar_operands.size() == 2);
-
+    SmallVector<SmallVector<Value, 4>, 4> non_scalar_eqs,
+    int64_t max_target_rank) {
+  assert(non_scalar_eqs.size() == 2 &&
+         "Expect two non-scalar equivalence classes.");
   auto shapes = llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
     return b.create<shape::ShapeOfOp>(loc, v).result();
   }));
-  auto non_scalar_lhs = llvm::find(op.operands(), non_scalar_operands[0]);
-  auto non_scalar_rhs = llvm::find(op.operands(), non_scalar_operands[1]);
+  ValueRange lhs_non_scalar_eqs = non_scalar_eqs[0];
+  ValueRange rhs_non_scalar_eqs = non_scalar_eqs[1];
 
   // Materialize all the different cases.
   Value unshaped_result = MaterializeScalarRankSpecializationCase(
-      b, loc, op, shapes, non_scalar_rhs.getIndex(),
-      [&](OpBuilder &b, Location loc) {
+      b, loc, op, shapes, rhs_non_scalar_eqs, [&](OpBuilder &b, Location loc) {
        b.create<scf::YieldOp>(
            loc, MaterializeScalarRankSpecializationCase(
-                    b, loc, op, shapes, non_scalar_lhs.getIndex(),
+                    b, loc, op, shapes, lhs_non_scalar_eqs,
                     [&](OpBuilder &b, Location loc) {
                       b.create<scf::YieldOp>(
                           loc, MaterializeDefaultRankSpecializationCases(
@@ -705,6 +713,54 @@ Value MaterializeDefaultRankSpecialization(OpBuilder &b, Location loc,
   return MaterializeFinalReshape(b, loc, op, unshaped_result).front();
 }
 
+// This is a very limited form of shape inference. It is correct but incomplete.
+SmallVector<SmallVector<Value, 4>, 4> FindNonScalarShapeEquivalences(
+    chlo::RankSpecializationClusterOp op) {
+  llvm::EquivalenceClasses<Value> eqs;
+
+  // Bridge the equivalences between operands and block arguments.
+  for (auto it : llvm::zip(op.operands(), op.getBody()->getArguments()))
+    eqs.unionSets(std::get<0>(it), std::get<1>(it));
+
+  // Find equalities through `SameOperandsAndResultShape` trait.
+  auto union_sets = [&](ValueRange vs) {
+    if (vs.empty()) return;
+    Value repr = vs.front();
+    for (Value v : vs.drop_front()) eqs.unionSets(repr, v);
+  };
+  for (Operation &nested_op : op.getBody()->without_terminator()) {
+    if (nested_op.hasTrait<OpTrait::SameOperandsAndResultShape>()) {
+      union_sets(nested_op.getOperands());
+      union_sets(nested_op.getResults());
+      if (!nested_op.getOperands().empty() && !nested_op.getResults().empty())
+        eqs.unionSets(nested_op.getResult(0), nested_op.getOperand(0));
+    }
+    // TODO(frgossen): Replace this with a check for the appropriate trait when
+    // that is available.
+    if (auto select_op = llvm::dyn_cast<mhlo::SelectOp>(nested_op)) {
+      union_sets(
+          {select_op.on_true(), select_op.on_false(), select_op.getResult()});
+    }
+  }
+
+  // Convert to a list-like equivalence class representation.
+  SmallVector<SmallVector<Value, 4>, 4> non_scalar_eqs;
+  for (Value v : op.operands()) {
+    if (IsScalarTensorType(v.getType())) continue;
+    bool inserted = false;
+    for (auto &eq_class : non_scalar_eqs) {
+      if (eqs.isEquivalent(eq_class.front(), v)) {
+        eq_class.push_back(v);
+        inserted = true;
+        break;
+      }
+    }
+    if (!inserted) non_scalar_eqs.push_back(SmallVector<Value, 4>({v}));
+  }
+
+  return non_scalar_eqs;
+}
+
 struct LowerRankSpecializationClusterPattern
     : public OpRewritePattern<chlo::RankSpecializationClusterOp> {
   LowerRankSpecializationClusterPattern(MLIRContext *ctx,
@@ -719,31 +775,27 @@ struct LowerRankSpecializationClusterPattern
     // all operands.
     if (op.getNumResults() != 1) return failure();
 
-    // If there is only a single non-scalar operand, we can flatten that operand
-    // completely.
+    // If there is only a single non-scalar shape equivalence class, we can
+    // flatten those operands completely.
+    SmallVector<SmallVector<Value, 4>, 4> non_scalar_eqs =
+        FindNonScalarShapeEquivalences(op);
     Location loc = op.getLoc();
-    auto non_scalar_operands =
-        llvm::to_vector<2>(llvm::make_filter_range(op.operands(), [](Value v) {
-          return !IsScalarTensorType(v.getType());
-        }));
-    if (non_scalar_operands.size() == 1) {
-      rewriter.replaceOp(op,
-                         MaterializeRankSpecializationForSingleNonScalarOperand(
-                             rewriter, loc, op, non_scalar_operands.front()));
+    if (non_scalar_eqs.size() == 1) {
+      rewriter.replaceOp(
+          op,
+          MaterializeRankSpecializationForSingleNonScalarShapeEquivalenceClass(
+              rewriter, loc, op, non_scalar_eqs.front()));
       return success();
    }
 
-    // If there are exactly two unranked operands and all others are known to be
-    // scalars, we can consider two extra cases: If either of the unranked
-    // operands turns out to be a scalar at runtime, we can, again, apply the
-    // trick for a single non-scalar operand.
-    if (non_scalar_operands.size() == 2 &&
-        llvm::all_of(non_scalar_operands, [](Value v) {
-          return v.getType().isa<UnrankedTensorType>();
-        })) {
+    // If there are exactly two non-scalar shape equivalence classes, we can
+    // consider two extra cases: If either of the operand classes turns out to
+    // be all-scalars at runtime, we can, again, flatten all operands.
+    if (non_scalar_eqs.size() == 2) {
       rewriter.replaceOp(
-          op, MaterializeRankSpecializationForTwoNonScalarOperands(
-                  rewriter, loc, op, non_scalar_operands, max_target_rank));
+          op,
+          MaterializeRankSpecializationForTwoNonScalarShapeEquivalenceClasses(
+              rewriter, loc, op, non_scalar_eqs, max_target_rank));
       return success();
     }
diff --git a/tests/rank-specialization.mlir b/tests/rank-specialization.mlir
index f7d406b..c71b71b 100644
--- a/tests/rank-specialization.mlir
+++ b/tests/rank-specialization.mlir
@@ -51,9 +51,8 @@ func @compare_const_like(%arg0 : tensor<*xf32>) -> tensor<*xi1> {
 // CHECK-SCF-DAG: %[[EQ21:.*]] = shape.shape_eq %[[SHAPE_ARG2]], %[[SHAPE_ARG1]]
 // CHECK-SCF-DAG: %[[SHAPES_EQ:.*]] = and %[[EQ20]], %[[EQ21]]
 // CHECK-SCF:     %[[UNSHAPED_RES_EQ_SHAPES:.*]] = scf.if %[[SHAPES_EQ]]
-// CHECK-SCF-DAG:   %[[S20:.*]] = shape.any %[[SHAPE_ARG2]], %[[SHAPE_ARG0]]
-// CHECK-SCF-DAG:   %[[S201:.*]] = shape.any %[[S20]], %[[SHAPE_ARG1]]
-// CHECK-SCF-DAG:   %[[N:.*]] = shape.num_elements %[[S201]]
+// CHECK-SCF-DAG:   %[[ANY_SHAPE:.*]] = shape.any %[[SHAPE_ARG2]], %[[SHAPE_ARG0]], %[[SHAPE_ARG1]]
+// CHECK-SCF-DAG:   %[[N:.*]] = shape.num_elements %[[ANY_SHAPE]]
 // CHECK-SCF-DAG:   %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
 // CHECK-SCF-DAG:   %[[FLAT_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[FLAT_SHAPE]])
 // CHECK-SCF-DAG:   %[[FLAT_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[FLAT_SHAPE]])
@@ -584,3 +583,35 @@ func @merge_clusters(%arg0: tensor<*xf64>, %arg1 : tensor<*xf64>)
   }) : (tensor<*xf64>, tensor<*xf64>, tensor<*xf64>) -> (tensor<*xf64>)
   return %2 : tensor<*xf64>
 }
+
+// -----
+
+// CHECK-LABEL: @all_equal_shapes_inferrable
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<*xf64>, %[[ARG1:.*]]: tensor<*xf64>)
+func @all_equal_shapes_inferrable(%arg0: tensor<*xf64>, %arg1 : tensor<*xf64>)
+    -> tensor<*xf64> {
+  // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG0]], %[[ARG1]])
+  // CHECK: ^bb0(%[[ARG0_:.*]]: tensor<*xf64>, %[[ARG1_:.*]]: tensor<*xf64>)
+  // CHECK: %[[INNER_RES:.*]] = mhlo.add %[[ARG0_]], %[[ARG1_]]
+  // CHECK: "chlo.rank_specialization_cluster_yield"(%[[INNER_RES]])
+  // CHECK: return %[[RES]]
+  %0 = "mhlo.add"(%arg0, %arg1)
+      : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64>
+  return %0 : tensor<*xf64>
+}
+
+// CHECK-SCF-LABEL: @all_equal_shapes_inferrable
+// CHECK-SCF-SAME:  (%[[ARG0:.*]]: tensor<*xf64>, %[[ARG1:.*]]: tensor<*xf64>)
+// CHECK-SCF-DAG: %[[S0:.*]] = shape.shape_of %[[ARG0]]
+// CHECK-SCF-DAG: %[[S1:.*]] = shape.shape_of %[[ARG1]]
+// CHECK-SCF-DAG: %[[S:.*]] = shape.any %[[S0]], %[[S1]]
+// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[S]]
+// CHECK-SCF-DAG: %[[FLAT_S:.*]] = tensor.from_elements %[[N]]
+// CHECK-SCF-DAG: %[[FLAT0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[FLAT_S]])
+// CHECK-SCF-DAG: %[[FLAT1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[FLAT_S]])
+// CHECK-SCF:     %[[FLAT_RES:.*]] = mhlo.add %[[FLAT0]], %[[FLAT1]]
+// CHECK-SCF-DAG: %[[S0:.*]] = shape.shape_of %[[ARG0]]
+// CHECK-SCF-DAG: %[[S1:.*]] = shape.shape_of %[[ARG1]]
+// CHECK-SCF-DAG: %[[RES_S:.*]] = shape.broadcast %[[S0]], %[[S1]]
+// CHECK-SCF-DAG: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RES]], %[[RES_S]])
+// CHECK-SCF:     return %[[RES]]
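
Note (illustration, not part of the patch): the grouping in
FindNonScalarShapeEquivalences is plain union-find via
llvm::EquivalenceClasses, which is why the patch extends operator< to
mlir::Value. Below is a minimal standalone sketch of the same
union-then-bucket pattern; the int keys and their values are hypothetical
stand-ins for mlir::Value, chosen only for illustration.

  #include "llvm/ADT/EquivalenceClasses.h"
  #include "llvm/ADT/SmallVector.h"
  #include <cstdio>

  int main() {
    llvm::EquivalenceClasses<int> eqs;

    // Union values known to share a shape, as the pass does for
    // (operand, block argument) pairs and SameOperandsAndResultShape ops.
    eqs.unionSets(0, 10);   // operand 0 <-> its block argument
    eqs.unionSets(10, 20);  // block argument <-> shape-preserving result
    eqs.unionSets(1, 11);   // a second, unrelated class

    // Bucket into a list-like representation, mirroring the conversion
    // loop at the end of FindNonScalarShapeEquivalences.
    llvm::SmallVector<llvm::SmallVector<int, 4>, 4> classes;
    for (int v : {0, 1, 10, 11, 20}) {
      bool inserted = false;
      for (auto &eq_class : classes) {
        if (eqs.isEquivalent(eq_class.front(), v)) {
          eq_class.push_back(v);
          inserted = true;
          break;
        }
      }
      if (!inserted) classes.push_back(llvm::SmallVector<int, 4>({v}));
    }

    // Two classes result: {0, 10, 20} and {1, 11}.
    std::printf("num classes: %zu\n", classes.size());
    return 0;
  }

The bucketing loop is quadratic in the number of classes, which is harmless at
the operand counts these clusters see; the pattern only specializes the one-
and two-class cases anyway.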