mlir-hlo-opt: set preloadDialectsInContext to false.
This requires specifying dependent dialects in several passes. PiperOrigin-RevId: 365758084
This commit is contained in:
parent
9ebadc4c4d
commit
6388e8d9ee
4
BUILD
4
BUILD
|
@ -642,6 +642,7 @@ cc_library(
|
||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
"@llvm-project//mlir:LinalgOps",
|
"@llvm-project//mlir:LinalgOps",
|
||||||
"@llvm-project//mlir:MathDialect",
|
"@llvm-project//mlir:MathDialect",
|
||||||
|
"@llvm-project//mlir:MemRefDialect",
|
||||||
"@llvm-project//mlir:Pass",
|
"@llvm-project//mlir:Pass",
|
||||||
"@llvm-project//mlir:SCFDialect",
|
"@llvm-project//mlir:SCFDialect",
|
||||||
"@llvm-project//mlir:StandardOps",
|
"@llvm-project//mlir:StandardOps",
|
||||||
|
@ -709,6 +710,7 @@ cc_library(
|
||||||
"@llvm-project//mlir:GPUDialect",
|
"@llvm-project//mlir:GPUDialect",
|
||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
"@llvm-project//mlir:LinalgOps",
|
"@llvm-project//mlir:LinalgOps",
|
||||||
|
"@llvm-project//mlir:MemRefDialect",
|
||||||
"@llvm-project//mlir:Pass",
|
"@llvm-project//mlir:Pass",
|
||||||
"@llvm-project//mlir:SCFDialect",
|
"@llvm-project//mlir:SCFDialect",
|
||||||
"@llvm-project//mlir:StandardOps",
|
"@llvm-project//mlir:StandardOps",
|
||||||
|
@ -752,6 +754,7 @@ cc_library(
|
||||||
":map_hlo_to_lhlo_op",
|
":map_hlo_to_lhlo_op",
|
||||||
"@llvm-project//llvm:Support",
|
"@llvm-project//llvm:Support",
|
||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
|
"@llvm-project//mlir:MemRefDialect",
|
||||||
"@llvm-project//mlir:Pass",
|
"@llvm-project//mlir:Pass",
|
||||||
"@llvm-project//mlir:Shape",
|
"@llvm-project//mlir:Shape",
|
||||||
"@llvm-project//mlir:ShapeTransforms",
|
"@llvm-project//mlir:ShapeTransforms",
|
||||||
|
@ -1033,6 +1036,7 @@ cc_library(
|
||||||
"@llvm-project//mlir:InferTypeOpInterface",
|
"@llvm-project//mlir:InferTypeOpInterface",
|
||||||
"@llvm-project//mlir:LLVMDialect",
|
"@llvm-project//mlir:LLVMDialect",
|
||||||
"@llvm-project//mlir:LLVMTransforms",
|
"@llvm-project//mlir:LLVMTransforms",
|
||||||
|
"@llvm-project//mlir:MemRefDialect",
|
||||||
"@llvm-project//mlir:Pass",
|
"@llvm-project//mlir:Pass",
|
||||||
"@llvm-project//mlir:SCFDialect",
|
"@llvm-project//mlir:SCFDialect",
|
||||||
"@llvm-project//mlir:Shape",
|
"@llvm-project//mlir:Shape",
|
||||||
|
|
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||||
#include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
|
#include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||||
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||||
#include "mlir/Dialect/Shape/Transforms/Passes.h"
|
#include "mlir/Dialect/Shape/Transforms/Passes.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
@ -564,7 +565,8 @@ class HloToLhloTensorStoreOpLegacyConverter
|
||||||
struct HloLegalizeToLhlo
|
struct HloLegalizeToLhlo
|
||||||
: public PassWrapper<HloLegalizeToLhlo, OperationPass<ModuleOp>> {
|
: public PassWrapper<HloLegalizeToLhlo, OperationPass<ModuleOp>> {
|
||||||
void getDependentDialects(DialectRegistry& registry) const override {
|
void getDependentDialects(DialectRegistry& registry) const override {
|
||||||
registry.insert<lmhlo::LmhloDialect>();
|
registry.insert<lmhlo::LmhloDialect, memref::MemRefDialect,
|
||||||
|
shape::ShapeDialect>();
|
||||||
}
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
|
@ -27,6 +27,7 @@ limitations under the License.
|
||||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||||
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
|
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
|
||||||
#include "mlir/Dialect/Math/IR/Math.h"
|
#include "mlir/Dialect/Math/IR/Math.h"
|
||||||
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
#include "mlir/Dialect/SCF/SCF.h"
|
#include "mlir/Dialect/SCF/SCF.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
@ -1965,7 +1966,9 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
||||||
struct LhloLegalizeToLinalgPass
|
struct LhloLegalizeToLinalgPass
|
||||||
: public PassWrapper<LhloLegalizeToLinalgPass, FunctionPass> {
|
: public PassWrapper<LhloLegalizeToLinalgPass, FunctionPass> {
|
||||||
void getDependentDialects(DialectRegistry& registry) const override {
|
void getDependentDialects(DialectRegistry& registry) const override {
|
||||||
registry.insert<AffineDialect, linalg::LinalgDialect, math::MathDialect>();
|
registry
|
||||||
|
.insert<AffineDialect, complex::ComplexDialect, linalg::LinalgDialect,
|
||||||
|
math::MathDialect, memref::MemRefDialect>();
|
||||||
}
|
}
|
||||||
|
|
||||||
void runOnFunction() override {
|
void runOnFunction() override {
|
||||||
|
@ -1986,8 +1989,9 @@ struct LhloLegalizeToLinalgPass
|
||||||
struct HloLegalizeToLinalgPass
|
struct HloLegalizeToLinalgPass
|
||||||
: public PassWrapper<HloLegalizeToLinalgPass, FunctionPass> {
|
: public PassWrapper<HloLegalizeToLinalgPass, FunctionPass> {
|
||||||
void getDependentDialects(DialectRegistry& registry) const override {
|
void getDependentDialects(DialectRegistry& registry) const override {
|
||||||
registry.insert<linalg::LinalgDialect, scf::SCFDialect,
|
registry
|
||||||
complex::ComplexDialect, math::MathDialect>();
|
.insert<linalg::LinalgDialect, scf::SCFDialect, complex::ComplexDialect,
|
||||||
|
math::MathDialect, memref::MemRefDialect>();
|
||||||
}
|
}
|
||||||
|
|
||||||
void runOnFunction() override {
|
void runOnFunction() override {
|
||||||
|
|
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||||
#include "mlir/Dialect/GPU/GPUDialect.h"
|
#include "mlir/Dialect/GPU/GPUDialect.h"
|
||||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||||
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
|
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
|
||||||
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
#include "mlir/Dialect/SCF/SCF.h"
|
#include "mlir/Dialect/SCF/SCF.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
#include "mlir/IR/Attributes.h"
|
#include "mlir/IR/Attributes.h"
|
||||||
|
@ -174,7 +175,7 @@ struct LhloLegalizeToGpuPass
|
||||||
: public PassWrapper<LhloLegalizeToGpuPass, FunctionPass> {
|
: public PassWrapper<LhloLegalizeToGpuPass, FunctionPass> {
|
||||||
void getDependentDialects(DialectRegistry& registry) const override {
|
void getDependentDialects(DialectRegistry& registry) const override {
|
||||||
registry.insert<AffineDialect, gpu::GPUDialect, linalg::LinalgDialect,
|
registry.insert<AffineDialect, gpu::GPUDialect, linalg::LinalgDialect,
|
||||||
scf::SCFDialect>();
|
memref::MemRefDialect, scf::SCFDialect>();
|
||||||
}
|
}
|
||||||
|
|
||||||
void runOnFunction() override {
|
void runOnFunction() override {
|
||||||
|
|
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||||
#include "mlir/IR/Attributes.h"
|
#include "mlir/IR/Attributes.h"
|
||||||
#include "mlir/IR/Identifier.h"
|
#include "mlir/IR/Identifier.h"
|
||||||
#include "mlir/IR/MLIRContext.h"
|
#include "mlir/IR/MLIRContext.h"
|
||||||
|
@ -83,6 +84,9 @@ struct ReifyReturnTypeShapesPattern : public RewritePattern {
|
||||||
|
|
||||||
struct TestInferShapedTypeMethodsPass
|
struct TestInferShapedTypeMethodsPass
|
||||||
: public PassWrapper<TestInferShapedTypeMethodsPass, FunctionPass> {
|
: public PassWrapper<TestInferShapedTypeMethodsPass, FunctionPass> {
|
||||||
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
|
registry.insert<shape::ShapeDialect>();
|
||||||
|
}
|
||||||
void runOnFunction() override {
|
void runOnFunction() override {
|
||||||
OwningRewritePatternList patterns(&getContext());
|
OwningRewritePatternList patterns(&getContext());
|
||||||
patterns.insert<ReifyReturnTypeShapesPattern>(&getContext());
|
patterns.insert<ReifyReturnTypeShapesPattern>(&getContext());
|
||||||
|
|
|
@ -528,7 +528,7 @@ struct ConvertUnrankedDynamicBroadcastSelectOp
|
||||||
struct TransformUnrankedHloPass
|
struct TransformUnrankedHloPass
|
||||||
: public PassWrapper<TransformUnrankedHloPass, FunctionPass> {
|
: public PassWrapper<TransformUnrankedHloPass, FunctionPass> {
|
||||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
registry.insert<chlo::HloClientDialect, mhlo::MhloDialect,
|
registry.insert<chlo::HloClientDialect, mhlo::MhloDialect, scf::SCFDialect,
|
||||||
shape::ShapeDialect>();
|
shape::ShapeDialect>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -15,7 +15,9 @@ limitations under the License.
|
||||||
|
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||||
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
#include "mlir/IR/Dialect.h"
|
||||||
#include "mlir/IR/MLIRContext.h"
|
#include "mlir/IR/MLIRContext.h"
|
||||||
#include "mlir/IR/Operation.h"
|
#include "mlir/IR/Operation.h"
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
|
@ -29,6 +31,9 @@ namespace {
|
||||||
|
|
||||||
struct TestUnfuseBatchNormPass
|
struct TestUnfuseBatchNormPass
|
||||||
: public PassWrapper<TestUnfuseBatchNormPass, OperationPass<>> {
|
: public PassWrapper<TestUnfuseBatchNormPass, OperationPass<>> {
|
||||||
|
void getDependentDialects(DialectRegistry& registry) const override {
|
||||||
|
registry.insert<memref::MemRefDialect>();
|
||||||
|
}
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
OwningRewritePatternList patterns(&getContext());
|
OwningRewritePatternList patterns(&getContext());
|
||||||
PopulateUnfuseBatchNormPatterns(&getContext(), &patterns);
|
PopulateUnfuseBatchNormPatterns(&getContext(), &patterns);
|
||||||
|
|
|
@ -35,5 +35,6 @@ int main(int argc, char **argv) {
|
||||||
registry.insert<mlir::lmhlo_gpu::LmhloGpuDialect>();
|
registry.insert<mlir::lmhlo_gpu::LmhloGpuDialect>();
|
||||||
|
|
||||||
return failed(mlir::MlirOptMain(argc, argv, "MLIR HLO pass driver\n",
|
return failed(mlir::MlirOptMain(argc, argv, "MLIR HLO pass driver\n",
|
||||||
registry, /*preloadDialectsInContext=*/true));
|
registry,
|
||||||
|
/*preloadDialectsInContext=*/false));
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue