diff --git a/BUILD b/BUILD index 7d2df24..29712df 100644 --- a/BUILD +++ b/BUILD @@ -642,6 +642,7 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:LinalgOps", "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:Pass", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:StandardOps", @@ -709,6 +710,7 @@ cc_library( "@llvm-project//mlir:GPUDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:LinalgOps", + "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:Pass", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:StandardOps", @@ -752,6 +754,7 @@ cc_library( ":map_hlo_to_lhlo_op", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", + "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Shape", "@llvm-project//mlir:ShapeTransforms", @@ -1033,6 +1036,7 @@ cc_library( "@llvm-project//mlir:InferTypeOpInterface", "@llvm-project//mlir:LLVMDialect", "@llvm-project//mlir:LLVMTransforms", + "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:Pass", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:Shape", diff --git a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc index ea9ad5a..87b2496 100644 --- a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc +++ b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc @@ -20,6 +20,7 @@ limitations under the License. #include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h" #include "mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Shape/Transforms/Passes.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" @@ -564,7 +565,8 @@ class HloToLhloTensorStoreOpLegacyConverter struct HloLegalizeToLhlo : public PassWrapper> { void getDependentDialects(DialectRegistry& registry) const override { - registry.insert(); + registry.insert(); } public: diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index c7fbd80..df77bad 100644 --- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -27,6 +27,7 @@ limitations under the License. #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -1965,7 +1966,9 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context, struct LhloLegalizeToLinalgPass : public PassWrapper { void getDependentDialects(DialectRegistry& registry) const override { - registry.insert(); + registry + .insert(); } void runOnFunction() override { @@ -1986,8 +1989,9 @@ struct LhloLegalizeToLinalgPass struct HloLegalizeToLinalgPass : public PassWrapper { void getDependentDialects(DialectRegistry& registry) const override { - registry.insert(); + registry + .insert(); } void runOnFunction() override { diff --git a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc index 915fef9..cc15c61 100644 --- a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc +++ b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc @@ -24,6 +24,7 @@ limitations under the License. #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Attributes.h" @@ -174,7 +175,7 @@ struct LhloLegalizeToGpuPass : public PassWrapper { void getDependentDialects(DialectRegistry& registry) const override { registry.insert(); + memref::MemRefDialect, scf::SCFDialect>(); } void runOnFunction() override { diff --git a/lib/Dialect/mhlo/transforms/test_infer_shaped_type_pass.cc b/lib/Dialect/mhlo/transforms/test_infer_shaped_type_pass.cc index de06621..d9859eb 100644 --- a/lib/Dialect/mhlo/transforms/test_infer_shaped_type_pass.cc +++ b/lib/Dialect/mhlo/transforms/test_infer_shaped_type_pass.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Identifier.h" #include "mlir/IR/MLIRContext.h" @@ -83,6 +84,9 @@ struct ReifyReturnTypeShapesPattern : public RewritePattern { struct TestInferShapedTypeMethodsPass : public PassWrapper { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } void runOnFunction() override { OwningRewritePatternList patterns(&getContext()); patterns.insert(&getContext()); diff --git a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc index e3984d1..e3ac0b6 100644 --- a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc +++ b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc @@ -528,7 +528,7 @@ struct ConvertUnrankedDynamicBroadcastSelectOp struct TransformUnrankedHloPass : public PassWrapper { void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); } diff --git a/lib/Dialect/mhlo/transforms/unfuse_batch_norm_pass.cc b/lib/Dialect/mhlo/transforms/unfuse_batch_norm_pass.cc index 33e0a12..34726c0 100644 --- a/lib/Dialect/mhlo/transforms/unfuse_batch_norm_pass.cc +++ b/lib/Dialect/mhlo/transforms/unfuse_batch_norm_pass.cc @@ -15,7 +15,9 @@ limitations under the License. #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Dialect.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Operation.h" #include "mlir/Pass/Pass.h" @@ -29,6 +31,9 @@ namespace { struct TestUnfuseBatchNormPass : public PassWrapper> { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } void runOnOperation() override { OwningRewritePatternList patterns(&getContext()); PopulateUnfuseBatchNormPatterns(&getContext(), &patterns); diff --git a/tools/mlir-hlo-opt/mlir-hlo-opt.cpp b/tools/mlir-hlo-opt/mlir-hlo-opt.cpp index 8c4ccb8..5c9a1e1 100644 --- a/tools/mlir-hlo-opt/mlir-hlo-opt.cpp +++ b/tools/mlir-hlo-opt/mlir-hlo-opt.cpp @@ -35,5 +35,6 @@ int main(int argc, char **argv) { registry.insert(); return failed(mlir::MlirOptMain(argc, argv, "MLIR HLO pass driver\n", - registry, /*preloadDialectsInContext=*/true)); + registry, + /*preloadDialectsInContext=*/false)); }