mlir-hlo-opt: set preloadDialectsInContext to false.

This requires specifying dependent dialects in several passes.

PiperOrigin-RevId: 365758084
This commit is contained in:
Adrian Kuegel 2021-03-30 01:06:12 -07:00 committed by TensorFlow MLIR Team
parent 9ebadc4c4d
commit 6388e8d9ee
8 changed files with 28 additions and 7 deletions

4
BUILD
View File

@ -642,6 +642,7 @@ cc_library(
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgOps", "@llvm-project//mlir:LinalgOps",
"@llvm-project//mlir:MathDialect", "@llvm-project//mlir:MathDialect",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:StandardOps", "@llvm-project//mlir:StandardOps",
@ -709,6 +710,7 @@ cc_library(
"@llvm-project//mlir:GPUDialect", "@llvm-project//mlir:GPUDialect",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgOps", "@llvm-project//mlir:LinalgOps",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:StandardOps", "@llvm-project//mlir:StandardOps",
@ -752,6 +754,7 @@ cc_library(
":map_hlo_to_lhlo_op", ":map_hlo_to_lhlo_op",
"@llvm-project//llvm:Support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
"@llvm-project//mlir:Shape", "@llvm-project//mlir:Shape",
"@llvm-project//mlir:ShapeTransforms", "@llvm-project//mlir:ShapeTransforms",
@ -1033,6 +1036,7 @@ cc_library(
"@llvm-project//mlir:InferTypeOpInterface", "@llvm-project//mlir:InferTypeOpInterface",
"@llvm-project//mlir:LLVMDialect", "@llvm-project//mlir:LLVMDialect",
"@llvm-project//mlir:LLVMTransforms", "@llvm-project//mlir:LLVMTransforms",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:Shape", "@llvm-project//mlir:Shape",

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h" #include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h" #include "mlir/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
@ -564,7 +565,8 @@ class HloToLhloTensorStoreOpLegacyConverter
struct HloLegalizeToLhlo struct HloLegalizeToLhlo
: public PassWrapper<HloLegalizeToLhlo, OperationPass<ModuleOp>> { : public PassWrapper<HloLegalizeToLhlo, OperationPass<ModuleOp>> {
void getDependentDialects(DialectRegistry& registry) const override { void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<lmhlo::LmhloDialect>(); registry.insert<lmhlo::LmhloDialect, memref::MemRefDialect,
shape::ShapeDialect>();
} }
public: public:

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
@ -1965,7 +1966,9 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
struct LhloLegalizeToLinalgPass struct LhloLegalizeToLinalgPass
: public PassWrapper<LhloLegalizeToLinalgPass, FunctionPass> { : public PassWrapper<LhloLegalizeToLinalgPass, FunctionPass> {
void getDependentDialects(DialectRegistry& registry) const override { void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<AffineDialect, linalg::LinalgDialect, math::MathDialect>(); registry
.insert<AffineDialect, complex::ComplexDialect, linalg::LinalgDialect,
math::MathDialect, memref::MemRefDialect>();
} }
void runOnFunction() override { void runOnFunction() override {
@ -1986,8 +1989,9 @@ struct LhloLegalizeToLinalgPass
struct HloLegalizeToLinalgPass struct HloLegalizeToLinalgPass
: public PassWrapper<HloLegalizeToLinalgPass, FunctionPass> { : public PassWrapper<HloLegalizeToLinalgPass, FunctionPass> {
void getDependentDialects(DialectRegistry& registry) const override { void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<linalg::LinalgDialect, scf::SCFDialect, registry
complex::ComplexDialect, math::MathDialect>(); .insert<linalg::LinalgDialect, scf::SCFDialect, complex::ComplexDialect,
math::MathDialect, memref::MemRefDialect>();
} }
void runOnFunction() override { void runOnFunction() override {

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h" #include "mlir/IR/Attributes.h"
@ -174,7 +175,7 @@ struct LhloLegalizeToGpuPass
: public PassWrapper<LhloLegalizeToGpuPass, FunctionPass> { : public PassWrapper<LhloLegalizeToGpuPass, FunctionPass> {
void getDependentDialects(DialectRegistry& registry) const override { void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<AffineDialect, gpu::GPUDialect, linalg::LinalgDialect, registry.insert<AffineDialect, gpu::GPUDialect, linalg::LinalgDialect,
scf::SCFDialect>(); memref::MemRefDialect, scf::SCFDialect>();
} }
void runOnFunction() override { void runOnFunction() override {

View File

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/IR/Attributes.h" #include "mlir/IR/Attributes.h"
#include "mlir/IR/Identifier.h" #include "mlir/IR/Identifier.h"
#include "mlir/IR/MLIRContext.h" #include "mlir/IR/MLIRContext.h"
@ -83,6 +84,9 @@ struct ReifyReturnTypeShapesPattern : public RewritePattern {
struct TestInferShapedTypeMethodsPass struct TestInferShapedTypeMethodsPass
: public PassWrapper<TestInferShapedTypeMethodsPass, FunctionPass> { : public PassWrapper<TestInferShapedTypeMethodsPass, FunctionPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<shape::ShapeDialect>();
}
void runOnFunction() override { void runOnFunction() override {
OwningRewritePatternList patterns(&getContext()); OwningRewritePatternList patterns(&getContext());
patterns.insert<ReifyReturnTypeShapesPattern>(&getContext()); patterns.insert<ReifyReturnTypeShapesPattern>(&getContext());

View File

@ -528,7 +528,7 @@ struct ConvertUnrankedDynamicBroadcastSelectOp
struct TransformUnrankedHloPass struct TransformUnrankedHloPass
: public PassWrapper<TransformUnrankedHloPass, FunctionPass> { : public PassWrapper<TransformUnrankedHloPass, FunctionPass> {
void getDependentDialects(DialectRegistry &registry) const override { void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<chlo::HloClientDialect, mhlo::MhloDialect, registry.insert<chlo::HloClientDialect, mhlo::MhloDialect, scf::SCFDialect,
shape::ShapeDialect>(); shape::ShapeDialect>();
} }

View File

@ -15,7 +15,9 @@ limitations under the License.
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/MLIRContext.h" #include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h" #include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
@ -29,6 +31,9 @@ namespace {
struct TestUnfuseBatchNormPass struct TestUnfuseBatchNormPass
: public PassWrapper<TestUnfuseBatchNormPass, OperationPass<>> { : public PassWrapper<TestUnfuseBatchNormPass, OperationPass<>> {
void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<memref::MemRefDialect>();
}
void runOnOperation() override { void runOnOperation() override {
OwningRewritePatternList patterns(&getContext()); OwningRewritePatternList patterns(&getContext());
PopulateUnfuseBatchNormPatterns(&getContext(), &patterns); PopulateUnfuseBatchNormPatterns(&getContext(), &patterns);

View File

@ -35,5 +35,6 @@ int main(int argc, char **argv) {
registry.insert<mlir::lmhlo_gpu::LmhloGpuDialect>(); registry.insert<mlir::lmhlo_gpu::LmhloGpuDialect>();
return failed(mlir::MlirOptMain(argc, argv, "MLIR HLO pass driver\n", return failed(mlir::MlirOptMain(argc, argv, "MLIR HLO pass driver\n",
registry, /*preloadDialectsInContext=*/true)); registry,
/*preloadDialectsInContext=*/false));
} }