[MLIR][HLO] Rename `move-up-dynamic-broadcasts-for-fusion` to `broadcast-propagation`

PiperOrigin-RevId: 378102608
This commit is contained in:
A. Unique TensorFlower 2021-06-08 01:50:02 -07:00 committed by TensorFlow MLIR Team
parent b2839c735b
commit c47869f931
7 changed files with 26 additions and 22 deletions

6
BUILD
View File

@ -781,8 +781,8 @@ cc_library(
)
cc_library(
name = "move_up_dynamic_broadcasts_for_fusion",
srcs = ["lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc"],
name = "broadcast_propagation",
srcs = ["lib/Dialect/mhlo/transforms/broadcast_propagation.cc"],
hdrs = [
"include/mlir-hlo/Dialect/mhlo/transforms/passes.h",
"include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h",
@ -1208,6 +1208,7 @@ cc_library(
deps = [
":LmhloPassIncGen",
":MhloPassIncGen",
":broadcast_propagation",
":chlo_legalize_to_hlo",
":hlo_legalize_to_lhlo",
":legalize_control_flow",
@ -1224,7 +1225,6 @@ cc_library(
":mhlo_control_flow_to_scf",
":mhlo_fusion",
":mhlo_to_mhlo_lowering_patterns",
":move_up_dynamic_broadcasts_for_fusion",
":rank_specialization",
":sink_constants_to_control_flow",
":test_passes",

View File

@ -118,11 +118,11 @@ def TransformUnrankedHloPass : Pass<"mhlo-transform-unranked-hlo", "FuncOp"> {
let constructor = "createTransformUnrankedHloPass()";
}
def MoveUpDynamicBroadcastsForFusionPass :
Pass<"mhlo-move-up-dynamic-broadcasts-for-fusion", "FuncOp"> {
let summary = "Move up dynamic broadcasts and shape computations to allow "
"for fusion across broadcasts.";
let constructor = "createMoveUpDynamicBroadcastsForFusionPass()";
def BroadcastPropagationPass : Pass<"mhlo-broadcast-propagation", "FuncOp"> {
let summary = "Move dynamic broadcasts up over element-wise operations and "
"broadcast the operands rather than the result. This will eventually allow "
"for larger fusions.";
let constructor = "createBroadcastPropagationPass()";
}
def TestUnfuseBatchNormPass : Pass<"mhlo-test-unfuse-batch-norm", "FuncOp"> {

View File

@ -68,7 +68,10 @@ std::unique_ptr<OperationPass<FuncOp>> createMhloFusionPass();
std::unique_ptr<OperationPass<FuncOp>>
createLegalizeTrigonometricToApproximationPass();
std::unique_ptr<FunctionPass> createMoveUpDynamicBroadcastsForFusionPass();
// Move dynamic broadcasts up over element-wise operations and broadcast the
// operands rather than the result. This will eventually allow for larger
// fusions.
std::unique_ptr<FunctionPass> createBroadcastPropagationPass();
/// Rank specialization passes:
/// - Find compatible operations and group them together in one rank

View File

@ -104,10 +104,11 @@ void PopulateUnfuseBatchNormPatterns(MLIRContext *context,
void PopulateTrigonometricToApproximationPatterns(
MLIRContext *context, OwningRewritePatternList *patterns);
void PopulateMoveUpDynamicBroadcastsForFusionLegality(ConversionTarget *target);
void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
MLIRContext *context, OwningRewritePatternList *patterns);
// Populate patterns to move dynamic broadcasts up over element-wise operations
// and broadcast the operands rather than the result. This will eventually allow
// for larger fusions.
void PopulateBroadcastsPropagationPatterns(MLIRContext *context,
OwningRewritePatternList *patterns);
/// Populate rank specialization clustering and lowering patterns.
void PopulateRankSpecializationClusterPatterns(

View File

@ -48,6 +48,7 @@ add_mlir_library(ChloPasses
)
add_mlir_library(MhloPasses
broadcast_propagation.cc
legalize_gather_to_torch_index_select.cc
legalize_trigonometric_to_approximation.cc
lower_complex.cc
@ -56,7 +57,6 @@ add_mlir_library(MhloPasses
materialize_broadcasts.cc
materialize_broadcasts_pass.cc
mhlo_fusion.cc
move_up_dynamic_broadcasts_for_fusion.cc
optimize_mhlo.cc
optimize_mhlo_pass.cc
rank_specialization.cc

View File

@ -355,8 +355,8 @@ struct MoveUpBroadcastInDimOpPattern
}
};
struct MoveUpDynamicBroadcastsForFusionPass
: public PassWrapper<MoveUpDynamicBroadcastsForFusionPass, FunctionPass> {
struct BroadcastPropagationPass
: public PassWrapper<BroadcastPropagationPass, FunctionPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<shape::ShapeDialect, mhlo::MhloDialect>();
}
@ -364,7 +364,7 @@ struct MoveUpDynamicBroadcastsForFusionPass
void runOnFunction() override {
MLIRContext *ctx = &getContext();
RewritePatternSet patterns(ctx);
mhlo::PopulateMoveUpDynamicBroadcastsForFusionPatterns(ctx, &patterns);
mhlo::PopulateBroadcastsPropagationPatterns(ctx, &patterns);
if (failed(
applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)))) {
return signalPassFailure();
@ -374,8 +374,8 @@ struct MoveUpDynamicBroadcastsForFusionPass
} // namespace
void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
MLIRContext *context, OwningRewritePatternList *patterns) {
void PopulateBroadcastsPropagationPatterns(MLIRContext *context,
OwningRewritePatternList *patterns) {
// clang-format off
patterns->insert<
InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>,
@ -397,8 +397,8 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
tensor::CastOp::getCanonicalizationPatterns(*patterns, context);
}
std::unique_ptr<FunctionPass> createMoveUpDynamicBroadcastsForFusionPass() {
return std::make_unique<MoveUpDynamicBroadcastsForFusionPass>();
std::unique_ptr<FunctionPass> createBroadcastPropagationPass() {
return std::make_unique<BroadcastPropagationPass>();
}
} // namespace mhlo

View File

@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt --split-input-file --allow-unregistered-dialect --mhlo-move-up-dynamic-broadcasts-for-fusion --canonicalize --cse %s | FileCheck %s
// RUN: mlir-hlo-opt --split-input-file --allow-unregistered-dialect --mhlo-broadcast-propagation --canonicalize --cse %s | FileCheck %s
// Shape computations shall be reified.
// CHECK-LABEL: @shape_of_unary