mlir-hlo-opt: set preloadDialectsInContext to false.

This requires specifying dependent dialects in several passes.

PiperOrigin-RevId: 365758084
This commit is contained in:
Adrian Kuegel 2021-03-30 01:06:12 -07:00 committed by TensorFlow MLIR Team
parent 9ebadc4c4d
commit 6388e8d9ee
8 changed files with 28 additions and 7 deletions

4
BUILD
View File

@ -642,6 +642,7 @@ cc_library(
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgOps",
"@llvm-project//mlir:MathDialect",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:StandardOps",
@ -709,6 +710,7 @@ cc_library(
"@llvm-project//mlir:GPUDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgOps",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:StandardOps",
@ -752,6 +754,7 @@ cc_library(
":map_hlo_to_lhlo_op",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Shape",
"@llvm-project//mlir:ShapeTransforms",
@ -1033,6 +1036,7 @@ cc_library(
"@llvm-project//mlir:InferTypeOpInterface",
"@llvm-project//mlir:LLVMDialect",
"@llvm-project//mlir:LLVMTransforms",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:Shape",

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
@ -564,7 +565,8 @@ class HloToLhloTensorStoreOpLegacyConverter
struct HloLegalizeToLhlo
: public PassWrapper<HloLegalizeToLhlo, OperationPass<ModuleOp>> {
void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<lmhlo::LmhloDialect>();
registry.insert<lmhlo::LmhloDialect, memref::MemRefDialect,
shape::ShapeDialect>();
}
public:

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@ -1965,7 +1966,9 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
struct LhloLegalizeToLinalgPass
: public PassWrapper<LhloLegalizeToLinalgPass, FunctionPass> {
void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<AffineDialect, linalg::LinalgDialect, math::MathDialect>();
registry
.insert<AffineDialect, complex::ComplexDialect, linalg::LinalgDialect,
math::MathDialect, memref::MemRefDialect>();
}
void runOnFunction() override {
@ -1986,8 +1989,9 @@ struct LhloLegalizeToLinalgPass
struct HloLegalizeToLinalgPass
: public PassWrapper<HloLegalizeToLinalgPass, FunctionPass> {
void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<linalg::LinalgDialect, scf::SCFDialect,
complex::ComplexDialect, math::MathDialect>();
registry
.insert<linalg::LinalgDialect, scf::SCFDialect, complex::ComplexDialect,
math::MathDialect, memref::MemRefDialect>();
}
void runOnFunction() override {

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
@ -174,7 +175,7 @@ struct LhloLegalizeToGpuPass
: public PassWrapper<LhloLegalizeToGpuPass, FunctionPass> {
void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<AffineDialect, gpu::GPUDialect, linalg::LinalgDialect,
scf::SCFDialect>();
memref::MemRefDialect, scf::SCFDialect>();
}
void runOnFunction() override {

View File

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/MLIRContext.h"
@ -83,6 +84,9 @@ struct ReifyReturnTypeShapesPattern : public RewritePattern {
struct TestInferShapedTypeMethodsPass
: public PassWrapper<TestInferShapedTypeMethodsPass, FunctionPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<shape::ShapeDialect>();
}
void runOnFunction() override {
OwningRewritePatternList patterns(&getContext());
patterns.insert<ReifyReturnTypeShapesPattern>(&getContext());

View File

@ -528,7 +528,7 @@ struct ConvertUnrankedDynamicBroadcastSelectOp
struct TransformUnrankedHloPass
: public PassWrapper<TransformUnrankedHloPass, FunctionPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<chlo::HloClientDialect, mhlo::MhloDialect,
registry.insert<chlo::HloClientDialect, mhlo::MhloDialect, scf::SCFDialect,
shape::ShapeDialect>();
}

View File

@ -15,7 +15,9 @@ limitations under the License.
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"
@ -29,6 +31,9 @@ namespace {
struct TestUnfuseBatchNormPass
: public PassWrapper<TestUnfuseBatchNormPass, OperationPass<>> {
void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<memref::MemRefDialect>();
}
void runOnOperation() override {
OwningRewritePatternList patterns(&getContext());
PopulateUnfuseBatchNormPatterns(&getContext(), &patterns);

View File

@ -35,5 +35,6 @@ int main(int argc, char **argv) {
registry.insert<mlir::lmhlo_gpu::LmhloGpuDialect>();
return failed(mlir::MlirOptMain(argc, argv, "MLIR HLO pass driver\n",
registry, /*preloadDialectsInContext=*/true));
registry,
/*preloadDialectsInContext=*/false));
}