From 8c8e81cb6909fd8917199d08a7a9a91a27b201ac Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Wed, 16 Jun 2021 19:04:23 -0700 Subject: [PATCH] Fix pass definition to inherit from the TableGen generated base class (NFC) PiperOrigin-RevId: 379860210 --- BUILD | 15 +++++ .../Dialect/mhlo/transforms/lmhlo_passes.td | 12 ++-- .../Dialect/mhlo/transforms/mhlo_passes.td | 28 ++++----- .../mhlo/transforms/broadcast_propagation.cc | 3 +- .../mhlo/transforms/legalize_control_flow.cc | 3 +- .../legalize_gather_to_torch_index_select.cc | 4 +- .../transforms/legalize_tensor_load_op.cc | 3 +- .../mhlo/transforms/legalize_to_linalg.cc | 5 +- .../mhlo/transforms/legalize_to_standard.cc | 3 +- ...legalize_trigonometric_to_approximation.cc | 5 +- .../mhlo/transforms/lhlo_fuse_linalg.cc | 16 +---- .../transforms/lhlo_legalize_to_affine.cc | 3 +- .../mhlo/transforms/lhlo_legalize_to_gpu.cc | 3 +- .../lhlo_legalize_to_parallel_loops.cc | 4 +- lib/Dialect/mhlo/transforms/lower_complex.cc | 43 ++++++------- .../mhlo/transforms/lower_general_dot.cc | 60 ++++++++----------- .../transforms/materialize_broadcasts_pass.cc | 3 +- .../transforms/mhlo_control_flow_to_scf.cc | 3 +- lib/Dialect/mhlo/transforms/mhlo_fusion.cc | 3 +- .../mhlo/transforms/optimize_mhlo_pass.cc | 21 ++++--- .../transforms/test_infer_shaped_type_pass.cc | 4 +- .../mhlo/transforms/transform_unranked_hlo.cc | 3 +- .../mhlo/transforms/unfuse_batch_norm_pass.cc | 3 +- 23 files changed, 126 insertions(+), 124 deletions(-) diff --git a/BUILD b/BUILD index aa0d8e8..6c3dc82 100644 --- a/BUILD +++ b/BUILD @@ -748,6 +748,7 @@ cc_library( hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/passes.h"], deps = [ ":hlo", + ":pass_details", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", @@ -799,6 +800,7 @@ cc_library( ":hlo", ":lhlo", ":map_lmhlo_to_scalar_op", + ":pass_details", "@llvm-project//llvm:Support", "@llvm-project//mlir:Affine", "@llvm-project//mlir:IR", @@ -814,6 +816,7 @@ cc_library( srcs = ["lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc"], deps = [ ":lhlo", + ":pass_details", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:LinalgOps", @@ -837,6 +840,7 @@ cc_library( ":hlo", ":lhlo", ":map_lmhlo_to_scalar_op", + ":pass_details", "@llvm-project//llvm:Support", "@llvm-project//mlir:Affine", "@llvm-project//mlir:IR", @@ -863,6 +867,7 @@ cc_library( deps = [ ":hlo", ":map_chlo_to_hlo_op", + ":pass_details", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", @@ -885,6 +890,7 @@ cc_library( deps = [ ":hlo", ":map_chlo_to_hlo_op", + ":pass_details", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", @@ -928,6 +934,7 @@ cc_library( ":hlo", ":lhlo", ":map_lmhlo_to_scalar_op", + ":pass_details", "@llvm-project//llvm:Support", "@llvm-project//mlir:Affine", "@llvm-project//mlir:GPUDialect", @@ -948,6 +955,7 @@ cc_library( hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/passes.h"], deps = [ ":lhlo", + ":pass_details", "@llvm-project//llvm:Support", "@llvm-project//mlir:Affine", "@llvm-project//mlir:IR", @@ -1008,6 +1016,7 @@ cc_library( deps = [ ":cycle_detector", ":hlo", + ":pass_details", "@llvm-project//llvm:Core", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", @@ -1046,6 +1055,7 @@ cc_library( hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/passes.h"], deps = [ ":hlo", + ":pass_details", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", @@ -1064,6 +1074,7 @@ cc_library( ":hlo", ":legalize_to_standard_inc_gen", ":legalize_trigonometric_to_approximation", + ":pass_details", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", @@ -1083,6 +1094,7 @@ cc_library( ], deps = [ ":hlo", + ":pass_details", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", @@ -1102,6 +1114,7 @@ cc_library( ], includes = ["include"], deps = [ + ":pass_details", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:MathDialect", @@ -1149,6 +1162,7 @@ cc_library( deps = [ ":hlo", ":lower_complex_inc_gen", + ":pass_details", "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", "@llvm-project//mlir:IR", @@ -1201,6 +1215,7 @@ cc_library( hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/passes.h"], deps = [ ":lhlo", + ":pass_details", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:MemRefDialect", diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.td b/include/mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.td index 59388f0..6c29e3c 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.td +++ b/include/mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.td @@ -15,13 +15,13 @@ limitations under the License. include "mlir/Pass/PassBase.td" -def LhloLegalizeToLinalgPass : Pass<"lhlo-legalize-to-linalg", "FuncOp"> { +def LhloLegalizeToLinalgPass : FunctionPass<"lhlo-legalize-to-linalg"> { let summary = "Legalize from LHLO dialect to Linalg dialect."; let constructor = "createLegalizeLhloToLinalgPass()"; } -def LhloFuseLinalgPass : Pass<"lhlo-fuse-linalg", "FuncOp"> { +def LhloFuseLinalgPass : FunctionPass<"lhlo-fuse-linalg"> { let summary = "Greedily fuse linalg ops obtained after LHLO lowering."; let constructor = "createLhloFuseLinalgPass()"; let options = [ @@ -34,24 +34,24 @@ def LhloFuseLinalgPass : Pass<"lhlo-fuse-linalg", "FuncOp"> { } -def LhloLegalizeToAffinePass : Pass<"lhlo-legalize-to-affine", "FuncOp"> { +def LhloLegalizeToAffinePass : FunctionPass<"lhlo-legalize-to-affine"> { let summary = "Legalize from LHLO dialect to affine dialect."; let constructor = "createLhloLegalizeToAffinePass()"; } -def LhloLegalizeToGpuPass : Pass<"lhlo-legalize-to-gpu", "FuncOp"> { +def LhloLegalizeToGpuPass : FunctionPass<"lhlo-legalize-to-gpu"> { let summary = "Legalize from LHLO dialect to GPU dialect."; let constructor = "createLegalizeToGpuPass()"; } -def LhloLegalizeToParallelLoopsPass : Pass<"lhlo-legalize-to-parallel-loops", "FuncOp"> { +def LhloLegalizeToParallelLoopsPass : FunctionPass<"lhlo-legalize-to-parallel-loops"> { let summary = "Legalize from LHLO dialect to parallel loops."; let constructor = "createLegalizeLhloToParallelLoopsPass()"; } -def LegalizeTensorLoadOpPass : Pass<"lhlo-legalize-tensor-load-op", "FuncOp"> { +def LegalizeTensorLoadOpPass : FunctionPass<"lhlo-legalize-tensor-load-op"> { let summary = "Legalize tensor load ops that are inserted during mhlo to lmhlo conversion."; let constructor = "createLegalizeTensorLoadOpPass()"; } diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td b/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td index 055523e..f7ee2d8 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td +++ b/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td @@ -37,64 +37,64 @@ def HloLegalizeToLhloPass : Pass<"hlo-legalize-to-lhlo", "ModuleOp"> { ]; } -def LegalizeControlFlowPass : Pass<"mhlo-legalize-control-flow", "FuncOp"> { +def LegalizeControlFlowPass : FunctionPass<"mhlo-legalize-control-flow"> { let summary = "Legalize from MHLO control flow to CFG control flow."; let constructor = "createLegalizeControlFlowPass()"; } -def LegalizeControlFlowToScfPass : Pass<"mhlo-control-flow-to-scf", "FuncOp"> { +def LegalizeControlFlowToScfPass : FunctionPass<"mhlo-control-flow-to-scf"> { let summary = "Legalize from MHLO control flow to SCF control flow."; let constructor = "createControlFlowToScfPass()"; } -def LegalizeGatherToTorchIndexSelectPass : Pass<"mhlo-legalize-gather-to-torch-index-select", "FuncOp"> { +def LegalizeGatherToTorchIndexSelectPass : FunctionPass<"mhlo-legalize-gather-to-torch-index-select"> { let summary = "Legalizes gathers to a torch index select."; let constructor = "createLegalizeGatherToTorchIndexSelectPass()"; } -def LegalizeTanhToApproximationPass : Pass<"mhlo-legalize-trigonometric-to-approximation", "FuncOp"> { +def LegalizeTanhToApproximationPass : FunctionPass<"mhlo-legalize-trigonometric-to-approximation"> { let summary = "Legalize trigonometric operations from standard dialect to an approximation."; let constructor = "createLegalizeTrigonometricToApproximationPass()"; } -def HloLegalizeToLinalgPass : Pass<"hlo-legalize-to-linalg", "FuncOp"> { +def HloLegalizeToLinalgPass : FunctionPass<"hlo-legalize-to-linalg"> { let summary = "Legalize from HLO dialect to Linalg dialect."; let constructor = "createLegalizeHloToLinalgPass()"; } -def LegalizeToStandardPass : Pass<"mhlo-legalize-to-std", "FuncOp"> { +def LegalizeToStandardPass : FunctionPass<"mhlo-legalize-to-std"> { let summary = "Legalize from MHLO dialect to standard dialect."; let constructor = "createLegalizeToStdPass()"; } -def LowerComplexPass : Pass<"mhlo-test-lower-complex", "FuncOp"> { +def LowerComplexPass : FunctionPass<"mhlo-test-lower-complex"> { let summary = "Lower complex operations into non-complex operations."; let constructor = "createLowerComplexPass()"; } -def LegalizeGeneralDotPass : Pass<"mhlo-test-lower-general-dot", "FuncOp"> { +def LegalizeGeneralDotPass : FunctionPass<"mhlo-test-lower-general-dot"> { let summary = "Tests lowering general dot to a non-batched dot when possible."; let constructor = "createLegalizeGeneralDotPass()"; } -def TestMaterializeBroadcastsPass : Pass<"mhlo-test-materialize-broadcasts", "FuncOp"> { +def TestMaterializeBroadcastsPass : FunctionPass<"mhlo-test-materialize-broadcasts"> { let summary = "Test pass for materializing 'broadcast_dimensions' attributes."; let constructor = "createTestMaterializeBroadcastsPass()"; } -def MhloFusionPass : Pass<"mhlo-fusion", "FuncOp"> { +def MhloFusionPass : FunctionPass<"mhlo-fusion"> { let summary = "Fuse mhlo ops to kLoop/kInput fusion patterns."; let constructor = "createMhloFusionPass()"; } -def OptimizeMhloPass : Pass<"mhlo-test-optimize", "FuncOp"> { +def OptimizeMhloPass : FunctionPass<"mhlo-test-optimize"> { let summary = "Run optional HLO optimizations."; let constructor = "createOptimizeMhloPass()"; } @@ -107,18 +107,18 @@ def SinkConstantsToControlFlowPass : FunctionPass<"mhlo-sink-constants-to-contro } -def TestInferShapedTypeMethodsPass : Pass<"mhlo-test-infer-shaped-type-methods", "FuncOp"> { +def TestInferShapedTypeMethodsPass : FunctionPass<"mhlo-test-infer-shaped-type-methods"> { let summary = "Uses test ops to invoke InferShapedTypeOpInterface methods."; let constructor = "createTestInferShapedTypeMethodsPass()"; } -def TransformUnrankedHloPass : Pass<"mhlo-transform-unranked-hlo", "FuncOp"> { +def TransformUnrankedHloPass : FunctionPass<"mhlo-transform-unranked-hlo"> { let summary = "Realize element-wise operations on ranked tensors where possible."; let constructor = "createTransformUnrankedHloPass()"; } -def BroadcastPropagationPass : Pass<"mhlo-broadcast-propagation", "FuncOp"> { +def BroadcastPropagationPass : FunctionPass<"mhlo-broadcast-propagation"> { let summary = "Move dynamic broadcasts up over element-wise operations and " "broadcast the operands rather than the result. This will eventually allow " "for larger fusions."; diff --git a/lib/Dialect/mhlo/transforms/broadcast_propagation.cc b/lib/Dialect/mhlo/transforms/broadcast_propagation.cc index 82ca3e7..91bec86 100644 --- a/lib/Dialect/mhlo/transforms/broadcast_propagation.cc +++ b/lib/Dialect/mhlo/transforms/broadcast_propagation.cc @@ -19,6 +19,7 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h" #include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h" #include "mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" @@ -376,7 +377,7 @@ struct EarlyBroadcastInDimOpPattern }; struct BroadcastPropagationPass - : public PassWrapper { + : public BroadcastPropagationPassBase { void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); } diff --git a/lib/Dialect/mhlo/transforms/legalize_control_flow.cc b/lib/Dialect/mhlo/transforms/legalize_control_flow.cc index 17c017d..4ac07ea 100644 --- a/lib/Dialect/mhlo/transforms/legalize_control_flow.cc +++ b/lib/Dialect/mhlo/transforms/legalize_control_flow.cc @@ -19,6 +19,7 @@ limitations under the License. #include "llvm/ADT/StringSwitch.h" #include "llvm/Support/Casting.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h" #include "mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" // TF:llvm-project @@ -37,7 +38,7 @@ namespace mlir { namespace mhlo { namespace { struct LegalizeControlFlowPass - : public mlir::PassWrapper { + : public LegalizeControlFlowPassBase { // Perform the lowering to MLIR control flow. void runOnFunction() override; }; diff --git a/lib/Dialect/mhlo/transforms/legalize_gather_to_torch_index_select.cc b/lib/Dialect/mhlo/transforms/legalize_gather_to_torch_index_select.cc index 2416076..a90cbc4 100644 --- a/lib/Dialect/mhlo/transforms/legalize_gather_to_torch_index_select.cc +++ b/lib/Dialect/mhlo/transforms/legalize_gather_to_torch_index_select.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h" #include "mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "mlir/IR/BuiltinOps.h" @@ -128,7 +129,8 @@ struct GatherIsTorchIndexSelect : public OpRewritePattern { }; struct LegalizeGatherToTorchIndexSelectPass - : public PassWrapper { + : public LegalizeGatherToTorchIndexSelectPassBase< + LegalizeGatherToTorchIndexSelectPass> { /// Perform the lowering of standard dialect operations to approximations. void runOnFunction() override { OwningRewritePatternList patterns(&getContext()); diff --git a/lib/Dialect/mhlo/transforms/legalize_tensor_load_op.cc b/lib/Dialect/mhlo/transforms/legalize_tensor_load_op.cc index 6d2733d..81f02a1 100644 --- a/lib/Dialect/mhlo/transforms/legalize_tensor_load_op.cc +++ b/lib/Dialect/mhlo/transforms/legalize_tensor_load_op.cc @@ -16,6 +16,7 @@ limitations under the License. // This file implements logic for lowering memref.tensor_load ops that are // inserted during `mhlo-legalize-to-lmhlo`. +#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h" #include "mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" @@ -71,7 +72,7 @@ struct ForwardShapeOfOp : public OpRewritePattern { }; struct LegalizeTensorLoadOpPass - : public mlir::PassWrapper { + : public LegalizeTensorLoadOpPassBase { // Perform the lowering to remove memref.tensor_load ops inserted during // `mhlo-legalize-to-lmhlo`. void runOnFunction() override { diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index 2478ec5..0c24d5e 100644 --- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -21,6 +21,7 @@ limitations under the License. #include "llvm/ADT/SetVector.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h" #include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" @@ -2428,7 +2429,7 @@ class RemoveSignTypeConverter : public TypeConverter { // iterator_types = ["parallel", "parallel"], // } : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () struct LhloLegalizeToLinalgPass - : public PassWrapper { + : public lmhlo::LhloLegalizeToLinalgPassBase { void getDependentDialects(DialectRegistry& registry) const override { registry .insert { + : public mhlo::HloLegalizeToLinalgPassBase { void getDependentDialects(DialectRegistry& registry) const override { registry .insert { namespace { struct LegalizeToStandardPass - : public PassWrapper { + : public LegalizeToStandardPassBase { void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); } diff --git a/lib/Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc b/lib/Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc index 24058e0..332b2d9 100644 --- a/lib/Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc +++ b/lib/Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc @@ -16,6 +16,7 @@ limitations under the License. // This file implements the lowering for trigonometric standard ops to // approximations. +#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h" #include "mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "mlir/Dialect/Math/IR/Math.h" @@ -154,8 +155,8 @@ class ApproximateTanhLowering }; struct LegalizeTrigonometricToApproximationPass - : public PassWrapper { + : public LegalizeTanhToApproximationPassBase< + LegalizeTrigonometricToApproximationPass> { /// Perform the lowering of standard dialect operations to approximations. void runOnFunction() override { OwningRewritePatternList patterns(&getContext()); diff --git a/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc b/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc index cecf96a..e371d90 100644 --- a/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc +++ b/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc @@ -18,6 +18,7 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" +#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h" #include "mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" @@ -36,8 +37,7 @@ namespace { using linalg::LinalgOp; -class LhloFuseLinalgPass - : public PassWrapper { +class LhloFuseLinalgPass : public LhloFuseLinalgPassBase { void getDependentDialects(DialectRegistry& registry) const override { registry.insert(); } @@ -202,18 +202,6 @@ class LhloFuseLinalgPass .setLoopType(loopType)); return tiled_generic_op.hasValue(); } - - Option use_parallel_loops_{ - *this, "use-parallel-loops", - llvm::cl::desc( - "Tiles GenericOp consumer to parallel loops before linalg fusion"), - llvm::cl::init(false)}; - - ListOption tile_sizes_{ - *this, "tile-sizes", - llvm::cl::desc( - "Tile sizes by which to tile linalg generic before linalg fusion"), - llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; }; } // namespace diff --git a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc index 2e7d2a4..b211169 100644 --- a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc +++ b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc @@ -16,6 +16,7 @@ limitations under the License. // This file implements logic for lowering LHLO dialect to Affine dialect. #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h" #include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" @@ -229,7 +230,7 @@ void populateLHLOToAffineConversionPattern(MLIRContext* context, } struct LhloLegalizeToAffinePass - : public PassWrapper { + : public LhloLegalizeToAffinePassBase { void getDependentDialects(DialectRegistry& registry) const override { registry.insert(); } diff --git a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc index e068fa7..3e5681e 100644 --- a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc +++ b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc @@ -19,6 +19,7 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h" #include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/GPU/GPUDialect.h" @@ -172,7 +173,7 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern { }; struct LhloLegalizeToGpuPass - : public PassWrapper { + : public LhloLegalizeToGpuPassBase { void getDependentDialects(DialectRegistry& registry) const override { registry.insert(); diff --git a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc index d054ee5..79a3ec4 100644 --- a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc +++ b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc @@ -17,6 +17,7 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/SCF.h" @@ -700,7 +701,8 @@ class SelectAndScatterOpConverter }; struct LhloLegalizeToParallelLoopsPass - : public PassWrapper { + : public LhloLegalizeToParallelLoopsPassBase< + LhloLegalizeToParallelLoopsPass> { void getDependentDialects(DialectRegistry& registry) const override { registry.insert(); } diff --git a/lib/Dialect/mhlo/transforms/lower_complex.cc b/lib/Dialect/mhlo/transforms/lower_complex.cc index 1e74906..c37bbb7 100644 --- a/lib/Dialect/mhlo/transforms/lower_complex.cc +++ b/lib/Dialect/mhlo/transforms/lower_complex.cc @@ -24,7 +24,9 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h" #include "mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "mlir-hlo/utils/hlo_utils.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/MLIRContext.h" @@ -35,35 +37,17 @@ limitations under the License. #include "mlir/Pass/PassRegistry.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -using mlir::FunctionPass; -using mlir::OwningRewritePatternList; -using mlir::PassWrapper; - -namespace { -class LowerComplexPass : public PassWrapper { - public: - explicit LowerComplexPass() : PassWrapper() {} - - /// Performs the lowering to MHLO dialect. - void runOnFunction() override; -}; -} // end anonymous namespace - namespace mlir { namespace mhlo { namespace { +class LowerComplexPass : public LowerComplexPassBase { + public: + /// Performs the lowering to MHLO dialect. + void runOnFunction() override; +}; #include "generated_lower_complex.inc" -} // end anonymous namespace - -void PopulateComplexLoweringPatterns(MLIRContext* context, - OwningRewritePatternList* patterns) { - populateWithGenerated(*patterns); -} -} // end namespace mhlo -} // end namespace mlir - // Lowers the complex operations that can be represented using other operations. void LowerComplexPass::runOnFunction() { // Add lowering patterns to the list. @@ -73,6 +57,15 @@ void LowerComplexPass::runOnFunction() { (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); } -std::unique_ptr mlir::mhlo::createLowerComplexPass() { - return std::make_unique(); +} // end anonymous namespace +} // end namespace mhlo +} // end namespace mlir + +void mlir::mhlo::PopulateComplexLoweringPatterns( + MLIRContext* context, OwningRewritePatternList* patterns) { + populateWithGenerated(*patterns); +} + +std::unique_ptr mlir::mhlo::createLowerComplexPass() { + return std::make_unique(); } diff --git a/lib/Dialect/mhlo/transforms/lower_general_dot.cc b/lib/Dialect/mhlo/transforms/lower_general_dot.cc index 8ab202d..3ddb6c1 100644 --- a/lib/Dialect/mhlo/transforms/lower_general_dot.cc +++ b/lib/Dialect/mhlo/transforms/lower_general_dot.cc @@ -18,6 +18,7 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringSwitch.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h" #include "mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" @@ -30,28 +31,16 @@ limitations under the License. #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -using mlir::DenseIntElementsAttr; -using mlir::ElementsAttr; -using mlir::failure; -using mlir::FunctionPass; -using mlir::LogicalResult; -using mlir::MLIRContext; -using mlir::OpRewritePattern; -using mlir::OwningRewritePatternList; -using mlir::PassWrapper; -using mlir::PatternRewriter; -using mlir::RankedTensorType; -using mlir::success; -using mlir::Value; - +namespace mlir { +namespace mhlo { namespace { -Value TransposeReshape(Value arg, mlir::Location loc, +Value TransposeReshape(Value arg, Location loc, llvm::ArrayRef left_dims, llvm::ArrayRef right_dims, llvm::ArrayRef arg_shape, PatternRewriter *rewriter) { - auto element_type = mlir::getElementTypeOrSelf(arg.getType()); + auto element_type = getElementTypeOrSelf(arg.getType()); int64_t left_size = 1; for (auto dim : left_dims) { @@ -68,7 +57,7 @@ Value TransposeReshape(Value arg, mlir::Location loc, left_dims.end()); transpose_permutation.append(right_dims.begin(), right_dims.end()); - mlir::TensorType transpose_permutation_type = RankedTensorType::get( + TensorType transpose_permutation_type = RankedTensorType::get( {static_cast(transpose_permutation.size())}, rewriter->getIntegerType(64)); @@ -83,20 +72,18 @@ Value TransposeReshape(Value arg, mlir::Location loc, transposed_shape.push_back(arg_shape[val]); } auto transpose_type = RankedTensorType::get(transposed_shape, element_type); - auto transpose_result = rewriter->create( + auto transpose_result = rewriter->create( loc, transpose_type, arg, transpose_permutation_attr); // Return the final result. auto reshaped_type = RankedTensorType::get({left_size, right_size}, element_type); - return rewriter->create(loc, reshaped_type, - transpose_result); + return rewriter->create(loc, reshaped_type, transpose_result); } -Value ProcessDotArg(Value arg, mlir::Location loc, - ElementsAttr contract_dims_attr, bool outer_dims_first, - PatternRewriter *rewriter) { - auto shape = arg.getType().cast().getShape(); +Value ProcessDotArg(Value arg, Location loc, ElementsAttr contract_dims_attr, + bool outer_dims_first, PatternRewriter *rewriter) { + auto shape = arg.getType().cast().getShape(); llvm::SmallVector is_outer_dim; is_outer_dim.resize(shape.size(), true); @@ -124,7 +111,7 @@ Value ProcessDotArg(Value arg, mlir::Location loc, return TransposeReshape(arg, loc, contract_dims, outer_dims, shape, rewriter); } -struct GeneralDotConvert : public OpRewritePattern { +struct GeneralDotConvert : public OpRewritePattern { // Attempts to lower a General Dot operator to a standard Dot operator. // General dots include batching dimensions and can have collapsing // dimensions along any axis. Inserting correctly arrange transpose and @@ -136,9 +123,9 @@ struct GeneralDotConvert : public OpRewritePattern { explicit GeneralDotConvert(MLIRContext *context) : OpRewritePattern(context) {} - LogicalResult matchAndRewrite(mlir::mhlo::DotGeneralOp op, + LogicalResult matchAndRewrite(DotGeneralOp op, PatternRewriter &rewriter) const override { - auto dot_element_type = mlir::getElementTypeOrSelf(op); + auto dot_element_type = getElementTypeOrSelf(op); auto dot_numbers = op.dot_dimension_numbers(); if (dot_numbers.lhs_batching_dimensions().getNumElements() != 0 || @@ -155,8 +142,8 @@ struct GeneralDotConvert : public OpRewritePattern { /*outer_dims_first=*/false, &rewriter); // Accept only static shaped types. - auto lhs_shape_type = lhs.getType().dyn_cast_or_null(); - auto rhs_shape_type = rhs.getType().dyn_cast_or_null(); + auto lhs_shape_type = lhs.getType().dyn_cast_or_null(); + auto rhs_shape_type = rhs.getType().dyn_cast_or_null(); if (!lhs_shape_type || !rhs_shape_type) return failure(); if (!lhs_shape_type.hasStaticShape() || !rhs_shape_type.hasStaticShape()) return failure(); @@ -167,28 +154,29 @@ struct GeneralDotConvert : public OpRewritePattern { auto new_dot_type = RankedTensorType::get({lhs_shape[0], rhs_shape[1]}, dot_element_type); - mlir::ArrayAttr precision_config; + ArrayAttr precision_config; if (op.precision_config()) precision_config = *op.precision_config(); - auto new_dot_op = rewriter.create( - op.getLoc(), new_dot_type, lhs, rhs, precision_config); + auto new_dot_op = rewriter.create(op.getLoc(), new_dot_type, lhs, + rhs, precision_config); - rewriter.replaceOpWithNewOp(op, op.getType(), - new_dot_op); + rewriter.replaceOpWithNewOp(op, op.getType(), new_dot_op); return success(); } }; struct LegalizeGeneralDotPass - : public PassWrapper { + : public LegalizeGeneralDotPassBase { /// Lower all general dots that can be represented as a non-batched matmul. void runOnFunction() override { OwningRewritePatternList patterns(&getContext()); - mlir::mhlo::PopulateGeneralDotOpLoweringPatterns(&patterns, &getContext()); + PopulateGeneralDotOpLoweringPatterns(&patterns, &getContext()); (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); } }; } // namespace +} // namespace mhlo +} // namespace mlir void mlir::mhlo::PopulateGeneralDotOpLoweringPatterns( OwningRewritePatternList *patterns, MLIRContext *ctx) { diff --git a/lib/Dialect/mhlo/transforms/materialize_broadcasts_pass.cc b/lib/Dialect/mhlo/transforms/materialize_broadcasts_pass.cc index 4cb618d..b6fa010 100644 --- a/lib/Dialect/mhlo/transforms/materialize_broadcasts_pass.cc +++ b/lib/Dialect/mhlo/transforms/materialize_broadcasts_pass.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/MLIRContext.h" @@ -28,7 +29,7 @@ namespace mhlo { namespace { struct TestMaterializeBroadcastsPass - : public PassWrapper { + : public TestMaterializeBroadcastsPassBase { void runOnFunction() override { ConversionTarget conversionTarget(getContext()); OwningRewritePatternList conversionPatterns(&getContext()); diff --git a/lib/Dialect/mhlo/transforms/mhlo_control_flow_to_scf.cc b/lib/Dialect/mhlo/transforms/mhlo_control_flow_to_scf.cc index d7b8537..2fa9ee9 100644 --- a/lib/Dialect/mhlo/transforms/mhlo_control_flow_to_scf.cc +++ b/lib/Dialect/mhlo/transforms/mhlo_control_flow_to_scf.cc @@ -16,6 +16,7 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h" #include "mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" @@ -39,7 +40,7 @@ void MatchAndRewrite(WhileOp whileOp); /// Pass that converts MHLO control flow to SCF. class ControlFlowToScfPass - : public mlir::PassWrapper { + : public LegalizeControlFlowToScfPassBase { void getDependentDialects(DialectRegistry& registry) const override { registry.insert(); } diff --git a/lib/Dialect/mhlo/transforms/mhlo_fusion.cc b/lib/Dialect/mhlo/transforms/mhlo_fusion.cc index 3dff7db..a987b0d 100644 --- a/lib/Dialect/mhlo/transforms/mhlo_fusion.cc +++ b/lib/Dialect/mhlo/transforms/mhlo_fusion.cc @@ -21,6 +21,7 @@ limitations under the License. #include "llvm/ADT/EquivalenceClasses.h" #include "llvm/Support/Debug.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h" #include "mlir-hlo/utils/cycle_detector.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/MLIRContext.h" // TF:llvm-project @@ -480,7 +481,7 @@ class FusionPlanner { EquivalenceClasses leader_for_node_; }; -struct MhloFusionPass : public mlir::PassWrapper { +struct MhloFusionPass : public MhloFusionPassBase { void runOnFunction() override { FuncOp func = getFunction(); if (!IsTargetFunc(func)) { diff --git a/lib/Dialect/mhlo/transforms/optimize_mhlo_pass.cc b/lib/Dialect/mhlo/transforms/optimize_mhlo_pass.cc index 539b70a..36da395 100644 --- a/lib/Dialect/mhlo/transforms/optimize_mhlo_pass.cc +++ b/lib/Dialect/mhlo/transforms/optimize_mhlo_pass.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h" #include "mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" @@ -23,28 +24,26 @@ limitations under the License. #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -using mlir::FunctionPass; -using mlir::PassWrapper; - +namespace mlir { +namespace mhlo { namespace { -class OptimizeMhloPass : public PassWrapper { +class OptimizeMhloPass : public OptimizeMhloPassBase { public: - explicit OptimizeMhloPass() : PassWrapper() {} - /// Performs the lowering to MHLO dialect. void runOnFunction() override; }; -} // end anonymous namespace - // Lowers the complex operations that can be represented using other operations. void OptimizeMhloPass::runOnFunction() { // Add lowering patterns to the list. - mlir::OwningRewritePatternList patterns(&getContext()); - mlir::mhlo::PopulateOptimizeMHLOPatterns(&getContext(), &patterns); + OwningRewritePatternList patterns(&getContext()); + PopulateOptimizeMHLOPatterns(&getContext(), &patterns); (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); } +} // end anonymous namespace +} // namespace mhlo +} // namespace mlir std::unique_ptr mlir::mhlo::createOptimizeMhloPass() { - return std::make_unique(); + return std::make_unique(); } diff --git a/lib/Dialect/mhlo/transforms/test_infer_shaped_type_pass.cc b/lib/Dialect/mhlo/transforms/test_infer_shaped_type_pass.cc index 251cc8b..1048076 100644 --- a/lib/Dialect/mhlo/transforms/test_infer_shaped_type_pass.cc +++ b/lib/Dialect/mhlo/transforms/test_infer_shaped_type_pass.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Identifier.h" @@ -84,7 +85,8 @@ struct ReifyReturnTypeShapesPattern : public RewritePattern { }; struct TestInferShapedTypeMethodsPass - : public PassWrapper { + : public TestInferShapedTypeMethodsPassBase< + TestInferShapedTypeMethodsPass> { void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); } diff --git a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc index c99901f..00cba94 100644 --- a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc +++ b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc @@ -16,6 +16,7 @@ limitations under the License. #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h" #include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "mlir/Dialect/SCF/SCF.h" @@ -538,7 +539,7 @@ struct ConvertUnrankedDynamicBroadcastNaryOp }; struct TransformUnrankedHloPass - : public PassWrapper { + : public mhlo::TransformUnrankedHloPassBase { void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); diff --git a/lib/Dialect/mhlo/transforms/unfuse_batch_norm_pass.cc b/lib/Dialect/mhlo/transforms/unfuse_batch_norm_pass.cc index 34726c0..f0179a2 100644 --- a/lib/Dialect/mhlo/transforms/unfuse_batch_norm_pass.cc +++ b/lib/Dialect/mhlo/transforms/unfuse_batch_norm_pass.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" @@ -30,7 +31,7 @@ namespace mhlo { namespace { struct TestUnfuseBatchNormPass - : public PassWrapper> { + : public TestUnfuseBatchNormPassBase { void getDependentDialects(DialectRegistry& registry) const override { registry.insert(); }