mlir-hlo-opt: set preloadDialectsInContext to false.
This requires specifying dependent dialects in several passes. PiperOrigin-RevId: 365758084
This commit is contained in:
		
							parent
							
								
									9ebadc4c4d
								
							
						
					
					
						commit
						6388e8d9ee
					
				
							
								
								
									
										4
									
								
								BUILD
								
								
								
								
							
							
						
						
									
										4
									
								
								BUILD
								
								
								
								
							| 
						 | 
					@ -642,6 +642,7 @@ cc_library(
 | 
				
			||||||
        "@llvm-project//mlir:IR",
 | 
					        "@llvm-project//mlir:IR",
 | 
				
			||||||
        "@llvm-project//mlir:LinalgOps",
 | 
					        "@llvm-project//mlir:LinalgOps",
 | 
				
			||||||
        "@llvm-project//mlir:MathDialect",
 | 
					        "@llvm-project//mlir:MathDialect",
 | 
				
			||||||
 | 
					        "@llvm-project//mlir:MemRefDialect",
 | 
				
			||||||
        "@llvm-project//mlir:Pass",
 | 
					        "@llvm-project//mlir:Pass",
 | 
				
			||||||
        "@llvm-project//mlir:SCFDialect",
 | 
					        "@llvm-project//mlir:SCFDialect",
 | 
				
			||||||
        "@llvm-project//mlir:StandardOps",
 | 
					        "@llvm-project//mlir:StandardOps",
 | 
				
			||||||
| 
						 | 
					@ -709,6 +710,7 @@ cc_library(
 | 
				
			||||||
        "@llvm-project//mlir:GPUDialect",
 | 
					        "@llvm-project//mlir:GPUDialect",
 | 
				
			||||||
        "@llvm-project//mlir:IR",
 | 
					        "@llvm-project//mlir:IR",
 | 
				
			||||||
        "@llvm-project//mlir:LinalgOps",
 | 
					        "@llvm-project//mlir:LinalgOps",
 | 
				
			||||||
 | 
					        "@llvm-project//mlir:MemRefDialect",
 | 
				
			||||||
        "@llvm-project//mlir:Pass",
 | 
					        "@llvm-project//mlir:Pass",
 | 
				
			||||||
        "@llvm-project//mlir:SCFDialect",
 | 
					        "@llvm-project//mlir:SCFDialect",
 | 
				
			||||||
        "@llvm-project//mlir:StandardOps",
 | 
					        "@llvm-project//mlir:StandardOps",
 | 
				
			||||||
| 
						 | 
					@ -752,6 +754,7 @@ cc_library(
 | 
				
			||||||
        ":map_hlo_to_lhlo_op",
 | 
					        ":map_hlo_to_lhlo_op",
 | 
				
			||||||
        "@llvm-project//llvm:Support",
 | 
					        "@llvm-project//llvm:Support",
 | 
				
			||||||
        "@llvm-project//mlir:IR",
 | 
					        "@llvm-project//mlir:IR",
 | 
				
			||||||
 | 
					        "@llvm-project//mlir:MemRefDialect",
 | 
				
			||||||
        "@llvm-project//mlir:Pass",
 | 
					        "@llvm-project//mlir:Pass",
 | 
				
			||||||
        "@llvm-project//mlir:Shape",
 | 
					        "@llvm-project//mlir:Shape",
 | 
				
			||||||
        "@llvm-project//mlir:ShapeTransforms",
 | 
					        "@llvm-project//mlir:ShapeTransforms",
 | 
				
			||||||
| 
						 | 
					@ -1033,6 +1036,7 @@ cc_library(
 | 
				
			||||||
        "@llvm-project//mlir:InferTypeOpInterface",
 | 
					        "@llvm-project//mlir:InferTypeOpInterface",
 | 
				
			||||||
        "@llvm-project//mlir:LLVMDialect",
 | 
					        "@llvm-project//mlir:LLVMDialect",
 | 
				
			||||||
        "@llvm-project//mlir:LLVMTransforms",
 | 
					        "@llvm-project//mlir:LLVMTransforms",
 | 
				
			||||||
 | 
					        "@llvm-project//mlir:MemRefDialect",
 | 
				
			||||||
        "@llvm-project//mlir:Pass",
 | 
					        "@llvm-project//mlir:Pass",
 | 
				
			||||||
        "@llvm-project//mlir:SCFDialect",
 | 
					        "@llvm-project//mlir:SCFDialect",
 | 
				
			||||||
        "@llvm-project//mlir:Shape",
 | 
					        "@llvm-project//mlir:Shape",
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -20,6 +20,7 @@ limitations under the License.
 | 
				
			||||||
#include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
 | 
					#include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
 | 
				
			||||||
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
 | 
					#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
 | 
				
			||||||
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
 | 
					#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
 | 
				
			||||||
 | 
					#include "mlir/Dialect/MemRef/IR/MemRef.h"
 | 
				
			||||||
#include "mlir/Dialect/Shape/IR/Shape.h"
 | 
					#include "mlir/Dialect/Shape/IR/Shape.h"
 | 
				
			||||||
#include "mlir/Dialect/Shape/Transforms/Passes.h"
 | 
					#include "mlir/Dialect/Shape/Transforms/Passes.h"
 | 
				
			||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
 | 
					#include "mlir/Dialect/StandardOps/IR/Ops.h"
 | 
				
			||||||
| 
						 | 
					@ -564,7 +565,8 @@ class HloToLhloTensorStoreOpLegacyConverter
 | 
				
			||||||
struct HloLegalizeToLhlo
 | 
					struct HloLegalizeToLhlo
 | 
				
			||||||
    : public PassWrapper<HloLegalizeToLhlo, OperationPass<ModuleOp>> {
 | 
					    : public PassWrapper<HloLegalizeToLhlo, OperationPass<ModuleOp>> {
 | 
				
			||||||
  void getDependentDialects(DialectRegistry& registry) const override {
 | 
					  void getDependentDialects(DialectRegistry& registry) const override {
 | 
				
			||||||
    registry.insert<lmhlo::LmhloDialect>();
 | 
					    registry.insert<lmhlo::LmhloDialect, memref::MemRefDialect,
 | 
				
			||||||
 | 
					                    shape::ShapeDialect>();
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 public:
 | 
					 public:
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -27,6 +27,7 @@ limitations under the License.
 | 
				
			||||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 | 
					#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 | 
				
			||||||
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
 | 
					#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
 | 
				
			||||||
#include "mlir/Dialect/Math/IR/Math.h"
 | 
					#include "mlir/Dialect/Math/IR/Math.h"
 | 
				
			||||||
 | 
					#include "mlir/Dialect/MemRef/IR/MemRef.h"
 | 
				
			||||||
#include "mlir/Dialect/SCF/SCF.h"
 | 
					#include "mlir/Dialect/SCF/SCF.h"
 | 
				
			||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
 | 
					#include "mlir/Dialect/StandardOps/IR/Ops.h"
 | 
				
			||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
 | 
					#include "mlir/Dialect/Tensor/IR/Tensor.h"
 | 
				
			||||||
| 
						 | 
					@ -1965,7 +1966,9 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
 | 
				
			||||||
struct LhloLegalizeToLinalgPass
 | 
					struct LhloLegalizeToLinalgPass
 | 
				
			||||||
    : public PassWrapper<LhloLegalizeToLinalgPass, FunctionPass> {
 | 
					    : public PassWrapper<LhloLegalizeToLinalgPass, FunctionPass> {
 | 
				
			||||||
  void getDependentDialects(DialectRegistry& registry) const override {
 | 
					  void getDependentDialects(DialectRegistry& registry) const override {
 | 
				
			||||||
    registry.insert<AffineDialect, linalg::LinalgDialect, math::MathDialect>();
 | 
					    registry
 | 
				
			||||||
 | 
					        .insert<AffineDialect, complex::ComplexDialect, linalg::LinalgDialect,
 | 
				
			||||||
 | 
					                math::MathDialect, memref::MemRefDialect>();
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  void runOnFunction() override {
 | 
					  void runOnFunction() override {
 | 
				
			||||||
| 
						 | 
					@ -1986,8 +1989,9 @@ struct LhloLegalizeToLinalgPass
 | 
				
			||||||
struct HloLegalizeToLinalgPass
 | 
					struct HloLegalizeToLinalgPass
 | 
				
			||||||
    : public PassWrapper<HloLegalizeToLinalgPass, FunctionPass> {
 | 
					    : public PassWrapper<HloLegalizeToLinalgPass, FunctionPass> {
 | 
				
			||||||
  void getDependentDialects(DialectRegistry& registry) const override {
 | 
					  void getDependentDialects(DialectRegistry& registry) const override {
 | 
				
			||||||
    registry.insert<linalg::LinalgDialect, scf::SCFDialect,
 | 
					    registry
 | 
				
			||||||
                    complex::ComplexDialect, math::MathDialect>();
 | 
					        .insert<linalg::LinalgDialect, scf::SCFDialect, complex::ComplexDialect,
 | 
				
			||||||
 | 
					                math::MathDialect, memref::MemRefDialect>();
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  void runOnFunction() override {
 | 
					  void runOnFunction() override {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -24,6 +24,7 @@ limitations under the License.
 | 
				
			||||||
#include "mlir/Dialect/GPU/GPUDialect.h"
 | 
					#include "mlir/Dialect/GPU/GPUDialect.h"
 | 
				
			||||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 | 
					#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 | 
				
			||||||
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
 | 
					#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
 | 
				
			||||||
 | 
					#include "mlir/Dialect/MemRef/IR/MemRef.h"
 | 
				
			||||||
#include "mlir/Dialect/SCF/SCF.h"
 | 
					#include "mlir/Dialect/SCF/SCF.h"
 | 
				
			||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
 | 
					#include "mlir/Dialect/StandardOps/IR/Ops.h"
 | 
				
			||||||
#include "mlir/IR/Attributes.h"
 | 
					#include "mlir/IR/Attributes.h"
 | 
				
			||||||
| 
						 | 
					@ -174,7 +175,7 @@ struct LhloLegalizeToGpuPass
 | 
				
			||||||
    : public PassWrapper<LhloLegalizeToGpuPass, FunctionPass> {
 | 
					    : public PassWrapper<LhloLegalizeToGpuPass, FunctionPass> {
 | 
				
			||||||
  void getDependentDialects(DialectRegistry& registry) const override {
 | 
					  void getDependentDialects(DialectRegistry& registry) const override {
 | 
				
			||||||
    registry.insert<AffineDialect, gpu::GPUDialect, linalg::LinalgDialect,
 | 
					    registry.insert<AffineDialect, gpu::GPUDialect, linalg::LinalgDialect,
 | 
				
			||||||
                    scf::SCFDialect>();
 | 
					                    memref::MemRefDialect, scf::SCFDialect>();
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  void runOnFunction() override {
 | 
					  void runOnFunction() override {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 | 
				
			||||||
limitations under the License.
 | 
					limitations under the License.
 | 
				
			||||||
==============================================================================*/
 | 
					==============================================================================*/
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "mlir/Dialect/Shape/IR/Shape.h"
 | 
				
			||||||
#include "mlir/IR/Attributes.h"
 | 
					#include "mlir/IR/Attributes.h"
 | 
				
			||||||
#include "mlir/IR/Identifier.h"
 | 
					#include "mlir/IR/Identifier.h"
 | 
				
			||||||
#include "mlir/IR/MLIRContext.h"
 | 
					#include "mlir/IR/MLIRContext.h"
 | 
				
			||||||
| 
						 | 
					@ -83,6 +84,9 @@ struct ReifyReturnTypeShapesPattern : public RewritePattern {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
struct TestInferShapedTypeMethodsPass
 | 
					struct TestInferShapedTypeMethodsPass
 | 
				
			||||||
    : public PassWrapper<TestInferShapedTypeMethodsPass, FunctionPass> {
 | 
					    : public PassWrapper<TestInferShapedTypeMethodsPass, FunctionPass> {
 | 
				
			||||||
 | 
					  void getDependentDialects(DialectRegistry ®istry) const override {
 | 
				
			||||||
 | 
					    registry.insert<shape::ShapeDialect>();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
  void runOnFunction() override {
 | 
					  void runOnFunction() override {
 | 
				
			||||||
    OwningRewritePatternList patterns(&getContext());
 | 
					    OwningRewritePatternList patterns(&getContext());
 | 
				
			||||||
    patterns.insert<ReifyReturnTypeShapesPattern>(&getContext());
 | 
					    patterns.insert<ReifyReturnTypeShapesPattern>(&getContext());
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -528,7 +528,7 @@ struct ConvertUnrankedDynamicBroadcastSelectOp
 | 
				
			||||||
struct TransformUnrankedHloPass
 | 
					struct TransformUnrankedHloPass
 | 
				
			||||||
    : public PassWrapper<TransformUnrankedHloPass, FunctionPass> {
 | 
					    : public PassWrapper<TransformUnrankedHloPass, FunctionPass> {
 | 
				
			||||||
  void getDependentDialects(DialectRegistry ®istry) const override {
 | 
					  void getDependentDialects(DialectRegistry ®istry) const override {
 | 
				
			||||||
    registry.insert<chlo::HloClientDialect, mhlo::MhloDialect,
 | 
					    registry.insert<chlo::HloClientDialect, mhlo::MhloDialect, scf::SCFDialect,
 | 
				
			||||||
                    shape::ShapeDialect>();
 | 
					                    shape::ShapeDialect>();
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -15,7 +15,9 @@ limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 | 
					#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 | 
				
			||||||
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
 | 
					#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
 | 
				
			||||||
 | 
					#include "mlir/Dialect/MemRef/IR/MemRef.h"
 | 
				
			||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
 | 
					#include "mlir/Dialect/StandardOps/IR/Ops.h"
 | 
				
			||||||
 | 
					#include "mlir/IR/Dialect.h"
 | 
				
			||||||
#include "mlir/IR/MLIRContext.h"
 | 
					#include "mlir/IR/MLIRContext.h"
 | 
				
			||||||
#include "mlir/IR/Operation.h"
 | 
					#include "mlir/IR/Operation.h"
 | 
				
			||||||
#include "mlir/Pass/Pass.h"
 | 
					#include "mlir/Pass/Pass.h"
 | 
				
			||||||
| 
						 | 
					@ -29,6 +31,9 @@ namespace {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
struct TestUnfuseBatchNormPass
 | 
					struct TestUnfuseBatchNormPass
 | 
				
			||||||
    : public PassWrapper<TestUnfuseBatchNormPass, OperationPass<>> {
 | 
					    : public PassWrapper<TestUnfuseBatchNormPass, OperationPass<>> {
 | 
				
			||||||
 | 
					  void getDependentDialects(DialectRegistry& registry) const override {
 | 
				
			||||||
 | 
					    registry.insert<memref::MemRefDialect>();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
  void runOnOperation() override {
 | 
					  void runOnOperation() override {
 | 
				
			||||||
    OwningRewritePatternList patterns(&getContext());
 | 
					    OwningRewritePatternList patterns(&getContext());
 | 
				
			||||||
    PopulateUnfuseBatchNormPatterns(&getContext(), &patterns);
 | 
					    PopulateUnfuseBatchNormPatterns(&getContext(), &patterns);
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -35,5 +35,6 @@ int main(int argc, char **argv) {
 | 
				
			||||||
  registry.insert<mlir::lmhlo_gpu::LmhloGpuDialect>();
 | 
					  registry.insert<mlir::lmhlo_gpu::LmhloGpuDialect>();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  return failed(mlir::MlirOptMain(argc, argv, "MLIR HLO pass driver\n",
 | 
					  return failed(mlir::MlirOptMain(argc, argv, "MLIR HLO pass driver\n",
 | 
				
			||||||
                                  registry, /*preloadDialectsInContext=*/true));
 | 
					                                  registry,
 | 
				
			||||||
 | 
					                                  /*preloadDialectsInContext=*/false));
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue