Remove the dependency on global dialect registry from mlir-hlo
PiperOrigin-RevId: 328457105
This commit is contained in:
parent
ebe7857fb1
commit
36ddbeb6b2
|
@ -29,6 +29,10 @@ namespace {
|
||||||
|
|
||||||
struct TestChloLegalizeToHloPass
|
struct TestChloLegalizeToHloPass
|
||||||
: public PassWrapper<TestChloLegalizeToHloPass, FunctionPass> {
|
: public PassWrapper<TestChloLegalizeToHloPass, FunctionPass> {
|
||||||
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
|
registry.insert<mhlo::MhloDialect, shape::ShapeDialect, scf::SCFDialect>();
|
||||||
|
}
|
||||||
|
|
||||||
void runOnFunction() override {
|
void runOnFunction() override {
|
||||||
ConversionTarget conversionTarget(getContext());
|
ConversionTarget conversionTarget(getContext());
|
||||||
OwningRewritePatternList conversionPatterns;
|
OwningRewritePatternList conversionPatterns;
|
||||||
|
|
|
@ -388,6 +388,10 @@ class HloToLhloTensorStoreOpConverter
|
||||||
|
|
||||||
struct HloLegalizeToLhlo
|
struct HloLegalizeToLhlo
|
||||||
: public PassWrapper<HloLegalizeToLhlo, OperationPass<ModuleOp>> {
|
: public PassWrapper<HloLegalizeToLhlo, OperationPass<ModuleOp>> {
|
||||||
|
void getDependentDialects(DialectRegistry& registry) const override {
|
||||||
|
registry.insert<lmhlo::LmhloDialect>();
|
||||||
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
HloLegalizeToLhlo() = default;
|
HloLegalizeToLhlo() = default;
|
||||||
HloLegalizeToLhlo(const HloLegalizeToLhlo& o) {
|
HloLegalizeToLhlo(const HloLegalizeToLhlo& o) {
|
||||||
|
|
|
@ -866,6 +866,10 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
||||||
// } : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
|
// } : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||||
struct LhloLegalizeToLinalgPass
|
struct LhloLegalizeToLinalgPass
|
||||||
: public PassWrapper<LhloLegalizeToLinalgPass, FunctionPass> {
|
: public PassWrapper<LhloLegalizeToLinalgPass, FunctionPass> {
|
||||||
|
void getDependentDialects(DialectRegistry& registry) const override {
|
||||||
|
registry.insert<AffineDialect, linalg::LinalgDialect>();
|
||||||
|
}
|
||||||
|
|
||||||
void runOnFunction() override {
|
void runOnFunction() override {
|
||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
ConversionTarget target(getContext());
|
ConversionTarget target(getContext());
|
||||||
|
@ -882,6 +886,10 @@ struct LhloLegalizeToLinalgPass
|
||||||
|
|
||||||
struct HloLegalizeToLinalgPass
|
struct HloLegalizeToLinalgPass
|
||||||
: public PassWrapper<HloLegalizeToLinalgPass, FunctionPass> {
|
: public PassWrapper<HloLegalizeToLinalgPass, FunctionPass> {
|
||||||
|
void getDependentDialects(DialectRegistry& registry) const override {
|
||||||
|
registry.insert<linalg::LinalgDialect>();
|
||||||
|
}
|
||||||
|
|
||||||
void runOnFunction() override {
|
void runOnFunction() override {
|
||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
ConversionTarget target(getContext());
|
ConversionTarget target(getContext());
|
||||||
|
|
|
@ -178,6 +178,10 @@ class ConvertIotaOp : public OpRewritePattern<mhlo::IotaOp> {
|
||||||
namespace {
|
namespace {
|
||||||
struct LegalizeToStandardPass
|
struct LegalizeToStandardPass
|
||||||
: public PassWrapper<LegalizeToStandardPass, FunctionPass> {
|
: public PassWrapper<LegalizeToStandardPass, FunctionPass> {
|
||||||
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
|
registry.insert<StandardOpsDialect>();
|
||||||
|
}
|
||||||
|
|
||||||
/// Perform the lowering to Standard dialect.
|
/// Perform the lowering to Standard dialect.
|
||||||
void runOnFunction() override;
|
void runOnFunction() override;
|
||||||
};
|
};
|
||||||
|
|
|
@ -19,8 +19,10 @@ limitations under the License.
|
||||||
#include "llvm/ADT/ArrayRef.h"
|
#include "llvm/ADT/ArrayRef.h"
|
||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||||
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||||
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
|
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
|
||||||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
||||||
|
#include "mlir/Dialect/SCF/SCF.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
#include "mlir/Transforms/FoldUtils.h"
|
#include "mlir/Transforms/FoldUtils.h"
|
||||||
|
@ -33,6 +35,10 @@ using linalg::LinalgOp;
|
||||||
|
|
||||||
class LhloFuseLinalgPass
|
class LhloFuseLinalgPass
|
||||||
: public PassWrapper<LhloFuseLinalgPass, FunctionPass> {
|
: public PassWrapper<LhloFuseLinalgPass, FunctionPass> {
|
||||||
|
void getDependentDialects(DialectRegistry& registry) const override {
|
||||||
|
registry.insert<AffineDialect, linalg::LinalgDialect, scf::SCFDialect>();
|
||||||
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
LhloFuseLinalgPass() = default;
|
LhloFuseLinalgPass() = default;
|
||||||
LhloFuseLinalgPass(const LhloFuseLinalgPass&) {}
|
LhloFuseLinalgPass(const LhloFuseLinalgPass&) {}
|
||||||
|
|
|
@ -139,6 +139,9 @@ void populateLHLOToAffineConversionPattern(MLIRContext* context,
|
||||||
|
|
||||||
struct LhloLegalizeToAffinePass
|
struct LhloLegalizeToAffinePass
|
||||||
: public PassWrapper<LhloLegalizeToAffinePass, FunctionPass> {
|
: public PassWrapper<LhloLegalizeToAffinePass, FunctionPass> {
|
||||||
|
void getDependentDialects(DialectRegistry& registry) const override {
|
||||||
|
registry.insert<AffineDialect>();
|
||||||
|
}
|
||||||
void runOnFunction() override {
|
void runOnFunction() override {
|
||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
auto func = getFunction();
|
auto func = getFunction();
|
||||||
|
|
|
@ -20,8 +20,10 @@ limitations under the License.
|
||||||
#include "llvm/ADT/ArrayRef.h"
|
#include "llvm/ADT/ArrayRef.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
|
#include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
|
||||||
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||||
#include "mlir/Dialect/GPU/GPUDialect.h"
|
#include "mlir/Dialect/GPU/GPUDialect.h"
|
||||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||||
|
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
|
||||||
#include "mlir/Dialect/SCF/SCF.h"
|
#include "mlir/Dialect/SCF/SCF.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
#include "mlir/IR/Attributes.h"
|
#include "mlir/IR/Attributes.h"
|
||||||
|
@ -169,6 +171,11 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern<ReduceOp> {
|
||||||
|
|
||||||
struct LhloLegalizeToGpuPass
|
struct LhloLegalizeToGpuPass
|
||||||
: public PassWrapper<LhloLegalizeToGpuPass, FunctionPass> {
|
: public PassWrapper<LhloLegalizeToGpuPass, FunctionPass> {
|
||||||
|
void getDependentDialects(DialectRegistry& registry) const override {
|
||||||
|
registry.insert<AffineDialect, gpu::GPUDialect, linalg::LinalgDialect,
|
||||||
|
scf::SCFDialect>();
|
||||||
|
}
|
||||||
|
|
||||||
void runOnFunction() override {
|
void runOnFunction() override {
|
||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
ConversionTarget target(getContext());
|
ConversionTarget target(getContext());
|
||||||
|
|
|
@ -29,6 +29,10 @@ namespace {
|
||||||
class TestLhloToLLVMPass
|
class TestLhloToLLVMPass
|
||||||
: public ::mlir::PassWrapper<TestLhloToLLVMPass,
|
: public ::mlir::PassWrapper<TestLhloToLLVMPass,
|
||||||
::mlir::OperationPass<::mlir::ModuleOp>> {
|
::mlir::OperationPass<::mlir::ModuleOp>> {
|
||||||
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
|
registry.insert<LLVM::LLVMDialect>();
|
||||||
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
ModuleOp m = getOperation();
|
ModuleOp m = getOperation();
|
||||||
|
|
|
@ -691,6 +691,10 @@ class SelectAndScatterOpConverter
|
||||||
|
|
||||||
struct LhloLegalizeToParallelLoopsPass
|
struct LhloLegalizeToParallelLoopsPass
|
||||||
: public PassWrapper<LhloLegalizeToParallelLoopsPass, FunctionPass> {
|
: public PassWrapper<LhloLegalizeToParallelLoopsPass, FunctionPass> {
|
||||||
|
void getDependentDialects(DialectRegistry& registry) const override {
|
||||||
|
registry.insert<StandardOpsDialect, scf::SCFDialect>();
|
||||||
|
}
|
||||||
|
|
||||||
void runOnFunction() override {
|
void runOnFunction() override {
|
||||||
auto func = getFunction();
|
auto func = getFunction();
|
||||||
|
|
||||||
|
|
|
@ -153,6 +153,10 @@ struct BinaryElementwiseOpConversion : public OpRewritePattern<OpTy> {
|
||||||
|
|
||||||
struct TransformUnrankedHloPass
|
struct TransformUnrankedHloPass
|
||||||
: public PassWrapper<TransformUnrankedHloPass, FunctionPass> {
|
: public PassWrapper<TransformUnrankedHloPass, FunctionPass> {
|
||||||
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
|
registry.insert<shape::ShapeDialect>();
|
||||||
|
}
|
||||||
|
|
||||||
void runOnFunction() override {
|
void runOnFunction() override {
|
||||||
// Setup conversion target.
|
// Setup conversion target.
|
||||||
MLIRContext &ctx = getContext();
|
MLIRContext &ctx = getContext();
|
||||||
|
|
Loading…
Reference in New Issue