diff --git a/BUILD b/BUILD index 999c0de..cf623ab 100644 --- a/BUILD +++ b/BUILD @@ -812,6 +812,7 @@ cc_library( ], deps = [ ":hlo", + ":pass_details", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td b/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td index 9f739d5..11804f8 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td +++ b/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td @@ -127,11 +127,16 @@ def TestUnfuseBatchNormPass : Pass<"mhlo-test-unfuse-batch-norm", "FuncOp"> { /// Rank specialization passes. def RankSpecializationClusterPass - : Pass<"mhlo-rank-specialization-cluster", "FuncOp"> { + : FunctionPass<"mhlo-rank-specialization-cluster"> { let constructor = "createRankSpecializationClusterPass()"; } def RankSpecializationToSCFPass - : Pass<"mhlo-rank-specialization-to-scf", "FuncOp"> { + : FunctionPass<"mhlo-rank-specialization-to-scf"> { let constructor = "createRankSpecializationToSCFPass()"; + let options = [ + Option<"max_target_rank_", "max-target-rank", "int", /*default=*/"8", + "The maximum supported rank after rank specialization. Any argument " + "of greater rank may result in a runtime failure.">, + ]; } diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/passes.h b/include/mlir-hlo/Dialect/mhlo/transforms/passes.h index 0061bc5..9c74818 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/passes.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/passes.h @@ -74,6 +74,8 @@ std::unique_ptr createMoveUpDynamicBroadcastsForFusionPass(); /// specialization cluster. /// - Lower rank specialization clusters to SCF and ranked operations. std::unique_ptr createRankSpecializationClusterPass(); +std::unique_ptr createRankSpecializationToSCFPass( + int64_t max_target_rank); std::unique_ptr createRankSpecializationToSCFPass(); std::unique_ptr createOptimizeMhloPass(); diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h b/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h index 2f7aa80..6b7c099 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h @@ -103,8 +103,9 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns( /// Populate rank specialization clustering and lowering patterns. void PopulateRankSpecializationClusterPatterns( MLIRContext *context, OwningRewritePatternList *patterns); -void PopulateRankSpecializationToSCFPatterns( - MLIRContext *context, OwningRewritePatternList *patterns); +void PopulateRankSpecializationToSCFPatterns(MLIRContext *context, + OwningRewritePatternList *patterns, + int64_t max_target_rank); } // namespace mhlo diff --git a/lib/Dialect/mhlo/transforms/rank_specialization.cc b/lib/Dialect/mhlo/transforms/rank_specialization.cc index 4d17904..60379b1 100644 --- a/lib/Dialect/mhlo/transforms/rank_specialization.cc +++ b/lib/Dialect/mhlo/transforms/rank_specialization.cc @@ -19,6 +19,7 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h" #include "mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "mlir/Dialect/SCF/SCF.h" @@ -158,7 +159,7 @@ struct RankSpecializationClusterPattern : public RewritePattern { }; struct RankSpecializationClusterPass - : public PassWrapper { + : public RankSpecializationClusterPassBase { void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); } @@ -471,7 +472,7 @@ Value RecusivelyMaterializeTargetRankSpecializationCases( Value MaterializeGenericRankSpecializationCases( OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op, - const SmallVector &shapes) { + const SmallVector &shapes, int64_t max_target_rank) { // Get the minimum broadcast shapes of the operands. auto non_scalar_shapes = llvm::to_vector<8>(llvm::make_filter_range( shapes, [](Value v) { return !IsScalarShapeType(v.getType()); })); @@ -505,22 +506,19 @@ Value MaterializeGenericRankSpecializationCases( } } - // Materialize rank specialization for ranks 1, ..., 8. - // TODO(frgossen): For clusters w/o a select operation, consider only ranks - // 1, ..., 5. - const int64_t kMinTargetRank = 1; - const int64_t kMaxTargetRank = 8; + // Materialize rank specialization for ranks 1, ... return RecusivelyMaterializeTargetRankSpecializationCases( - b, loc, op, reduced_shapes, max_rank, kMinTargetRank, kMaxTargetRank); + b, loc, op, reduced_shapes, max_rank, /*min_target_rank=*/1, + max_target_rank); } Value MaterializeDefaultRankSpecializationCases( OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op, - const SmallVector &shapes) { + const SmallVector &shapes, int64_t max_target_rank) { return MaterializeEqualShapesRankSpecializationCase( b, loc, op, shapes, [&](OpBuilder &b, Location loc) { - b.create( - loc, MaterializeGenericRankSpecializationCases(b, loc, op, shapes)); + b.create(loc, MaterializeGenericRankSpecializationCases( + b, loc, op, shapes, max_target_rank)); }); } @@ -555,7 +553,7 @@ SmallVector MaterializeRankSpecializationForSingleNonScalarOperand( Value MaterializeRankSpecializationForTwoNonScalarOperands( OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op, - ValueRange non_scalar_operands) { + ValueRange non_scalar_operands, int64_t max_target_rank) { assert(non_scalar_operands.size() == 2); auto shapes = llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) { @@ -574,7 +572,7 @@ Value MaterializeRankSpecializationForTwoNonScalarOperands( [&](OpBuilder &b, Location loc) { b.create( loc, MaterializeDefaultRankSpecializationCases( - b, loc, op, shapes)); + b, loc, op, shapes, max_target_rank)); })); }); @@ -583,15 +581,16 @@ Value MaterializeRankSpecializationForTwoNonScalarOperands( } // Materialize rank generic rank specialization. -Value MaterializeDefaultRankSpecialization( - OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op) { +Value MaterializeDefaultRankSpecialization(OpBuilder &b, Location loc, + chlo::RankSpecializationClusterOp op, + int64_t max_target_rank) { auto shapes = llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) { return b.create(loc, v).result(); })); // Materialize all the different cases. - Value unshaped_result = - MaterializeDefaultRankSpecializationCases(b, loc, op, shapes); + Value unshaped_result = MaterializeDefaultRankSpecializationCases( + b, loc, op, shapes, max_target_rank); // Materialize final reshape once and for all rank specialization cases. return MaterializeFinalReshape(b, loc, op, unshaped_result).front(); @@ -599,7 +598,10 @@ Value MaterializeDefaultRankSpecialization( struct LowerRankSpecializationClusterPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + LowerRankSpecializationClusterPattern(MLIRContext *ctx, + int64_t max_target_rank) + : OpRewritePattern(ctx, /*benefit=*/1), + max_target_rank(max_target_rank) {} LogicalResult matchAndRewrite(chlo::RankSpecializationClusterOp op, PatternRewriter &rewriter) const override { @@ -630,22 +632,33 @@ struct LowerRankSpecializationClusterPattern llvm::all_of(non_scalar_operands, [](Value v) { return v.getType().isa(); })) { - rewriter.replaceOp(op, - MaterializeRankSpecializationForTwoNonScalarOperands( - rewriter, loc, op, non_scalar_operands)); + rewriter.replaceOp( + op, MaterializeRankSpecializationForTwoNonScalarOperands( + rewriter, loc, op, non_scalar_operands, max_target_rank)); return success(); } // For all other cases, reshape the operands to match in rank, apply the // operation, and restore the expected shape. - rewriter.replaceOp(op, - MaterializeDefaultRankSpecialization(rewriter, loc, op)); + rewriter.replaceOp(op, MaterializeDefaultRankSpecialization( + rewriter, loc, op, max_target_rank)); return success(); } + + private: + int64_t max_target_rank; }; struct RankSpecializationToSCFPass - : public PassWrapper { + : public RankSpecializationToSCFPassBase { + using RankSpecializationToSCFPassBase< + RankSpecializationToSCFPass>::RankSpecializationToSCFPassBase; + explicit RankSpecializationToSCFPass(int64_t max_target_rank) + : RankSpecializationToSCFPassBase< + RankSpecializationToSCFPass>::RankSpecializationToSCFPassBase() { + this->max_target_rank_ = max_target_rank; + } + void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); @@ -654,7 +667,8 @@ struct RankSpecializationToSCFPass void runOnFunction() override { MLIRContext *ctx = &getContext(); RewritePatternSet patterns(ctx); - PopulateRankSpecializationToSCFPatterns(ctx, &patterns); + PopulateRankSpecializationToSCFPatterns(ctx, &patterns, + this->max_target_rank_); if (failed( applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)))) { return signalPassFailure(); @@ -669,15 +683,22 @@ void PopulateRankSpecializationClusterPatterns( patterns->insert(context); } -void PopulateRankSpecializationToSCFPatterns( - MLIRContext *context, OwningRewritePatternList *patterns) { - patterns->insert(context); +void PopulateRankSpecializationToSCFPatterns(MLIRContext *context, + OwningRewritePatternList *patterns, + int64_t max_target_rank) { + patterns->insert(context, + max_target_rank); } std::unique_ptr createRankSpecializationClusterPass() { return std::make_unique(); } +std::unique_ptr createRankSpecializationToSCFPass( + int64_t max_target_rank) { + return std::make_unique(max_target_rank); +} + std::unique_ptr createRankSpecializationToSCFPass() { return std::make_unique(); } diff --git a/tests/rank-specialization.mlir b/tests/rank-specialization.mlir index 5c84eb6..fec06df 100644 --- a/tests/rank-specialization.mlir +++ b/tests/rank-specialization.mlir @@ -1,5 +1,5 @@ // RUN: mlir-hlo-opt %s --split-input-file --mhlo-rank-specialization-cluster | FileCheck %s -// RUN: mlir-hlo-opt %s --split-input-file --mhlo-rank-specialization-cluster --mhlo-rank-specialization-to-scf | FileCheck %s --check-prefix CHECK-SCF +// RUN: mlir-hlo-opt %s --split-input-file --mhlo-rank-specialization-cluster --mhlo-rank-specialization-to-scf=max-target-rank=8 | FileCheck %s --check-prefix CHECK-SCF // CHECK-LABEL: @add_mul // CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>, %[[ARG2:.*]]: tensor<*xf32>)