[MLIR][HLO] Rename `move-up-dynamic-broadcasts-for-fusion` to `broadcast-propagation`
PiperOrigin-RevId: 378102608
This commit is contained in:
parent
b2839c735b
commit
c47869f931
6
BUILD
6
BUILD
|
@ -781,8 +781,8 @@ cc_library(
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "move_up_dynamic_broadcasts_for_fusion",
|
name = "broadcast_propagation",
|
||||||
srcs = ["lib/Dialect/mhlo/transforms/move_up_dynamic_broadcasts_for_fusion.cc"],
|
srcs = ["lib/Dialect/mhlo/transforms/broadcast_propagation.cc"],
|
||||||
hdrs = [
|
hdrs = [
|
||||||
"include/mlir-hlo/Dialect/mhlo/transforms/passes.h",
|
"include/mlir-hlo/Dialect/mhlo/transforms/passes.h",
|
||||||
"include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h",
|
"include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h",
|
||||||
|
@ -1208,6 +1208,7 @@ cc_library(
|
||||||
deps = [
|
deps = [
|
||||||
":LmhloPassIncGen",
|
":LmhloPassIncGen",
|
||||||
":MhloPassIncGen",
|
":MhloPassIncGen",
|
||||||
|
":broadcast_propagation",
|
||||||
":chlo_legalize_to_hlo",
|
":chlo_legalize_to_hlo",
|
||||||
":hlo_legalize_to_lhlo",
|
":hlo_legalize_to_lhlo",
|
||||||
":legalize_control_flow",
|
":legalize_control_flow",
|
||||||
|
@ -1224,7 +1225,6 @@ cc_library(
|
||||||
":mhlo_control_flow_to_scf",
|
":mhlo_control_flow_to_scf",
|
||||||
":mhlo_fusion",
|
":mhlo_fusion",
|
||||||
":mhlo_to_mhlo_lowering_patterns",
|
":mhlo_to_mhlo_lowering_patterns",
|
||||||
":move_up_dynamic_broadcasts_for_fusion",
|
|
||||||
":rank_specialization",
|
":rank_specialization",
|
||||||
":sink_constants_to_control_flow",
|
":sink_constants_to_control_flow",
|
||||||
":test_passes",
|
":test_passes",
|
||||||
|
|
|
@ -118,11 +118,11 @@ def TransformUnrankedHloPass : Pass<"mhlo-transform-unranked-hlo", "FuncOp"> {
|
||||||
let constructor = "createTransformUnrankedHloPass()";
|
let constructor = "createTransformUnrankedHloPass()";
|
||||||
}
|
}
|
||||||
|
|
||||||
def MoveUpDynamicBroadcastsForFusionPass :
|
def BroadcastPropagationPass : Pass<"mhlo-broadcast-propagation", "FuncOp"> {
|
||||||
Pass<"mhlo-move-up-dynamic-broadcasts-for-fusion", "FuncOp"> {
|
let summary = "Move dynamic broadcasts up over element-wise operations and "
|
||||||
let summary = "Move up dynamic broadcasts and shape computations to allow "
|
"broadcast the operands rather than the result. This will eventually allow "
|
||||||
"for fusion across broadcasts.";
|
"for larger fusions.";
|
||||||
let constructor = "createMoveUpDynamicBroadcastsForFusionPass()";
|
let constructor = "createBroadcastPropagationPass()";
|
||||||
}
|
}
|
||||||
|
|
||||||
def TestUnfuseBatchNormPass : Pass<"mhlo-test-unfuse-batch-norm", "FuncOp"> {
|
def TestUnfuseBatchNormPass : Pass<"mhlo-test-unfuse-batch-norm", "FuncOp"> {
|
||||||
|
|
|
@ -68,7 +68,10 @@ std::unique_ptr<OperationPass<FuncOp>> createMhloFusionPass();
|
||||||
std::unique_ptr<OperationPass<FuncOp>>
|
std::unique_ptr<OperationPass<FuncOp>>
|
||||||
createLegalizeTrigonometricToApproximationPass();
|
createLegalizeTrigonometricToApproximationPass();
|
||||||
|
|
||||||
std::unique_ptr<FunctionPass> createMoveUpDynamicBroadcastsForFusionPass();
|
// Move dynamic broadcasts up over element-wise operations and broadcast the
|
||||||
|
// operands rather than the result. This will eventually allow for larger
|
||||||
|
// fusions.
|
||||||
|
std::unique_ptr<FunctionPass> createBroadcastPropagationPass();
|
||||||
|
|
||||||
/// Rank specialization passes:
|
/// Rank specialization passes:
|
||||||
/// - Find compatible operations and group them together in one rank
|
/// - Find compatible operations and group them together in one rank
|
||||||
|
|
|
@ -104,10 +104,11 @@ void PopulateUnfuseBatchNormPatterns(MLIRContext *context,
|
||||||
void PopulateTrigonometricToApproximationPatterns(
|
void PopulateTrigonometricToApproximationPatterns(
|
||||||
MLIRContext *context, OwningRewritePatternList *patterns);
|
MLIRContext *context, OwningRewritePatternList *patterns);
|
||||||
|
|
||||||
void PopulateMoveUpDynamicBroadcastsForFusionLegality(ConversionTarget *target);
|
// Populate patterns to move dynamic broadcasts up over element-wise operations
|
||||||
|
// and broadcast the operands rather than the result. This will eventually allow
|
||||||
void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
|
// for larger fusions.
|
||||||
MLIRContext *context, OwningRewritePatternList *patterns);
|
void PopulateBroadcastsPropagationPatterns(MLIRContext *context,
|
||||||
|
OwningRewritePatternList *patterns);
|
||||||
|
|
||||||
/// Populate rank specialization clustering and lowering patterns.
|
/// Populate rank specialization clustering and lowering patterns.
|
||||||
void PopulateRankSpecializationClusterPatterns(
|
void PopulateRankSpecializationClusterPatterns(
|
||||||
|
|
|
@ -48,6 +48,7 @@ add_mlir_library(ChloPasses
|
||||||
)
|
)
|
||||||
|
|
||||||
add_mlir_library(MhloPasses
|
add_mlir_library(MhloPasses
|
||||||
|
broadcast_propagation.cc
|
||||||
legalize_gather_to_torch_index_select.cc
|
legalize_gather_to_torch_index_select.cc
|
||||||
legalize_trigonometric_to_approximation.cc
|
legalize_trigonometric_to_approximation.cc
|
||||||
lower_complex.cc
|
lower_complex.cc
|
||||||
|
@ -56,7 +57,6 @@ add_mlir_library(MhloPasses
|
||||||
materialize_broadcasts.cc
|
materialize_broadcasts.cc
|
||||||
materialize_broadcasts_pass.cc
|
materialize_broadcasts_pass.cc
|
||||||
mhlo_fusion.cc
|
mhlo_fusion.cc
|
||||||
move_up_dynamic_broadcasts_for_fusion.cc
|
|
||||||
optimize_mhlo.cc
|
optimize_mhlo.cc
|
||||||
optimize_mhlo_pass.cc
|
optimize_mhlo_pass.cc
|
||||||
rank_specialization.cc
|
rank_specialization.cc
|
||||||
|
|
|
@ -355,8 +355,8 @@ struct MoveUpBroadcastInDimOpPattern
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct MoveUpDynamicBroadcastsForFusionPass
|
struct BroadcastPropagationPass
|
||||||
: public PassWrapper<MoveUpDynamicBroadcastsForFusionPass, FunctionPass> {
|
: public PassWrapper<BroadcastPropagationPass, FunctionPass> {
|
||||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
registry.insert<shape::ShapeDialect, mhlo::MhloDialect>();
|
registry.insert<shape::ShapeDialect, mhlo::MhloDialect>();
|
||||||
}
|
}
|
||||||
|
@ -364,7 +364,7 @@ struct MoveUpDynamicBroadcastsForFusionPass
|
||||||
void runOnFunction() override {
|
void runOnFunction() override {
|
||||||
MLIRContext *ctx = &getContext();
|
MLIRContext *ctx = &getContext();
|
||||||
RewritePatternSet patterns(ctx);
|
RewritePatternSet patterns(ctx);
|
||||||
mhlo::PopulateMoveUpDynamicBroadcastsForFusionPatterns(ctx, &patterns);
|
mhlo::PopulateBroadcastsPropagationPatterns(ctx, &patterns);
|
||||||
if (failed(
|
if (failed(
|
||||||
applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)))) {
|
applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)))) {
|
||||||
return signalPassFailure();
|
return signalPassFailure();
|
||||||
|
@ -374,8 +374,8 @@ struct MoveUpDynamicBroadcastsForFusionPass
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
|
void PopulateBroadcastsPropagationPatterns(MLIRContext *context,
|
||||||
MLIRContext *context, OwningRewritePatternList *patterns) {
|
OwningRewritePatternList *patterns) {
|
||||||
// clang-format off
|
// clang-format off
|
||||||
patterns->insert<
|
patterns->insert<
|
||||||
InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>,
|
InlineBroadcastedShapeOperandsPattern<shape::CstrBroadcastableOp>,
|
||||||
|
@ -397,8 +397,8 @@ void PopulateMoveUpDynamicBroadcastsForFusionPatterns(
|
||||||
tensor::CastOp::getCanonicalizationPatterns(*patterns, context);
|
tensor::CastOp::getCanonicalizationPatterns(*patterns, context);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<FunctionPass> createMoveUpDynamicBroadcastsForFusionPass() {
|
std::unique_ptr<FunctionPass> createBroadcastPropagationPass() {
|
||||||
return std::make_unique<MoveUpDynamicBroadcastsForFusionPass>();
|
return std::make_unique<BroadcastPropagationPass>();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mhlo
|
} // namespace mhlo
|
|
@ -1,4 +1,4 @@
|
||||||
// RUN: mlir-hlo-opt --split-input-file --allow-unregistered-dialect --mhlo-move-up-dynamic-broadcasts-for-fusion --canonicalize --cse %s | FileCheck %s
|
// RUN: mlir-hlo-opt --split-input-file --allow-unregistered-dialect --mhlo-broadcast-propagation --canonicalize --cse %s | FileCheck %s
|
||||||
|
|
||||||
// Shape computations shall be reified.
|
// Shape computations shall be reified.
|
||||||
// CHECK-LABEL: @shape_of_unary
|
// CHECK-LABEL: @shape_of_unary
|
Loading…
Reference in New Issue