Fix pass definition to inherit from the TableGen generated base class (NFC)
PiperOrigin-RevId: 379860210
This commit is contained in:
parent
2e08c246e9
commit
8c8e81cb69
15
BUILD
15
BUILD
|
@ -748,6 +748,7 @@ cc_library(
|
||||||
hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/passes.h"],
|
hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/passes.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":hlo",
|
":hlo",
|
||||||
|
":pass_details",
|
||||||
"@llvm-project//llvm:Support",
|
"@llvm-project//llvm:Support",
|
||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
"@llvm-project//mlir:Pass",
|
"@llvm-project//mlir:Pass",
|
||||||
|
@ -799,6 +800,7 @@ cc_library(
|
||||||
":hlo",
|
":hlo",
|
||||||
":lhlo",
|
":lhlo",
|
||||||
":map_lmhlo_to_scalar_op",
|
":map_lmhlo_to_scalar_op",
|
||||||
|
":pass_details",
|
||||||
"@llvm-project//llvm:Support",
|
"@llvm-project//llvm:Support",
|
||||||
"@llvm-project//mlir:Affine",
|
"@llvm-project//mlir:Affine",
|
||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
|
@ -814,6 +816,7 @@ cc_library(
|
||||||
srcs = ["lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc"],
|
srcs = ["lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc"],
|
||||||
deps = [
|
deps = [
|
||||||
":lhlo",
|
":lhlo",
|
||||||
|
":pass_details",
|
||||||
"@llvm-project//llvm:Support",
|
"@llvm-project//llvm:Support",
|
||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
"@llvm-project//mlir:LinalgOps",
|
"@llvm-project//mlir:LinalgOps",
|
||||||
|
@ -837,6 +840,7 @@ cc_library(
|
||||||
":hlo",
|
":hlo",
|
||||||
":lhlo",
|
":lhlo",
|
||||||
":map_lmhlo_to_scalar_op",
|
":map_lmhlo_to_scalar_op",
|
||||||
|
":pass_details",
|
||||||
"@llvm-project//llvm:Support",
|
"@llvm-project//llvm:Support",
|
||||||
"@llvm-project//mlir:Affine",
|
"@llvm-project//mlir:Affine",
|
||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
|
@ -863,6 +867,7 @@ cc_library(
|
||||||
deps = [
|
deps = [
|
||||||
":hlo",
|
":hlo",
|
||||||
":map_chlo_to_hlo_op",
|
":map_chlo_to_hlo_op",
|
||||||
|
":pass_details",
|
||||||
"@llvm-project//llvm:Support",
|
"@llvm-project//llvm:Support",
|
||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
"@llvm-project//mlir:Pass",
|
"@llvm-project//mlir:Pass",
|
||||||
|
@ -885,6 +890,7 @@ cc_library(
|
||||||
deps = [
|
deps = [
|
||||||
":hlo",
|
":hlo",
|
||||||
":map_chlo_to_hlo_op",
|
":map_chlo_to_hlo_op",
|
||||||
|
":pass_details",
|
||||||
"@llvm-project//llvm:Support",
|
"@llvm-project//llvm:Support",
|
||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
"@llvm-project//mlir:InferTypeOpInterface",
|
"@llvm-project//mlir:InferTypeOpInterface",
|
||||||
|
@ -928,6 +934,7 @@ cc_library(
|
||||||
":hlo",
|
":hlo",
|
||||||
":lhlo",
|
":lhlo",
|
||||||
":map_lmhlo_to_scalar_op",
|
":map_lmhlo_to_scalar_op",
|
||||||
|
":pass_details",
|
||||||
"@llvm-project//llvm:Support",
|
"@llvm-project//llvm:Support",
|
||||||
"@llvm-project//mlir:Affine",
|
"@llvm-project//mlir:Affine",
|
||||||
"@llvm-project//mlir:GPUDialect",
|
"@llvm-project//mlir:GPUDialect",
|
||||||
|
@ -948,6 +955,7 @@ cc_library(
|
||||||
hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/passes.h"],
|
hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/passes.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":lhlo",
|
":lhlo",
|
||||||
|
":pass_details",
|
||||||
"@llvm-project//llvm:Support",
|
"@llvm-project//llvm:Support",
|
||||||
"@llvm-project//mlir:Affine",
|
"@llvm-project//mlir:Affine",
|
||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
|
@ -1008,6 +1016,7 @@ cc_library(
|
||||||
deps = [
|
deps = [
|
||||||
":cycle_detector",
|
":cycle_detector",
|
||||||
":hlo",
|
":hlo",
|
||||||
|
":pass_details",
|
||||||
"@llvm-project//llvm:Core",
|
"@llvm-project//llvm:Core",
|
||||||
"@llvm-project//llvm:Support",
|
"@llvm-project//llvm:Support",
|
||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
|
@ -1046,6 +1055,7 @@ cc_library(
|
||||||
hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/passes.h"],
|
hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/passes.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":hlo",
|
":hlo",
|
||||||
|
":pass_details",
|
||||||
"@llvm-project//llvm:Support",
|
"@llvm-project//llvm:Support",
|
||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
"@llvm-project//mlir:Pass",
|
"@llvm-project//mlir:Pass",
|
||||||
|
@ -1064,6 +1074,7 @@ cc_library(
|
||||||
":hlo",
|
":hlo",
|
||||||
":legalize_to_standard_inc_gen",
|
":legalize_to_standard_inc_gen",
|
||||||
":legalize_trigonometric_to_approximation",
|
":legalize_trigonometric_to_approximation",
|
||||||
|
":pass_details",
|
||||||
"@llvm-project//llvm:Support",
|
"@llvm-project//llvm:Support",
|
||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
"@llvm-project//mlir:Pass",
|
"@llvm-project//mlir:Pass",
|
||||||
|
@ -1083,6 +1094,7 @@ cc_library(
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":hlo",
|
":hlo",
|
||||||
|
":pass_details",
|
||||||
"@llvm-project//llvm:Support",
|
"@llvm-project//llvm:Support",
|
||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
"@llvm-project//mlir:Pass",
|
"@llvm-project//mlir:Pass",
|
||||||
|
@ -1102,6 +1114,7 @@ cc_library(
|
||||||
],
|
],
|
||||||
includes = ["include"],
|
includes = ["include"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":pass_details",
|
||||||
"@llvm-project//llvm:Support",
|
"@llvm-project//llvm:Support",
|
||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
"@llvm-project//mlir:MathDialect",
|
"@llvm-project//mlir:MathDialect",
|
||||||
|
@ -1149,6 +1162,7 @@ cc_library(
|
||||||
deps = [
|
deps = [
|
||||||
":hlo",
|
":hlo",
|
||||||
":lower_complex_inc_gen",
|
":lower_complex_inc_gen",
|
||||||
|
":pass_details",
|
||||||
"@llvm-project//llvm:Support",
|
"@llvm-project//llvm:Support",
|
||||||
"@llvm-project//mlir:Analysis",
|
"@llvm-project//mlir:Analysis",
|
||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
|
@ -1201,6 +1215,7 @@ cc_library(
|
||||||
hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/passes.h"],
|
hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/passes.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":lhlo",
|
":lhlo",
|
||||||
|
":pass_details",
|
||||||
"@llvm-project//llvm:Support",
|
"@llvm-project//llvm:Support",
|
||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
"@llvm-project//mlir:MemRefDialect",
|
"@llvm-project//mlir:MemRefDialect",
|
||||||
|
|
|
@ -15,13 +15,13 @@ limitations under the License.
|
||||||
|
|
||||||
include "mlir/Pass/PassBase.td"
|
include "mlir/Pass/PassBase.td"
|
||||||
|
|
||||||
def LhloLegalizeToLinalgPass : Pass<"lhlo-legalize-to-linalg", "FuncOp"> {
|
def LhloLegalizeToLinalgPass : FunctionPass<"lhlo-legalize-to-linalg"> {
|
||||||
let summary = "Legalize from LHLO dialect to Linalg dialect.";
|
let summary = "Legalize from LHLO dialect to Linalg dialect.";
|
||||||
let constructor = "createLegalizeLhloToLinalgPass()";
|
let constructor = "createLegalizeLhloToLinalgPass()";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def LhloFuseLinalgPass : Pass<"lhlo-fuse-linalg", "FuncOp"> {
|
def LhloFuseLinalgPass : FunctionPass<"lhlo-fuse-linalg"> {
|
||||||
let summary = "Greedily fuse linalg ops obtained after LHLO lowering.";
|
let summary = "Greedily fuse linalg ops obtained after LHLO lowering.";
|
||||||
let constructor = "createLhloFuseLinalgPass()";
|
let constructor = "createLhloFuseLinalgPass()";
|
||||||
let options = [
|
let options = [
|
||||||
|
@ -34,24 +34,24 @@ def LhloFuseLinalgPass : Pass<"lhlo-fuse-linalg", "FuncOp"> {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def LhloLegalizeToAffinePass : Pass<"lhlo-legalize-to-affine", "FuncOp"> {
|
def LhloLegalizeToAffinePass : FunctionPass<"lhlo-legalize-to-affine"> {
|
||||||
let summary = "Legalize from LHLO dialect to affine dialect.";
|
let summary = "Legalize from LHLO dialect to affine dialect.";
|
||||||
let constructor = "createLhloLegalizeToAffinePass()";
|
let constructor = "createLhloLegalizeToAffinePass()";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def LhloLegalizeToGpuPass : Pass<"lhlo-legalize-to-gpu", "FuncOp"> {
|
def LhloLegalizeToGpuPass : FunctionPass<"lhlo-legalize-to-gpu"> {
|
||||||
let summary = "Legalize from LHLO dialect to GPU dialect.";
|
let summary = "Legalize from LHLO dialect to GPU dialect.";
|
||||||
let constructor = "createLegalizeToGpuPass()";
|
let constructor = "createLegalizeToGpuPass()";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def LhloLegalizeToParallelLoopsPass : Pass<"lhlo-legalize-to-parallel-loops", "FuncOp"> {
|
def LhloLegalizeToParallelLoopsPass : FunctionPass<"lhlo-legalize-to-parallel-loops"> {
|
||||||
let summary = "Legalize from LHLO dialect to parallel loops.";
|
let summary = "Legalize from LHLO dialect to parallel loops.";
|
||||||
let constructor = "createLegalizeLhloToParallelLoopsPass()";
|
let constructor = "createLegalizeLhloToParallelLoopsPass()";
|
||||||
}
|
}
|
||||||
|
|
||||||
def LegalizeTensorLoadOpPass : Pass<"lhlo-legalize-tensor-load-op", "FuncOp"> {
|
def LegalizeTensorLoadOpPass : FunctionPass<"lhlo-legalize-tensor-load-op"> {
|
||||||
let summary = "Legalize tensor load ops that are inserted during mhlo to lmhlo conversion.";
|
let summary = "Legalize tensor load ops that are inserted during mhlo to lmhlo conversion.";
|
||||||
let constructor = "createLegalizeTensorLoadOpPass()";
|
let constructor = "createLegalizeTensorLoadOpPass()";
|
||||||
}
|
}
|
||||||
|
|
|
@ -37,64 +37,64 @@ def HloLegalizeToLhloPass : Pass<"hlo-legalize-to-lhlo", "ModuleOp"> {
|
||||||
];
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
def LegalizeControlFlowPass : Pass<"mhlo-legalize-control-flow", "FuncOp"> {
|
def LegalizeControlFlowPass : FunctionPass<"mhlo-legalize-control-flow"> {
|
||||||
let summary = "Legalize from MHLO control flow to CFG control flow.";
|
let summary = "Legalize from MHLO control flow to CFG control flow.";
|
||||||
let constructor = "createLegalizeControlFlowPass()";
|
let constructor = "createLegalizeControlFlowPass()";
|
||||||
}
|
}
|
||||||
|
|
||||||
def LegalizeControlFlowToScfPass : Pass<"mhlo-control-flow-to-scf", "FuncOp"> {
|
def LegalizeControlFlowToScfPass : FunctionPass<"mhlo-control-flow-to-scf"> {
|
||||||
let summary = "Legalize from MHLO control flow to SCF control flow.";
|
let summary = "Legalize from MHLO control flow to SCF control flow.";
|
||||||
let constructor = "createControlFlowToScfPass()";
|
let constructor = "createControlFlowToScfPass()";
|
||||||
}
|
}
|
||||||
|
|
||||||
def LegalizeGatherToTorchIndexSelectPass : Pass<"mhlo-legalize-gather-to-torch-index-select", "FuncOp"> {
|
def LegalizeGatherToTorchIndexSelectPass : FunctionPass<"mhlo-legalize-gather-to-torch-index-select"> {
|
||||||
let summary = "Legalizes gathers to a torch index select.";
|
let summary = "Legalizes gathers to a torch index select.";
|
||||||
let constructor = "createLegalizeGatherToTorchIndexSelectPass()";
|
let constructor = "createLegalizeGatherToTorchIndexSelectPass()";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def LegalizeTanhToApproximationPass : Pass<"mhlo-legalize-trigonometric-to-approximation", "FuncOp"> {
|
def LegalizeTanhToApproximationPass : FunctionPass<"mhlo-legalize-trigonometric-to-approximation"> {
|
||||||
let summary = "Legalize trigonometric operations from standard dialect to an approximation.";
|
let summary = "Legalize trigonometric operations from standard dialect to an approximation.";
|
||||||
let constructor = "createLegalizeTrigonometricToApproximationPass()";
|
let constructor = "createLegalizeTrigonometricToApproximationPass()";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def HloLegalizeToLinalgPass : Pass<"hlo-legalize-to-linalg", "FuncOp"> {
|
def HloLegalizeToLinalgPass : FunctionPass<"hlo-legalize-to-linalg"> {
|
||||||
let summary = "Legalize from HLO dialect to Linalg dialect.";
|
let summary = "Legalize from HLO dialect to Linalg dialect.";
|
||||||
let constructor = "createLegalizeHloToLinalgPass()";
|
let constructor = "createLegalizeHloToLinalgPass()";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def LegalizeToStandardPass : Pass<"mhlo-legalize-to-std", "FuncOp"> {
|
def LegalizeToStandardPass : FunctionPass<"mhlo-legalize-to-std"> {
|
||||||
let summary = "Legalize from MHLO dialect to standard dialect.";
|
let summary = "Legalize from MHLO dialect to standard dialect.";
|
||||||
let constructor = "createLegalizeToStdPass()";
|
let constructor = "createLegalizeToStdPass()";
|
||||||
}
|
}
|
||||||
|
|
||||||
def LowerComplexPass : Pass<"mhlo-test-lower-complex", "FuncOp"> {
|
def LowerComplexPass : FunctionPass<"mhlo-test-lower-complex"> {
|
||||||
let summary = "Lower complex operations into non-complex operations.";
|
let summary = "Lower complex operations into non-complex operations.";
|
||||||
let constructor = "createLowerComplexPass()";
|
let constructor = "createLowerComplexPass()";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def LegalizeGeneralDotPass : Pass<"mhlo-test-lower-general-dot", "FuncOp"> {
|
def LegalizeGeneralDotPass : FunctionPass<"mhlo-test-lower-general-dot"> {
|
||||||
let summary = "Tests lowering general dot to a non-batched dot when possible.";
|
let summary = "Tests lowering general dot to a non-batched dot when possible.";
|
||||||
let constructor = "createLegalizeGeneralDotPass()";
|
let constructor = "createLegalizeGeneralDotPass()";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def TestMaterializeBroadcastsPass : Pass<"mhlo-test-materialize-broadcasts", "FuncOp"> {
|
def TestMaterializeBroadcastsPass : FunctionPass<"mhlo-test-materialize-broadcasts"> {
|
||||||
let summary = "Test pass for materializing 'broadcast_dimensions' attributes.";
|
let summary = "Test pass for materializing 'broadcast_dimensions' attributes.";
|
||||||
let constructor = "createTestMaterializeBroadcastsPass()";
|
let constructor = "createTestMaterializeBroadcastsPass()";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def MhloFusionPass : Pass<"mhlo-fusion", "FuncOp"> {
|
def MhloFusionPass : FunctionPass<"mhlo-fusion"> {
|
||||||
let summary = "Fuse mhlo ops to kLoop/kInput fusion patterns.";
|
let summary = "Fuse mhlo ops to kLoop/kInput fusion patterns.";
|
||||||
let constructor = "createMhloFusionPass()";
|
let constructor = "createMhloFusionPass()";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def OptimizeMhloPass : Pass<"mhlo-test-optimize", "FuncOp"> {
|
def OptimizeMhloPass : FunctionPass<"mhlo-test-optimize"> {
|
||||||
let summary = "Run optional HLO optimizations.";
|
let summary = "Run optional HLO optimizations.";
|
||||||
let constructor = "createOptimizeMhloPass()";
|
let constructor = "createOptimizeMhloPass()";
|
||||||
}
|
}
|
||||||
|
@ -107,18 +107,18 @@ def SinkConstantsToControlFlowPass : FunctionPass<"mhlo-sink-constants-to-contro
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def TestInferShapedTypeMethodsPass : Pass<"mhlo-test-infer-shaped-type-methods", "FuncOp"> {
|
def TestInferShapedTypeMethodsPass : FunctionPass<"mhlo-test-infer-shaped-type-methods"> {
|
||||||
let summary = "Uses test ops to invoke InferShapedTypeOpInterface methods.";
|
let summary = "Uses test ops to invoke InferShapedTypeOpInterface methods.";
|
||||||
let constructor = "createTestInferShapedTypeMethodsPass()";
|
let constructor = "createTestInferShapedTypeMethodsPass()";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def TransformUnrankedHloPass : Pass<"mhlo-transform-unranked-hlo", "FuncOp"> {
|
def TransformUnrankedHloPass : FunctionPass<"mhlo-transform-unranked-hlo"> {
|
||||||
let summary = "Realize element-wise operations on ranked tensors where possible.";
|
let summary = "Realize element-wise operations on ranked tensors where possible.";
|
||||||
let constructor = "createTransformUnrankedHloPass()";
|
let constructor = "createTransformUnrankedHloPass()";
|
||||||
}
|
}
|
||||||
|
|
||||||
def BroadcastPropagationPass : Pass<"mhlo-broadcast-propagation", "FuncOp"> {
|
def BroadcastPropagationPass : FunctionPass<"mhlo-broadcast-propagation"> {
|
||||||
let summary = "Move dynamic broadcasts up over element-wise operations and "
|
let summary = "Move dynamic broadcasts up over element-wise operations and "
|
||||||
"broadcast the operands rather than the result. This will eventually allow "
|
"broadcast the operands rather than the result. This will eventually allow "
|
||||||
"for larger fusions.";
|
"for larger fusions.";
|
||||||
|
|
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||||
#include "llvm/Support/Casting.h"
|
#include "llvm/Support/Casting.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
|
#include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||||
|
@ -376,7 +377,7 @@ struct EarlyBroadcastInDimOpPattern
|
||||||
};
|
};
|
||||||
|
|
||||||
struct BroadcastPropagationPass
|
struct BroadcastPropagationPass
|
||||||
: public PassWrapper<BroadcastPropagationPass, FunctionPass> {
|
: public BroadcastPropagationPassBase<BroadcastPropagationPass> {
|
||||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
registry.insert<shape::ShapeDialect, mhlo::MhloDialect>();
|
registry.insert<shape::ShapeDialect, mhlo::MhloDialect>();
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||||
#include "llvm/ADT/StringSwitch.h"
|
#include "llvm/ADT/StringSwitch.h"
|
||||||
#include "llvm/Support/Casting.h"
|
#include "llvm/Support/Casting.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h" // TF:llvm-project
|
#include "mlir/Dialect/Tensor/IR/Tensor.h" // TF:llvm-project
|
||||||
|
@ -37,7 +38,7 @@ namespace mlir {
|
||||||
namespace mhlo {
|
namespace mhlo {
|
||||||
namespace {
|
namespace {
|
||||||
struct LegalizeControlFlowPass
|
struct LegalizeControlFlowPass
|
||||||
: public mlir::PassWrapper<LegalizeControlFlowPass, FunctionPass> {
|
: public LegalizeControlFlowPassBase<LegalizeControlFlowPass> {
|
||||||
// Perform the lowering to MLIR control flow.
|
// Perform the lowering to MLIR control flow.
|
||||||
void runOnFunction() override;
|
void runOnFunction() override;
|
||||||
};
|
};
|
||||||
|
|
|
@ -14,6 +14,7 @@ limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||||
#include "mlir/IR/BuiltinOps.h"
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
|
@ -128,7 +129,8 @@ struct GatherIsTorchIndexSelect : public OpRewritePattern<GatherOp> {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct LegalizeGatherToTorchIndexSelectPass
|
struct LegalizeGatherToTorchIndexSelectPass
|
||||||
: public PassWrapper<LegalizeGatherToTorchIndexSelectPass, FunctionPass> {
|
: public LegalizeGatherToTorchIndexSelectPassBase<
|
||||||
|
LegalizeGatherToTorchIndexSelectPass> {
|
||||||
/// Perform the lowering of standard dialect operations to approximations.
|
/// Perform the lowering of standard dialect operations to approximations.
|
||||||
void runOnFunction() override {
|
void runOnFunction() override {
|
||||||
OwningRewritePatternList patterns(&getContext());
|
OwningRewritePatternList patterns(&getContext());
|
||||||
|
|
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||||
// This file implements logic for lowering memref.tensor_load ops that are
|
// This file implements logic for lowering memref.tensor_load ops that are
|
||||||
// inserted during `mhlo-legalize-to-lmhlo`.
|
// inserted during `mhlo-legalize-to-lmhlo`.
|
||||||
|
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
@ -71,7 +72,7 @@ struct ForwardShapeOfOp : public OpRewritePattern<ShapeOfOp> {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct LegalizeTensorLoadOpPass
|
struct LegalizeTensorLoadOpPass
|
||||||
: public mlir::PassWrapper<LegalizeTensorLoadOpPass, FunctionPass> {
|
: public LegalizeTensorLoadOpPassBase<LegalizeTensorLoadOpPass> {
|
||||||
// Perform the lowering to remove memref.tensor_load ops inserted during
|
// Perform the lowering to remove memref.tensor_load ops inserted during
|
||||||
// `mhlo-legalize-to-lmhlo`.
|
// `mhlo-legalize-to-lmhlo`.
|
||||||
void runOnFunction() override {
|
void runOnFunction() override {
|
||||||
|
|
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||||
#include "llvm/ADT/SetVector.h"
|
#include "llvm/ADT/SetVector.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
|
#include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||||
|
@ -2428,7 +2429,7 @@ class RemoveSignTypeConverter : public TypeConverter {
|
||||||
// iterator_types = ["parallel", "parallel"],
|
// iterator_types = ["parallel", "parallel"],
|
||||||
// } : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
|
// } : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||||
struct LhloLegalizeToLinalgPass
|
struct LhloLegalizeToLinalgPass
|
||||||
: public PassWrapper<LhloLegalizeToLinalgPass, FunctionPass> {
|
: public lmhlo::LhloLegalizeToLinalgPassBase<LhloLegalizeToLinalgPass> {
|
||||||
void getDependentDialects(DialectRegistry& registry) const override {
|
void getDependentDialects(DialectRegistry& registry) const override {
|
||||||
registry
|
registry
|
||||||
.insert<AffineDialect, complex::ComplexDialect, linalg::LinalgDialect,
|
.insert<AffineDialect, complex::ComplexDialect, linalg::LinalgDialect,
|
||||||
|
@ -2454,7 +2455,7 @@ struct LhloLegalizeToLinalgPass
|
||||||
};
|
};
|
||||||
|
|
||||||
struct HloLegalizeToLinalgPass
|
struct HloLegalizeToLinalgPass
|
||||||
: public PassWrapper<HloLegalizeToLinalgPass, FunctionPass> {
|
: public mhlo::HloLegalizeToLinalgPassBase<HloLegalizeToLinalgPass> {
|
||||||
void getDependentDialects(DialectRegistry& registry) const override {
|
void getDependentDialects(DialectRegistry& registry) const override {
|
||||||
registry
|
registry
|
||||||
.insert<linalg::LinalgDialect, scf::SCFDialect, complex::ComplexDialect,
|
.insert<linalg::LinalgDialect, scf::SCFDialect, complex::ComplexDialect,
|
||||||
|
|
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||||
|
|
||||||
#include "llvm/ADT/StringSwitch.h"
|
#include "llvm/ADT/StringSwitch.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
@ -177,7 +178,7 @@ class ConvertIotaOp : public OpRewritePattern<mhlo::IotaOp> {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
struct LegalizeToStandardPass
|
struct LegalizeToStandardPass
|
||||||
: public PassWrapper<LegalizeToStandardPass, FunctionPass> {
|
: public LegalizeToStandardPassBase<LegalizeToStandardPass> {
|
||||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
registry.insert<StandardOpsDialect>();
|
registry.insert<StandardOpsDialect>();
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||||
// This file implements the lowering for trigonometric standard ops to
|
// This file implements the lowering for trigonometric standard ops to
|
||||||
// approximations.
|
// approximations.
|
||||||
|
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||||
#include "mlir/Dialect/Math/IR/Math.h"
|
#include "mlir/Dialect/Math/IR/Math.h"
|
||||||
|
@ -154,8 +155,8 @@ class ApproximateTanhLowering
|
||||||
};
|
};
|
||||||
|
|
||||||
struct LegalizeTrigonometricToApproximationPass
|
struct LegalizeTrigonometricToApproximationPass
|
||||||
: public PassWrapper<LegalizeTrigonometricToApproximationPass,
|
: public LegalizeTanhToApproximationPassBase<
|
||||||
FunctionPass> {
|
LegalizeTrigonometricToApproximationPass> {
|
||||||
/// Perform the lowering of standard dialect operations to approximations.
|
/// Perform the lowering of standard dialect operations to approximations.
|
||||||
void runOnFunction() override {
|
void runOnFunction() override {
|
||||||
OwningRewritePatternList patterns(&getContext());
|
OwningRewritePatternList patterns(&getContext());
|
||||||
|
|
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||||
|
|
||||||
#include "llvm/ADT/ArrayRef.h"
|
#include "llvm/ADT/ArrayRef.h"
|
||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||||
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
|
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
|
||||||
|
@ -36,8 +37,7 @@ namespace {
|
||||||
|
|
||||||
using linalg::LinalgOp;
|
using linalg::LinalgOp;
|
||||||
|
|
||||||
class LhloFuseLinalgPass
|
class LhloFuseLinalgPass : public LhloFuseLinalgPassBase<LhloFuseLinalgPass> {
|
||||||
: public PassWrapper<LhloFuseLinalgPass, FunctionPass> {
|
|
||||||
void getDependentDialects(DialectRegistry& registry) const override {
|
void getDependentDialects(DialectRegistry& registry) const override {
|
||||||
registry.insert<AffineDialect, linalg::LinalgDialect, scf::SCFDialect>();
|
registry.insert<AffineDialect, linalg::LinalgDialect, scf::SCFDialect>();
|
||||||
}
|
}
|
||||||
|
@ -202,18 +202,6 @@ class LhloFuseLinalgPass
|
||||||
.setLoopType(loopType));
|
.setLoopType(loopType));
|
||||||
return tiled_generic_op.hasValue();
|
return tiled_generic_op.hasValue();
|
||||||
}
|
}
|
||||||
|
|
||||||
Option<bool> use_parallel_loops_{
|
|
||||||
*this, "use-parallel-loops",
|
|
||||||
llvm::cl::desc(
|
|
||||||
"Tiles GenericOp consumer to parallel loops before linalg fusion"),
|
|
||||||
llvm::cl::init(false)};
|
|
||||||
|
|
||||||
ListOption<unsigned> tile_sizes_{
|
|
||||||
*this, "tile-sizes",
|
|
||||||
llvm::cl::desc(
|
|
||||||
"Tile sizes by which to tile linalg generic before linalg fusion"),
|
|
||||||
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||||
// This file implements logic for lowering LHLO dialect to Affine dialect.
|
// This file implements logic for lowering LHLO dialect to Affine dialect.
|
||||||
|
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
|
#include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
|
||||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
@ -229,7 +230,7 @@ void populateLHLOToAffineConversionPattern(MLIRContext* context,
|
||||||
}
|
}
|
||||||
|
|
||||||
struct LhloLegalizeToAffinePass
|
struct LhloLegalizeToAffinePass
|
||||||
: public PassWrapper<LhloLegalizeToAffinePass, FunctionPass> {
|
: public LhloLegalizeToAffinePassBase<LhloLegalizeToAffinePass> {
|
||||||
void getDependentDialects(DialectRegistry& registry) const override {
|
void getDependentDialects(DialectRegistry& registry) const override {
|
||||||
registry.insert<AffineDialect>();
|
registry.insert<AffineDialect>();
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||||
|
|
||||||
#include "llvm/ADT/ArrayRef.h"
|
#include "llvm/ADT/ArrayRef.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
|
#include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
|
||||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||||
#include "mlir/Dialect/GPU/GPUDialect.h"
|
#include "mlir/Dialect/GPU/GPUDialect.h"
|
||||||
|
@ -172,7 +173,7 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern<ReduceOp> {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct LhloLegalizeToGpuPass
|
struct LhloLegalizeToGpuPass
|
||||||
: public PassWrapper<LhloLegalizeToGpuPass, FunctionPass> {
|
: public LhloLegalizeToGpuPassBase<LhloLegalizeToGpuPass> {
|
||||||
void getDependentDialects(DialectRegistry& registry) const override {
|
void getDependentDialects(DialectRegistry& registry) const override {
|
||||||
registry.insert<AffineDialect, gpu::GPUDialect, linalg::LinalgDialect,
|
registry.insert<AffineDialect, gpu::GPUDialect, linalg::LinalgDialect,
|
||||||
memref::MemRefDialect, scf::SCFDialect>();
|
memref::MemRefDialect, scf::SCFDialect>();
|
||||||
|
|
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
|
||||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
#include "mlir/Dialect/SCF/SCF.h"
|
#include "mlir/Dialect/SCF/SCF.h"
|
||||||
|
@ -700,7 +701,8 @@ class SelectAndScatterOpConverter
|
||||||
};
|
};
|
||||||
|
|
||||||
struct LhloLegalizeToParallelLoopsPass
|
struct LhloLegalizeToParallelLoopsPass
|
||||||
: public PassWrapper<LhloLegalizeToParallelLoopsPass, FunctionPass> {
|
: public LhloLegalizeToParallelLoopsPassBase<
|
||||||
|
LhloLegalizeToParallelLoopsPass> {
|
||||||
void getDependentDialects(DialectRegistry& registry) const override {
|
void getDependentDialects(DialectRegistry& registry) const override {
|
||||||
registry.insert<StandardOpsDialect, scf::SCFDialect>();
|
registry.insert<StandardOpsDialect, scf::SCFDialect>();
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,7 +24,9 @@ limitations under the License.
|
||||||
|
|
||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||||
#include "mlir-hlo/utils/hlo_utils.h"
|
#include "mlir-hlo/utils/hlo_utils.h"
|
||||||
#include "mlir/IR/Attributes.h"
|
#include "mlir/IR/Attributes.h"
|
||||||
#include "mlir/IR/MLIRContext.h"
|
#include "mlir/IR/MLIRContext.h"
|
||||||
|
@ -35,35 +37,17 @@ limitations under the License.
|
||||||
#include "mlir/Pass/PassRegistry.h"
|
#include "mlir/Pass/PassRegistry.h"
|
||||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
|
|
||||||
using mlir::FunctionPass;
|
|
||||||
using mlir::OwningRewritePatternList;
|
|
||||||
using mlir::PassWrapper;
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
class LowerComplexPass : public PassWrapper<LowerComplexPass, FunctionPass> {
|
|
||||||
public:
|
|
||||||
explicit LowerComplexPass() : PassWrapper<LowerComplexPass, FunctionPass>() {}
|
|
||||||
|
|
||||||
/// Performs the lowering to MHLO dialect.
|
|
||||||
void runOnFunction() override;
|
|
||||||
};
|
|
||||||
} // end anonymous namespace
|
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace mhlo {
|
namespace mhlo {
|
||||||
namespace {
|
namespace {
|
||||||
|
class LowerComplexPass : public LowerComplexPassBase<LowerComplexPass> {
|
||||||
|
public:
|
||||||
|
/// Performs the lowering to MHLO dialect.
|
||||||
|
void runOnFunction() override;
|
||||||
|
};
|
||||||
|
|
||||||
#include "generated_lower_complex.inc"
|
#include "generated_lower_complex.inc"
|
||||||
|
|
||||||
} // end anonymous namespace
|
|
||||||
|
|
||||||
void PopulateComplexLoweringPatterns(MLIRContext* context,
|
|
||||||
OwningRewritePatternList* patterns) {
|
|
||||||
populateWithGenerated(*patterns);
|
|
||||||
}
|
|
||||||
} // end namespace mhlo
|
|
||||||
} // end namespace mlir
|
|
||||||
|
|
||||||
// Lowers the complex operations that can be represented using other operations.
|
// Lowers the complex operations that can be represented using other operations.
|
||||||
void LowerComplexPass::runOnFunction() {
|
void LowerComplexPass::runOnFunction() {
|
||||||
// Add lowering patterns to the list.
|
// Add lowering patterns to the list.
|
||||||
|
@ -73,6 +57,15 @@ void LowerComplexPass::runOnFunction() {
|
||||||
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
|
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<FunctionPass> mlir::mhlo::createLowerComplexPass() {
|
} // end anonymous namespace
|
||||||
return std::make_unique<LowerComplexPass>();
|
} // end namespace mhlo
|
||||||
|
} // end namespace mlir
|
||||||
|
|
||||||
|
void mlir::mhlo::PopulateComplexLoweringPatterns(
|
||||||
|
MLIRContext* context, OwningRewritePatternList* patterns) {
|
||||||
|
populateWithGenerated(*patterns);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<mlir::FunctionPass> mlir::mhlo::createLowerComplexPass() {
|
||||||
|
return std::make_unique<mlir::mhlo::LowerComplexPass>();
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
#include "llvm/ADT/StringSwitch.h"
|
#include "llvm/ADT/StringSwitch.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
@ -30,28 +31,16 @@ limitations under the License.
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
|
|
||||||
using mlir::DenseIntElementsAttr;
|
namespace mlir {
|
||||||
using mlir::ElementsAttr;
|
namespace mhlo {
|
||||||
using mlir::failure;
|
|
||||||
using mlir::FunctionPass;
|
|
||||||
using mlir::LogicalResult;
|
|
||||||
using mlir::MLIRContext;
|
|
||||||
using mlir::OpRewritePattern;
|
|
||||||
using mlir::OwningRewritePatternList;
|
|
||||||
using mlir::PassWrapper;
|
|
||||||
using mlir::PatternRewriter;
|
|
||||||
using mlir::RankedTensorType;
|
|
||||||
using mlir::success;
|
|
||||||
using mlir::Value;
|
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
Value TransposeReshape(Value arg, mlir::Location loc,
|
Value TransposeReshape(Value arg, Location loc,
|
||||||
llvm::ArrayRef<int64_t> left_dims,
|
llvm::ArrayRef<int64_t> left_dims,
|
||||||
llvm::ArrayRef<int64_t> right_dims,
|
llvm::ArrayRef<int64_t> right_dims,
|
||||||
llvm::ArrayRef<int64_t> arg_shape,
|
llvm::ArrayRef<int64_t> arg_shape,
|
||||||
PatternRewriter *rewriter) {
|
PatternRewriter *rewriter) {
|
||||||
auto element_type = mlir::getElementTypeOrSelf(arg.getType());
|
auto element_type = getElementTypeOrSelf(arg.getType());
|
||||||
|
|
||||||
int64_t left_size = 1;
|
int64_t left_size = 1;
|
||||||
for (auto dim : left_dims) {
|
for (auto dim : left_dims) {
|
||||||
|
@ -68,7 +57,7 @@ Value TransposeReshape(Value arg, mlir::Location loc,
|
||||||
left_dims.end());
|
left_dims.end());
|
||||||
transpose_permutation.append(right_dims.begin(), right_dims.end());
|
transpose_permutation.append(right_dims.begin(), right_dims.end());
|
||||||
|
|
||||||
mlir::TensorType transpose_permutation_type = RankedTensorType::get(
|
TensorType transpose_permutation_type = RankedTensorType::get(
|
||||||
{static_cast<int64_t>(transpose_permutation.size())},
|
{static_cast<int64_t>(transpose_permutation.size())},
|
||||||
rewriter->getIntegerType(64));
|
rewriter->getIntegerType(64));
|
||||||
|
|
||||||
|
@ -83,20 +72,18 @@ Value TransposeReshape(Value arg, mlir::Location loc,
|
||||||
transposed_shape.push_back(arg_shape[val]);
|
transposed_shape.push_back(arg_shape[val]);
|
||||||
}
|
}
|
||||||
auto transpose_type = RankedTensorType::get(transposed_shape, element_type);
|
auto transpose_type = RankedTensorType::get(transposed_shape, element_type);
|
||||||
auto transpose_result = rewriter->create<mlir::mhlo::TransposeOp>(
|
auto transpose_result = rewriter->create<TransposeOp>(
|
||||||
loc, transpose_type, arg, transpose_permutation_attr);
|
loc, transpose_type, arg, transpose_permutation_attr);
|
||||||
|
|
||||||
// Return the final result.
|
// Return the final result.
|
||||||
auto reshaped_type =
|
auto reshaped_type =
|
||||||
RankedTensorType::get({left_size, right_size}, element_type);
|
RankedTensorType::get({left_size, right_size}, element_type);
|
||||||
return rewriter->create<mlir::mhlo::ReshapeOp>(loc, reshaped_type,
|
return rewriter->create<ReshapeOp>(loc, reshaped_type, transpose_result);
|
||||||
transpose_result);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Value ProcessDotArg(Value arg, mlir::Location loc,
|
Value ProcessDotArg(Value arg, Location loc, ElementsAttr contract_dims_attr,
|
||||||
ElementsAttr contract_dims_attr, bool outer_dims_first,
|
bool outer_dims_first, PatternRewriter *rewriter) {
|
||||||
PatternRewriter *rewriter) {
|
auto shape = arg.getType().cast<ShapedType>().getShape();
|
||||||
auto shape = arg.getType().cast<mlir::ShapedType>().getShape();
|
|
||||||
|
|
||||||
llvm::SmallVector<bool, 5> is_outer_dim;
|
llvm::SmallVector<bool, 5> is_outer_dim;
|
||||||
is_outer_dim.resize(shape.size(), true);
|
is_outer_dim.resize(shape.size(), true);
|
||||||
|
@ -124,7 +111,7 @@ Value ProcessDotArg(Value arg, mlir::Location loc,
|
||||||
return TransposeReshape(arg, loc, contract_dims, outer_dims, shape, rewriter);
|
return TransposeReshape(arg, loc, contract_dims, outer_dims, shape, rewriter);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct GeneralDotConvert : public OpRewritePattern<mlir::mhlo::DotGeneralOp> {
|
struct GeneralDotConvert : public OpRewritePattern<DotGeneralOp> {
|
||||||
// Attempts to lower a General Dot operator to a standard Dot operator.
|
// Attempts to lower a General Dot operator to a standard Dot operator.
|
||||||
// General dots include batching dimensions and can have collapsing
|
// General dots include batching dimensions and can have collapsing
|
||||||
// dimensions along any axis. Inserting correctly arrange transpose and
|
// dimensions along any axis. Inserting correctly arrange transpose and
|
||||||
|
@ -136,9 +123,9 @@ struct GeneralDotConvert : public OpRewritePattern<mlir::mhlo::DotGeneralOp> {
|
||||||
explicit GeneralDotConvert(MLIRContext *context)
|
explicit GeneralDotConvert(MLIRContext *context)
|
||||||
: OpRewritePattern(context) {}
|
: OpRewritePattern(context) {}
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(mlir::mhlo::DotGeneralOp op,
|
LogicalResult matchAndRewrite(DotGeneralOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
auto dot_element_type = mlir::getElementTypeOrSelf(op);
|
auto dot_element_type = getElementTypeOrSelf(op);
|
||||||
|
|
||||||
auto dot_numbers = op.dot_dimension_numbers();
|
auto dot_numbers = op.dot_dimension_numbers();
|
||||||
if (dot_numbers.lhs_batching_dimensions().getNumElements() != 0 ||
|
if (dot_numbers.lhs_batching_dimensions().getNumElements() != 0 ||
|
||||||
|
@ -155,8 +142,8 @@ struct GeneralDotConvert : public OpRewritePattern<mlir::mhlo::DotGeneralOp> {
|
||||||
/*outer_dims_first=*/false, &rewriter);
|
/*outer_dims_first=*/false, &rewriter);
|
||||||
|
|
||||||
// Accept only static shaped types.
|
// Accept only static shaped types.
|
||||||
auto lhs_shape_type = lhs.getType().dyn_cast_or_null<mlir::ShapedType>();
|
auto lhs_shape_type = lhs.getType().dyn_cast_or_null<ShapedType>();
|
||||||
auto rhs_shape_type = rhs.getType().dyn_cast_or_null<mlir::ShapedType>();
|
auto rhs_shape_type = rhs.getType().dyn_cast_or_null<ShapedType>();
|
||||||
if (!lhs_shape_type || !rhs_shape_type) return failure();
|
if (!lhs_shape_type || !rhs_shape_type) return failure();
|
||||||
if (!lhs_shape_type.hasStaticShape() || !rhs_shape_type.hasStaticShape())
|
if (!lhs_shape_type.hasStaticShape() || !rhs_shape_type.hasStaticShape())
|
||||||
return failure();
|
return failure();
|
||||||
|
@ -167,28 +154,29 @@ struct GeneralDotConvert : public OpRewritePattern<mlir::mhlo::DotGeneralOp> {
|
||||||
auto new_dot_type =
|
auto new_dot_type =
|
||||||
RankedTensorType::get({lhs_shape[0], rhs_shape[1]}, dot_element_type);
|
RankedTensorType::get({lhs_shape[0], rhs_shape[1]}, dot_element_type);
|
||||||
|
|
||||||
mlir::ArrayAttr precision_config;
|
ArrayAttr precision_config;
|
||||||
if (op.precision_config()) precision_config = *op.precision_config();
|
if (op.precision_config()) precision_config = *op.precision_config();
|
||||||
auto new_dot_op = rewriter.create<mlir::mhlo::DotOp>(
|
auto new_dot_op = rewriter.create<DotOp>(op.getLoc(), new_dot_type, lhs,
|
||||||
op.getLoc(), new_dot_type, lhs, rhs, precision_config);
|
rhs, precision_config);
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<mlir::mhlo::ReshapeOp>(op, op.getType(),
|
rewriter.replaceOpWithNewOp<ReshapeOp>(op, op.getType(), new_dot_op);
|
||||||
new_dot_op);
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct LegalizeGeneralDotPass
|
struct LegalizeGeneralDotPass
|
||||||
: public PassWrapper<LegalizeGeneralDotPass, FunctionPass> {
|
: public LegalizeGeneralDotPassBase<LegalizeGeneralDotPass> {
|
||||||
/// Lower all general dots that can be represented as a non-batched matmul.
|
/// Lower all general dots that can be represented as a non-batched matmul.
|
||||||
void runOnFunction() override {
|
void runOnFunction() override {
|
||||||
OwningRewritePatternList patterns(&getContext());
|
OwningRewritePatternList patterns(&getContext());
|
||||||
mlir::mhlo::PopulateGeneralDotOpLoweringPatterns(&patterns, &getContext());
|
PopulateGeneralDotOpLoweringPatterns(&patterns, &getContext());
|
||||||
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
|
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
} // namespace mhlo
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
void mlir::mhlo::PopulateGeneralDotOpLoweringPatterns(
|
void mlir::mhlo::PopulateGeneralDotOpLoweringPatterns(
|
||||||
OwningRewritePatternList *patterns, MLIRContext *ctx) {
|
OwningRewritePatternList *patterns, MLIRContext *ctx) {
|
||||||
|
|
|
@ -14,6 +14,7 @@ limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
#include "mlir/IR/MLIRContext.h"
|
#include "mlir/IR/MLIRContext.h"
|
||||||
|
@ -28,7 +29,7 @@ namespace mhlo {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
struct TestMaterializeBroadcastsPass
|
struct TestMaterializeBroadcastsPass
|
||||||
: public PassWrapper<TestMaterializeBroadcastsPass, FunctionPass> {
|
: public TestMaterializeBroadcastsPassBase<TestMaterializeBroadcastsPass> {
|
||||||
void runOnFunction() override {
|
void runOnFunction() override {
|
||||||
ConversionTarget conversionTarget(getContext());
|
ConversionTarget conversionTarget(getContext());
|
||||||
OwningRewritePatternList conversionPatterns(&getContext());
|
OwningRewritePatternList conversionPatterns(&getContext());
|
||||||
|
|
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||||
#include "llvm/Support/Casting.h"
|
#include "llvm/Support/Casting.h"
|
||||||
#include "llvm/Support/Debug.h"
|
#include "llvm/Support/Debug.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||||
#include "mlir/Dialect/SCF/SCF.h"
|
#include "mlir/Dialect/SCF/SCF.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
@ -39,7 +40,7 @@ void MatchAndRewrite(WhileOp whileOp);
|
||||||
|
|
||||||
/// Pass that converts MHLO control flow to SCF.
|
/// Pass that converts MHLO control flow to SCF.
|
||||||
class ControlFlowToScfPass
|
class ControlFlowToScfPass
|
||||||
: public mlir::PassWrapper<ControlFlowToScfPass, FunctionPass> {
|
: public LegalizeControlFlowToScfPassBase<ControlFlowToScfPass> {
|
||||||
void getDependentDialects(DialectRegistry& registry) const override {
|
void getDependentDialects(DialectRegistry& registry) const override {
|
||||||
registry.insert<scf::SCFDialect>();
|
registry.insert<scf::SCFDialect>();
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||||
#include "llvm/ADT/EquivalenceClasses.h"
|
#include "llvm/ADT/EquivalenceClasses.h"
|
||||||
#include "llvm/Support/Debug.h"
|
#include "llvm/Support/Debug.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
|
||||||
#include "mlir-hlo/utils/cycle_detector.h"
|
#include "mlir-hlo/utils/cycle_detector.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||||
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||||
|
@ -480,7 +481,7 @@ class FusionPlanner {
|
||||||
EquivalenceClasses<int32_t> leader_for_node_;
|
EquivalenceClasses<int32_t> leader_for_node_;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct MhloFusionPass : public mlir::PassWrapper<MhloFusionPass, FunctionPass> {
|
struct MhloFusionPass : public MhloFusionPassBase<MhloFusionPass> {
|
||||||
void runOnFunction() override {
|
void runOnFunction() override {
|
||||||
FuncOp func = getFunction();
|
FuncOp func = getFunction();
|
||||||
if (!IsTargetFunc(func)) {
|
if (!IsTargetFunc(func)) {
|
||||||
|
|
|
@ -14,6 +14,7 @@ limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
@ -23,28 +24,26 @@ limitations under the License.
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
|
|
||||||
using mlir::FunctionPass;
|
namespace mlir {
|
||||||
using mlir::PassWrapper;
|
namespace mhlo {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class OptimizeMhloPass : public PassWrapper<OptimizeMhloPass, FunctionPass> {
|
class OptimizeMhloPass : public OptimizeMhloPassBase<OptimizeMhloPass> {
|
||||||
public:
|
public:
|
||||||
explicit OptimizeMhloPass() : PassWrapper<OptimizeMhloPass, FunctionPass>() {}
|
|
||||||
|
|
||||||
/// Performs the lowering to MHLO dialect.
|
/// Performs the lowering to MHLO dialect.
|
||||||
void runOnFunction() override;
|
void runOnFunction() override;
|
||||||
};
|
};
|
||||||
} // end anonymous namespace
|
|
||||||
|
|
||||||
// Lowers the complex operations that can be represented using other operations.
|
// Lowers the complex operations that can be represented using other operations.
|
||||||
void OptimizeMhloPass::runOnFunction() {
|
void OptimizeMhloPass::runOnFunction() {
|
||||||
// Add lowering patterns to the list.
|
// Add lowering patterns to the list.
|
||||||
mlir::OwningRewritePatternList patterns(&getContext());
|
OwningRewritePatternList patterns(&getContext());
|
||||||
mlir::mhlo::PopulateOptimizeMHLOPatterns(&getContext(), &patterns);
|
PopulateOptimizeMHLOPatterns(&getContext(), &patterns);
|
||||||
|
|
||||||
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
|
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
|
||||||
}
|
}
|
||||||
|
} // end anonymous namespace
|
||||||
|
} // namespace mhlo
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
std::unique_ptr<mlir::FunctionPass> mlir::mhlo::createOptimizeMhloPass() {
|
std::unique_ptr<mlir::FunctionPass> mlir::mhlo::createOptimizeMhloPass() {
|
||||||
return std::make_unique<OptimizeMhloPass>();
|
return std::make_unique<mlir::mhlo::OptimizeMhloPass>();
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
|
||||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||||
#include "mlir/IR/Attributes.h"
|
#include "mlir/IR/Attributes.h"
|
||||||
#include "mlir/IR/Identifier.h"
|
#include "mlir/IR/Identifier.h"
|
||||||
|
@ -84,7 +85,8 @@ struct ReifyReturnTypeShapesPattern : public RewritePattern {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct TestInferShapedTypeMethodsPass
|
struct TestInferShapedTypeMethodsPass
|
||||||
: public PassWrapper<TestInferShapedTypeMethodsPass, FunctionPass> {
|
: public TestInferShapedTypeMethodsPassBase<
|
||||||
|
TestInferShapedTypeMethodsPass> {
|
||||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
registry.insert<shape::ShapeDialect>();
|
registry.insert<shape::ShapeDialect>();
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||||
|
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
|
#include "mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||||
#include "mlir/Dialect/SCF/SCF.h"
|
#include "mlir/Dialect/SCF/SCF.h"
|
||||||
|
@ -538,7 +539,7 @@ struct ConvertUnrankedDynamicBroadcastNaryOp
|
||||||
};
|
};
|
||||||
|
|
||||||
struct TransformUnrankedHloPass
|
struct TransformUnrankedHloPass
|
||||||
: public PassWrapper<TransformUnrankedHloPass, FunctionPass> {
|
: public mhlo::TransformUnrankedHloPassBase<TransformUnrankedHloPass> {
|
||||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
registry.insert<chlo::HloClientDialect, mhlo::MhloDialect, scf::SCFDialect,
|
registry.insert<chlo::HloClientDialect, mhlo::MhloDialect, scf::SCFDialect,
|
||||||
shape::ShapeDialect>();
|
shape::ShapeDialect>();
|
||||||
|
|
|
@ -14,6 +14,7 @@ limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
@ -30,7 +31,7 @@ namespace mhlo {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
struct TestUnfuseBatchNormPass
|
struct TestUnfuseBatchNormPass
|
||||||
: public PassWrapper<TestUnfuseBatchNormPass, OperationPass<>> {
|
: public TestUnfuseBatchNormPassBase<TestUnfuseBatchNormPass> {
|
||||||
void getDependentDialects(DialectRegistry& registry) const override {
|
void getDependentDialects(DialectRegistry& registry) const override {
|
||||||
registry.insert<memref::MemRefDialect>();
|
registry.insert<memref::MemRefDialect>();
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue