[MLIR][KernelGen] Make maximum supported rank in rank specialization configurable

The maximum supported target rank of 5 is sufficient for all operations but
`select`. Make the maximum target rank configurable in the rank specialization.
This reduces the number of generated kernels for operations that don't require
it.

PiperOrigin-RevId: 376822496
This commit is contained in:
A. Unique TensorFlower 2021-06-01 06:53:30 -07:00 committed by TensorFlow MLIR Team
parent c7c245eaf1
commit d1828625ab
6 changed files with 63 additions and 33 deletions

1
BUILD
View File

@@ -812,6 +812,7 @@ cc_library(
], ],
deps = [ deps = [
":hlo", ":hlo",
":pass_details",
"@llvm-project//llvm:Support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:InferTypeOpInterface", "@llvm-project//mlir:InferTypeOpInterface",

View File

@@ -127,11 +127,16 @@ def TestUnfuseBatchNormPass : Pass<"mhlo-test-unfuse-batch-norm", "FuncOp"> {
/// Rank specialization passes. /// Rank specialization passes.
def RankSpecializationClusterPass def RankSpecializationClusterPass
: Pass<"mhlo-rank-specialization-cluster", "FuncOp"> { : FunctionPass<"mhlo-rank-specialization-cluster"> {
let constructor = "createRankSpecializationClusterPass()"; let constructor = "createRankSpecializationClusterPass()";
} }
def RankSpecializationToSCFPass def RankSpecializationToSCFPass
: Pass<"mhlo-rank-specialization-to-scf", "FuncOp"> { : FunctionPass<"mhlo-rank-specialization-to-scf"> {
let constructor = "createRankSpecializationToSCFPass()"; let constructor = "createRankSpecializationToSCFPass()";
let options = [
Option<"max_target_rank_", "max-target-rank", "int", /*default=*/"8",
"The maximum supported rank after rank specialization. Any argument "
"of greater rank may result in a runtime failure.">,
];
} }

View File

@@ -74,6 +74,8 @@ std::unique_ptr<FunctionPass> createMoveUpDynamicBroadcastsForFusionPass();
/// specialization cluster. /// specialization cluster.
/// - Lower rank specialization clusters to SCF and ranked operations. /// - Lower rank specialization clusters to SCF and ranked operations.
std::unique_ptr<FunctionPass> createRankSpecializationClusterPass(); std::unique_ptr<FunctionPass> createRankSpecializationClusterPass();
std::unique_ptr<FunctionPass> createRankSpecializationToSCFPass(
int64_t max_target_rank);
std::unique_ptr<FunctionPass> createRankSpecializationToSCFPass(); std::unique_ptr<FunctionPass> createRankSpecializationToSCFPass();
std::unique_ptr<FunctionPass> createOptimizeMhloPass(); std::unique_ptr<FunctionPass> createOptimizeMhloPass();

View File

@@ -103,8 +103,9 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
/// Populate rank specialization clustering and lowering patterns. /// Populate rank specialization clustering and lowering patterns.
void PopulateRankSpecializationClusterPatterns( void PopulateRankSpecializationClusterPatterns(
MLIRContext *context, OwningRewritePatternList *patterns); MLIRContext *context, OwningRewritePatternList *patterns);
void PopulateRankSpecializationToSCFPatterns( void PopulateRankSpecializationToSCFPatterns(MLIRContext *context,
MLIRContext *context, OwningRewritePatternList *patterns); OwningRewritePatternList *patterns,
int64_t max_target_rank);
} // namespace mhlo } // namespace mhlo

View File

@@ -19,6 +19,7 @@ limitations under the License.
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/SCF/SCF.h"
@@ -158,7 +159,7 @@ struct RankSpecializationClusterPattern : public RewritePattern {
}; };
struct RankSpecializationClusterPass struct RankSpecializationClusterPass
: public PassWrapper<RankSpecializationClusterPass, FunctionPass> { : public RankSpecializationClusterPassBase<RankSpecializationClusterPass> {
void getDependentDialects(DialectRegistry &registry) const override { void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<mhlo::MhloDialect, chlo::HloClientDialect>(); registry.insert<mhlo::MhloDialect, chlo::HloClientDialect>();
} }
@@ -471,7 +472,7 @@ Value RecusivelyMaterializeTargetRankSpecializationCases(
Value MaterializeGenericRankSpecializationCases( Value MaterializeGenericRankSpecializationCases(
OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op, OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
const SmallVector<Value, 8> &shapes) { const SmallVector<Value, 8> &shapes, int64_t max_target_rank) {
// Get the minimum broadcast shapes of the operands. // Get the minimum broadcast shapes of the operands.
auto non_scalar_shapes = llvm::to_vector<8>(llvm::make_filter_range( auto non_scalar_shapes = llvm::to_vector<8>(llvm::make_filter_range(
shapes, [](Value v) { return !IsScalarShapeType(v.getType()); })); shapes, [](Value v) { return !IsScalarShapeType(v.getType()); }));
@@ -505,22 +506,19 @@ Value MaterializeGenericRankSpecializationCases(
} }
} }
// Materialize rank specialization for ranks 1, ..., 8. // Materialize rank specialization for ranks 1, ...
// TODO(frgossen): For clusters w/o a select operation, consider only ranks
// 1, ..., 5.
const int64_t kMinTargetRank = 1;
const int64_t kMaxTargetRank = 8;
return RecusivelyMaterializeTargetRankSpecializationCases( return RecusivelyMaterializeTargetRankSpecializationCases(
b, loc, op, reduced_shapes, max_rank, kMinTargetRank, kMaxTargetRank); b, loc, op, reduced_shapes, max_rank, /*min_target_rank=*/1,
max_target_rank);
} }
Value MaterializeDefaultRankSpecializationCases( Value MaterializeDefaultRankSpecializationCases(
OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op, OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
const SmallVector<Value, 8> &shapes) { const SmallVector<Value, 8> &shapes, int64_t max_target_rank) {
return MaterializeEqualShapesRankSpecializationCase( return MaterializeEqualShapesRankSpecializationCase(
b, loc, op, shapes, [&](OpBuilder &b, Location loc) { b, loc, op, shapes, [&](OpBuilder &b, Location loc) {
b.create<scf::YieldOp>( b.create<scf::YieldOp>(loc, MaterializeGenericRankSpecializationCases(
loc, MaterializeGenericRankSpecializationCases(b, loc, op, shapes)); b, loc, op, shapes, max_target_rank));
}); });
} }
@@ -555,7 +553,7 @@ SmallVector<Value, 8> MaterializeRankSpecializationForSingleNonScalarOperand(
Value MaterializeRankSpecializationForTwoNonScalarOperands( Value MaterializeRankSpecializationForTwoNonScalarOperands(
OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op, OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op,
ValueRange non_scalar_operands) { ValueRange non_scalar_operands, int64_t max_target_rank) {
assert(non_scalar_operands.size() == 2); assert(non_scalar_operands.size() == 2);
auto shapes = llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) { auto shapes = llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
@@ -574,7 +572,7 @@ Value MaterializeRankSpecializationForTwoNonScalarOperands(
[&](OpBuilder &b, Location loc) { [&](OpBuilder &b, Location loc) {
b.create<scf::YieldOp>( b.create<scf::YieldOp>(
loc, MaterializeDefaultRankSpecializationCases( loc, MaterializeDefaultRankSpecializationCases(
b, loc, op, shapes)); b, loc, op, shapes, max_target_rank));
})); }));
}); });
@@ -583,15 +581,16 @@ Value MaterializeRankSpecializationForTwoNonScalarOperands(
} }
// Materialize rank generic rank specialization. // Materialize rank generic rank specialization.
Value MaterializeDefaultRankSpecialization( Value MaterializeDefaultRankSpecialization(OpBuilder &b, Location loc,
OpBuilder &b, Location loc, chlo::RankSpecializationClusterOp op) { chlo::RankSpecializationClusterOp op,
int64_t max_target_rank) {
auto shapes = llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) { auto shapes = llvm::to_vector<8>(llvm::map_range(op.operands(), [&](Value v) {
return b.create<shape::ShapeOfOp>(loc, v).result(); return b.create<shape::ShapeOfOp>(loc, v).result();
})); }));
// Materialize all the different cases. // Materialize all the different cases.
Value unshaped_result = Value unshaped_result = MaterializeDefaultRankSpecializationCases(
MaterializeDefaultRankSpecializationCases(b, loc, op, shapes); b, loc, op, shapes, max_target_rank);
// Materialize final reshape once and for all rank specialization cases. // Materialize final reshape once and for all rank specialization cases.
return MaterializeFinalReshape(b, loc, op, unshaped_result).front(); return MaterializeFinalReshape(b, loc, op, unshaped_result).front();
@@ -599,7 +598,10 @@ Value MaterializeDefaultRankSpecialization(
struct LowerRankSpecializationClusterPattern struct LowerRankSpecializationClusterPattern
: public OpRewritePattern<chlo::RankSpecializationClusterOp> { : public OpRewritePattern<chlo::RankSpecializationClusterOp> {
using OpRewritePattern<chlo::RankSpecializationClusterOp>::OpRewritePattern; LowerRankSpecializationClusterPattern(MLIRContext *ctx,
int64_t max_target_rank)
: OpRewritePattern<chlo::RankSpecializationClusterOp>(ctx, /*benefit=*/1),
max_target_rank(max_target_rank) {}
LogicalResult matchAndRewrite(chlo::RankSpecializationClusterOp op, LogicalResult matchAndRewrite(chlo::RankSpecializationClusterOp op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
@@ -630,22 +632,33 @@ struct LowerRankSpecializationClusterPattern
llvm::all_of(non_scalar_operands, [](Value v) { llvm::all_of(non_scalar_operands, [](Value v) {
return v.getType().isa<UnrankedTensorType>(); return v.getType().isa<UnrankedTensorType>();
})) { })) {
rewriter.replaceOp(op, rewriter.replaceOp(
MaterializeRankSpecializationForTwoNonScalarOperands( op, MaterializeRankSpecializationForTwoNonScalarOperands(
rewriter, loc, op, non_scalar_operands)); rewriter, loc, op, non_scalar_operands, max_target_rank));
return success(); return success();
} }
// For all other cases, reshape the operands to match in rank, apply the // For all other cases, reshape the operands to match in rank, apply the
// operation, and restore the expected shape. // operation, and restore the expected shape.
rewriter.replaceOp(op, rewriter.replaceOp(op, MaterializeDefaultRankSpecialization(
MaterializeDefaultRankSpecialization(rewriter, loc, op)); rewriter, loc, op, max_target_rank));
return success(); return success();
} }
private:
int64_t max_target_rank;
}; };
struct RankSpecializationToSCFPass struct RankSpecializationToSCFPass
: public PassWrapper<RankSpecializationToSCFPass, FunctionPass> { : public RankSpecializationToSCFPassBase<RankSpecializationToSCFPass> {
using RankSpecializationToSCFPassBase<
RankSpecializationToSCFPass>::RankSpecializationToSCFPassBase;
explicit RankSpecializationToSCFPass(int64_t max_target_rank)
: RankSpecializationToSCFPassBase<
RankSpecializationToSCFPass>::RankSpecializationToSCFPassBase() {
this->max_target_rank_ = max_target_rank;
}
void getDependentDialects(DialectRegistry &registry) const override { void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<mhlo::MhloDialect, chlo::HloClientDialect, registry.insert<mhlo::MhloDialect, chlo::HloClientDialect,
shape::ShapeDialect, scf::SCFDialect>(); shape::ShapeDialect, scf::SCFDialect>();
@@ -654,7 +667,8 @@ struct RankSpecializationToSCFPass
void runOnFunction() override { void runOnFunction() override {
MLIRContext *ctx = &getContext(); MLIRContext *ctx = &getContext();
RewritePatternSet patterns(ctx); RewritePatternSet patterns(ctx);
PopulateRankSpecializationToSCFPatterns(ctx, &patterns); PopulateRankSpecializationToSCFPatterns(ctx, &patterns,
this->max_target_rank_);
if (failed( if (failed(
applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)))) { applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)))) {
return signalPassFailure(); return signalPassFailure();
@@ -669,15 +683,22 @@ void PopulateRankSpecializationClusterPatterns(
patterns->insert<RankSpecializationClusterPattern>(context); patterns->insert<RankSpecializationClusterPattern>(context);
} }
void PopulateRankSpecializationToSCFPatterns( void PopulateRankSpecializationToSCFPatterns(MLIRContext *context,
MLIRContext *context, OwningRewritePatternList *patterns) { OwningRewritePatternList *patterns,
patterns->insert<LowerRankSpecializationClusterPattern>(context); int64_t max_target_rank) {
patterns->insert<LowerRankSpecializationClusterPattern>(context,
max_target_rank);
} }
std::unique_ptr<FunctionPass> createRankSpecializationClusterPass() { std::unique_ptr<FunctionPass> createRankSpecializationClusterPass() {
return std::make_unique<RankSpecializationClusterPass>(); return std::make_unique<RankSpecializationClusterPass>();
} }
std::unique_ptr<FunctionPass> createRankSpecializationToSCFPass(
int64_t max_target_rank) {
return std::make_unique<RankSpecializationToSCFPass>(max_target_rank);
}
std::unique_ptr<FunctionPass> createRankSpecializationToSCFPass() { std::unique_ptr<FunctionPass> createRankSpecializationToSCFPass() {
return std::make_unique<RankSpecializationToSCFPass>(); return std::make_unique<RankSpecializationToSCFPass>();
} }

View File

@@ -1,5 +1,5 @@
// RUN: mlir-hlo-opt %s --split-input-file --mhlo-rank-specialization-cluster | FileCheck %s // RUN: mlir-hlo-opt %s --split-input-file --mhlo-rank-specialization-cluster | FileCheck %s
// RUN: mlir-hlo-opt %s --split-input-file --mhlo-rank-specialization-cluster --mhlo-rank-specialization-to-scf | FileCheck %s --check-prefix CHECK-SCF // RUN: mlir-hlo-opt %s --split-input-file --mhlo-rank-specialization-cluster --mhlo-rank-specialization-to-scf=max-target-rank=8 | FileCheck %s --check-prefix CHECK-SCF
// CHECK-LABEL: @add_mul // CHECK-LABEL: @add_mul
// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>, %[[ARG2:.*]]: tensor<*xf32>) // CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>, %[[ARG2:.*]]: tensor<*xf32>)