[MLIR][HLO] Rename `move-up-dynamic-broadcasts-for-fusion` to `broadcast-propagation`
PiperOrigin-RevId: 378102608
This commit is contained in:
parent
b2839c735b
commit
c47869f931
6
BUILD
6
BUILD
|
@ -781,8 +781,8 @@ cc_library(
|
|||
)
|
||||
|
||||
cc_library(
|
||||
name = "move_up_dynamic_broadcasts_for_fusion",
|
||||
srcs = ["lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc"],
|
||||
name = "broadcast_propagation",
|
||||
srcs = ["lib/Dialect/mhlo/transforms/broadcast_propagation.cc"],
|
||||
hdrs = [
|
||||
"include/mlir-hlo/Dialect/mhlo/transforms/passes.h",
|
||||
"include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h",
|
||||
|
@ -1208,6 +1208,7 @@ cc_library(
|
|||
deps = [
|
||||
":LmhloPassIncGen",
|
||||
":MhloPassIncGen",
|
||||
":broadcast_propagation",
|
||||
":chlo_legalize_to_hlo",
|
||||
":hlo_legalize_to_lhlo",
|
||||
":legalize_control_flow",
|
||||
|
@ -1224,7 +1225,6 @@ cc_library(
|
|||
":mhlo_control_flow_to_scf",
|
||||
":mhlo_fusion",
|
||||
":mhlo_to_mhlo_lowering_patterns",
|
||||
":move_up_dynamic_broadcasts_for_fusion",
|
||||
":rank_specialization",
|
||||
":sink_constants_to_control_flow",
|
||||
":test_passes",
|
||||
|
|
|
@ -118,11 +118,11 @@ def TransformUnrankedHloPass : Pass<"mhlo-transform-unranked-hlo", "FuncOp"> {
|
|||
let constructor = "createTransformUnrankedHloPass()";
|
||||
}
|
||||
|
||||
def MoveUpDynamicBroadcastsForFusionPass :
|
||||
Pass<"mhlo-move-up-dynamic-broadcasts-for-fusion", "FuncOp"> {
|
||||
let summary = "Move up dynamic broadcasts and shape computations to allow "
|
||||
"for fusion across broadcasts.";
|
||||
let constructor = "createMoveUpDynamicBroadcastsForFusionPass()";
|
||||
def BroadcastPropagationPass : Pass<"mhlo-broadcast-propagation", "FuncOp"> {
|
||||
let summary = "Move dynamic broadcasts up over element-wise operations and "
|
||||
"broadcast the operands rather than the result. This will eventually allow "
|
||||
"for larger fusions.";
|
||||
let constructor = "createBroadcastPropagationPass()";
|
||||
}
|
||||
|
||||
def TestUnfuseBatchNormPass : Pass<"mhlo-test-unfuse-batch-norm", "FuncOp"> {
|
||||
|
|
|
@ -68,7 +68,10 @@ std::unique_ptr<OperationPass<FuncOp>> createMhloFusionPass();
|
|||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
createLegalizeTrigonometricToApproximationPass();
|
||||
|
||||
std::unique_ptr<FunctionPass> createMoveUpDynamicBroadcastsForFusionPass();
|
||||
// Move dynamic broadcasts up over element-wise operations and broadcast the
|
||||
// operands rather than the result. This will eventually allow for larger
|
||||
// fusions.
|
||||
std::unique_ptr<FunctionPass> createBroadcastPropagationPass();
|
||||
|
||||
/// Rank specialization passes:
|
||||
/// - Find compatible operations and group them together in one rank
|
||||
|
|
|
@ -104,10 +104,11 @@ void PopulateUnfuseBatchNormPatterns(MLIRContext *context,
|
|||
void PopulateTrigonometricToApproximationPatterns(
|
||||
MLIRContext *context, OwningRewritePatternList *patterns);
|
||||
|
||||
void PopulateMoveUpDynamicBroadcastsForFusionLegality(ConversionTarget *target);
|
||||
|
||||
void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
|
||||
MLIRContext *context, OwningRewritePatternList *patterns);
|
||||
// Populate patterns to move dynamic broadcasts up over element-wise operations
|
||||
// and broadcast the operands rather than the result. This will eventually allow
|
||||
// for larger fusions.
|
||||
void PopulateBroadcastsPropagationPatterns(MLIRContext *context,
|
||||
OwningRewritePatternList *patterns);
|
||||
|
||||
/// Populate rank specialization clustering and lowering patterns.
|
||||
void PopulateRankSpecializationClusterPatterns(
|
||||
|
|
|
@ -48,6 +48,7 @@ add_mlir_library(ChloPasses
|
|||
)
|
||||
|
||||
add_mlir_library(MhloPasses
|
||||
broadcast_propagation.cc
|
||||
legalize_gather_to_torch_index_select.cc
|
||||
legalize_trigonometric_to_approximation.cc
|
||||
lower_complex.cc
|
||||
|
@ -56,7 +57,6 @@ add_mlir_library(MhloPasses
|
|||
materialize_broadcasts.cc
|
||||
materialize_broadcasts_pass.cc
|
||||
mhlo_fusion.cc
|
||||
move_up_dynamic_broadcasts_for_fusion.cc
|
||||
optimize_mhlo.cc
|
||||
optimize_mhlo_pass.cc
|
||||
rank_specialization.cc
|
||||
|
|
|
@ -355,8 +355,8 @@ struct MoveUpBroadcastInDimOpPattern
|
|||
}
|
||||
};
|
||||
|
||||
struct MoveUpDynamicBroadcastsForFusionPass
|
||||
: public PassWrapper<MoveUpDynamicBroadcastsForFusionPass, FunctionPass> {
|
||||
struct BroadcastPropagationPass
|
||||
: public PassWrapper<BroadcastPropagationPass, FunctionPass> {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<shape::ShapeDialect, mhlo::MhloDialect>();
|
||||
}
|
||||
|
@ -364,7 +364,7 @@ struct MoveUpDynamicBroadcastsForFusionPass
|
|||
void runOnFunction() override {
|
||||
MLIRContext *ctx = &getContext();
|
||||
RewritePatternSet patterns(ctx);
|
||||
mhlo::PopulateMoveUpDynamicBroadcastsForFusionPatterns(ctx, &patterns);
|
||||
mhlo::PopulateBroadcastsPropagationPatterns(ctx, &patterns);
|
||||
if (failed(
|
||||
applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)))) {
|
||||
return signalPassFailure();
|
||||
|
@ -374,8 +374,8 @@ struct MoveUpDynamicBroadcastsForFusionPass
|
|||
|
||||
} // namespace
|
||||
|
||||
void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
|
||||
MLIRContext *context, OwningRewritePatternList *patterns) {
|
||||
void PopulateBroadcastsPropagationPatterns(MLIRContext *context,
|
||||
OwningRewritePatternList *patterns) {
|
||||
// clang-format off
|
||||
patterns->insert<
|
||||
InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>,
|
||||
|
@ -397,8 +397,8 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
|
|||
tensor::CastOp::getCanonicalizationPatterns(*patterns, context);
|
||||
}
|
||||
|
||||
std::unique_ptr<FunctionPass> createMoveUpDynamicBroadcastsForFusionPass() {
|
||||
return std::make_unique<MoveUpDynamicBroadcastsForFusionPass>();
|
||||
std::unique_ptr<FunctionPass> createBroadcastPropagationPass() {
|
||||
return std::make_unique<BroadcastPropagationPass>();
|
||||
}
|
||||
|
||||
} // namespace mhlo
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: mlir-hlo-opt --split-input-file --allow-unregistered-dialect --mhlo-move-up-dynamic-broadcasts-for-fusion --canonicalize --cse %s | FileCheck %s
|
||||
// RUN: mlir-hlo-opt --split-input-file --allow-unregistered-dialect --mhlo-broadcast-propagation --canonicalize --cse %s | FileCheck %s
|
||||
|
||||
// Shape computations shall be reified.
|
||||
// CHECK-LABEL: @shape_of_unary
|
Loading…
Reference in New Issue