Remove the dependency on global dialect registry from mlir-hlo
PiperOrigin-RevId: 328457105
This commit is contained in:
		
							parent
							
								
									ebe7857fb1
								
							
						
					
					
						commit
						36ddbeb6b2
					
				| 
						 | 
					@ -29,6 +29,10 @@ namespace {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
struct TestChloLegalizeToHloPass
 | 
					struct TestChloLegalizeToHloPass
 | 
				
			||||||
    : public PassWrapper<TestChloLegalizeToHloPass, FunctionPass> {
 | 
					    : public PassWrapper<TestChloLegalizeToHloPass, FunctionPass> {
 | 
				
			||||||
 | 
					  void getDependentDialects(DialectRegistry ®istry) const override {
 | 
				
			||||||
 | 
					    registry.insert<mhlo::MhloDialect, shape::ShapeDialect, scf::SCFDialect>();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  void runOnFunction() override {
 | 
					  void runOnFunction() override {
 | 
				
			||||||
    ConversionTarget conversionTarget(getContext());
 | 
					    ConversionTarget conversionTarget(getContext());
 | 
				
			||||||
    OwningRewritePatternList conversionPatterns;
 | 
					    OwningRewritePatternList conversionPatterns;
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -388,6 +388,10 @@ class HloToLhloTensorStoreOpConverter
 | 
				
			||||||
 | 
					
 | 
				
			||||||
struct HloLegalizeToLhlo
 | 
					struct HloLegalizeToLhlo
 | 
				
			||||||
    : public PassWrapper<HloLegalizeToLhlo, OperationPass<ModuleOp>> {
 | 
					    : public PassWrapper<HloLegalizeToLhlo, OperationPass<ModuleOp>> {
 | 
				
			||||||
 | 
					  void getDependentDialects(DialectRegistry& registry) const override {
 | 
				
			||||||
 | 
					    registry.insert<lmhlo::LmhloDialect>();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 public:
 | 
					 public:
 | 
				
			||||||
  HloLegalizeToLhlo() = default;
 | 
					  HloLegalizeToLhlo() = default;
 | 
				
			||||||
  HloLegalizeToLhlo(const HloLegalizeToLhlo& o) {
 | 
					  HloLegalizeToLhlo(const HloLegalizeToLhlo& o) {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -866,6 +866,10 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
 | 
				
			||||||
// } : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
 | 
					// } : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
 | 
				
			||||||
struct LhloLegalizeToLinalgPass
 | 
					struct LhloLegalizeToLinalgPass
 | 
				
			||||||
    : public PassWrapper<LhloLegalizeToLinalgPass, FunctionPass> {
 | 
					    : public PassWrapper<LhloLegalizeToLinalgPass, FunctionPass> {
 | 
				
			||||||
 | 
					  void getDependentDialects(DialectRegistry& registry) const override {
 | 
				
			||||||
 | 
					    registry.insert<AffineDialect, linalg::LinalgDialect>();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  void runOnFunction() override {
 | 
					  void runOnFunction() override {
 | 
				
			||||||
    OwningRewritePatternList patterns;
 | 
					    OwningRewritePatternList patterns;
 | 
				
			||||||
    ConversionTarget target(getContext());
 | 
					    ConversionTarget target(getContext());
 | 
				
			||||||
| 
						 | 
					@ -882,6 +886,10 @@ struct LhloLegalizeToLinalgPass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
struct HloLegalizeToLinalgPass
 | 
					struct HloLegalizeToLinalgPass
 | 
				
			||||||
    : public PassWrapper<HloLegalizeToLinalgPass, FunctionPass> {
 | 
					    : public PassWrapper<HloLegalizeToLinalgPass, FunctionPass> {
 | 
				
			||||||
 | 
					  void getDependentDialects(DialectRegistry& registry) const override {
 | 
				
			||||||
 | 
					    registry.insert<linalg::LinalgDialect>();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  void runOnFunction() override {
 | 
					  void runOnFunction() override {
 | 
				
			||||||
    OwningRewritePatternList patterns;
 | 
					    OwningRewritePatternList patterns;
 | 
				
			||||||
    ConversionTarget target(getContext());
 | 
					    ConversionTarget target(getContext());
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -178,6 +178,10 @@ class ConvertIotaOp : public OpRewritePattern<mhlo::IotaOp> {
 | 
				
			||||||
namespace {
 | 
					namespace {
 | 
				
			||||||
struct LegalizeToStandardPass
 | 
					struct LegalizeToStandardPass
 | 
				
			||||||
    : public PassWrapper<LegalizeToStandardPass, FunctionPass> {
 | 
					    : public PassWrapper<LegalizeToStandardPass, FunctionPass> {
 | 
				
			||||||
 | 
					  void getDependentDialects(DialectRegistry ®istry) const override {
 | 
				
			||||||
 | 
					    registry.insert<StandardOpsDialect>();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  /// Perform the lowering to Standard dialect.
 | 
					  /// Perform the lowering to Standard dialect.
 | 
				
			||||||
  void runOnFunction() override;
 | 
					  void runOnFunction() override;
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -19,8 +19,10 @@ limitations under the License.
 | 
				
			||||||
#include "llvm/ADT/ArrayRef.h"
 | 
					#include "llvm/ADT/ArrayRef.h"
 | 
				
			||||||
#include "llvm/ADT/STLExtras.h"
 | 
					#include "llvm/ADT/STLExtras.h"
 | 
				
			||||||
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
 | 
					#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
 | 
				
			||||||
 | 
					#include "mlir/Dialect/Affine/IR/AffineOps.h"
 | 
				
			||||||
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
 | 
					#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
 | 
				
			||||||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 | 
					#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 | 
				
			||||||
 | 
					#include "mlir/Dialect/SCF/SCF.h"
 | 
				
			||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
 | 
					#include "mlir/Dialect/StandardOps/IR/Ops.h"
 | 
				
			||||||
#include "mlir/Pass/Pass.h"
 | 
					#include "mlir/Pass/Pass.h"
 | 
				
			||||||
#include "mlir/Transforms/FoldUtils.h"
 | 
					#include "mlir/Transforms/FoldUtils.h"
 | 
				
			||||||
| 
						 | 
					@ -33,6 +35,10 @@ using linalg::LinalgOp;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class LhloFuseLinalgPass
 | 
					class LhloFuseLinalgPass
 | 
				
			||||||
    : public PassWrapper<LhloFuseLinalgPass, FunctionPass> {
 | 
					    : public PassWrapper<LhloFuseLinalgPass, FunctionPass> {
 | 
				
			||||||
 | 
					  void getDependentDialects(DialectRegistry& registry) const override {
 | 
				
			||||||
 | 
					    registry.insert<AffineDialect, linalg::LinalgDialect, scf::SCFDialect>();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 public:
 | 
					 public:
 | 
				
			||||||
  LhloFuseLinalgPass() = default;
 | 
					  LhloFuseLinalgPass() = default;
 | 
				
			||||||
  LhloFuseLinalgPass(const LhloFuseLinalgPass&) {}
 | 
					  LhloFuseLinalgPass(const LhloFuseLinalgPass&) {}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -139,6 +139,9 @@ void populateLHLOToAffineConversionPattern(MLIRContext* context,
 | 
				
			||||||
 | 
					
 | 
				
			||||||
struct LhloLegalizeToAffinePass
 | 
					struct LhloLegalizeToAffinePass
 | 
				
			||||||
    : public PassWrapper<LhloLegalizeToAffinePass, FunctionPass> {
 | 
					    : public PassWrapper<LhloLegalizeToAffinePass, FunctionPass> {
 | 
				
			||||||
 | 
					  void getDependentDialects(DialectRegistry& registry) const override {
 | 
				
			||||||
 | 
					    registry.insert<AffineDialect>();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
  void runOnFunction() override {
 | 
					  void runOnFunction() override {
 | 
				
			||||||
    OwningRewritePatternList patterns;
 | 
					    OwningRewritePatternList patterns;
 | 
				
			||||||
    auto func = getFunction();
 | 
					    auto func = getFunction();
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -20,8 +20,10 @@ limitations under the License.
 | 
				
			||||||
#include "llvm/ADT/ArrayRef.h"
 | 
					#include "llvm/ADT/ArrayRef.h"
 | 
				
			||||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
 | 
					#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
 | 
				
			||||||
#include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
 | 
					#include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
 | 
				
			||||||
 | 
					#include "mlir/Dialect/Affine/IR/AffineOps.h"
 | 
				
			||||||
#include "mlir/Dialect/GPU/GPUDialect.h"
 | 
					#include "mlir/Dialect/GPU/GPUDialect.h"
 | 
				
			||||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 | 
					#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 | 
				
			||||||
 | 
					#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
 | 
				
			||||||
#include "mlir/Dialect/SCF/SCF.h"
 | 
					#include "mlir/Dialect/SCF/SCF.h"
 | 
				
			||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
 | 
					#include "mlir/Dialect/StandardOps/IR/Ops.h"
 | 
				
			||||||
#include "mlir/IR/Attributes.h"
 | 
					#include "mlir/IR/Attributes.h"
 | 
				
			||||||
| 
						 | 
					@ -169,6 +171,11 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern<ReduceOp> {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
struct LhloLegalizeToGpuPass
 | 
					struct LhloLegalizeToGpuPass
 | 
				
			||||||
    : public PassWrapper<LhloLegalizeToGpuPass, FunctionPass> {
 | 
					    : public PassWrapper<LhloLegalizeToGpuPass, FunctionPass> {
 | 
				
			||||||
 | 
					  void getDependentDialects(DialectRegistry& registry) const override {
 | 
				
			||||||
 | 
					    registry.insert<AffineDialect, gpu::GPUDialect, linalg::LinalgDialect,
 | 
				
			||||||
 | 
					                    scf::SCFDialect>();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  void runOnFunction() override {
 | 
					  void runOnFunction() override {
 | 
				
			||||||
    OwningRewritePatternList patterns;
 | 
					    OwningRewritePatternList patterns;
 | 
				
			||||||
    ConversionTarget target(getContext());
 | 
					    ConversionTarget target(getContext());
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -29,6 +29,10 @@ namespace {
 | 
				
			||||||
class TestLhloToLLVMPass
 | 
					class TestLhloToLLVMPass
 | 
				
			||||||
    : public ::mlir::PassWrapper<TestLhloToLLVMPass,
 | 
					    : public ::mlir::PassWrapper<TestLhloToLLVMPass,
 | 
				
			||||||
                                 ::mlir::OperationPass<::mlir::ModuleOp>> {
 | 
					                                 ::mlir::OperationPass<::mlir::ModuleOp>> {
 | 
				
			||||||
 | 
					  void getDependentDialects(DialectRegistry ®istry) const override {
 | 
				
			||||||
 | 
					    registry.insert<LLVM::LLVMDialect>();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 public:
 | 
					 public:
 | 
				
			||||||
  void runOnOperation() override {
 | 
					  void runOnOperation() override {
 | 
				
			||||||
    ModuleOp m = getOperation();
 | 
					    ModuleOp m = getOperation();
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -691,6 +691,10 @@ class SelectAndScatterOpConverter
 | 
				
			||||||
 | 
					
 | 
				
			||||||
struct LhloLegalizeToParallelLoopsPass
 | 
					struct LhloLegalizeToParallelLoopsPass
 | 
				
			||||||
    : public PassWrapper<LhloLegalizeToParallelLoopsPass, FunctionPass> {
 | 
					    : public PassWrapper<LhloLegalizeToParallelLoopsPass, FunctionPass> {
 | 
				
			||||||
 | 
					  void getDependentDialects(DialectRegistry& registry) const override {
 | 
				
			||||||
 | 
					    registry.insert<StandardOpsDialect, scf::SCFDialect>();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  void runOnFunction() override {
 | 
					  void runOnFunction() override {
 | 
				
			||||||
    auto func = getFunction();
 | 
					    auto func = getFunction();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -153,6 +153,10 @@ struct BinaryElementwiseOpConversion : public OpRewritePattern<OpTy> {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
struct TransformUnrankedHloPass
 | 
					struct TransformUnrankedHloPass
 | 
				
			||||||
    : public PassWrapper<TransformUnrankedHloPass, FunctionPass> {
 | 
					    : public PassWrapper<TransformUnrankedHloPass, FunctionPass> {
 | 
				
			||||||
 | 
					  void getDependentDialects(DialectRegistry ®istry) const override {
 | 
				
			||||||
 | 
					    registry.insert<shape::ShapeDialect>();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  void runOnFunction() override {
 | 
					  void runOnFunction() override {
 | 
				
			||||||
    // Setup conversion target.
 | 
					    // Setup conversion target.
 | 
				
			||||||
    MLIRContext &ctx = getContext();
 | 
					    MLIRContext &ctx = getContext();
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue