From bd5752f0bf2743fcbcb5b8e8491d0cfa643e23ae Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 11 Jun 2021 03:59:02 -0700 Subject: [PATCH] [MLIR][HLO] Find shape equivalences and use them for better rank specialization Find shape equivalence classes among the operands and use them for better rank specialization. If all operands are known to be of the same shape, we can flatten them to rank one. If there are two shape equivalence classes, we can generalize the scalar rank specialization cases. PiperOrigin-RevId: 378844575 --- .../mhlo/transforms/rank_specialization.cc | 218 +++++++++++------- tests/rank-specialization.mlir | 37 ++- 2 files changed, 169 insertions(+), 86 deletions(-) diff --git a/lib/Dialect/mhlo/transforms/rank_specialization.cc b/lib/Dialect/mhlo/transforms/rank_specialization.cc index 3f8310f..4db229c 100644 --- a/lib/Dialect/mhlo/transforms/rank_specialization.cc +++ b/lib/Dialect/mhlo/transforms/rank_specialization.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ +#include "llvm/ADT/EquivalenceClasses.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" @@ -38,7 +39,8 @@ limitations under the License. namespace mlir { -/// Needed to build `llvm::SmallSet`s of `mlir::Value`s. +/// Needed to build `llvm::SmallSet`s and `llvm::EquivalenceClasses` of +/// `mlir::Value`s. static bool operator<(const Value &lhs, const Value &rhs) { return lhs.getAsOpaquePointer() < rhs.getAsOpaquePointer(); } @@ -308,18 +310,6 @@ Type DeriveUnrankedTensorTypes(Type ty) { return ty; } -Optional FindUniqueNonScalar(ValueRange values) { - Value unique_non_scalar; - for (Value v : values) { - if (!IsScalarTensorType(v.getType())) { - if (unique_non_scalar) return llvm::None; - unique_non_scalar = v; - } - } - if (!unique_non_scalar) return llvm::None; - return unique_non_scalar; -} - SmallVector MaterializeRankedOperations( OpBuilder &b, Location loc, BlockAndValueMapping &bvm, chlo::RankSpecializationClusterOp op) { @@ -375,20 +365,35 @@ SmallVector MaterializeFinalReshape( })); } +Value MaterializeFlatShape(OpBuilder &b, Location loc, ValueRange same_shapes) { + assert(!same_shapes.empty() && "Expected at least one shape."); + Value shape = same_shapes.size() == 1 + ? same_shapes.front() + : b.create(loc, same_shapes.front().getType(), + same_shapes); + return b.create( + loc, + b.create(loc, b.getIndexType(), shape).result()); +} + Value MaterializeScalarRankSpecializationCase( OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op, - const SmallVector &shapes, int64_t non_scalar_idx, + const SmallVector &shapes, ValueRange non_scalars_of_same_shape, function_ref else_builder_fn) { - // Materialize predicate: All operands except one are scalars. + // Materialize predicate: All operands are scalars, except the expected + // non-scalars. Value one = b.create(loc, 1); Value all_others_are_scalar; - for (auto it : llvm::enumerate(shapes)) { - if (it.index() == non_scalar_idx) continue; - // For statically known scalars, there is no need to test. - if (IsScalarTensorType(op.getOperand(it.index()).getType())) continue; + for (auto it : llvm::zip(op.operands(), shapes)) { + Value operand, shape; + std::tie(operand, shape) = it; + if (llvm::is_contained(non_scalars_of_same_shape, operand) || + IsScalarTensorType(operand.getType())) { + continue; + } auto literal = b.create(loc, CmpIPredicate::eq, - b.create(loc, it.value()), one); + b.create(loc, shape), one); all_others_are_scalar = all_others_are_scalar ? b.create(loc, all_others_are_scalar, literal) @@ -399,24 +404,31 @@ Value MaterializeScalarRankSpecializationCase( auto if_op = b.create( loc, op->getResultTypes(), all_others_are_scalar, [&](OpBuilder &b, Location loc) { - // Flatten the non-scalar operand. - Value flat_shape = b.create( - loc, b.create(loc, b.getIndexType(), - shapes[non_scalar_idx]) - .result()); - Value non_scalar_operand = op.operands()[non_scalar_idx]; - Value flat_non_scalar_operand = b.create( - loc, - DeriveRankedTensorTypes(non_scalar_operand.getType(), /*rank=*/1), - non_scalar_operand, flat_shape); + // Compute flat non-scalar shape. + SmallVector non_scalar_shapes; + for (auto it : llvm::zip(op.operands(), shapes)) { + Value operand, shape; + std::tie(operand, shape) = it; + if (llvm::is_contained(non_scalars_of_same_shape, operand)) + non_scalar_shapes.push_back(shape); + } + Value flat_shape = MaterializeFlatShape(b, loc, non_scalar_shapes); // Derive ranked operands. auto ranked_operands = llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) { - if (v == non_scalar_operand) return flat_non_scalar_operand; + if (IsScalarTensorType(v.getType())) return v; + if (!llvm::is_contained(non_scalars_of_same_shape, v)) { + return b + .create( + loc, DeriveRankedTensorTypes(v.getType(), /*rank=*/0), + v) + .getResult(); + } return b - .create( - loc, DeriveRankedTensorTypes(v.getType(), /*rank=*/0), v) + .create( + loc, DeriveRankedTensorTypes(v.getType(), /*rank=*/1), v, + flat_shape) .getResult(); })); @@ -464,14 +476,7 @@ Value MaterializeEqualShapesRankSpecializationCase( loc, op->getResultTypes(), all_shapes_eq_or_scalar, [&](OpBuilder &b, Location loc) { // Flatten non-scalar operands. - Value shape = non_scalar_shapes.front(); - for (Value s : llvm::drop_begin(non_scalar_shapes)) { - shape = b.create(loc, shape.getType(), - ValueRange{shape, s}); - } - Value flat_shape = b.create( - loc, b.create(loc, b.getIndexType(), shape) - .result()); + Value flat_shape = MaterializeFlatShape(b, loc, non_scalar_shapes); auto flat_operands = llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) { if (IsScalarTensorType(v.getType())) return v; @@ -631,18 +636,15 @@ Value MaterializeDefaultRankSpecializationCases( }); } -SmallVector MaterializeRankSpecializationForSingleNonScalarOperand( +SmallVector +MaterializeRankSpecializationForSingleNonScalarShapeEquivalenceClass( OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op, - Value non_scalar_operand) { - // Flatten the non-scalar operand. - Value flat_shape = b.create( - loc, b.create( - loc, b.getIndexType(), - b.create(loc, non_scalar_operand)) - .result()); - Value flat_non_scalar_operand = b.create( - loc, DeriveRankedTensorTypes(non_scalar_operand.getType(), /*rank=*/1), - non_scalar_operand, flat_shape); + ValueRange non_scalars_of_same_shape) { + // Compute flat operand shape. + auto non_scalar_shapes = llvm::to_vector<4>(llvm::map_range( + non_scalars_of_same_shape, + [&](Value v) { return b.create(loc, v).result(); })); + Value flat_shape = MaterializeFlatShape(b, loc, non_scalar_shapes); // Materialize ranked variants for the element-wise operations. BlockAndValueMapping bvm; @@ -650,8 +652,14 @@ SmallVector MaterializeRankSpecializationForSingleNonScalarOperand( Value operand; Value bb_arg; std::tie(bb_arg, operand) = it; - bvm.map(bb_arg, - operand == non_scalar_operand ? flat_non_scalar_operand : operand); + if (!IsScalarTensorType(operand.getType())) { + assert(llvm::is_contained(non_scalars_of_same_shape, operand) && + "Expected all non-scalars in the same shape equivalence class."); + operand = b.create( + loc, DeriveRankedTensorTypes(operand.getType(), /*rank=*/1), operand, + flat_shape); + } + bvm.map(bb_arg, operand); } SmallVector unshaped_results = MaterializeRankedOperations(b, loc, bvm, op); @@ -660,24 +668,24 @@ SmallVector MaterializeRankSpecializationForSingleNonScalarOperand( return MaterializeFinalReshape(b, loc, op, unshaped_results); } -Value MaterializeRankSpecializationForTwoNonScalarOperands( +Value MaterializeRankSpecializationForTwoNonScalarShapeEquivalenceClasses( OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op, - ValueRange non_scalar_operands, int64_t max_target_rank) { - assert(non_scalar_operands.size() == 2); - + SmallVector, 4> non_scalar_eqs, + int64_t max_target_rank) { + assert(non_scalar_eqs.size() == 2 && + "Expect two non-scalar equivalence classes."); auto shapes = llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) { return b.create(loc, v).result(); })); - auto non_scalar_lhs = llvm::find(op.operands(), non_scalar_operands[0]); - auto non_scalar_rhs = llvm::find(op.operands(), non_scalar_operands[1]); + ValueRange lhs_non_scalar_eqs = non_scalar_eqs[0]; + ValueRange rhs_non_scalar_eqs = non_scalar_eqs[1]; // Materialize all the different cases. Value unshaped_result = MaterializeScalarRankSpecializationCase( - b, loc, op, shapes, non_scalar_rhs.getIndex(), - [&](OpBuilder &b, Location loc) { + b, loc, op, shapes, rhs_non_scalar_eqs, [&](OpBuilder &b, Location loc) { b.create( loc, MaterializeScalarRankSpecializationCase( - b, loc, op, shapes, non_scalar_lhs.getIndex(), + b, loc, op, shapes, lhs_non_scalar_eqs, [&](OpBuilder &b, Location loc) { b.create( loc, MaterializeDefaultRankSpecializationCases( @@ -705,6 +713,54 @@ Value MaterializeDefaultRankSpecialization(OpBuilder &b, Location loc, return MaterializeFinalReshape(b, loc, op, unshaped_result).front(); } +// This is a very limited form of shape inference. It is correct but incomplete. +SmallVector, 4> FindNonScalarShapeEquivalences( + chlo::RankSpecializationClusterOp op) { + llvm::EquivalenceClasses eqs; + + // Bridge the equivalences between operands and block arguments. + for (auto it : llvm::zip(op.operands(), op.getBody()->getArguments())) + eqs.unionSets(std::get<0>(it), std::get<1>(it)); + + // Find equalities through `SameOperandsAndResultShape` trait. + auto union_sets = [&](ValueRange vs) { + if (vs.empty()) return; + Value repr = vs.front(); + for (Value v : vs.drop_front()) eqs.unionSets(repr, v); + }; + for (Operation &nested_op : op.getBody()->without_terminator()) { + if (nested_op.hasTrait()) { + union_sets(nested_op.getOperands()); + union_sets(nested_op.getResults()); + if (!nested_op.getOperands().empty() && !nested_op.getResults().empty()) + eqs.unionSets(nested_op.getResult(0), nested_op.getOperand(0)); + } + // TODO(frgossen): Replace this with a check for the appropriate trait when + // that is available. + if (auto select_op = llvm::dyn_cast(nested_op)) { + union_sets( + {select_op.on_true(), select_op.on_false(), select_op.getResult()}); + } + } + + // Convert to a list-like equivalence class representation. + SmallVector, 4> non_scalar_eqs; + for (Value v : op.operands()) { + if (IsScalarTensorType(v.getType())) continue; + bool inserted = false; + for (auto &eq_class : non_scalar_eqs) { + if (eqs.isEquivalent(eq_class.front(), v)) { + eq_class.push_back(v); + inserted = true; + break; + } + } + if (!inserted) non_scalar_eqs.push_back(SmallVector({v})); + } + + return non_scalar_eqs; +} + struct LowerRankSpecializationClusterPattern : public OpRewritePattern { LowerRankSpecializationClusterPattern(MLIRContext *ctx, @@ -719,31 +775,27 @@ struct LowerRankSpecializationClusterPattern // all operands. if (op.getNumResults() != 1) return failure(); - // If there is only a single non-scalar operand, we can flatten that operand - // completely. + // If there is only a single non-scalar shape equivalence class, we can + // flatten that operands completely. + SmallVector, 4> non_scalar_eqs = + FindNonScalarShapeEquivalences(op); Location loc = op.getLoc(); - auto non_scalar_operands = - llvm::to_vector<2>(llvm::make_filter_range(op.operands(), [](Value v) { - return !IsScalarTensorType(v.getType()); - })); - if (non_scalar_operands.size() == 1) { - rewriter.replaceOp(op, - MaterializeRankSpecializationForSingleNonScalarOperand( - rewriter, loc, op, non_scalar_operands.front())); + if (non_scalar_eqs.size() == 1) { + rewriter.replaceOp( + op, + MaterializeRankSpecializationForSingleNonScalarShapeEquivalenceClass( + rewriter, loc, op, non_scalar_eqs.front())); return success(); } - // If there are exactly two unranked operands and all others are known to be - // scalars, we can consider two extra cases: If either of the unranked - // operands turns out to be a scalar at runtime, we can, again, apply the - // trick for a single non-scalar operand. - if (non_scalar_operands.size() == 2 && - llvm::all_of(non_scalar_operands, [](Value v) { - return v.getType().isa(); - })) { + // If there are exactly two non-scalar shape equivalence classes, we can + // consider two extra cases: If either of the operand classes turns out to + // be all-scalars at runtime, we can, again, flatten all operands. + if (non_scalar_eqs.size() == 2) { rewriter.replaceOp( - op, MaterializeRankSpecializationForTwoNonScalarOperands( - rewriter, loc, op, non_scalar_operands, max_target_rank)); + op, + MaterializeRankSpecializationForTwoNonScalarShapeEquivalenceClasses( + rewriter, loc, op, non_scalar_eqs, max_target_rank)); return success(); } diff --git a/tests/rank-specialization.mlir b/tests/rank-specialization.mlir index f7d406b..c71b71b 100644 --- a/tests/rank-specialization.mlir +++ b/tests/rank-specialization.mlir @@ -51,9 +51,8 @@ func @compare_const_like(%arg0 : tensor<*xf32>) -> tensor<*xi1> { // CHECK-SCF-DAG: %[[EQ21:.*]] = shape.shape_eq %[[SHAPE_ARG2]], %[[SHAPE_ARG1]] // CHECK-SCF-DAG: %[[SHAPES_EQ:.*]] = and %[[EQ20]], %[[EQ21]] // CHECK-SCF: %[[UNSHAPED_RES_EQ_SHAPES:.*]] = scf.if %[[SHAPES_EQ]] -// CHECK-SCF-DAG: %[[S20:.*]] = shape.any %[[SHAPE_ARG2]], %[[SHAPE_ARG0]] -// CHECK-SCF-DAG: %[[S201:.*]] = shape.any %[[S20]], %[[SHAPE_ARG1]] -// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[S201]] +// CHECK-SCF-DAG: %[[ANY_SHAPE:.*]] = shape.any %[[SHAPE_ARG2]], %[[SHAPE_ARG0]], %[[SHAPE_ARG1]] +// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[ANY_SHAPE]] // CHECK-SCF-DAG: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]] // CHECK-SCF-DAG: %[[FLAT_ARG0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[FLAT_SHAPE]]) // CHECK-SCF-DAG: %[[FLAT_ARG1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[FLAT_SHAPE]]) @@ -584,3 +583,35 @@ func @merge_clusters(%arg0: tensor<*xf64>, %arg1 : tensor<*xf64>) }) : (tensor<*xf64>, tensor<*xf64>, tensor<*xf64>) -> (tensor<*xf64>) return %2 : tensor<*xf64> } + +// ----- + +// CHECK-LABEL: @all_equal_shapes_inferrable +// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf64>, %[[ARG1:.*]]: tensor<*xf64>) +func @all_equal_shapes_inferrable(%arg0: tensor<*xf64>, %arg1 : tensor<*xf64>) + -> tensor<*xf64> { + // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG0]], %[[ARG1]]) + // CHECK: ^bb0(%[[ARG0_:.*]]: tensor<*xf64>, %[[ARG1_:.*]]: tensor<*xf64>) + // CHECK: %[[INNER_RES:.*]] = mhlo.add %[[ARG0_]], %[[ARG1_]] + // CHECK: "chlo.rank_specialization_cluster_yield"(%[[INNER_RES]]) + // CHECK: return %[[RES]] + %0 = "mhlo.add"(%arg0, %arg1) + : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64> + return %0 : tensor<*xf64> +} + +// CHECK-SCF-LABEL: @all_equal_shapes_inferrable +// CHECK-SCF-SAME: (%[[ARG0:.*]]: tensor<*xf64>, %[[ARG1:.*]]: tensor<*xf64>) +// CHECK-SCF-DAG: %[[S0:.*]] = shape.shape_of %[[ARG0]] +// CHECK-SCF-DAG: %[[S1:.*]] = shape.shape_of %[[ARG1]] +// CHECK-SCF-DAG: %[[S:.*]] = shape.any %[[S0]], %[[S1]] +// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[S]] +// CHECK-SCF-DAG: %[[FLAT_S:.*]] = tensor.from_elements %[[N]] +// CHECK-SCF-DAG: %[[FLAT0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[FLAT_S]]) +// CHECK-SCF-DAG: %[[FLAT1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[FLAT_S]]) +// CHECK-SCF: %[[FLAT_RES:.*]] = mhlo.add %[[FLAT0]], %[[FLAT1]] +// CHECK-SCF-DAG: %[[S0:.*]] = shape.shape_of %[[ARG0]] +// CHECK-SCF-DAG: %[[S1:.*]] = shape.shape_of %[[ARG1]] +// CHECK-SCF-DAG: %[[RES_S:.*]] = shape.broadcast %8, %9 +// CHECK-SCF-DAG: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RES]], %[[RES_S]]) +// CHECK-SCF: return %[[RES]]