Fix pass definition to inherit from the TableGen generated base class (NFC)

PiperOrigin-RevId: 379860210
This commit is contained in:
Mehdi Amini 2021-06-16 19:04:23 -07:00 committed by TensorFlow MLIR Team
parent 2e08c246e9
commit 8c8e81cb69
23 changed files with 126 additions and 124 deletions

15
BUILD
View File

@ -748,6 +748,7 @@ cc_library(
hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/passes.h"], hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/passes.h"],
deps = [ deps = [
":hlo", ":hlo",
":pass_details",
"@llvm-project//llvm:Support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
@ -799,6 +800,7 @@ cc_library(
":hlo", ":hlo",
":lhlo", ":lhlo",
":map_lmhlo_to_scalar_op", ":map_lmhlo_to_scalar_op",
":pass_details",
"@llvm-project//llvm:Support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:Affine", "@llvm-project//mlir:Affine",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
@ -814,6 +816,7 @@ cc_library(
srcs = ["lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc"], srcs = ["lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc"],
deps = [ deps = [
":lhlo", ":lhlo",
":pass_details",
"@llvm-project//llvm:Support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgOps", "@llvm-project//mlir:LinalgOps",
@ -837,6 +840,7 @@ cc_library(
":hlo", ":hlo",
":lhlo", ":lhlo",
":map_lmhlo_to_scalar_op", ":map_lmhlo_to_scalar_op",
":pass_details",
"@llvm-project//llvm:Support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:Affine", "@llvm-project//mlir:Affine",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
@ -863,6 +867,7 @@ cc_library(
deps = [ deps = [
":hlo", ":hlo",
":map_chlo_to_hlo_op", ":map_chlo_to_hlo_op",
":pass_details",
"@llvm-project//llvm:Support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
@ -885,6 +890,7 @@ cc_library(
deps = [ deps = [
":hlo", ":hlo",
":map_chlo_to_hlo_op", ":map_chlo_to_hlo_op",
":pass_details",
"@llvm-project//llvm:Support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:InferTypeOpInterface", "@llvm-project//mlir:InferTypeOpInterface",
@ -928,6 +934,7 @@ cc_library(
":hlo", ":hlo",
":lhlo", ":lhlo",
":map_lmhlo_to_scalar_op", ":map_lmhlo_to_scalar_op",
":pass_details",
"@llvm-project//llvm:Support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:Affine", "@llvm-project//mlir:Affine",
"@llvm-project//mlir:GPUDialect", "@llvm-project//mlir:GPUDialect",
@ -948,6 +955,7 @@ cc_library(
hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/passes.h"], hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/passes.h"],
deps = [ deps = [
":lhlo", ":lhlo",
":pass_details",
"@llvm-project//llvm:Support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:Affine", "@llvm-project//mlir:Affine",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
@ -1008,6 +1016,7 @@ cc_library(
deps = [ deps = [
":cycle_detector", ":cycle_detector",
":hlo", ":hlo",
":pass_details",
"@llvm-project//llvm:Core", "@llvm-project//llvm:Core",
"@llvm-project//llvm:Support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
@ -1046,6 +1055,7 @@ cc_library(
hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/passes.h"], hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/passes.h"],
deps = [ deps = [
":hlo", ":hlo",
":pass_details",
"@llvm-project//llvm:Support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
@ -1064,6 +1074,7 @@ cc_library(
":hlo", ":hlo",
":legalize_to_standard_inc_gen", ":legalize_to_standard_inc_gen",
":legalize_trigonometric_to_approximation", ":legalize_trigonometric_to_approximation",
":pass_details",
"@llvm-project//llvm:Support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
@ -1083,6 +1094,7 @@ cc_library(
], ],
deps = [ deps = [
":hlo", ":hlo",
":pass_details",
"@llvm-project//llvm:Support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
@ -1102,6 +1114,7 @@ cc_library(
], ],
includes = ["include"], includes = ["include"],
deps = [ deps = [
":pass_details",
"@llvm-project//llvm:Support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:MathDialect", "@llvm-project//mlir:MathDialect",
@ -1149,6 +1162,7 @@ cc_library(
deps = [ deps = [
":hlo", ":hlo",
":lower_complex_inc_gen", ":lower_complex_inc_gen",
":pass_details",
"@llvm-project//llvm:Support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis", "@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
@ -1201,6 +1215,7 @@ cc_library(
hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/passes.h"], hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/passes.h"],
deps = [ deps = [
":lhlo", ":lhlo",
":pass_details",
"@llvm-project//llvm:Support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:MemRefDialect",

View File

@ -15,13 +15,13 @@ limitations under the License.
include "mlir/Pass/PassBase.td" include "mlir/Pass/PassBase.td"
def LhloLegalizeToLinalgPass : Pass<"lhlo-legalize-to-linalg", "FuncOp"> { def LhloLegalizeToLinalgPass : FunctionPass<"lhlo-legalize-to-linalg"> {
let summary = "Legalize from LHLO dialect to Linalg dialect."; let summary = "Legalize from LHLO dialect to Linalg dialect.";
let constructor = "createLegalizeLhloToLinalgPass()"; let constructor = "createLegalizeLhloToLinalgPass()";
} }
def LhloFuseLinalgPass : Pass<"lhlo-fuse-linalg", "FuncOp"> { def LhloFuseLinalgPass : FunctionPass<"lhlo-fuse-linalg"> {
let summary = "Greedily fuse linalg ops obtained after LHLO lowering."; let summary = "Greedily fuse linalg ops obtained after LHLO lowering.";
let constructor = "createLhloFuseLinalgPass()"; let constructor = "createLhloFuseLinalgPass()";
let options = [ let options = [
@ -34,24 +34,24 @@ def LhloFuseLinalgPass : Pass<"lhlo-fuse-linalg", "FuncOp"> {
} }
def LhloLegalizeToAffinePass : Pass<"lhlo-legalize-to-affine", "FuncOp"> { def LhloLegalizeToAffinePass : FunctionPass<"lhlo-legalize-to-affine"> {
let summary = "Legalize from LHLO dialect to affine dialect."; let summary = "Legalize from LHLO dialect to affine dialect.";
let constructor = "createLhloLegalizeToAffinePass()"; let constructor = "createLhloLegalizeToAffinePass()";
} }
def LhloLegalizeToGpuPass : Pass<"lhlo-legalize-to-gpu", "FuncOp"> { def LhloLegalizeToGpuPass : FunctionPass<"lhlo-legalize-to-gpu"> {
let summary = "Legalize from LHLO dialect to GPU dialect."; let summary = "Legalize from LHLO dialect to GPU dialect.";
let constructor = "createLegalizeToGpuPass()"; let constructor = "createLegalizeToGpuPass()";
} }
def LhloLegalizeToParallelLoopsPass : Pass<"lhlo-legalize-to-parallel-loops", "FuncOp"> { def LhloLegalizeToParallelLoopsPass : FunctionPass<"lhlo-legalize-to-parallel-loops"> {
let summary = "Legalize from LHLO dialect to parallel loops."; let summary = "Legalize from LHLO dialect to parallel loops.";
let constructor = "createLegalizeLhloToParallelLoopsPass()"; let constructor = "createLegalizeLhloToParallelLoopsPass()";
} }
def LegalizeTensorLoadOpPass : Pass<"lhlo-legalize-tensor-load-op", "FuncOp"> { def LegalizeTensorLoadOpPass : FunctionPass<"lhlo-legalize-tensor-load-op"> {
let summary = "Legalize tensor load ops that are inserted during mhlo to lmhlo conversion."; let summary = "Legalize tensor load ops that are inserted during mhlo to lmhlo conversion.";
let constructor = "createLegalizeTensorLoadOpPass()"; let constructor = "createLegalizeTensorLoadOpPass()";
} }

View File

@ -37,64 +37,64 @@ def HloLegalizeToLhloPass : Pass<"hlo-legalize-to-lhlo", "ModuleOp"> {
]; ];
} }
def LegalizeControlFlowPass : Pass<"mhlo-legalize-control-flow", "FuncOp"> { def LegalizeControlFlowPass : FunctionPass<"mhlo-legalize-control-flow"> {
let summary = "Legalize from MHLO control flow to CFG control flow."; let summary = "Legalize from MHLO control flow to CFG control flow.";
let constructor = "createLegalizeControlFlowPass()"; let constructor = "createLegalizeControlFlowPass()";
} }
def LegalizeControlFlowToScfPass : Pass<"mhlo-control-flow-to-scf", "FuncOp"> { def LegalizeControlFlowToScfPass : FunctionPass<"mhlo-control-flow-to-scf"> {
let summary = "Legalize from MHLO control flow to SCF control flow."; let summary = "Legalize from MHLO control flow to SCF control flow.";
let constructor = "createControlFlowToScfPass()"; let constructor = "createControlFlowToScfPass()";
} }
def LegalizeGatherToTorchIndexSelectPass : Pass<"mhlo-legalize-gather-to-torch-index-select", "FuncOp"> { def LegalizeGatherToTorchIndexSelectPass : FunctionPass<"mhlo-legalize-gather-to-torch-index-select"> {
let summary = "Legalizes gathers to a torch index select."; let summary = "Legalizes gathers to a torch index select.";
let constructor = "createLegalizeGatherToTorchIndexSelectPass()"; let constructor = "createLegalizeGatherToTorchIndexSelectPass()";
} }
def LegalizeTanhToApproximationPass : Pass<"mhlo-legalize-trigonometric-to-approximation", "FuncOp"> { def LegalizeTanhToApproximationPass : FunctionPass<"mhlo-legalize-trigonometric-to-approximation"> {
let summary = "Legalize trigonometric operations from standard dialect to an approximation."; let summary = "Legalize trigonometric operations from standard dialect to an approximation.";
let constructor = "createLegalizeTrigonometricToApproximationPass()"; let constructor = "createLegalizeTrigonometricToApproximationPass()";
} }
def HloLegalizeToLinalgPass : Pass<"hlo-legalize-to-linalg", "FuncOp"> { def HloLegalizeToLinalgPass : FunctionPass<"hlo-legalize-to-linalg"> {
let summary = "Legalize from HLO dialect to Linalg dialect."; let summary = "Legalize from HLO dialect to Linalg dialect.";
let constructor = "createLegalizeHloToLinalgPass()"; let constructor = "createLegalizeHloToLinalgPass()";
} }
def LegalizeToStandardPass : Pass<"mhlo-legalize-to-std", "FuncOp"> { def LegalizeToStandardPass : FunctionPass<"mhlo-legalize-to-std"> {
let summary = "Legalize from MHLO dialect to standard dialect."; let summary = "Legalize from MHLO dialect to standard dialect.";
let constructor = "createLegalizeToStdPass()"; let constructor = "createLegalizeToStdPass()";
} }
def LowerComplexPass : Pass<"mhlo-test-lower-complex", "FuncOp"> { def LowerComplexPass : FunctionPass<"mhlo-test-lower-complex"> {
let summary = "Lower complex operations into non-complex operations."; let summary = "Lower complex operations into non-complex operations.";
let constructor = "createLowerComplexPass()"; let constructor = "createLowerComplexPass()";
} }
def LegalizeGeneralDotPass : Pass<"mhlo-test-lower-general-dot", "FuncOp"> { def LegalizeGeneralDotPass : FunctionPass<"mhlo-test-lower-general-dot"> {
let summary = "Tests lowering general dot to a non-batched dot when possible."; let summary = "Tests lowering general dot to a non-batched dot when possible.";
let constructor = "createLegalizeGeneralDotPass()"; let constructor = "createLegalizeGeneralDotPass()";
} }
def TestMaterializeBroadcastsPass : Pass<"mhlo-test-materialize-broadcasts", "FuncOp"> { def TestMaterializeBroadcastsPass : FunctionPass<"mhlo-test-materialize-broadcasts"> {
let summary = "Test pass for materializing 'broadcast_dimensions' attributes."; let summary = "Test pass for materializing 'broadcast_dimensions' attributes.";
let constructor = "createTestMaterializeBroadcastsPass()"; let constructor = "createTestMaterializeBroadcastsPass()";
} }
def MhloFusionPass : Pass<"mhlo-fusion", "FuncOp"> { def MhloFusionPass : FunctionPass<"mhlo-fusion"> {
let summary = "Fuse mhlo ops to kLoop/kInput fusion patterns."; let summary = "Fuse mhlo ops to kLoop/kInput fusion patterns.";
let constructor = "createMhloFusionPass()"; let constructor = "createMhloFusionPass()";
} }
def OptimizeMhloPass : Pass<"mhlo-test-optimize", "FuncOp"> { def OptimizeMhloPass : FunctionPass<"mhlo-test-optimize"> {
let summary = "Run optional HLO optimizations."; let summary = "Run optional HLO optimizations.";
let constructor = "createOptimizeMhloPass()"; let constructor = "createOptimizeMhloPass()";
} }
@ -107,18 +107,18 @@ def SinkConstantsToControlFlowPass : FunctionPass<"mhlo-sink-constants-to-contro
} }
def TestInferShapedTypeMethodsPass : Pass<"mhlo-test-infer-shaped-type-methods", "FuncOp"> { def TestInferShapedTypeMethodsPass : FunctionPass<"mhlo-test-infer-shaped-type-methods"> {
let summary = "Uses test ops to invoke InferShapedTypeOpInterface methods."; let summary = "Uses test ops to invoke InferShapedTypeOpInterface methods.";
let constructor = "createTestInferShapedTypeMethodsPass()"; let constructor = "createTestInferShapedTypeMethodsPass()";
} }
def TransformUnrankedHloPass : Pass<"mhlo-transform-unranked-hlo", "FuncOp"> { def TransformUnrankedHloPass : FunctionPass<"mhlo-transform-unranked-hlo"> {
let summary = "Realize element-wise operations on ranked tensors where possible."; let summary = "Realize element-wise operations on ranked tensors where possible.";
let constructor = "createTransformUnrankedHloPass()"; let constructor = "createTransformUnrankedHloPass()";
} }
def BroadcastPropagationPass : Pass<"mhlo-broadcast-propagation", "FuncOp"> { def BroadcastPropagationPass : FunctionPass<"mhlo-broadcast-propagation"> {
let summary = "Move dynamic broadcasts up over element-wise operations and " let summary = "Move dynamic broadcasts up over element-wise operations and "
"broadcast the operands rather than the result. This will eventually allow " "broadcast the operands rather than the result. This will eventually allow "
"for larger fusions."; "for larger fusions.";

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "llvm/Support/Casting.h" #include "llvm/Support/Casting.h"
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h" #include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
@ -376,7 +377,7 @@ struct EarlyBroadcastInDimOpPattern
}; };
struct BroadcastPropagationPass struct BroadcastPropagationPass
: public PassWrapper<BroadcastPropagationPass, FunctionPass> { : public BroadcastPropagationPassBase<BroadcastPropagationPass> {
void getDependentDialects(DialectRegistry &registry) const override { void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<shape::ShapeDialect, mhlo::MhloDialect>(); registry.insert<shape::ShapeDialect, mhlo::MhloDialect>();
} }

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "llvm/ADT/StringSwitch.h" #include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h" #include "llvm/Support/Casting.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" // TF:llvm-project #include "mlir/Dialect/Tensor/IR/Tensor.h" // TF:llvm-project
@ -37,7 +38,7 @@ namespace mlir {
namespace mhlo { namespace mhlo {
namespace { namespace {
struct LegalizeControlFlowPass struct LegalizeControlFlowPass
: public mlir::PassWrapper<LegalizeControlFlowPass, FunctionPass> { : public LegalizeControlFlowPassBase<LegalizeControlFlowPass> {
// Perform the lowering to MLIR control flow. // Perform the lowering to MLIR control flow.
void runOnFunction() override; void runOnFunction() override;
}; };

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinOps.h"
@ -128,7 +129,8 @@ struct GatherIsTorchIndexSelect : public OpRewritePattern<GatherOp> {
}; };
struct LegalizeGatherToTorchIndexSelectPass struct LegalizeGatherToTorchIndexSelectPass
: public PassWrapper<LegalizeGatherToTorchIndexSelectPass, FunctionPass> { : public LegalizeGatherToTorchIndexSelectPassBase<
LegalizeGatherToTorchIndexSelectPass> {
/// Perform the lowering of standard dialect operations to approximations. /// Perform the lowering of standard dialect operations to approximations.
void runOnFunction() override { void runOnFunction() override {
OwningRewritePatternList patterns(&getContext()); OwningRewritePatternList patterns(&getContext());

View File

@ -16,6 +16,7 @@ limitations under the License.
// This file implements logic for lowering memref.tensor_load ops that are // This file implements logic for lowering memref.tensor_load ops that are
// inserted during `mhlo-legalize-to-lmhlo`. // inserted during `mhlo-legalize-to-lmhlo`.
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
@ -71,7 +72,7 @@ struct ForwardShapeOfOp : public OpRewritePattern<ShapeOfOp> {
}; };
struct LegalizeTensorLoadOpPass struct LegalizeTensorLoadOpPass
: public mlir::PassWrapper<LegalizeTensorLoadOpPass, FunctionPass> { : public LegalizeTensorLoadOpPassBase<LegalizeTensorLoadOpPass> {
// Perform the lowering to remove memref.tensor_load ops inserted during // Perform the lowering to remove memref.tensor_load ops inserted during
// `mhlo-legalize-to-lmhlo`. // `mhlo-legalize-to-lmhlo`.
void runOnFunction() override { void runOnFunction() override {

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SetVector.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h" #include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/AffineOps.h"
@ -2428,7 +2429,7 @@ class RemoveSignTypeConverter : public TypeConverter {
// iterator_types = ["parallel", "parallel"], // iterator_types = ["parallel", "parallel"],
// } : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () // } : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
struct LhloLegalizeToLinalgPass struct LhloLegalizeToLinalgPass
: public PassWrapper<LhloLegalizeToLinalgPass, FunctionPass> { : public lmhlo::LhloLegalizeToLinalgPassBase<LhloLegalizeToLinalgPass> {
void getDependentDialects(DialectRegistry& registry) const override { void getDependentDialects(DialectRegistry& registry) const override {
registry registry
.insert<AffineDialect, complex::ComplexDialect, linalg::LinalgDialect, .insert<AffineDialect, complex::ComplexDialect, linalg::LinalgDialect,
@ -2454,7 +2455,7 @@ struct LhloLegalizeToLinalgPass
}; };
struct HloLegalizeToLinalgPass struct HloLegalizeToLinalgPass
: public PassWrapper<HloLegalizeToLinalgPass, FunctionPass> { : public mhlo::HloLegalizeToLinalgPassBase<HloLegalizeToLinalgPass> {
void getDependentDialects(DialectRegistry& registry) const override { void getDependentDialects(DialectRegistry& registry) const override {
registry registry
.insert<linalg::LinalgDialect, scf::SCFDialect, complex::ComplexDialect, .insert<linalg::LinalgDialect, scf::SCFDialect, complex::ComplexDialect,

View File

@ -17,6 +17,7 @@ limitations under the License.
#include "llvm/ADT/StringSwitch.h" #include "llvm/ADT/StringSwitch.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
@ -177,7 +178,7 @@ class ConvertIotaOp : public OpRewritePattern<mhlo::IotaOp> {
namespace { namespace {
struct LegalizeToStandardPass struct LegalizeToStandardPass
: public PassWrapper<LegalizeToStandardPass, FunctionPass> { : public LegalizeToStandardPassBase<LegalizeToStandardPass> {
void getDependentDialects(DialectRegistry &registry) const override { void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<StandardOpsDialect>(); registry.insert<StandardOpsDialect>();
} }

View File

@ -16,6 +16,7 @@ limitations under the License.
// This file implements the lowering for trigonometric standard ops to // This file implements the lowering for trigonometric standard ops to
// approximations. // approximations.
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Math/IR/Math.h"
@ -154,8 +155,8 @@ class ApproximateTanhLowering
}; };
struct LegalizeTrigonometricToApproximationPass struct LegalizeTrigonometricToApproximationPass
: public PassWrapper<LegalizeTrigonometricToApproximationPass, : public LegalizeTanhToApproximationPassBase<
FunctionPass> { LegalizeTrigonometricToApproximationPass> {
/// Perform the lowering of standard dialect operations to approximations. /// Perform the lowering of standard dialect operations to approximations.
void runOnFunction() override { void runOnFunction() override {
OwningRewritePatternList patterns(&getContext()); OwningRewritePatternList patterns(&getContext());

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
@ -36,8 +37,7 @@ namespace {
using linalg::LinalgOp; using linalg::LinalgOp;
class LhloFuseLinalgPass class LhloFuseLinalgPass : public LhloFuseLinalgPassBase<LhloFuseLinalgPass> {
: public PassWrapper<LhloFuseLinalgPass, FunctionPass> {
void getDependentDialects(DialectRegistry& registry) const override { void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<AffineDialect, linalg::LinalgDialect, scf::SCFDialect>(); registry.insert<AffineDialect, linalg::LinalgDialect, scf::SCFDialect>();
} }
@ -202,18 +202,6 @@ class LhloFuseLinalgPass
.setLoopType(loopType)); .setLoopType(loopType));
return tiled_generic_op.hasValue(); return tiled_generic_op.hasValue();
} }
Option<bool> use_parallel_loops_{
*this, "use-parallel-loops",
llvm::cl::desc(
"Tiles GenericOp consumer to parallel loops before linalg fusion"),
llvm::cl::init(false)};
ListOption<unsigned> tile_sizes_{
*this, "tile-sizes",
llvm::cl::desc(
"Tile sizes by which to tile linalg generic before linalg fusion"),
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
}; };
} // namespace } // namespace

View File

@ -16,6 +16,7 @@ limitations under the License.
// This file implements logic for lowering LHLO dialect to Affine dialect. // This file implements logic for lowering LHLO dialect to Affine dialect.
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h" #include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
@ -229,7 +230,7 @@ void populateLHLOToAffineConversionPattern(MLIRContext* context,
} }
struct LhloLegalizeToAffinePass struct LhloLegalizeToAffinePass
: public PassWrapper<LhloLegalizeToAffinePass, FunctionPass> { : public LhloLegalizeToAffinePassBase<LhloLegalizeToAffinePass> {
void getDependentDialects(DialectRegistry& registry) const override { void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<AffineDialect>(); registry.insert<AffineDialect>();
} }

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/ArrayRef.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h" #include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/GPU/GPUDialect.h"
@ -172,7 +173,7 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern<ReduceOp> {
}; };
struct LhloLegalizeToGpuPass struct LhloLegalizeToGpuPass
: public PassWrapper<LhloLegalizeToGpuPass, FunctionPass> { : public LhloLegalizeToGpuPassBase<LhloLegalizeToGpuPass> {
void getDependentDialects(DialectRegistry& registry) const override { void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<AffineDialect, gpu::GPUDialect, linalg::LinalgDialect, registry.insert<AffineDialect, gpu::GPUDialect, linalg::LinalgDialect,
memref::MemRefDialect, scf::SCFDialect>(); memref::MemRefDialect, scf::SCFDialect>();

View File

@ -17,6 +17,7 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/SCF/SCF.h"
@ -700,7 +701,8 @@ class SelectAndScatterOpConverter
}; };
struct LhloLegalizeToParallelLoopsPass struct LhloLegalizeToParallelLoopsPass
: public PassWrapper<LhloLegalizeToParallelLoopsPass, FunctionPass> { : public LhloLegalizeToParallelLoopsPassBase<
LhloLegalizeToParallelLoopsPass> {
void getDependentDialects(DialectRegistry& registry) const override { void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<StandardOpsDialect, scf::SCFDialect>(); registry.insert<StandardOpsDialect, scf::SCFDialect>();
} }

View File

@ -24,7 +24,9 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir-hlo/utils/hlo_utils.h" #include "mlir-hlo/utils/hlo_utils.h"
#include "mlir/IR/Attributes.h" #include "mlir/IR/Attributes.h"
#include "mlir/IR/MLIRContext.h" #include "mlir/IR/MLIRContext.h"
@ -35,35 +37,17 @@ limitations under the License.
#include "mlir/Pass/PassRegistry.h" #include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using mlir::FunctionPass;
using mlir::OwningRewritePatternList;
using mlir::PassWrapper;
namespace {
class LowerComplexPass : public PassWrapper<LowerComplexPass, FunctionPass> {
public:
explicit LowerComplexPass() : PassWrapper<LowerComplexPass, FunctionPass>() {}
/// Performs the lowering to MHLO dialect.
void runOnFunction() override;
};
} // end anonymous namespace
namespace mlir { namespace mlir {
namespace mhlo { namespace mhlo {
namespace { namespace {
class LowerComplexPass : public LowerComplexPassBase<LowerComplexPass> {
public:
/// Performs the lowering to MHLO dialect.
void runOnFunction() override;
};
#include "generated_lower_complex.inc" #include "generated_lower_complex.inc"
} // end anonymous namespace
void PopulateComplexLoweringPatterns(MLIRContext* context,
OwningRewritePatternList* patterns) {
populateWithGenerated(*patterns);
}
} // end namespace mhlo
} // end namespace mlir
// Lowers the complex operations that can be represented using other operations. // Lowers the complex operations that can be represented using other operations.
void LowerComplexPass::runOnFunction() { void LowerComplexPass::runOnFunction() {
// Add lowering patterns to the list. // Add lowering patterns to the list.
@ -73,6 +57,15 @@ void LowerComplexPass::runOnFunction() {
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
} }
std::unique_ptr<FunctionPass> mlir::mhlo::createLowerComplexPass() { } // end anonymous namespace
return std::make_unique<LowerComplexPass>(); } // end namespace mhlo
} // end namespace mlir
void mlir::mhlo::PopulateComplexLoweringPatterns(
MLIRContext* context, OwningRewritePatternList* patterns) {
populateWithGenerated(*patterns);
}
std::unique_ptr<mlir::FunctionPass> mlir::mhlo::createLowerComplexPass() {
return std::make_unique<mlir::mhlo::LowerComplexPass>();
} }

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h" #include "llvm/ADT/StringSwitch.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
@ -30,28 +31,16 @@ limitations under the License.
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using mlir::DenseIntElementsAttr; namespace mlir {
using mlir::ElementsAttr; namespace mhlo {
using mlir::failure;
using mlir::FunctionPass;
using mlir::LogicalResult;
using mlir::MLIRContext;
using mlir::OpRewritePattern;
using mlir::OwningRewritePatternList;
using mlir::PassWrapper;
using mlir::PatternRewriter;
using mlir::RankedTensorType;
using mlir::success;
using mlir::Value;
namespace { namespace {
Value TransposeReshape(Value arg, mlir::Location loc, Value TransposeReshape(Value arg, Location loc,
llvm::ArrayRef<int64_t> left_dims, llvm::ArrayRef<int64_t> left_dims,
llvm::ArrayRef<int64_t> right_dims, llvm::ArrayRef<int64_t> right_dims,
llvm::ArrayRef<int64_t> arg_shape, llvm::ArrayRef<int64_t> arg_shape,
PatternRewriter *rewriter) { PatternRewriter *rewriter) {
auto element_type = mlir::getElementTypeOrSelf(arg.getType()); auto element_type = getElementTypeOrSelf(arg.getType());
int64_t left_size = 1; int64_t left_size = 1;
for (auto dim : left_dims) { for (auto dim : left_dims) {
@ -68,7 +57,7 @@ Value TransposeReshape(Value arg, mlir::Location loc,
left_dims.end()); left_dims.end());
transpose_permutation.append(right_dims.begin(), right_dims.end()); transpose_permutation.append(right_dims.begin(), right_dims.end());
mlir::TensorType transpose_permutation_type = RankedTensorType::get( TensorType transpose_permutation_type = RankedTensorType::get(
{static_cast<int64_t>(transpose_permutation.size())}, {static_cast<int64_t>(transpose_permutation.size())},
rewriter->getIntegerType(64)); rewriter->getIntegerType(64));
@ -83,20 +72,18 @@ Value TransposeReshape(Value arg, mlir::Location loc,
transposed_shape.push_back(arg_shape[val]); transposed_shape.push_back(arg_shape[val]);
} }
auto transpose_type = RankedTensorType::get(transposed_shape, element_type); auto transpose_type = RankedTensorType::get(transposed_shape, element_type);
auto transpose_result = rewriter->create<mlir::mhlo::TransposeOp>( auto transpose_result = rewriter->create<TransposeOp>(
loc, transpose_type, arg, transpose_permutation_attr); loc, transpose_type, arg, transpose_permutation_attr);
// Return the final result. // Return the final result.
auto reshaped_type = auto reshaped_type =
RankedTensorType::get({left_size, right_size}, element_type); RankedTensorType::get({left_size, right_size}, element_type);
return rewriter->create<mlir::mhlo::ReshapeOp>(loc, reshaped_type, return rewriter->create<ReshapeOp>(loc, reshaped_type, transpose_result);
transpose_result);
} }
Value ProcessDotArg(Value arg, mlir::Location loc, Value ProcessDotArg(Value arg, Location loc, ElementsAttr contract_dims_attr,
ElementsAttr contract_dims_attr, bool outer_dims_first, bool outer_dims_first, PatternRewriter *rewriter) {
PatternRewriter *rewriter) { auto shape = arg.getType().cast<ShapedType>().getShape();
auto shape = arg.getType().cast<mlir::ShapedType>().getShape();
llvm::SmallVector<bool, 5> is_outer_dim; llvm::SmallVector<bool, 5> is_outer_dim;
is_outer_dim.resize(shape.size(), true); is_outer_dim.resize(shape.size(), true);
@ -124,7 +111,7 @@ Value ProcessDotArg(Value arg, mlir::Location loc,
return TransposeReshape(arg, loc, contract_dims, outer_dims, shape, rewriter); return TransposeReshape(arg, loc, contract_dims, outer_dims, shape, rewriter);
} }
struct GeneralDotConvert : public OpRewritePattern<mlir::mhlo::DotGeneralOp> { struct GeneralDotConvert : public OpRewritePattern<DotGeneralOp> {
// Attempts to lower a General Dot operator to a standard Dot operator. // Attempts to lower a General Dot operator to a standard Dot operator.
// General dots include batching dimensions and can have collapsing // General dots include batching dimensions and can have collapsing
// dimensions along any axis. Inserting correctly arrange transpose and // dimensions along any axis. Inserting correctly arrange transpose and
@ -136,9 +123,9 @@ struct GeneralDotConvert : public OpRewritePattern<mlir::mhlo::DotGeneralOp> {
explicit GeneralDotConvert(MLIRContext *context) explicit GeneralDotConvert(MLIRContext *context)
: OpRewritePattern(context) {} : OpRewritePattern(context) {}
LogicalResult matchAndRewrite(mlir::mhlo::DotGeneralOp op, LogicalResult matchAndRewrite(DotGeneralOp op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
auto dot_element_type = mlir::getElementTypeOrSelf(op); auto dot_element_type = getElementTypeOrSelf(op);
auto dot_numbers = op.dot_dimension_numbers(); auto dot_numbers = op.dot_dimension_numbers();
if (dot_numbers.lhs_batching_dimensions().getNumElements() != 0 || if (dot_numbers.lhs_batching_dimensions().getNumElements() != 0 ||
@ -155,8 +142,8 @@ struct GeneralDotConvert : public OpRewritePattern<mlir::mhlo::DotGeneralOp> {
/*outer_dims_first=*/false, &rewriter); /*outer_dims_first=*/false, &rewriter);
// Accept only static shaped types. // Accept only static shaped types.
auto lhs_shape_type = lhs.getType().dyn_cast_or_null<mlir::ShapedType>(); auto lhs_shape_type = lhs.getType().dyn_cast_or_null<ShapedType>();
auto rhs_shape_type = rhs.getType().dyn_cast_or_null<mlir::ShapedType>(); auto rhs_shape_type = rhs.getType().dyn_cast_or_null<ShapedType>();
if (!lhs_shape_type || !rhs_shape_type) return failure(); if (!lhs_shape_type || !rhs_shape_type) return failure();
if (!lhs_shape_type.hasStaticShape() || !rhs_shape_type.hasStaticShape()) if (!lhs_shape_type.hasStaticShape() || !rhs_shape_type.hasStaticShape())
return failure(); return failure();
@ -167,28 +154,29 @@ struct GeneralDotConvert : public OpRewritePattern<mlir::mhlo::DotGeneralOp> {
auto new_dot_type = auto new_dot_type =
RankedTensorType::get({lhs_shape[0], rhs_shape[1]}, dot_element_type); RankedTensorType::get({lhs_shape[0], rhs_shape[1]}, dot_element_type);
mlir::ArrayAttr precision_config; ArrayAttr precision_config;
if (op.precision_config()) precision_config = *op.precision_config(); if (op.precision_config()) precision_config = *op.precision_config();
auto new_dot_op = rewriter.create<mlir::mhlo::DotOp>( auto new_dot_op = rewriter.create<DotOp>(op.getLoc(), new_dot_type, lhs,
op.getLoc(), new_dot_type, lhs, rhs, precision_config); rhs, precision_config);
rewriter.replaceOpWithNewOp<mlir::mhlo::ReshapeOp>(op, op.getType(), rewriter.replaceOpWithNewOp<ReshapeOp>(op, op.getType(), new_dot_op);
new_dot_op);
return success(); return success();
} }
}; };
struct LegalizeGeneralDotPass struct LegalizeGeneralDotPass
: public PassWrapper<LegalizeGeneralDotPass, FunctionPass> { : public LegalizeGeneralDotPassBase<LegalizeGeneralDotPass> {
/// Lower all general dots that can be represented as a non-batched matmul. /// Lower all general dots that can be represented as a non-batched matmul.
void runOnFunction() override { void runOnFunction() override {
OwningRewritePatternList patterns(&getContext()); OwningRewritePatternList patterns(&getContext());
mlir::mhlo::PopulateGeneralDotOpLoweringPatterns(&patterns, &getContext()); PopulateGeneralDotOpLoweringPatterns(&patterns, &getContext());
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
} }
}; };
} // namespace } // namespace
} // namespace mhlo
} // namespace mlir
void mlir::mhlo::PopulateGeneralDotOpLoweringPatterns( void mlir::mhlo::PopulateGeneralDotOpLoweringPatterns(
OwningRewritePatternList *patterns, MLIRContext *ctx) { OwningRewritePatternList *patterns, MLIRContext *ctx) {

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/MLIRContext.h" #include "mlir/IR/MLIRContext.h"
@ -28,7 +29,7 @@ namespace mhlo {
namespace { namespace {
struct TestMaterializeBroadcastsPass struct TestMaterializeBroadcastsPass
: public PassWrapper<TestMaterializeBroadcastsPass, FunctionPass> { : public TestMaterializeBroadcastsPassBase<TestMaterializeBroadcastsPass> {
void runOnFunction() override { void runOnFunction() override {
ConversionTarget conversionTarget(getContext()); ConversionTarget conversionTarget(getContext());
OwningRewritePatternList conversionPatterns(&getContext()); OwningRewritePatternList conversionPatterns(&getContext());

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "llvm/Support/Casting.h" #include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h" #include "llvm/Support/Debug.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
@ -39,7 +40,7 @@ void MatchAndRewrite(WhileOp whileOp);
/// Pass that converts MHLO control flow to SCF. /// Pass that converts MHLO control flow to SCF.
class ControlFlowToScfPass class ControlFlowToScfPass
: public mlir::PassWrapper<ControlFlowToScfPass, FunctionPass> { : public LegalizeControlFlowToScfPassBase<ControlFlowToScfPass> {
void getDependentDialects(DialectRegistry& registry) const override { void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<scf::SCFDialect>(); registry.insert<scf::SCFDialect>();
} }

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "llvm/ADT/EquivalenceClasses.h" #include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/Support/Debug.h" #include "llvm/Support/Debug.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/utils/cycle_detector.h" #include "mlir-hlo/utils/cycle_detector.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project #include "mlir/IR/MLIRContext.h" // TF:llvm-project
@ -480,7 +481,7 @@ class FusionPlanner {
EquivalenceClasses<int32_t> leader_for_node_; EquivalenceClasses<int32_t> leader_for_node_;
}; };
struct MhloFusionPass : public mlir::PassWrapper<MhloFusionPass, FunctionPass> { struct MhloFusionPass : public MhloFusionPassBase<MhloFusionPass> {
void runOnFunction() override { void runOnFunction() override {
FuncOp func = getFunction(); FuncOp func = getFunction();
if (!IsTargetFunc(func)) { if (!IsTargetFunc(func)) {

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
@ -23,28 +24,26 @@ limitations under the License.
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using mlir::FunctionPass; namespace mlir {
using mlir::PassWrapper; namespace mhlo {
namespace { namespace {
class OptimizeMhloPass : public PassWrapper<OptimizeMhloPass, FunctionPass> { class OptimizeMhloPass : public OptimizeMhloPassBase<OptimizeMhloPass> {
public: public:
explicit OptimizeMhloPass() : PassWrapper<OptimizeMhloPass, FunctionPass>() {}
/// Performs the lowering to MHLO dialect. /// Performs the lowering to MHLO dialect.
void runOnFunction() override; void runOnFunction() override;
}; };
} // end anonymous namespace
// Lowers the complex operations that can be represented using other operations. // Lowers the complex operations that can be represented using other operations.
void OptimizeMhloPass::runOnFunction() { void OptimizeMhloPass::runOnFunction() {
// Add lowering patterns to the list. // Add lowering patterns to the list.
mlir::OwningRewritePatternList patterns(&getContext()); OwningRewritePatternList patterns(&getContext());
mlir::mhlo::PopulateOptimizeMHLOPatterns(&getContext(), &patterns); PopulateOptimizeMHLOPatterns(&getContext(), &patterns);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
} }
} // end anonymous namespace
} // namespace mhlo
} // namespace mlir
std::unique_ptr<mlir::FunctionPass> mlir::mhlo::createOptimizeMhloPass() { std::unique_ptr<mlir::FunctionPass> mlir::mhlo::createOptimizeMhloPass() {
return std::make_unique<OptimizeMhloPass>(); return std::make_unique<mlir::mhlo::OptimizeMhloPass>();
} }

View File

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/IR/Attributes.h" #include "mlir/IR/Attributes.h"
#include "mlir/IR/Identifier.h" #include "mlir/IR/Identifier.h"
@ -84,7 +85,8 @@ struct ReifyReturnTypeShapesPattern : public RewritePattern {
}; };
struct TestInferShapedTypeMethodsPass struct TestInferShapedTypeMethodsPass
: public PassWrapper<TestInferShapedTypeMethodsPass, FunctionPass> { : public TestInferShapedTypeMethodsPassBase<
TestInferShapedTypeMethodsPass> {
void getDependentDialects(DialectRegistry &registry) const override { void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<shape::ShapeDialect>(); registry.insert<shape::ShapeDialect>();
} }

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h" #include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/SCF/SCF.h"
@ -538,7 +539,7 @@ struct ConvertUnrankedDynamicBroadcastNaryOp
}; };
struct TransformUnrankedHloPass struct TransformUnrankedHloPass
: public PassWrapper<TransformUnrankedHloPass, FunctionPass> { : public mhlo::TransformUnrankedHloPassBase<TransformUnrankedHloPass> {
void getDependentDialects(DialectRegistry &registry) const override { void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<chlo::HloClientDialect, mhlo::MhloDialect, scf::SCFDialect, registry.insert<chlo::HloClientDialect, mhlo::MhloDialect, scf::SCFDialect,
shape::ShapeDialect>(); shape::ShapeDialect>();

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
@ -30,7 +31,7 @@ namespace mhlo {
namespace { namespace {
struct TestUnfuseBatchNormPass struct TestUnfuseBatchNormPass
: public PassWrapper<TestUnfuseBatchNormPass, OperationPass<>> { : public TestUnfuseBatchNormPassBase<TestUnfuseBatchNormPass> {
void getDependentDialects(DialectRegistry& registry) const override { void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<memref::MemRefDialect>(); registry.insert<memref::MemRefDialect>();
} }