From 36ddbeb6b24e887b1dc86c35b1173a7c7988ea54 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Tue, 25 Aug 2020 20:30:05 -0700 Subject: [PATCH] Remove the dependency on global dialect registry from mlir-hlo PiperOrigin-RevId: 328457105 --- lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc | 4 ++++ lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc | 4 ++++ lib/Dialect/mhlo/transforms/legalize_to_linalg.cc | 8 ++++++++ lib/Dialect/mhlo/transforms/legalize_to_standard.cc | 4 ++++ lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc | 6 ++++++ lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc | 3 +++ lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc | 7 +++++++ lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc | 4 ++++ .../mhlo/transforms/lhlo_legalize_to_parallel_loops.cc | 4 ++++ lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc | 4 ++++ 10 files changed, 48 insertions(+) diff --git a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc index 50cd6df..263b6cd 100644 --- a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc +++ b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc @@ -29,6 +29,10 @@ namespace { struct TestChloLegalizeToHloPass : public PassWrapper { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + void runOnFunction() override { ConversionTarget conversionTarget(getContext()); OwningRewritePatternList conversionPatterns; diff --git a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc index a8c3ad1..a784c05 100644 --- a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc +++ b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc @@ -388,6 +388,10 @@ class HloToLhloTensorStoreOpConverter struct HloLegalizeToLhlo : public PassWrapper> { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + public: HloLegalizeToLhlo() = default; HloLegalizeToLhlo(const HloLegalizeToLhlo& o) { diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index 033021c..ff780fd 100644 --- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -866,6 +866,10 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context, // } : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () struct LhloLegalizeToLinalgPass : public PassWrapper { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + void runOnFunction() override { OwningRewritePatternList patterns; ConversionTarget target(getContext()); @@ -882,6 +886,10 @@ struct LhloLegalizeToLinalgPass struct HloLegalizeToLinalgPass : public PassWrapper { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + void runOnFunction() override { OwningRewritePatternList patterns; ConversionTarget target(getContext()); diff --git a/lib/Dialect/mhlo/transforms/legalize_to_standard.cc b/lib/Dialect/mhlo/transforms/legalize_to_standard.cc index cc574e0..5000fce 100644 --- a/lib/Dialect/mhlo/transforms/legalize_to_standard.cc +++ b/lib/Dialect/mhlo/transforms/legalize_to_standard.cc @@ -178,6 +178,10 @@ class ConvertIotaOp : public OpRewritePattern { namespace { struct LegalizeToStandardPass : public PassWrapper { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + /// Perform the lowering to Standard dialect. void runOnFunction() override; }; diff --git a/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc b/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc index 1467f01..6dc5b64 100644 --- a/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc +++ b/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc @@ -19,8 +19,10 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/FoldUtils.h" @@ -33,6 +35,10 @@ using linalg::LinalgOp; class LhloFuseLinalgPass : public PassWrapper { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + public: LhloFuseLinalgPass() = default; LhloFuseLinalgPass(const LhloFuseLinalgPass&) {} diff --git a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc index 0789132..2771afc 100644 --- a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc +++ b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc @@ -139,6 +139,9 @@ void populateLHLOToAffineConversionPattern(MLIRContext* context, struct LhloLegalizeToAffinePass : public PassWrapper { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } void runOnFunction() override { OwningRewritePatternList patterns; auto func = getFunction(); diff --git a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc index cffb58b..fbade8f 100644 --- a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc +++ b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc @@ -20,8 +20,10 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Attributes.h" @@ -169,6 +171,11 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern { struct LhloLegalizeToGpuPass : public PassWrapper { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + void runOnFunction() override { OwningRewritePatternList patterns; ConversionTarget target(getContext()); diff --git a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc index 8493a1f..3d49027 100644 --- a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc +++ b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc @@ -29,6 +29,10 @@ namespace { class TestLhloToLLVMPass : public ::mlir::PassWrapper> { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + public: void runOnOperation() override { ModuleOp m = getOperation(); diff --git a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc index 19f47d0..d9a2d99 100644 --- a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc +++ b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc @@ -691,6 +691,10 @@ class SelectAndScatterOpConverter struct LhloLegalizeToParallelLoopsPass : public PassWrapper { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + void runOnFunction() override { auto func = getFunction(); diff --git a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc index 7c985ea..58a0f0e 100644 --- a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc +++ b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc @@ -153,6 +153,10 @@ struct BinaryElementwiseOpConversion : public OpRewritePattern { struct TransformUnrankedHloPass : public PassWrapper { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + void runOnFunction() override { // Setup conversion target. MLIRContext &ctx = getContext();