[MLIR][KernelGen] Make maximum supported rank in rank specialization configurable
The maximum supported target rank of 5 is sufficient for all operations but `select`. Make the maximum target rank configurable in the rank specialization. This reduces the number of generated kernels for operations that don't require it. PiperOrigin-RevId: 376822496
This commit is contained in:
parent
c7c245eaf1
commit
d1828625ab
1
BUILD
1
BUILD
|
@ -812,6 +812,7 @@ cc_library(
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":hlo",
|
":hlo",
|
||||||
|
":pass_details",
|
||||||
"@llvm-project//llvm:Support",
|
"@llvm-project//llvm:Support",
|
||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
"@llvm-project//mlir:InferTypeOpInterface",
|
"@llvm-project//mlir:InferTypeOpInterface",
|
||||||
|
|
|
@ -127,11 +127,16 @@ def TestUnfuseBatchNormPass : Pass<"mhlo-test-unfuse-batch-norm", "FuncOp"> {
|
||||||
/// Rank specialization passes.
|
/// Rank specialization passes.
|
||||||
|
|
||||||
def RankSpecializationClusterPass
|
def RankSpecializationClusterPass
|
||||||
: Pass<"mhlo-rank-specialization-cluster", "FuncOp"> {
|
: FunctionPass<"mhlo-rank-specialization-cluster"> {
|
||||||
let constructor = "createRankSpecializationClusterPass()";
|
let constructor = "createRankSpecializationClusterPass()";
|
||||||
}
|
}
|
||||||
|
|
||||||
def RankSpecializationToSCFPass
|
def RankSpecializationToSCFPass
|
||||||
: Pass<"mhlo-rank-specialization-to-scf", "FuncOp"> {
|
: FunctionPass<"mhlo-rank-specialization-to-scf"> {
|
||||||
let constructor = "createRankSpecializationToSCFPass()";
|
let constructor = "createRankSpecializationToSCFPass()";
|
||||||
|
let options = [
|
||||||
|
Option<"max_target_rank_", "max-target-rank", "int", /*default=*/"8",
|
||||||
|
"The maximum supported rank after rank specialization. Any argument "
|
||||||
|
"of greater rank may result in a runtime failure.">,
|
||||||
|
];
|
||||||
}
|
}
|
||||||
|
|
|
@ -74,6 +74,8 @@ std::unique_ptr<FunctionPass> createMoveUpDynamicBroadcastsForFusionPass();
|
||||||
/// specialization cluster.
|
/// specialization cluster.
|
||||||
/// - Lower rank specialization clusters to SCF and ranked operations.
|
/// - Lower rank specialization clusters to SCF and ranked operations.
|
||||||
std::unique_ptr<FunctionPass> createRankSpecializationClusterPass();
|
std::unique_ptr<FunctionPass> createRankSpecializationClusterPass();
|
||||||
|
std::unique_ptr<FunctionPass> createRankSpecializationToSCFPass(
|
||||||
|
int64_t max_target_rank);
|
||||||
std::unique_ptr<FunctionPass> createRankSpecializationToSCFPass();
|
std::unique_ptr<FunctionPass> createRankSpecializationToSCFPass();
|
||||||
|
|
||||||
std::unique_ptr<FunctionPass> createOptimizeMhloPass();
|
std::unique_ptr<FunctionPass> createOptimizeMhloPass();
|
||||||
|
|
|
@ -103,8 +103,9 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
|
||||||
/// Populate rank specialization clustering and lowering patterns.
|
/// Populate rank specialization clustering and lowering patterns.
|
||||||
void PopulateRankSpecializationClusterPatterns(
|
void PopulateRankSpecializationClusterPatterns(
|
||||||
MLIRContext *context, OwningRewritePatternList *patterns);
|
MLIRContext *context, OwningRewritePatternList *patterns);
|
||||||
void PopulateRankSpecializationToSCFPatterns(
|
void PopulateRankSpecializationToSCFPatterns(MLIRContext *context,
|
||||||
MLIRContext *context, OwningRewritePatternList *patterns);
|
OwningRewritePatternList *patterns,
|
||||||
|
int64_t max_target_rank);
|
||||||
|
|
||||||
} // namespace mhlo
|
} // namespace mhlo
|
||||||
|
|
||||||
|
|
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||||
#include "mlir/Dialect/SCF/SCF.h"
|
#include "mlir/Dialect/SCF/SCF.h"
|
||||||
|
@ -158,7 +159,7 @@ struct RankSpecializationClusterPattern : public RewritePattern {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct RankSpecializationClusterPass
|
struct RankSpecializationClusterPass
|
||||||
: public PassWrapper<RankSpecializationClusterPass, FunctionPass> {
|
: public RankSpecializationClusterPassBase<RankSpecializationClusterPass> {
|
||||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
registry.insert<mhlo::MhloDialect, chlo::HloClientDialect>();
|
registry.insert<mhlo::MhloDialect, chlo::HloClientDialect>();
|
||||||
}
|
}
|
||||||
|
@ -471,7 +472,7 @@ Value RecusivelyMaterializeTargetRankSpecializationCases(
|
||||||
|
|
||||||
Value MaterializeGenericRankSpecializationCases(
|
Value MaterializeGenericRankSpecializationCases(
|
||||||
OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
|
OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
|
||||||
const SmallVector<Value, 8> &shapes) {
|
const SmallVector<Value, 8> &shapes, int64_t max_target_rank) {
|
||||||
// Get the minimum broadcast shapes of the operands.
|
// Get the minimum broadcast shapes of the operands.
|
||||||
auto non_scalar_shapes = llvm::to_vector<8>(llvm::make_filter_range(
|
auto non_scalar_shapes = llvm::to_vector<8>(llvm::make_filter_range(
|
||||||
shapes, [](Value v) { return !IsScalarShapeType(v.getType()); }));
|
shapes, [](Value v) { return !IsScalarShapeType(v.getType()); }));
|
||||||
|
@ -505,22 +506,19 @@ Value MaterializeGenericRankSpecializationCases(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Materialize rank specialization for ranks 1, ..., 8.
|
// Materialize rank specialization for ranks 1, ...
|
||||||
// TODO(frgossen): For clusters w/o a select operation, consider only ranks
|
|
||||||
// 1, ..., 5.
|
|
||||||
const int64_t kMinTargetRank = 1;
|
|
||||||
const int64_t kMaxTargetRank = 8;
|
|
||||||
return RecusivelyMaterializeTargetRankSpecializationCases(
|
return RecusivelyMaterializeTargetRankSpecializationCases(
|
||||||
b, loc, op, reduced_shapes, max_rank, kMinTargetRank, kMaxTargetRank);
|
b, loc, op, reduced_shapes, max_rank, /*min_target_rank=*/1,
|
||||||
|
max_target_rank);
|
||||||
}
|
}
|
||||||
|
|
||||||
Value MaterializeDefaultRankSpecializationCases(
|
Value MaterializeDefaultRankSpecializationCases(
|
||||||
OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
|
OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
|
||||||
const SmallVector<Value, 8> &shapes) {
|
const SmallVector<Value, 8> &shapes, int64_t max_target_rank) {
|
||||||
return MaterializeEqualShapesRankSpecializationCase(
|
return MaterializeEqualShapesRankSpecializationCase(
|
||||||
b, loc, op, shapes, [&](OpBuilder &b, Location loc) {
|
b, loc, op, shapes, [&](OpBuilder &b, Location loc) {
|
||||||
b.create<scf::YieldOp>(
|
b.create<scf::YieldOp>(loc, MaterializeGenericRankSpecializationCases(
|
||||||
loc, MaterializeGenericRankSpecializationCases(b, loc, op, shapes));
|
b, loc, op, shapes, max_target_rank));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -555,7 +553,7 @@ SmallVector<Value, 8> MaterializeRankSpecializationForSingleNonScalarOperand(
|
||||||
|
|
||||||
Value MaterializeRankSpecializationForTwoNonScalarOperands(
|
Value MaterializeRankSpecializationForTwoNonScalarOperands(
|
||||||
OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
|
OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
|
||||||
ValueRange non_scalar_operands) {
|
ValueRange non_scalar_operands, int64_t max_target_rank) {
|
||||||
assert(non_scalar_operands.size() == 2);
|
assert(non_scalar_operands.size() == 2);
|
||||||
|
|
||||||
auto shapes = llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
|
auto shapes = llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
|
||||||
|
@ -574,7 +572,7 @@ Value MaterializeRankSpecializationForTwoNonScalarOperands(
|
||||||
[&](OpBuilder &b, Location loc) {
|
[&](OpBuilder &b, Location loc) {
|
||||||
b.create<scf::YieldOp>(
|
b.create<scf::YieldOp>(
|
||||||
loc, MaterializeDefaultRankSpecializationCases(
|
loc, MaterializeDefaultRankSpecializationCases(
|
||||||
b, loc, op, shapes));
|
b, loc, op, shapes, max_target_rank));
|
||||||
}));
|
}));
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -583,15 +581,16 @@ Value MaterializeRankSpecializationForTwoNonScalarOperands(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Materialize rank generic rank specialization.
|
// Materialize rank generic rank specialization.
|
||||||
Value MaterializeDefaultRankSpecialization(
|
Value MaterializeDefaultRankSpecialization(OpBuilder &b, Location loc,
|
||||||
OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op) {
|
chlo::RankSpecializationClusterOp op,
|
||||||
|
int64_t max_target_rank) {
|
||||||
auto shapes = llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
|
auto shapes = llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
|
||||||
return b.create<shape::ShapeOfOp>(loc, v).result();
|
return b.create<shape::ShapeOfOp>(loc, v).result();
|
||||||
}));
|
}));
|
||||||
|
|
||||||
// Materialize all the different cases.
|
// Materialize all the different cases.
|
||||||
Value unshaped_result =
|
Value unshaped_result = MaterializeDefaultRankSpecializationCases(
|
||||||
MaterializeDefaultRankSpecializationCases(b, loc, op, shapes);
|
b, loc, op, shapes, max_target_rank);
|
||||||
|
|
||||||
// Materialize final reshape once and for all rank specialization cases.
|
// Materialize final reshape once and for all rank specialization cases.
|
||||||
return MaterializeFinalReshape(b, loc, op, unshaped_result).front();
|
return MaterializeFinalReshape(b, loc, op, unshaped_result).front();
|
||||||
|
@ -599,7 +598,10 @@ Value MaterializeDefaultRankSpecialization(
|
||||||
|
|
||||||
struct LowerRankSpecializationClusterPattern
|
struct LowerRankSpecializationClusterPattern
|
||||||
: public OpRewritePattern<chlo::RankSpecializationClusterOp> {
|
: public OpRewritePattern<chlo::RankSpecializationClusterOp> {
|
||||||
using OpRewritePattern<chlo::RankSpecializationClusterOp>::OpRewritePattern;
|
LowerRankSpecializationClusterPattern(MLIRContext *ctx,
|
||||||
|
int64_t max_target_rank)
|
||||||
|
: OpRewritePattern<chlo::RankSpecializationClusterOp>(ctx, /*benefit=*/1),
|
||||||
|
max_target_rank(max_target_rank) {}
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(chlo::RankSpecializationClusterOp op,
|
LogicalResult matchAndRewrite(chlo::RankSpecializationClusterOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
|
@ -630,22 +632,33 @@ struct LowerRankSpecializationClusterPattern
|
||||||
llvm::all_of(non_scalar_operands, [](Value v) {
|
llvm::all_of(non_scalar_operands, [](Value v) {
|
||||||
return v.getType().isa<UnrankedTensorType>();
|
return v.getType().isa<UnrankedTensorType>();
|
||||||
})) {
|
})) {
|
||||||
rewriter.replaceOp(op,
|
rewriter.replaceOp(
|
||||||
MaterializeRankSpecializationForTwoNonScalarOperands(
|
op, MaterializeRankSpecializationForTwoNonScalarOperands(
|
||||||
rewriter, loc, op, non_scalar_operands));
|
rewriter, loc, op, non_scalar_operands, max_target_rank));
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
// For all other cases, reshape the operands to match in rank, apply the
|
// For all other cases, reshape the operands to match in rank, apply the
|
||||||
// operation, and restore the expected shape.
|
// operation, and restore the expected shape.
|
||||||
rewriter.replaceOp(op,
|
rewriter.replaceOp(op, MaterializeDefaultRankSpecialization(
|
||||||
MaterializeDefaultRankSpecialization(rewriter, loc, op));
|
rewriter, loc, op, max_target_rank));
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
int64_t max_target_rank;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct RankSpecializationToSCFPass
|
struct RankSpecializationToSCFPass
|
||||||
: public PassWrapper<RankSpecializationToSCFPass, FunctionPass> {
|
: public RankSpecializationToSCFPassBase<RankSpecializationToSCFPass> {
|
||||||
|
using RankSpecializationToSCFPassBase<
|
||||||
|
RankSpecializationToSCFPass>::RankSpecializationToSCFPassBase;
|
||||||
|
explicit RankSpecializationToSCFPass(int64_t max_target_rank)
|
||||||
|
: RankSpecializationToSCFPassBase<
|
||||||
|
RankSpecializationToSCFPass>::RankSpecializationToSCFPassBase() {
|
||||||
|
this->max_target_rank_ = max_target_rank;
|
||||||
|
}
|
||||||
|
|
||||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
registry.insert<mhlo::MhloDialect, chlo::HloClientDialect,
|
registry.insert<mhlo::MhloDialect, chlo::HloClientDialect,
|
||||||
shape::ShapeDialect, scf::SCFDialect>();
|
shape::ShapeDialect, scf::SCFDialect>();
|
||||||
|
@ -654,7 +667,8 @@ struct RankSpecializationToSCFPass
|
||||||
void runOnFunction() override {
|
void runOnFunction() override {
|
||||||
MLIRContext *ctx = &getContext();
|
MLIRContext *ctx = &getContext();
|
||||||
RewritePatternSet patterns(ctx);
|
RewritePatternSet patterns(ctx);
|
||||||
PopulateRankSpecializationToSCFPatterns(ctx, &patterns);
|
PopulateRankSpecializationToSCFPatterns(ctx, &patterns,
|
||||||
|
this->max_target_rank_);
|
||||||
if (failed(
|
if (failed(
|
||||||
applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)))) {
|
applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)))) {
|
||||||
return signalPassFailure();
|
return signalPassFailure();
|
||||||
|
@ -669,15 +683,22 @@ void PopulateRankSpecializationClusterPatterns(
|
||||||
patterns->insert<RankSpecializationClusterPattern>(context);
|
patterns->insert<RankSpecializationClusterPattern>(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
void PopulateRankSpecializationToSCFPatterns(
|
void PopulateRankSpecializationToSCFPatterns(MLIRContext *context,
|
||||||
MLIRContext *context, OwningRewritePatternList *patterns) {
|
OwningRewritePatternList *patterns,
|
||||||
patterns->insert<LowerRankSpecializationClusterPattern>(context);
|
int64_t max_target_rank) {
|
||||||
|
patterns->insert<LowerRankSpecializationClusterPattern>(context,
|
||||||
|
max_target_rank);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<FunctionPass> createRankSpecializationClusterPass() {
|
std::unique_ptr<FunctionPass> createRankSpecializationClusterPass() {
|
||||||
return std::make_unique<RankSpecializationClusterPass>();
|
return std::make_unique<RankSpecializationClusterPass>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<FunctionPass> createRankSpecializationToSCFPass(
|
||||||
|
int64_t max_target_rank) {
|
||||||
|
return std::make_unique<RankSpecializationToSCFPass>(max_target_rank);
|
||||||
|
}
|
||||||
|
|
||||||
std::unique_ptr<FunctionPass> createRankSpecializationToSCFPass() {
|
std::unique_ptr<FunctionPass> createRankSpecializationToSCFPass() {
|
||||||
return std::make_unique<RankSpecializationToSCFPass>();
|
return std::make_unique<RankSpecializationToSCFPass>();
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
// RUN: mlir-hlo-opt %s --split-input-file --mhlo-rank-specialization-cluster | FileCheck %s
|
// RUN: mlir-hlo-opt %s --split-input-file --mhlo-rank-specialization-cluster | FileCheck %s
|
||||||
// RUN: mlir-hlo-opt %s --split-input-file --mhlo-rank-specialization-cluster --mhlo-rank-specialization-to-scf | FileCheck %s --check-prefix CHECK-SCF
|
// RUN: mlir-hlo-opt %s --split-input-file --mhlo-rank-specialization-cluster --mhlo-rank-specialization-to-scf=max-target-rank=8 | FileCheck %s --check-prefix CHECK-SCF
|
||||||
|
|
||||||
// CHECK-LABEL: @add_mul
|
// CHECK-LABEL: @add_mul
|
||||||
// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>, %[[ARG2:.*]]: tensor<*xf32>)
|
// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>, %[[ARG2:.*]]: tensor<*xf32>)
|
||||||
|
|
Loading…
Reference in New Issue