Remove the dependency on global dialect registry from mlir-hlo

PiperOrigin-RevId: 328457105
This commit is contained in:
Mehdi Amini 2020-08-25 20:30:05 -07:00 committed by TensorFlow MLIR Team
parent ebe7857fb1
commit 36ddbeb6b2
10 changed files with 48 additions and 0 deletions

View File

@ -29,6 +29,10 @@ namespace {
struct TestChloLegalizeToHloPass struct TestChloLegalizeToHloPass
: public PassWrapper<TestChloLegalizeToHloPass, FunctionPass> { : public PassWrapper<TestChloLegalizeToHloPass, FunctionPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<mhlo::MhloDialect, shape::ShapeDialect, scf::SCFDialect>();
}
void runOnFunction() override { void runOnFunction() override {
ConversionTarget conversionTarget(getContext()); ConversionTarget conversionTarget(getContext());
OwningRewritePatternList conversionPatterns; OwningRewritePatternList conversionPatterns;

View File

@ -388,6 +388,10 @@ class HloToLhloTensorStoreOpConverter
struct HloLegalizeToLhlo struct HloLegalizeToLhlo
: public PassWrapper<HloLegalizeToLhlo, OperationPass<ModuleOp>> { : public PassWrapper<HloLegalizeToLhlo, OperationPass<ModuleOp>> {
void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<lmhlo::LmhloDialect>();
}
public: public:
HloLegalizeToLhlo() = default; HloLegalizeToLhlo() = default;
HloLegalizeToLhlo(const HloLegalizeToLhlo& o) { HloLegalizeToLhlo(const HloLegalizeToLhlo& o) {

View File

@ -866,6 +866,10 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
// } : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () // } : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
struct LhloLegalizeToLinalgPass struct LhloLegalizeToLinalgPass
: public PassWrapper<LhloLegalizeToLinalgPass, FunctionPass> { : public PassWrapper<LhloLegalizeToLinalgPass, FunctionPass> {
void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<AffineDialect, linalg::LinalgDialect>();
}
void runOnFunction() override { void runOnFunction() override {
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
ConversionTarget target(getContext()); ConversionTarget target(getContext());
@ -882,6 +886,10 @@ struct LhloLegalizeToLinalgPass
struct HloLegalizeToLinalgPass struct HloLegalizeToLinalgPass
: public PassWrapper<HloLegalizeToLinalgPass, FunctionPass> { : public PassWrapper<HloLegalizeToLinalgPass, FunctionPass> {
void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<linalg::LinalgDialect>();
}
void runOnFunction() override { void runOnFunction() override {
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
ConversionTarget target(getContext()); ConversionTarget target(getContext());

View File

@ -178,6 +178,10 @@ class ConvertIotaOp : public OpRewritePattern<mhlo::IotaOp> {
namespace { namespace {
struct LegalizeToStandardPass struct LegalizeToStandardPass
: public PassWrapper<LegalizeToStandardPass, FunctionPass> { : public PassWrapper<LegalizeToStandardPass, FunctionPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<StandardOpsDialect>();
}
/// Perform the lowering to Standard dialect. /// Perform the lowering to Standard dialect.
void runOnFunction() override; void runOnFunction() override;
}; };

View File

@ -19,8 +19,10 @@ limitations under the License.
#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/FoldUtils.h"
@ -33,6 +35,10 @@ using linalg::LinalgOp;
class LhloFuseLinalgPass class LhloFuseLinalgPass
: public PassWrapper<LhloFuseLinalgPass, FunctionPass> { : public PassWrapper<LhloFuseLinalgPass, FunctionPass> {
void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<AffineDialect, linalg::LinalgDialect, scf::SCFDialect>();
}
public: public:
LhloFuseLinalgPass() = default; LhloFuseLinalgPass() = default;
LhloFuseLinalgPass(const LhloFuseLinalgPass&) {} LhloFuseLinalgPass(const LhloFuseLinalgPass&) {}

View File

@ -139,6 +139,9 @@ void populateLHLOToAffineConversionPattern(MLIRContext* context,
struct LhloLegalizeToAffinePass struct LhloLegalizeToAffinePass
: public PassWrapper<LhloLegalizeToAffinePass, FunctionPass> { : public PassWrapper<LhloLegalizeToAffinePass, FunctionPass> {
void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<AffineDialect>();
}
void runOnFunction() override { void runOnFunction() override {
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
auto func = getFunction(); auto func = getFunction();

View File

@ -20,8 +20,10 @@ limitations under the License.
#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/ArrayRef.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h" #include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h" #include "mlir/IR/Attributes.h"
@ -169,6 +171,11 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern<ReduceOp> {
struct LhloLegalizeToGpuPass struct LhloLegalizeToGpuPass
: public PassWrapper<LhloLegalizeToGpuPass, FunctionPass> { : public PassWrapper<LhloLegalizeToGpuPass, FunctionPass> {
void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<AffineDialect, gpu::GPUDialect, linalg::LinalgDialect,
scf::SCFDialect>();
}
void runOnFunction() override { void runOnFunction() override {
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
ConversionTarget target(getContext()); ConversionTarget target(getContext());

View File

@ -29,6 +29,10 @@ namespace {
class TestLhloToLLVMPass class TestLhloToLLVMPass
: public ::mlir::PassWrapper<TestLhloToLLVMPass, : public ::mlir::PassWrapper<TestLhloToLLVMPass,
::mlir::OperationPass<::mlir::ModuleOp>> { ::mlir::OperationPass<::mlir::ModuleOp>> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<LLVM::LLVMDialect>();
}
public: public:
void runOnOperation() override { void runOnOperation() override {
ModuleOp m = getOperation(); ModuleOp m = getOperation();

View File

@ -691,6 +691,10 @@ class SelectAndScatterOpConverter
struct LhloLegalizeToParallelLoopsPass struct LhloLegalizeToParallelLoopsPass
: public PassWrapper<LhloLegalizeToParallelLoopsPass, FunctionPass> { : public PassWrapper<LhloLegalizeToParallelLoopsPass, FunctionPass> {
void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<StandardOpsDialect, scf::SCFDialect>();
}
void runOnFunction() override { void runOnFunction() override {
auto func = getFunction(); auto func = getFunction();

View File

@ -153,6 +153,10 @@ struct BinaryElementwiseOpConversion : public OpRewritePattern<OpTy> {
struct TransformUnrankedHloPass struct TransformUnrankedHloPass
: public PassWrapper<TransformUnrankedHloPass, FunctionPass> { : public PassWrapper<TransformUnrankedHloPass, FunctionPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<shape::ShapeDialect>();
}
void runOnFunction() override { void runOnFunction() override {
// Setup conversion target. // Setup conversion target.
MLIRContext &ctx = getContext(); MLIRContext &ctx = getContext();