[MLIR][KernelGen] Make maximum supported rank in rank specialization configurable
The maximum supported target rank of 5 is sufficient for all operations but `select`. Make the maximum target rank configurable in the rank specialization. This reduces the number of generated kernels for operations that don't require it. PiperOrigin-RevId: 376822496
This commit is contained in:
parent
c7c245eaf1
commit
d1828625ab
1
BUILD
1
BUILD
|
@ -812,6 +812,7 @@ cc_library(
|
|||
],
|
||||
deps = [
|
||||
":hlo",
|
||||
":pass_details",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:InferTypeOpInterface",
|
||||
|
|
|
@ -127,11 +127,16 @@ def TestUnfuseBatchNormPass : Pass<"mhlo-test-unfuse-batch-norm", "FuncOp"> {
|
|||
/// Rank specialization passes.
|
||||
|
||||
def RankSpecializationClusterPass
|
||||
: Pass<"mhlo-rank-specialization-cluster", "FuncOp"> {
|
||||
: FunctionPass<"mhlo-rank-specialization-cluster"> {
|
||||
let constructor = "createRankSpecializationClusterPass()";
|
||||
}
|
||||
|
||||
def RankSpecializationToSCFPass
|
||||
: Pass<"mhlo-rank-specialization-to-scf", "FuncOp"> {
|
||||
: FunctionPass<"mhlo-rank-specialization-to-scf"> {
|
||||
let constructor = "createRankSpecializationToSCFPass()";
|
||||
let options = [
|
||||
Option<"max_target_rank_", "max-target-rank", "int", /*default=*/"8",
|
||||
"The maximum supported rank after rank specialization. Any argument "
|
||||
"of greater rank may result in a runtime failure.">,
|
||||
];
|
||||
}
|
||||
|
|
|
@ -74,6 +74,8 @@ std::unique_ptr<FunctionPass> createMoveUpDynamicBroadcastsForFusionPass();
|
|||
/// specialization cluster.
|
||||
/// - Lower rank specialization clusters to SCF and ranked operations.
|
||||
std::unique_ptr<FunctionPass> createRankSpecializationClusterPass();
|
||||
std::unique_ptr<FunctionPass> createRankSpecializationToSCFPass(
|
||||
int64_t max_target_rank);
|
||||
std::unique_ptr<FunctionPass> createRankSpecializationToSCFPass();
|
||||
|
||||
std::unique_ptr<FunctionPass> createOptimizeMhloPass();
|
||||
|
|
|
@ -103,8 +103,9 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
|
|||
/// Populate rank specialization clustering and lowering patterns.
|
||||
void PopulateRankSpecializationClusterPatterns(
|
||||
MLIRContext *context, OwningRewritePatternList *patterns);
|
||||
void PopulateRankSpecializationToSCFPatterns(
|
||||
MLIRContext *context, OwningRewritePatternList *patterns);
|
||||
void PopulateRankSpecializationToSCFPatterns(MLIRContext *context,
|
||||
OwningRewritePatternList *patterns,
|
||||
int64_t max_target_rank);
|
||||
|
||||
} // namespace mhlo
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@ limitations under the License.
|
|||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
|
@ -158,7 +159,7 @@ struct RankSpecializationClusterPattern : public RewritePattern {
|
|||
};
|
||||
|
||||
struct RankSpecializationClusterPass
|
||||
: public PassWrapper<RankSpecializationClusterPass, FunctionPass> {
|
||||
: public RankSpecializationClusterPassBase<RankSpecializationClusterPass> {
|
||||
void getDependentDialects(DialectRegistry &registry) const override {
|
||||
registry.insert<mhlo::MhloDialect, chlo::HloClientDialect>();
|
||||
}
|
||||
|
@ -471,7 +472,7 @@ Value RecusivelyMaterializeTargetRankSpecializationCases(
|
|||
|
||||
Value MaterializeGenericRankSpecializationCases(
|
||||
OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
|
||||
const SmallVector<Value, 8> &shapes) {
|
||||
const SmallVector<Value, 8> &shapes, int64_t max_target_rank) {
|
||||
// Get the minimum broadcast shapes of the operands.
|
||||
auto non_scalar_shapes = llvm::to_vector<8>(llvm::make_filter_range(
|
||||
shapes, [](Value v) { return !IsScalarShapeType(v.getType()); }));
|
||||
|
@ -505,22 +506,19 @@ Value MaterializeGenericRankSpecializationCases(
|
|||
}
|
||||
}
|
||||
|
||||
// Materialize rank specialization for ranks 1, ..., 8.
|
||||
// TODO(frgossen): For clusters w/o a select operation, consider only ranks
|
||||
// 1, ..., 5.
|
||||
const int64_t kMinTargetRank = 1;
|
||||
const int64_t kMaxTargetRank = 8;
|
||||
// Materialize rank specialization for ranks 1, ...
|
||||
return RecusivelyMaterializeTargetRankSpecializationCases(
|
||||
b, loc, op, reduced_shapes, max_rank, kMinTargetRank, kMaxTargetRank);
|
||||
b, loc, op, reduced_shapes, max_rank, /*min_target_rank=*/1,
|
||||
max_target_rank);
|
||||
}
|
||||
|
||||
Value MaterializeDefaultRankSpecializationCases(
|
||||
OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
|
||||
const SmallVector<Value, 8> &shapes) {
|
||||
const SmallVector<Value, 8> &shapes, int64_t max_target_rank) {
|
||||
return MaterializeEqualShapesRankSpecializationCase(
|
||||
b, loc, op, shapes, [&](OpBuilder &b, Location loc) {
|
||||
b.create<scf::YieldOp>(
|
||||
loc, MaterializeGenericRankSpecializationCases(b, loc, op, shapes));
|
||||
b.create<scf::YieldOp>(loc, MaterializeGenericRankSpecializationCases(
|
||||
b, loc, op, shapes, max_target_rank));
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -555,7 +553,7 @@ SmallVector<Value, 8> MaterializeRankSpecializationForSingleNonScalarOperand(
|
|||
|
||||
Value MaterializeRankSpecializationForTwoNonScalarOperands(
|
||||
OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
|
||||
ValueRange non_scalar_operands) {
|
||||
ValueRange non_scalar_operands, int64_t max_target_rank) {
|
||||
assert(non_scalar_operands.size() == 2);
|
||||
|
||||
auto shapes = llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
|
||||
|
@ -574,7 +572,7 @@ Value MaterializeRankSpecializationForTwoNonScalarOperands(
|
|||
[&](OpBuilder &b, Location loc) {
|
||||
b.create<scf::YieldOp>(
|
||||
loc, MaterializeDefaultRankSpecializationCases(
|
||||
b, loc, op, shapes));
|
||||
b, loc, op, shapes, max_target_rank));
|
||||
}));
|
||||
});
|
||||
|
||||
|
@ -583,15 +581,16 @@ Value MaterializeRankSpecializationForTwoNonScalarOperands(
|
|||
}
|
||||
|
||||
// Materialize rank generic rank specialization.
|
||||
Value MaterializeDefaultRankSpecialization(
|
||||
OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op) {
|
||||
Value MaterializeDefaultRankSpecialization(OpBuilder &b, Location loc,
|
||||
chlo::RankSpecializationClusterOp op,
|
||||
int64_t max_target_rank) {
|
||||
auto shapes = llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
|
||||
return b.create<shape::ShapeOfOp>(loc, v).result();
|
||||
}));
|
||||
|
||||
// Materialize all the different cases.
|
||||
Value unshaped_result =
|
||||
MaterializeDefaultRankSpecializationCases(b, loc, op, shapes);
|
||||
Value unshaped_result = MaterializeDefaultRankSpecializationCases(
|
||||
b, loc, op, shapes, max_target_rank);
|
||||
|
||||
// Materialize final reshape once and for all rank specialization cases.
|
||||
return MaterializeFinalReshape(b, loc, op, unshaped_result).front();
|
||||
|
@ -599,7 +598,10 @@ Value MaterializeDefaultRankSpecialization(
|
|||
|
||||
struct LowerRankSpecializationClusterPattern
|
||||
: public OpRewritePattern<chlo::RankSpecializationClusterOp> {
|
||||
using OpRewritePattern<chlo::RankSpecializationClusterOp>::OpRewritePattern;
|
||||
LowerRankSpecializationClusterPattern(MLIRContext *ctx,
|
||||
int64_t max_target_rank)
|
||||
: OpRewritePattern<chlo::RankSpecializationClusterOp>(ctx, /*benefit=*/1),
|
||||
max_target_rank(max_target_rank) {}
|
||||
|
||||
LogicalResult matchAndRewrite(chlo::RankSpecializationClusterOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
|
@ -630,22 +632,33 @@ struct LowerRankSpecializationClusterPattern
|
|||
llvm::all_of(non_scalar_operands, [](Value v) {
|
||||
return v.getType().isa<UnrankedTensorType>();
|
||||
})) {
|
||||
rewriter.replaceOp(op,
|
||||
MaterializeRankSpecializationForTwoNonScalarOperands(
|
||||
rewriter, loc, op, non_scalar_operands));
|
||||
rewriter.replaceOp(
|
||||
op, MaterializeRankSpecializationForTwoNonScalarOperands(
|
||||
rewriter, loc, op, non_scalar_operands, max_target_rank));
|
||||
return success();
|
||||
}
|
||||
|
||||
// For all other cases, reshape the operands to match in rank, apply the
|
||||
// operation, and restore the expected shape.
|
||||
rewriter.replaceOp(op,
|
||||
MaterializeDefaultRankSpecialization(rewriter, loc, op));
|
||||
rewriter.replaceOp(op, MaterializeDefaultRankSpecialization(
|
||||
rewriter, loc, op, max_target_rank));
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
int64_t max_target_rank;
|
||||
};
|
||||
|
||||
struct RankSpecializationToSCFPass
|
||||
: public PassWrapper<RankSpecializationToSCFPass, FunctionPass> {
|
||||
: public RankSpecializationToSCFPassBase<RankSpecializationToSCFPass> {
|
||||
using RankSpecializationToSCFPassBase<
|
||||
RankSpecializationToSCFPass>::RankSpecializationToSCFPassBase;
|
||||
explicit RankSpecializationToSCFPass(int64_t max_target_rank)
|
||||
: RankSpecializationToSCFPassBase<
|
||||
RankSpecializationToSCFPass>::RankSpecializationToSCFPassBase() {
|
||||
this->max_target_rank_ = max_target_rank;
|
||||
}
|
||||
|
||||
void getDependentDialects(DialectRegistry &registry) const override {
|
||||
registry.insert<mhlo::MhloDialect, chlo::HloClientDialect,
|
||||
shape::ShapeDialect, scf::SCFDialect>();
|
||||
|
@ -654,7 +667,8 @@ struct RankSpecializationToSCFPass
|
|||
void runOnFunction() override {
|
||||
MLIRContext *ctx = &getContext();
|
||||
RewritePatternSet patterns(ctx);
|
||||
PopulateRankSpecializationToSCFPatterns(ctx, &patterns);
|
||||
PopulateRankSpecializationToSCFPatterns(ctx, &patterns,
|
||||
this->max_target_rank_);
|
||||
if (failed(
|
||||
applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)))) {
|
||||
return signalPassFailure();
|
||||
|
@ -669,15 +683,22 @@ void PopulateRankSpecializationClusterPatterns(
|
|||
patterns->insert<RankSpecializationClusterPattern>(context);
|
||||
}
|
||||
|
||||
void PopulateRankSpecializationToSCFPatterns(
|
||||
MLIRContext *context, OwningRewritePatternList *patterns) {
|
||||
patterns->insert<LowerRankSpecializationClusterPattern>(context);
|
||||
void PopulateRankSpecializationToSCFPatterns(MLIRContext *context,
|
||||
OwningRewritePatternList *patterns,
|
||||
int64_t max_target_rank) {
|
||||
patterns->insert<LowerRankSpecializationClusterPattern>(context,
|
||||
max_target_rank);
|
||||
}
|
||||
|
||||
std::unique_ptr<FunctionPass> createRankSpecializationClusterPass() {
|
||||
return std::make_unique<RankSpecializationClusterPass>();
|
||||
}
|
||||
|
||||
std::unique_ptr<FunctionPass> createRankSpecializationToSCFPass(
|
||||
int64_t max_target_rank) {
|
||||
return std::make_unique<RankSpecializationToSCFPass>(max_target_rank);
|
||||
}
|
||||
|
||||
std::unique_ptr<FunctionPass> createRankSpecializationToSCFPass() {
|
||||
return std::make_unique<RankSpecializationToSCFPass>();
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
// RUN: mlir-hlo-opt %s --split-input-file --mhlo-rank-specialization-cluster | FileCheck %s
|
||||
// RUN: mlir-hlo-opt %s --split-input-file --mhlo-rank-specialization-cluster --mhlo-rank-specialization-to-scf | FileCheck %s --check-prefix CHECK-SCF
|
||||
// RUN: mlir-hlo-opt %s --split-input-file --mhlo-rank-specialization-cluster --mhlo-rank-specialization-to-scf=max-target-rank=8 | FileCheck %s --check-prefix CHECK-SCF
|
||||
|
||||
// CHECK-LABEL: @add_mul
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>, %[[ARG2:.*]]: tensor<*xf32>)
|
||||
|
|
Loading…
Reference in New Issue