[MLIR][HLO] Rename `move-up-dynamic-broadcasts-for-fusion` to `broadcast-propagation`

PiperOrigin-RevId: 378102608
This commit is contained in:
A. Unique TensorFlower 2021-06-08 01:50:02 -07:00 committed by TensorFlow MLIR Team
parent b2839c735b
commit c47869f931
7 changed files with 26 additions and 22 deletions

6
BUILD
View File

@ -781,8 +781,8 @@ cc_library(
) )
cc_library( cc_library(
name = "move_up_dynamic_broadcasts_for_fusion", name = "broadcast_propagation",
srcs = ["lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc"], srcs = ["lib/Dialect/mhlo/transforms/broadcast_propagation.cc"],
hdrs = [ hdrs = [
"include/mlir-hlo/Dialect/mhlo/transforms/passes.h", "include/mlir-hlo/Dialect/mhlo/transforms/passes.h",
"include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h", "include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h",
@ -1208,6 +1208,7 @@ cc_library(
deps = [ deps = [
":LmhloPassIncGen", ":LmhloPassIncGen",
":MhloPassIncGen", ":MhloPassIncGen",
":broadcast_propagation",
":chlo_legalize_to_hlo", ":chlo_legalize_to_hlo",
":hlo_legalize_to_lhlo", ":hlo_legalize_to_lhlo",
":legalize_control_flow", ":legalize_control_flow",
@ -1224,7 +1225,6 @@ cc_library(
":mhlo_control_flow_to_scf", ":mhlo_control_flow_to_scf",
":mhlo_fusion", ":mhlo_fusion",
":mhlo_to_mhlo_lowering_patterns", ":mhlo_to_mhlo_lowering_patterns",
":move_up_dynamic_broadcasts_for_fusion",
":rank_specialization", ":rank_specialization",
":sink_constants_to_control_flow", ":sink_constants_to_control_flow",
":test_passes", ":test_passes",

View File

@ -118,11 +118,11 @@ def TransformUnrankedHloPass : Pass<"mhlo-transform-unranked-hlo", "FuncOp"> {
let constructor = "createTransformUnrankedHloPass()"; let constructor = "createTransformUnrankedHloPass()";
} }
def MoveUpDynamicBroadcastsForFusionPass : def BroadcastPropagationPass : Pass<"mhlo-broadcast-propagation", "FuncOp"> {
Pass<"mhlo-move-up-dynamic-broadcasts-for-fusion", "FuncOp"> { let summary = "Move dynamic broadcasts up over element-wise operations and "
let summary = "Move up dynamic broadcasts and shape computations to allow " "broadcast the operands rather than the result. This will eventually allow "
"for fusion across broadcasts."; "for larger fusions.";
let constructor = "createMoveUpDynamicBroadcastsForFusionPass()"; let constructor = "createBroadcastPropagationPass()";
} }
def TestUnfuseBatchNormPass : Pass<"mhlo-test-unfuse-batch-norm", "FuncOp"> { def TestUnfuseBatchNormPass : Pass<"mhlo-test-unfuse-batch-norm", "FuncOp"> {

View File

@ -68,7 +68,10 @@ std::unique_ptr<OperationPass<FuncOp>> createMhloFusionPass();
std::unique_ptr<OperationPass<FuncOp>> std::unique_ptr<OperationPass<FuncOp>>
createLegalizeTrigonometricToApproximationPass(); createLegalizeTrigonometricToApproximationPass();
std::unique_ptr<FunctionPass> createMoveUpDynamicBroadcastsForFusionPass(); // Move dynamic broadcasts up over element-wise operations and broadcast the
// operands rather than the result. This will eventually allow for larger
// fusions.
std::unique_ptr<FunctionPass> createBroadcastPropagationPass();
/// Rank specialization passes: /// Rank specialization passes:
/// - Find compatible operations and group them together in one rank /// - Find compatible operations and group them together in one rank

View File

@ -104,10 +104,11 @@ void PopulateUnfuseBatchNormPatterns(MLIRContext *context,
void PopulateTrigonometricToApproximationPatterns( void PopulateTrigonometricToApproximationPatterns(
MLIRContext *context, OwningRewritePatternList *patterns); MLIRContext *context, OwningRewritePatternList *patterns);
void PopulateMoveUpDynamicBroadcastsForFusionLegality(ConversionTarget *target); // Populate patterns to move dynamic broadcasts up over element-wise operations
// and broadcast the operands rather than the result. This will eventually allow
void PopulateMoveUpDynamicBroadcastsForFusionPatterns( // for larger fusions.
MLIRContext *context, OwningRewritePatternList *patterns); void PopulateBroadcastsPropagationPatterns(MLIRContext *context,
OwningRewritePatternList *patterns);
/// Populate rank specialization clustering and lowering patterns. /// Populate rank specialization clustering and lowering patterns.
void PopulateRankSpecializationClusterPatterns( void PopulateRankSpecializationClusterPatterns(

View File

@ -48,6 +48,7 @@ add_mlir_library(ChloPasses
) )
add_mlir_library(MhloPasses add_mlir_library(MhloPasses
broadcast_propagation.cc
legalize_gather_to_torch_index_select.cc legalize_gather_to_torch_index_select.cc
legalize_trigonometric_to_approximation.cc legalize_trigonometric_to_approximation.cc
lower_complex.cc lower_complex.cc
@ -56,7 +57,6 @@ add_mlir_library(MhloPasses
materialize_broadcasts.cc materialize_broadcasts.cc
materialize_broadcasts_pass.cc materialize_broadcasts_pass.cc
mhlo_fusion.cc mhlo_fusion.cc
move_up_dynamic_broadcasts_for_fusion.cc
optimize_mhlo.cc optimize_mhlo.cc
optimize_mhlo_pass.cc optimize_mhlo_pass.cc
rank_specialization.cc rank_specialization.cc

View File

@ -355,8 +355,8 @@ struct MoveUpBroadcastInDimOpPattern
} }
}; };
struct MoveUpDynamicBroadcastsForFusionPass struct BroadcastPropagationPass
: public PassWrapper<MoveUpDynamicBroadcastsForFusionPass, FunctionPass> { : public PassWrapper<BroadcastPropagationPass, FunctionPass> {
void getDependentDialects(DialectRegistry &registry) const override { void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<shape::ShapeDialect, mhlo::MhloDialect>(); registry.insert<shape::ShapeDialect, mhlo::MhloDialect>();
} }
@ -364,7 +364,7 @@ struct MoveUpDynamicBroadcastsForFusionPass
void runOnFunction() override { void runOnFunction() override {
MLIRContext *ctx = &getContext(); MLIRContext *ctx = &getContext();
RewritePatternSet patterns(ctx); RewritePatternSet patterns(ctx);
mhlo::PopulateMoveUpDynamicBroadcastsForFusionPatterns(ctx, &patterns); mhlo::PopulateBroadcastsPropagationPatterns(ctx, &patterns);
if (failed( if (failed(
applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)))) { applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)))) {
return signalPassFailure(); return signalPassFailure();
@ -374,8 +374,8 @@ struct MoveUpDynamicBroadcastsForFusionPass
} // namespace } // namespace
void PopulateMoveUpDynamicBroadcastsForFusionPatterns( void PopulateBroadcastsPropagationPatterns(MLIRContext *context,
MLIRContext *context, OwningRewritePatternList *patterns) { OwningRewritePatternList *patterns) {
// clang-format off // clang-format off
patterns->insert< patterns->insert<
InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>, InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>,
@ -397,8 +397,8 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
tensor::CastOp::getCanonicalizationPatterns(*patterns, context); tensor::CastOp::getCanonicalizationPatterns(*patterns, context);
} }
std::unique_ptr<FunctionPass> createMoveUpDynamicBroadcastsForFusionPass() { std::unique_ptr<FunctionPass> createBroadcastPropagationPass() {
return std::make_unique<MoveUpDynamicBroadcastsForFusionPass>(); return std::make_unique<BroadcastPropagationPass>();
} }
} // namespace mhlo } // namespace mhlo

View File

@ -1,4 +1,4 @@
// RUN: mlir-hlo-opt --split-input-file --allow-unregistered-dialect --mhlo-move-up-dynamic-broadcasts-for-fusion --canonicalize --cse %s | FileCheck %s // RUN: mlir-hlo-opt --split-input-file --allow-unregistered-dialect --mhlo-broadcast-propagation --canonicalize --cse %s | FileCheck %s
// Shape computations shall be reified. // Shape computations shall be reified.
// CHECK-LABEL: @shape_of_unary // CHECK-LABEL: @shape_of_unary