Fix pass definition to inherit from the TableGen generated base class (NFC)

PiperOrigin-RevId: 379860210
This commit is contained in:
Mehdi Amini 2021-06-16 19:04:23 -07:00 committed by TensorFlow MLIR Team
parent 2e08c246e9
commit 8c8e81cb69
23 changed files with 126 additions and 124 deletions

15
BUILD
View File

@ -748,6 +748,7 @@ cc_library(
hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/passes.h"],
deps = [
":hlo",
":pass_details",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
@ -799,6 +800,7 @@ cc_library(
":hlo",
":lhlo",
":map_lmhlo_to_scalar_op",
":pass_details",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Affine",
"@llvm-project//mlir:IR",
@ -814,6 +816,7 @@ cc_library(
srcs = ["lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc"],
deps = [
":lhlo",
":pass_details",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgOps",
@ -837,6 +840,7 @@ cc_library(
":hlo",
":lhlo",
":map_lmhlo_to_scalar_op",
":pass_details",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Affine",
"@llvm-project//mlir:IR",
@ -863,6 +867,7 @@ cc_library(
deps = [
":hlo",
":map_chlo_to_hlo_op",
":pass_details",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
@ -885,6 +890,7 @@ cc_library(
deps = [
":hlo",
":map_chlo_to_hlo_op",
":pass_details",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:InferTypeOpInterface",
@ -928,6 +934,7 @@ cc_library(
":hlo",
":lhlo",
":map_lmhlo_to_scalar_op",
":pass_details",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Affine",
"@llvm-project//mlir:GPUDialect",
@ -948,6 +955,7 @@ cc_library(
hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/passes.h"],
deps = [
":lhlo",
":pass_details",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Affine",
"@llvm-project//mlir:IR",
@ -1008,6 +1016,7 @@ cc_library(
deps = [
":cycle_detector",
":hlo",
":pass_details",
"@llvm-project//llvm:Core",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
@ -1046,6 +1055,7 @@ cc_library(
hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/passes.h"],
deps = [
":hlo",
":pass_details",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
@ -1064,6 +1074,7 @@ cc_library(
":hlo",
":legalize_to_standard_inc_gen",
":legalize_trigonometric_to_approximation",
":pass_details",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
@ -1083,6 +1094,7 @@ cc_library(
],
deps = [
":hlo",
":pass_details",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
@ -1102,6 +1114,7 @@ cc_library(
],
includes = ["include"],
deps = [
":pass_details",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MathDialect",
@ -1149,6 +1162,7 @@ cc_library(
deps = [
":hlo",
":lower_complex_inc_gen",
":pass_details",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
@ -1201,6 +1215,7 @@ cc_library(
hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/passes.h"],
deps = [
":lhlo",
":pass_details",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MemRefDialect",

View File

@ -15,13 +15,13 @@ limitations under the License.
include "mlir/Pass/PassBase.td"
def LhloLegalizeToLinalgPass : Pass<"lhlo-legalize-to-linalg", "FuncOp"> {
def LhloLegalizeToLinalgPass : FunctionPass<"lhlo-legalize-to-linalg"> {
let summary = "Legalize from LHLO dialect to Linalg dialect.";
let constructor = "createLegalizeLhloToLinalgPass()";
}
def LhloFuseLinalgPass : Pass<"lhlo-fuse-linalg", "FuncOp"> {
def LhloFuseLinalgPass : FunctionPass<"lhlo-fuse-linalg"> {
let summary = "Greedily fuse linalg ops obtained after LHLO lowering.";
let constructor = "createLhloFuseLinalgPass()";
let options = [
@ -34,24 +34,24 @@ def LhloFuseLinalgPass : Pass<"lhlo-fuse-linalg", "FuncOp"> {
}
def LhloLegalizeToAffinePass : Pass<"lhlo-legalize-to-affine", "FuncOp"> {
def LhloLegalizeToAffinePass : FunctionPass<"lhlo-legalize-to-affine"> {
let summary = "Legalize from LHLO dialect to affine dialect.";
let constructor = "createLhloLegalizeToAffinePass()";
}
def LhloLegalizeToGpuPass : Pass<"lhlo-legalize-to-gpu", "FuncOp"> {
def LhloLegalizeToGpuPass : FunctionPass<"lhlo-legalize-to-gpu"> {
let summary = "Legalize from LHLO dialect to GPU dialect.";
let constructor = "createLegalizeToGpuPass()";
}
def LhloLegalizeToParallelLoopsPass : Pass<"lhlo-legalize-to-parallel-loops", "FuncOp"> {
def LhloLegalizeToParallelLoopsPass : FunctionPass<"lhlo-legalize-to-parallel-loops"> {
let summary = "Legalize from LHLO dialect to parallel loops.";
let constructor = "createLegalizeLhloToParallelLoopsPass()";
}
def LegalizeTensorLoadOpPass : Pass<"lhlo-legalize-tensor-load-op", "FuncOp"> {
def LegalizeTensorLoadOpPass : FunctionPass<"lhlo-legalize-tensor-load-op"> {
let summary = "Legalize tensor load ops that are inserted during mhlo to lmhlo conversion.";
let constructor = "createLegalizeTensorLoadOpPass()";
}

View File

@ -37,64 +37,64 @@ def HloLegalizeToLhloPass : Pass<"hlo-legalize-to-lhlo", "ModuleOp"> {
];
}
def LegalizeControlFlowPass : Pass<"mhlo-legalize-control-flow", "FuncOp"> {
def LegalizeControlFlowPass : FunctionPass<"mhlo-legalize-control-flow"> {
let summary = "Legalize from MHLO control flow to CFG control flow.";
let constructor = "createLegalizeControlFlowPass()";
}
def LegalizeControlFlowToScfPass : Pass<"mhlo-control-flow-to-scf", "FuncOp"> {
def LegalizeControlFlowToScfPass : FunctionPass<"mhlo-control-flow-to-scf"> {
let summary = "Legalize from MHLO control flow to SCF control flow.";
let constructor = "createControlFlowToScfPass()";
}
def LegalizeGatherToTorchIndexSelectPass : Pass<"mhlo-legalize-gather-to-torch-index-select", "FuncOp"> {
def LegalizeGatherToTorchIndexSelectPass : FunctionPass<"mhlo-legalize-gather-to-torch-index-select"> {
let summary = "Legalizes gathers to a torch index select.";
let constructor = "createLegalizeGatherToTorchIndexSelectPass()";
}
def LegalizeTanhToApproximationPass : Pass<"mhlo-legalize-trigonometric-to-approximation", "FuncOp"> {
def LegalizeTanhToApproximationPass : FunctionPass<"mhlo-legalize-trigonometric-to-approximation"> {
let summary = "Legalize trigonometric operations from standard dialect to an approximation.";
let constructor = "createLegalizeTrigonometricToApproximationPass()";
}
def HloLegalizeToLinalgPass : Pass<"hlo-legalize-to-linalg", "FuncOp"> {
def HloLegalizeToLinalgPass : FunctionPass<"hlo-legalize-to-linalg"> {
let summary = "Legalize from HLO dialect to Linalg dialect.";
let constructor = "createLegalizeHloToLinalgPass()";
}
def LegalizeToStandardPass : Pass<"mhlo-legalize-to-std", "FuncOp"> {
def LegalizeToStandardPass : FunctionPass<"mhlo-legalize-to-std"> {
let summary = "Legalize from MHLO dialect to standard dialect.";
let constructor = "createLegalizeToStdPass()";
}
def LowerComplexPass : Pass<"mhlo-test-lower-complex", "FuncOp"> {
def LowerComplexPass : FunctionPass<"mhlo-test-lower-complex"> {
let summary = "Lower complex operations into non-complex operations.";
let constructor = "createLowerComplexPass()";
}
def LegalizeGeneralDotPass : Pass<"mhlo-test-lower-general-dot", "FuncOp"> {
def LegalizeGeneralDotPass : FunctionPass<"mhlo-test-lower-general-dot"> {
let summary = "Tests lowering general dot to a non-batched dot when possible.";
let constructor = "createLegalizeGeneralDotPass()";
}
def TestMaterializeBroadcastsPass : Pass<"mhlo-test-materialize-broadcasts", "FuncOp"> {
def TestMaterializeBroadcastsPass : FunctionPass<"mhlo-test-materialize-broadcasts"> {
let summary = "Test pass for materializing 'broadcast_dimensions' attributes.";
let constructor = "createTestMaterializeBroadcastsPass()";
}
def MhloFusionPass : Pass<"mhlo-fusion", "FuncOp"> {
def MhloFusionPass : FunctionPass<"mhlo-fusion"> {
let summary = "Fuse mhlo ops to kLoop/kInput fusion patterns.";
let constructor = "createMhloFusionPass()";
}
def OptimizeMhloPass : Pass<"mhlo-test-optimize", "FuncOp"> {
def OptimizeMhloPass : FunctionPass<"mhlo-test-optimize"> {
let summary = "Run optional HLO optimizations.";
let constructor = "createOptimizeMhloPass()";
}
@ -107,18 +107,18 @@ def SinkConstantsToControlFlowPass : FunctionPass<"mhlo-sink-constants-to-contro
}
def TestInferShapedTypeMethodsPass : Pass<"mhlo-test-infer-shaped-type-methods", "FuncOp"> {
def TestInferShapedTypeMethodsPass : FunctionPass<"mhlo-test-infer-shaped-type-methods"> {
let summary = "Uses test ops to invoke InferShapedTypeOpInterface methods.";
let constructor = "createTestInferShapedTypeMethodsPass()";
}
def TransformUnrankedHloPass : Pass<"mhlo-transform-unranked-hlo", "FuncOp"> {
def TransformUnrankedHloPass : FunctionPass<"mhlo-transform-unranked-hlo"> {
let summary = "Realize element-wise operations on ranked tensors where possible.";
let constructor = "createTransformUnrankedHloPass()";
}
def BroadcastPropagationPass : Pass<"mhlo-broadcast-propagation", "FuncOp"> {
def BroadcastPropagationPass : FunctionPass<"mhlo-broadcast-propagation"> {
let summary = "Move dynamic broadcasts up over element-wise operations and "
"broadcast the operands rather than the result. This will eventually allow "
"for larger fusions.";

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "llvm/Support/Casting.h"
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
@ -376,7 +377,7 @@ struct EarlyBroadcastInDimOpPattern
};
struct BroadcastPropagationPass
: public PassWrapper<BroadcastPropagationPass, FunctionPass> {
: public BroadcastPropagationPassBase<BroadcastPropagationPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<shape::ShapeDialect, mhlo::MhloDialect>();
}

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" // TF:llvm-project
@ -37,7 +38,7 @@ namespace mlir {
namespace mhlo {
namespace {
struct LegalizeControlFlowPass
: public mlir::PassWrapper<LegalizeControlFlowPass, FunctionPass> {
: public LegalizeControlFlowPassBase<LegalizeControlFlowPass> {
// Perform the lowering to MLIR control flow.
void runOnFunction() override;
};

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/IR/BuiltinOps.h"
@ -128,7 +129,8 @@ struct GatherIsTorchIndexSelect : public OpRewritePattern<GatherOp> {
};
struct LegalizeGatherToTorchIndexSelectPass
: public PassWrapper<LegalizeGatherToTorchIndexSelectPass, FunctionPass> {
: public LegalizeGatherToTorchIndexSelectPassBase<
LegalizeGatherToTorchIndexSelectPass> {
/// Perform the lowering of standard dialect operations to approximations.
void runOnFunction() override {
OwningRewritePatternList patterns(&getContext());

View File

@ -16,6 +16,7 @@ limitations under the License.
// This file implements logic for lowering memref.tensor_load ops that are
// inserted during `mhlo-legalize-to-lmhlo`.
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
@ -71,7 +72,7 @@ struct ForwardShapeOfOp : public OpRewritePattern<ShapeOfOp> {
};
struct LegalizeTensorLoadOpPass
: public mlir::PassWrapper<LegalizeTensorLoadOpPass, FunctionPass> {
: public LegalizeTensorLoadOpPassBase<LegalizeTensorLoadOpPass> {
// Perform the lowering to remove memref.tensor_load ops inserted during
// `mhlo-legalize-to-lmhlo`.
void runOnFunction() override {

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "llvm/ADT/SetVector.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
@ -2428,7 +2429,7 @@ class RemoveSignTypeConverter : public TypeConverter {
// iterator_types = ["parallel", "parallel"],
// } : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
struct LhloLegalizeToLinalgPass
: public PassWrapper<LhloLegalizeToLinalgPass, FunctionPass> {
: public lmhlo::LhloLegalizeToLinalgPassBase<LhloLegalizeToLinalgPass> {
void getDependentDialects(DialectRegistry& registry) const override {
registry
.insert<AffineDialect, complex::ComplexDialect, linalg::LinalgDialect,
@ -2454,7 +2455,7 @@ struct LhloLegalizeToLinalgPass
};
struct HloLegalizeToLinalgPass
: public PassWrapper<HloLegalizeToLinalgPass, FunctionPass> {
: public mhlo::HloLegalizeToLinalgPassBase<HloLegalizeToLinalgPass> {
void getDependentDialects(DialectRegistry& registry) const override {
registry
.insert<linalg::LinalgDialect, scf::SCFDialect, complex::ComplexDialect,

View File

@ -17,6 +17,7 @@ limitations under the License.
#include "llvm/ADT/StringSwitch.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
@ -177,7 +178,7 @@ class ConvertIotaOp : public OpRewritePattern<mhlo::IotaOp> {
namespace {
struct LegalizeToStandardPass
: public PassWrapper<LegalizeToStandardPass, FunctionPass> {
: public LegalizeToStandardPassBase<LegalizeToStandardPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<StandardOpsDialect>();
}

View File

@ -16,6 +16,7 @@ limitations under the License.
// This file implements the lowering for trigonometric standard ops to
// approximations.
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/Math/IR/Math.h"
@ -154,8 +155,8 @@ class ApproximateTanhLowering
};
struct LegalizeTrigonometricToApproximationPass
: public PassWrapper<LegalizeTrigonometricToApproximationPass,
FunctionPass> {
: public LegalizeTanhToApproximationPassBase<
LegalizeTrigonometricToApproximationPass> {
/// Perform the lowering of standard dialect operations to approximations.
void runOnFunction() override {
OwningRewritePatternList patterns(&getContext());

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
@ -36,8 +37,7 @@ namespace {
using linalg::LinalgOp;
class LhloFuseLinalgPass
: public PassWrapper<LhloFuseLinalgPass, FunctionPass> {
class LhloFuseLinalgPass : public LhloFuseLinalgPassBase<LhloFuseLinalgPass> {
void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<AffineDialect, linalg::LinalgDialect, scf::SCFDialect>();
}
@ -202,18 +202,6 @@ class LhloFuseLinalgPass
.setLoopType(loopType));
return tiled_generic_op.hasValue();
}
Option<bool> use_parallel_loops_{
*this, "use-parallel-loops",
llvm::cl::desc(
"Tiles GenericOp consumer to parallel loops before linalg fusion"),
llvm::cl::init(false)};
ListOption<unsigned> tile_sizes_{
*this, "tile-sizes",
llvm::cl::desc(
"Tile sizes by which to tile linalg generic before linalg fusion"),
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
};
} // namespace

View File

@ -16,6 +16,7 @@ limitations under the License.
// This file implements logic for lowering LHLO dialect to Affine dialect.
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
@ -229,7 +230,7 @@ void populateLHLOToAffineConversionPattern(MLIRContext* context,
}
struct LhloLegalizeToAffinePass
: public PassWrapper<LhloLegalizeToAffinePass, FunctionPass> {
: public LhloLegalizeToAffinePassBase<LhloLegalizeToAffinePass> {
void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<AffineDialect>();
}

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "llvm/ADT/ArrayRef.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
@ -172,7 +173,7 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern<ReduceOp> {
};
struct LhloLegalizeToGpuPass
: public PassWrapper<LhloLegalizeToGpuPass, FunctionPass> {
: public LhloLegalizeToGpuPassBase<LhloLegalizeToGpuPass> {
void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<AffineDialect, gpu::GPUDialect, linalg::LinalgDialect,
memref::MemRefDialect, scf::SCFDialect>();

View File

@ -17,6 +17,7 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
@ -700,7 +701,8 @@ class SelectAndScatterOpConverter
};
struct LhloLegalizeToParallelLoopsPass
: public PassWrapper<LhloLegalizeToParallelLoopsPass, FunctionPass> {
: public LhloLegalizeToParallelLoopsPassBase<
LhloLegalizeToParallelLoopsPass> {
void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<StandardOpsDialect, scf::SCFDialect>();
}

View File

@ -24,7 +24,9 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir-hlo/utils/hlo_utils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/MLIRContext.h"
@ -35,35 +37,17 @@ limitations under the License.
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using mlir::FunctionPass;
using mlir::OwningRewritePatternList;
using mlir::PassWrapper;
namespace {
class LowerComplexPass : public PassWrapper<LowerComplexPass, FunctionPass> {
public:
explicit LowerComplexPass() : PassWrapper<LowerComplexPass, FunctionPass>() {}
/// Performs the lowering to MHLO dialect.
void runOnFunction() override;
};
} // end anonymous namespace
namespace mlir {
namespace mhlo {
namespace {
class LowerComplexPass : public LowerComplexPassBase<LowerComplexPass> {
public:
/// Performs the lowering to MHLO dialect.
void runOnFunction() override;
};
#include "generated_lower_complex.inc"
} // end anonymous namespace
void PopulateComplexLoweringPatterns(MLIRContext* context,
OwningRewritePatternList* patterns) {
populateWithGenerated(*patterns);
}
} // end namespace mhlo
} // end namespace mlir
// Lowers the complex operations that can be represented using other operations.
void LowerComplexPass::runOnFunction() {
// Add lowering patterns to the list.
@ -73,6 +57,15 @@ void LowerComplexPass::runOnFunction() {
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
std::unique_ptr<FunctionPass> mlir::mhlo::createLowerComplexPass() {
return std::make_unique<LowerComplexPass>();
} // end anonymous namespace
} // end namespace mhlo
} // end namespace mlir
void mlir::mhlo::PopulateComplexLoweringPatterns(
MLIRContext* context, OwningRewritePatternList* patterns) {
populateWithGenerated(*patterns);
}
std::unique_ptr<mlir::FunctionPass> mlir::mhlo::createLowerComplexPass() {
return std::make_unique<mlir::mhlo::LowerComplexPass>();
}

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
@ -30,28 +31,16 @@ limitations under the License.
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using mlir::DenseIntElementsAttr;
using mlir::ElementsAttr;
using mlir::failure;
using mlir::FunctionPass;
using mlir::LogicalResult;
using mlir::MLIRContext;
using mlir::OpRewritePattern;
using mlir::OwningRewritePatternList;
using mlir::PassWrapper;
using mlir::PatternRewriter;
using mlir::RankedTensorType;
using mlir::success;
using mlir::Value;
namespace mlir {
namespace mhlo {
namespace {
Value TransposeReshape(Value arg, mlir::Location loc,
Value TransposeReshape(Value arg, Location loc,
llvm::ArrayRef<int64_t> left_dims,
llvm::ArrayRef<int64_t> right_dims,
llvm::ArrayRef<int64_t> arg_shape,
PatternRewriter *rewriter) {
auto element_type = mlir::getElementTypeOrSelf(arg.getType());
auto element_type = getElementTypeOrSelf(arg.getType());
int64_t left_size = 1;
for (auto dim : left_dims) {
@ -68,7 +57,7 @@ Value TransposeReshape(Value arg, mlir::Location loc,
left_dims.end());
transpose_permutation.append(right_dims.begin(), right_dims.end());
mlir::TensorType transpose_permutation_type = RankedTensorType::get(
TensorType transpose_permutation_type = RankedTensorType::get(
{static_cast<int64_t>(transpose_permutation.size())},
rewriter->getIntegerType(64));
@ -83,20 +72,18 @@ Value TransposeReshape(Value arg, mlir::Location loc,
transposed_shape.push_back(arg_shape[val]);
}
auto transpose_type = RankedTensorType::get(transposed_shape, element_type);
auto transpose_result = rewriter->create<mlir::mhlo::TransposeOp>(
auto transpose_result = rewriter->create<TransposeOp>(
loc, transpose_type, arg, transpose_permutation_attr);
// Return the final result.
auto reshaped_type =
RankedTensorType::get({left_size, right_size}, element_type);
return rewriter->create<mlir::mhlo::ReshapeOp>(loc, reshaped_type,
transpose_result);
return rewriter->create<ReshapeOp>(loc, reshaped_type, transpose_result);
}
Value ProcessDotArg(Value arg, mlir::Location loc,
ElementsAttr contract_dims_attr, bool outer_dims_first,
PatternRewriter *rewriter) {
auto shape = arg.getType().cast<mlir::ShapedType>().getShape();
Value ProcessDotArg(Value arg, Location loc, ElementsAttr contract_dims_attr,
bool outer_dims_first, PatternRewriter *rewriter) {
auto shape = arg.getType().cast<ShapedType>().getShape();
llvm::SmallVector<bool, 5> is_outer_dim;
is_outer_dim.resize(shape.size(), true);
@ -124,7 +111,7 @@ Value ProcessDotArg(Value arg, mlir::Location loc,
return TransposeReshape(arg, loc, contract_dims, outer_dims, shape, rewriter);
}
struct GeneralDotConvert : public OpRewritePattern<mlir::mhlo::DotGeneralOp> {
struct GeneralDotConvert : public OpRewritePattern<DotGeneralOp> {
// Attempts to lower a General Dot operator to a standard Dot operator.
// General dots include batching dimensions and can have collapsing
// dimensions along any axis. Inserting correctly arrange transpose and
@ -136,9 +123,9 @@ struct GeneralDotConvert : public OpRewritePattern<mlir::mhlo::DotGeneralOp> {
explicit GeneralDotConvert(MLIRContext *context)
: OpRewritePattern(context) {}
LogicalResult matchAndRewrite(mlir::mhlo::DotGeneralOp op,
LogicalResult matchAndRewrite(DotGeneralOp op,
PatternRewriter &rewriter) const override {
auto dot_element_type = mlir::getElementTypeOrSelf(op);
auto dot_element_type = getElementTypeOrSelf(op);
auto dot_numbers = op.dot_dimension_numbers();
if (dot_numbers.lhs_batching_dimensions().getNumElements() != 0 ||
@ -155,8 +142,8 @@ struct GeneralDotConvert : public OpRewritePattern<mlir::mhlo::DotGeneralOp> {
/*outer_dims_first=*/false, &rewriter);
// Accept only static shaped types.
auto lhs_shape_type = lhs.getType().dyn_cast_or_null<mlir::ShapedType>();
auto rhs_shape_type = rhs.getType().dyn_cast_or_null<mlir::ShapedType>();
auto lhs_shape_type = lhs.getType().dyn_cast_or_null<ShapedType>();
auto rhs_shape_type = rhs.getType().dyn_cast_or_null<ShapedType>();
if (!lhs_shape_type || !rhs_shape_type) return failure();
if (!lhs_shape_type.hasStaticShape() || !rhs_shape_type.hasStaticShape())
return failure();
@ -167,28 +154,29 @@ struct GeneralDotConvert : public OpRewritePattern<mlir::mhlo::DotGeneralOp> {
auto new_dot_type =
RankedTensorType::get({lhs_shape[0], rhs_shape[1]}, dot_element_type);
mlir::ArrayAttr precision_config;
ArrayAttr precision_config;
if (op.precision_config()) precision_config = *op.precision_config();
auto new_dot_op = rewriter.create<mlir::mhlo::DotOp>(
op.getLoc(), new_dot_type, lhs, rhs, precision_config);
auto new_dot_op = rewriter.create<DotOp>(op.getLoc(), new_dot_type, lhs,
rhs, precision_config);
rewriter.replaceOpWithNewOp<mlir::mhlo::ReshapeOp>(op, op.getType(),
new_dot_op);
rewriter.replaceOpWithNewOp<ReshapeOp>(op, op.getType(), new_dot_op);
return success();
}
};
struct LegalizeGeneralDotPass
: public PassWrapper<LegalizeGeneralDotPass, FunctionPass> {
: public LegalizeGeneralDotPassBase<LegalizeGeneralDotPass> {
/// Lower all general dots that can be represented as a non-batched matmul.
void runOnFunction() override {
OwningRewritePatternList patterns(&getContext());
mlir::mhlo::PopulateGeneralDotOpLoweringPatterns(&patterns, &getContext());
PopulateGeneralDotOpLoweringPatterns(&patterns, &getContext());
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
};
} // namespace
} // namespace mhlo
} // namespace mlir
void mlir::mhlo::PopulateGeneralDotOpLoweringPatterns(
OwningRewritePatternList *patterns, MLIRContext *ctx) {

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/MLIRContext.h"
@ -28,7 +29,7 @@ namespace mhlo {
namespace {
struct TestMaterializeBroadcastsPass
: public PassWrapper<TestMaterializeBroadcastsPass, FunctionPass> {
: public TestMaterializeBroadcastsPassBase<TestMaterializeBroadcastsPass> {
void runOnFunction() override {
ConversionTarget conversionTarget(getContext());
OwningRewritePatternList conversionPatterns(&getContext());

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
@ -39,7 +40,7 @@ void MatchAndRewrite(WhileOp whileOp);
/// Pass that converts MHLO control flow to SCF.
class ControlFlowToScfPass
: public mlir::PassWrapper<ControlFlowToScfPass, FunctionPass> {
: public LegalizeControlFlowToScfPassBase<ControlFlowToScfPass> {
void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<scf::SCFDialect>();
}

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/Support/Debug.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/utils/cycle_detector.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
@ -480,7 +481,7 @@ class FusionPlanner {
EquivalenceClasses<int32_t> leader_for_node_;
};
struct MhloFusionPass : public mlir::PassWrapper<MhloFusionPass, FunctionPass> {
struct MhloFusionPass : public MhloFusionPassBase<MhloFusionPass> {
void runOnFunction() override {
FuncOp func = getFunction();
if (!IsTargetFunc(func)) {

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
@ -23,28 +24,26 @@ limitations under the License.
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using mlir::FunctionPass;
using mlir::PassWrapper;
namespace mlir {
namespace mhlo {
namespace {
class OptimizeMhloPass : public PassWrapper<OptimizeMhloPass, FunctionPass> {
class OptimizeMhloPass : public OptimizeMhloPassBase<OptimizeMhloPass> {
public:
explicit OptimizeMhloPass() : PassWrapper<OptimizeMhloPass, FunctionPass>() {}
/// Performs the lowering to MHLO dialect.
void runOnFunction() override;
};
} // end anonymous namespace
// Lowers the complex operations that can be represented using other operations.
void OptimizeMhloPass::runOnFunction() {
// Add lowering patterns to the list.
mlir::OwningRewritePatternList patterns(&getContext());
mlir::mhlo::PopulateOptimizeMHLOPatterns(&getContext(), &patterns);
OwningRewritePatternList patterns(&getContext());
PopulateOptimizeMHLOPatterns(&getContext(), &patterns);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
} // end anonymous namespace
} // namespace mhlo
} // namespace mlir
std::unique_ptr<mlir::FunctionPass> mlir::mhlo::createOptimizeMhloPass() {
return std::make_unique<OptimizeMhloPass>();
return std::make_unique<mlir::mhlo::OptimizeMhloPass>();
}

View File

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Identifier.h"
@ -84,7 +85,8 @@ struct ReifyReturnTypeShapesPattern : public RewritePattern {
};
struct TestInferShapedTypeMethodsPass
: public PassWrapper<TestInferShapedTypeMethodsPass, FunctionPass> {
: public TestInferShapedTypeMethodsPassBase<
TestInferShapedTypeMethodsPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<shape::ShapeDialect>();
}

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/SCF/SCF.h"
@ -538,7 +539,7 @@ struct ConvertUnrankedDynamicBroadcastNaryOp
};
struct TransformUnrankedHloPass
: public PassWrapper<TransformUnrankedHloPass, FunctionPass> {
: public mhlo::TransformUnrankedHloPassBase<TransformUnrankedHloPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<chlo::HloClientDialect, mhlo::MhloDialect, scf::SCFDialect,
shape::ShapeDialect>();

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
@ -30,7 +31,7 @@ namespace mhlo {
namespace {
struct TestUnfuseBatchNormPass
: public PassWrapper<TestUnfuseBatchNormPass, OperationPass<>> {
: public TestUnfuseBatchNormPassBase<TestUnfuseBatchNormPass> {
void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<memref::MemRefDialect>();
}