diff --git a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h index dad298f..c5483e9 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h +++ b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h @@ -16,16 +16,16 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_CHLO_OPS_H_ #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_CHLO_OPS_H_ -#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Dialect.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/DialectImplementation.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpDefinition.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/InferTypeOpInterface.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/SideEffectInterfaces.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Types.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" namespace mlir { namespace chlo { @@ -37,7 +37,7 @@ class HloClientDialect : public Dialect { }; #define GET_OP_CLASSES -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h.inc" +#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h.inc" } // namespace chlo } // namespace mlir diff --git a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td index e93bf91..79d6fb2 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td @@ -29,9 +29,9 @@ limitations under the License. #ifndef CHLO_OPS #define CHLO_OPS -include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td" -include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/InferTypeOpInterface.td" -include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td" def HLOClient_Dialect : Dialect { diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h index 976b06f..0036cc0 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h @@ -18,24 +18,24 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_ #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_ -#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Dialect.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/DialectImplementation.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpDefinition.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/InferTypeOpInterface.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/SideEffectInterfaces.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h" +#include "llvm/ADT/StringRef.h" +#include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Types.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" namespace mlir { class OpBuilder; -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_structs.h.inc" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_structs.h.inc" namespace mhlo { @@ -91,7 +91,7 @@ LogicalResult deriveShapeFromFirstOperand( SmallVectorImpl *reifiedReturnShapes); #define GET_OP_CLASSES -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc" } // end namespace mhlo } // end namespace mlir diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index 653d742..0ed4235 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -18,12 +18,12 @@ limitations under the License. #ifndef HLO_OPS #define HLO_OPS -include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td" -include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/InferTypeOpInterface.td" -include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/SideEffectInterfaces.td" -include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td" -include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td" -include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td" +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td" +include "mlir-hlo/Dialect/mhlo/IR/hlo_utils.td" +include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td" def HLO_Dialect : Dialect { let name = "mhlo"; diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td index cf90fc2..7f9784d 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td @@ -16,7 +16,7 @@ limitations under the License. #ifndef HLO_OPS_BASE #define HLO_OPS_BASE -include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td" +include "mlir/IR/OpBase.td" def HLO_Pred : TypeAlias; diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td index 7c7b643..e1ae9e1 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td @@ -18,7 +18,7 @@ limitations under the License. #ifndef HLO_UTILS #define HLO_UTILS -include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td" +include "mlir/IR/OpBase.td" def NullArrayAttr : NativeCodeCall<"ArrayAttr()">; diff --git a/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h b/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h index 76b4ad8..00de117 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h +++ b/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h @@ -21,7 +21,7 @@ limitations under the License. namespace mlir { -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h.inc" +#include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h.inc" } // namespace mlir diff --git a/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td b/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td index 7e66ee8..f8e02d4 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td @@ -19,7 +19,7 @@ limitations under the License. #ifndef MLIR_INFER_FUSIBILITY_OP_INTERFACE #define MLIR_INFER_FUSIBILITY_OP_INTERFACE -include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td" +include "mlir/IR/OpBase.td" // OpInterface to query if an op is fusible and to query the shape equality // constraint among the inputs and outputs of an op. diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h index 2aff525..bb9b290 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h +++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h @@ -18,22 +18,22 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_ #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_ -#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Dialect.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpDefinition.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/SideEffectInterfaces.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/ViewLikeInterface.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Types.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Interfaces/ViewLikeInterface.h" namespace mlir { class OpBuilder; -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_structs.h.inc" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h.inc" namespace lmhlo { @@ -44,7 +44,7 @@ class LmhloDialect : public Dialect { }; #define GET_OP_CLASSES -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc" } // namespace lmhlo } // end namespace mlir diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td index e8fcae3..8708221 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td @@ -33,9 +33,9 @@ limitations under the License. #ifndef LHLO_OPS #define LHLO_OPS -include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td" -include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/SideEffectInterfaces.td" -include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/ViewLikeInterface.td" +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/ViewLikeInterface.td" include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td" def LHLO_Dialect : Dialect { diff --git a/include/mlir-hlo/Dialect/mhlo/IR/register.h b/include/mlir-hlo/Dialect/mhlo/IR/register.h new file mode 100644 index 0000000..5773901 --- /dev/null +++ b/include/mlir-hlo/Dialect/mhlo/IR/register.h @@ -0,0 +1,27 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MLIR_HLO_DIALECT_MHLO_IR_REGISTER_H_ +#define MLIR_HLO_DIALECT_MHLO_IR_REGISTER_H_ + +namespace mlir { +namespace mhlo { + +void registerAllDialects(); + +} +} // namespace mlir + +#endif // MLIR_HLO_DIALECT_MHLO_IR_REGISTER_H_ diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.td b/include/mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.td new file mode 100644 index 0000000..963ff5d --- /dev/null +++ b/include/mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.td @@ -0,0 +1,65 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +include "mlir/Pass/PassBase.td" + +def LhloCopyRemovalPass : Pass<"lhlo-copy-removal", "FuncOp"> { + let summary = "Removes redundant LHLO copy operations."; + let constructor = "createLhloCopyRemovalPass()"; +} + + +def LhloLegalizeToLinalgPass : Pass<"lhlo-legalize-to-linalg", "FuncOp"> { + let summary = "Legalize from LHLO dialect to Linalg dialect."; + let constructor = "createLegalizeLhloToLinalgPass()"; +} + + +def LhloFuseLinalgPass : Pass<"lhlo-fuse-linalg", "FuncOp"> { + let summary = "Greedily fuse linalg ops obtained after LHLO lowering."; + let constructor = "createLhloFuseLinalgPass()"; + let options = [ + Option<"use_parallel_loops_", "use-parallel-loops", "bool", + /*default=*/"false", "Tiles GenericOp consumer to parallel loops before linalg fusion">, + ListOption<"tile_sizes_", "tile-sizes", "unsigned", + "Faster memory space number to promote fusion buffers to", + "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">, + ]; +} + + +def LhloLegalizeToAffinePass : Pass<"lhlo-legalize-to-affine", "FuncOp"> { + let summary = "Legalize from LHLO dialect to affine dialect."; + let constructor = "createLhloLegalizeToAffinePass()"; +} + + +def LhloLegalizeToGpuPass : Pass<"lhlo-legalize-to-gpu", "FuncOp"> { + let summary = "Legalize from LHLO dialect to GPU dialect."; + let constructor = "createLegalizeToGpuPass()"; +} + + +def TestLhloToLLVMPass : Pass<"test-lhlo-legalize-to-llvm", "FuncOp"> { + let summary = "Legalize from LHLO dialect to LLVM."; + let constructor = "createTestLhloToLLVMPass()"; +} + + +def LhloLegalizeToParallelLoopsPass : Pass<"lhlo-legalize-to-parallel-loops", "FuncOp"> { + let summary = "Legalize from LHLO dialect to parallel loops."; + let constructor = "createLegalizeLhloToParallelLoopsPass()"; +} + diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h b/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h index 9e7126e..c51bcfc 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h @@ -18,8 +18,8 @@ limitations under the License. #include -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" namespace mlir { namespace mhlo { diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h index fbcb21a..2bb5ab2 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h @@ -16,12 +16,12 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_ #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_ -#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h" -#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringSwitch.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/StringSwitch.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" namespace mlir { namespace lmhlo { diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td b/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td new file mode 100644 index 0000000..fa3bde2 --- /dev/null +++ b/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td @@ -0,0 +1,108 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +include "mlir/Pass/PassBase.td" + +def TestChloLegalizeToHloPass : Pass<"mhlo-test-chlo-legalize-to-hlo", "FuncOp"> { + let summary = "Test pass for applying chlo -> hlo legalization patterns."; + let constructor = "createTestChloLegalizeToHloPass()"; +} + +def HloLegalizeToLhloPass : Pass<"hlo-legalize-to-lhlo", "ModuleOp"> { + let summary = "Legalize from HLO dialect to LHLO dialect."; + let constructor = "createLegalizeToLhloPass()"; +} + +def LegalizeControlFlowPass : Pass<"mhlo-legalize-control-flow", "FuncOp"> { + let summary = "Legalize from MHLO control flow to CFG control flow."; + let constructor = "createLegalizeControlFlowPass()"; +} + +def LegalizeGatherToTorchIndexSelectPass : Pass<"mhlo-legalize-gather-to-torch-index-select", "FuncOp"> { + let summary = "Legalizes gathers to a torch index select."; + let constructor = "createLegalizeGatherToTorchIndexSelectPass()"; +} + + +def LegalizeTanhToApproximationPass : Pass<"mhlo-legalize-tanh-to-approximation", "FuncOp"> { + let summary = "Legalize tanh from standard dialect to an approximation."; + let constructor = "createLegalizeTanhToApproximationPass()"; +} + + +def HloLegalizeToLinalgPass : Pass<"hlo-legalize-to-linalg", "FuncOp"> { + let summary = "Legalize from HLO dialect to Linalg dialect."; + let constructor = "createLegalizeHloToLinalgPass()"; +} + + +def LegalizeToStandardPass : Pass<"mhlo-legalize-to-std", "FuncOp"> { + let summary = "Legalize from MHLO dialect to standard dialect."; + let constructor = "createLegalizeToStdPass()"; +} + +def LowerComplexPass : Pass<"mhlo-test-lower-complex", "FuncOp"> { + let summary = "Lower complex operations into non-complex operations."; + let constructor = "createLowerComplexPass()"; +} + + +def LegalizeGeneralDotPass : Pass<"mhlo-test-lower-general-dot", "FuncOp"> { + let summary = "Tests lowering general dot to a non-batched dot when possible."; + let constructor = "createLegalizeGeneralDotPass()"; +} + + +def TestMaterializeBroadcastsPass : Pass<"mhlo-test-materialize-broadcasts", "FuncOp"> { + let summary = "Test pass for materializing 'broadcast_dimensions' attributes."; + let constructor = "createTestMaterializeBroadcastsPass()"; +} + + +def MhloFusionPass : Pass<"mhlo-fusion", "FuncOp"> { + let summary = "Fuse mhlo ops to kLoop/kInput fusion patterns."; + let constructor = "createMhloFusionPass()"; +} + + +def OptimizeMhloPass : Pass<"mhlo-test-optimize", "FuncOp"> { + let summary = "Run optional HLO optimizations."; + let constructor = "createOptimizeMhloPass()"; +} + + +def SinkConstantsToControlFlowPass : Pass<"mhlo-sink-constants-to-control-flow", "FuncOp"> { + let summary = "Sink constants implicitly captured in control flow regions. This " + "is necessary to export to XLA."; + let constructor = "createSinkConstantsToControlFlowPass()"; +} + + +def TestInferShapedTypeMethodsPass : Pass<"mhlo-test-infer-shaped-type-methods", "FuncOp"> { + let summary = "Uses test ops to invoke InferShapedTypeOpInterface methods."; + let constructor = "createTestInferShapedTypeMethodsPass()"; +} + + +def TransformUnrankedHloPass : Pass<"transform-unranked-hlo", "FuncOp"> { + let summary = "Realize element-wise operations on ranked tensors where possible."; + let constructor = "createTransformUnrankedHloPass()"; +} + + +def TestUnfuseBatchNormPass : Pass<"mhlo-test-unfuse-batch-norm", "FuncOp"> { + let summary = "Test pass for materializing 'broadcast_dimensions' attributes."; + let constructor = "createTestUnfuseBatchNormPass()"; +} diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/passes.h b/include/mlir-hlo/Dialect/mhlo/transforms/passes.h index 8a52578..efa116f 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/passes.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/passes.h @@ -18,11 +18,12 @@ limitations under the License. #include -#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/ArrayRef.h" +#include "llvm/ADT/ArrayRef.h" namespace mlir { class FuncOp; +class FunctionPass; class ModuleOp; class Operation; template @@ -58,18 +59,26 @@ std::unique_ptr> createSinkConstantsToControlFlowPass(); // fuse mhlo ops to kLoop/kInput fusion patterns std::unique_ptr> createMhloFusionPass(); +/// Lowers the standard TanhOp to an approximation that does not use intrinsics. +std::unique_ptr> createLegalizeTanhToApproximationPass(); + +std::unique_ptr createOptimizeMhloPass(); +std::unique_ptr createLowerComplexPass(); +std::unique_ptr<::mlir::Pass> createLegalizeGeneralDotPass(); +std::unique_ptr createLegalizeGatherToTorchIndexSelectPass(); + } // namespace mhlo namespace lmhlo { // Lowers from LHLO dialect to Affine dialect. -std::unique_ptr> createLegalizeToAffinePass(); +std::unique_ptr> createLhloLegalizeToAffinePass(); // Lowers from LHLO dialect to Linalg dialect. std::unique_ptr> createLegalizeLhloToLinalgPass(); // Lowers from LHLO dialect to GPU dialect. -std::unique_ptr> createLegalizeToGpuPass(); +std::unique_ptr createLegalizeToGpuPass(); // Fuses linalg ops obtained after LHLO lowering. To enable fusion, // operations are first tiled. @@ -80,7 +89,7 @@ std::unique_ptr> createLegalizeToGpuPass(); // 'tile_sizes' provides the tile sizes to use for tiling. If the linalg // operation has more dimensions than tile sizes provided, 1 is used as // default. -std::unique_ptr> createLhloFuseLinalg( +std::unique_ptr createLhloFuseLinalgPass( bool use_parallel_loops = false, llvm::ArrayRef tile_sizes = {}); // Removes unnecessary LHLO copies which copy from the allocated buffers to the @@ -94,12 +103,6 @@ std::unique_ptr> createLegalizeLhloToParallelLoopsPass(); } // namespace lmhlo -namespace hlo { - -/// Lowers the standard TanhOp to an approximation that does not use intrinsics. -std::unique_ptr> createLegalizeTanhToApproximationPass(); - -} // namespace hlo } // namespace mlir #endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_PASSES_H_ diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/register_passes.h b/include/mlir-hlo/Dialect/mhlo/transforms/register_passes.h new file mode 100644 index 0000000..5c862d8 --- /dev/null +++ b/include/mlir-hlo/Dialect/mhlo/transforms/register_passes.h @@ -0,0 +1,49 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MLIR_HLO_DIALECT_MHLO_TRANSFORMS_REGISTER_PASSES_H_ +#define MLIR_HLO_DIALECT_MHLO_TRANSFORMS_REGISTER_PASSES_H_ + +#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace mhlo { + +std::unique_ptr createTestChloLegalizeToHloPass(); +std::unique_ptr createTestInferShapedTypeMethodsPass(); +std::unique_ptr createTestMaterializeBroadcastsPass(); +std::unique_ptr createTestUnfuseBatchNormPass(); + +inline void registerAllMhloPasses() { +#define GEN_PASS_REGISTRATION +#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc" +} + +} // namespace mhlo + +namespace lmhlo { + +std::unique_ptr createTestLhloToLLVMPass(); + +inline void registerAllLmhloPasses() { +#define GEN_PASS_REGISTRATION +#include "mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.h.inc" +} + +} // namespace lmhlo +} // namespace mlir + +#endif // MLIR_HLO_DIALECT_MHLO_TRANSFORMS_REGISTER_PASSES_H_ diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h b/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h index 88a7758..e5ca4f7 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h @@ -18,9 +18,9 @@ limitations under the License. #include -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" namespace mlir { class LLVMTypeConverter; @@ -80,6 +80,11 @@ void PopulateTransformUnrankedHloPatterns(MLIRContext *context, void PopulateUnfuseBatchNormPatterns(MLIRContext *context, OwningRewritePatternList *patterns); +// Populates a pattern that translates the standard TanhOp to an approximation +// that does not use intrinsics. +void PopulateTanhToApproximationPatterns(MLIRContext *context, + OwningRewritePatternList *patterns); + } // namespace mhlo namespace lmhlo { @@ -100,14 +105,6 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context, } // namespace chlo -namespace hlo { - -// Populates a pattern that translates the standard TanhOp to an approximation -// that does not use intrinsics. -void PopulateTanhToApproximationPatterns(MLIRContext *context, - OwningRewritePatternList *patterns); - -} // namespace hlo } // namespace mlir #endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_REWRITERS_H_ diff --git a/include/mlir-hlo/utils/broadcast_utils.h b/include/mlir-hlo/utils/broadcast_utils.h index 85e5ffe..1e24042 100644 --- a/include/mlir-hlo/utils/broadcast_utils.h +++ b/include/mlir-hlo/utils/broadcast_utils.h @@ -19,12 +19,12 @@ limitations under the License. // Utilities relating to implementing HLO broadcasting. // Note: This file should not depend on any non-MLIR TensorFlow libraries. -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/InferTypeOpInterface.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Support/LLVM.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Support/LLVM.h" namespace mlir { namespace hlo { diff --git a/include/mlir-hlo/utils/convert_op_folder.h b/include/mlir-hlo/utils/convert_op_folder.h index 9a62a03..4cf7438 100644 --- a/include/mlir-hlo/utils/convert_op_folder.h +++ b/include/mlir-hlo/utils/convert_op_folder.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_CONVERT_OP_FOLDER_H_ #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_CONVERT_OP_FOLDER_H_ -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/StandardTypes.h" namespace mlir { namespace hlo { diff --git a/include/mlir-hlo/utils/cycle_detector.h b/include/mlir-hlo/utils/cycle_detector.h index 0cec777..79b56b3 100644 --- a/include/mlir-hlo/utils/cycle_detector.h +++ b/include/mlir-hlo/utils/cycle_detector.h @@ -18,7 +18,7 @@ limitations under the License. #include -#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseMap.h" namespace mlir { diff --git a/include/mlir-hlo/utils/hlo_utils.h b/include/mlir-hlo/utils/hlo_utils.h index de9d9b3..1e335ae 100644 --- a/include/mlir-hlo/utils/hlo_utils.h +++ b/include/mlir-hlo/utils/hlo_utils.h @@ -16,11 +16,11 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_HLO_UTILS_H_ #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_HLO_UTILS_H_ -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" namespace mlir { namespace hlo { diff --git a/lib/Dialect/mhlo/IR/chlo_ops.cc b/lib/Dialect/mhlo/IR/chlo_ops.cc index 63a7cda..99ed8bc 100644 --- a/lib/Dialect/mhlo/IR/chlo_ops.cc +++ b/lib/Dialect/mhlo/IR/chlo_ops.cc @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Diagnostics.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h" +#include "mlir-hlo/utils/broadcast_utils.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" namespace mlir { namespace chlo { @@ -260,7 +260,7 @@ BROADCAST_BINARY_OP_DEFS(BroadcastXorOp); #undef BROADCAST_BINARY_OP_DEFS #define GET_OP_CLASSES -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc" +#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc" //===----------------------------------------------------------------------===// // chlo Dialect Constructor @@ -270,7 +270,7 @@ HloClientDialect::HloClientDialect(MLIRContext* context) : Dialect(getDialectNamespace(), context) { addOperations< #define GET_OP_LIST -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc" +#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc" >(); } diff --git a/lib/Dialect/mhlo/IR/dialect_registration.cc b/lib/Dialect/mhlo/IR/dialect_registration.cc index 0608341..9d1c354 100644 --- a/lib/Dialect/mhlo/IR/dialect_registration.cc +++ b/lib/Dialect/mhlo/IR/dialect_registration.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" // Static initialization for *HLO dialects registration. static mlir::DialectRegistration mhlo_ops; diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc index a0a266c..69b0100 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -15,7 +15,7 @@ limitations under the License. // This file defines the operations used in the MHLO dialect. -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include #include @@ -24,44 +24,43 @@ limitations under the License. #include #include -#include "third_party/absl/container/flat_hash_set.h" -#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/APFloat.h" -#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/APInt.h" -#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/ArrayRef.h" -#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h" -#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/SmallVector.h" -#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h" -#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/iterator_range.h" -#include "third_party/llvm/llvm-project/llvm/include/llvm/Support/Casting.h" -#include "third_party/llvm/llvm-project/llvm/include/llvm/Support/FormatVariadic.h" -#include "third_party/llvm/llvm-project/llvm/include/llvm/Support/MathExtras.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Shape/IR/Shape.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Dialect.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Matchers.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpDefinition.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpImplementation.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OperationSupport.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Value.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Support/LLVM.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Support/LogicalResult.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/InliningUtils.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/convert_op_folder.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/iterator_range.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/MathExtras.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc" +#include "mlir-hlo/utils/convert_op_folder.h" +#include "mlir-hlo/utils/hlo_utils.h" +#include "mlir/Dialect/Shape/IR/Shape.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/InliningUtils.h" namespace mlir { -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_patterns.cc.inc" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_structs.cc.inc" +#include "hlo_patterns.cc.inc" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_structs.cc.inc" namespace mhlo { Operation* MhloDialect::materializeConstant(OpBuilder& builder, Attribute value, @@ -106,7 +105,7 @@ DenseIntElementsAttr BuildSliceLimits(DenseIntElementsAttr start_indices, return GetI64ElementsAttr(slice_limits, builder); } -#include "third_party/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/generated_canonicalize.inc" +#include "mhlo_canonicalize.inc" } // namespace //===----------------------------------------------------------------------===// @@ -375,8 +374,8 @@ static LogicalResult Verify(CollectivePermuteOp op) { << "expect source_target_pairs attribute of shape (N, 2), but got (" << type.getShape() << ")"; // Check source target pairs for duplicate sources or targets - absl::flat_hash_set sources; - absl::flat_hash_set targets; + llvm::DenseSet sources; + llvm::DenseSet targets; for (auto i = op.source_target_pairs().begin(), e = op.source_target_pairs().end(); i != e; ++i) { @@ -2123,7 +2122,7 @@ void CompareOp::build(OpBuilder& builder, OperationState& result, Value lhs, } #define GET_OP_CLASSES -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc" //===----------------------------------------------------------------------===// // mhlo Dialect Interfaces @@ -2154,7 +2153,7 @@ MhloDialect::MhloDialect(MLIRContext* context) : Dialect(getDialectNamespace(), context) { addOperations< #define GET_OP_LIST -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc" >(); addInterfaces(); addTypes(); diff --git a/lib/Dialect/mhlo/IR/hlo_patterns.td b/lib/Dialect/mhlo/IR/hlo_patterns.td index 2c035dc..b8b6cb8 100644 --- a/lib/Dialect/mhlo/IR/hlo_patterns.td +++ b/lib/Dialect/mhlo/IR/hlo_patterns.td @@ -15,8 +15,8 @@ limitations under the License. // Canonicalization patterns for the MHLO dialect. -include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td" -include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td" +include "mlir/Dialect/Shape/IR/ShapeOps.td" +include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td" def EqualBinaryOperands : Constraint>; diff --git a/lib/Dialect/mhlo/IR/infer_fusibility_op_interface.cc b/lib/Dialect/mhlo/IR/infer_fusibility_op_interface.cc index 23712d1..e93a6cf 100644 --- a/lib/Dialect/mhlo/IR/infer_fusibility_op_interface.cc +++ b/lib/Dialect/mhlo/IR/infer_fusibility_op_interface.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h" +#include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h" namespace mlir { -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.cc.inc" +#include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.cpp.inc" } // namespace mlir diff --git a/lib/Dialect/mhlo/IR/init.cc b/lib/Dialect/mhlo/IR/init.cc new file mode 100644 index 0000000..9fffeae --- /dev/null +++ b/lib/Dialect/mhlo/IR/init.cc @@ -0,0 +1,33 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/register.h" + +// Static initialization for *HLO dialects registration. + +void mlir::mhlo::registerAllDialects() { + static bool init_once = []() { + registerDialect(); + registerDialect(); + registerDialect(); + return true; + }(); + (void)init_once; + + // Dependent dialects +} diff --git a/lib/Dialect/mhlo/IR/lhlo_ops.cc b/lib/Dialect/mhlo/IR/lhlo_ops.cc index 57d3b78..bbb463c 100644 --- a/lib/Dialect/mhlo/IR/lhlo_ops.cc +++ b/lib/Dialect/mhlo/IR/lhlo_ops.cc @@ -15,44 +15,44 @@ limitations under the License. // This file defines the operations used in the LMHLO dialect. -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include #include #include -#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/APFloat.h" -#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/APInt.h" -#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/ArrayRef.h" -#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h" -#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/SmallVector.h" -#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h" -#include "third_party/llvm/llvm-project/llvm/include/llvm/Support/FormatVariadic.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Dialect.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpDefinition.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpImplementation.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OperationSupport.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Value.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/FormatVariadic.h" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" namespace mlir { -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_structs.cc.inc" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.cc.inc" namespace lmhlo { LmhloDialect::LmhloDialect(MLIRContext *context) : Dialect(getDialectNamespace(), context) { addOperations< #define GET_OP_LIST -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.cc.inc" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.cc.inc" >(); } @@ -127,7 +127,7 @@ static LogicalResult Verify(ReshapeMemRefCastOp op) { } #define GET_OP_CLASSES -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.cc.inc" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.cc.inc" // TODO(cheshire): Support folding, reuse code from hlo_ops.cc. diff --git a/lib/Dialect/mhlo/transforms/canonicalize.td b/lib/Dialect/mhlo/IR/mhlo_canonicalize.td similarity index 94% rename from lib/Dialect/mhlo/transforms/canonicalize.td rename to lib/Dialect/mhlo/IR/mhlo_canonicalize.td index a6435bc..eb92d9e 100644 --- a/lib/Dialect/mhlo/transforms/canonicalize.td +++ b/lib/Dialect/mhlo/IR/mhlo_canonicalize.td @@ -15,7 +15,7 @@ limitations under the License. // This is the canonicalize pattern definition file. -include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td" +include "mlir/IR/OpBase.td" include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td" include "mlir-hlo/Dialect/mhlo/IR/hlo_utils.td" diff --git a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc index 87e0752..2a8482b 100644 --- a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc +++ b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc @@ -13,19 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/SCF/SCF.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Shape/IR/Shape.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OperationSupport.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h" +#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir-hlo/utils/broadcast_utils.h" +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/Shape/IR/Shape.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Transforms/DialectConversion.h" namespace mlir { namespace chlo { diff --git a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc index bdfdd39..50cd6df 100644 --- a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc +++ b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc @@ -13,16 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/SCF/SCF.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Shape/IR/Shape.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/Shape/IR/Shape.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Pass/Pass.h" namespace mlir { -namespace chlo { +namespace mhlo { namespace { @@ -32,7 +33,7 @@ struct TestChloLegalizeToHloPass ConversionTarget conversionTarget(getContext()); OwningRewritePatternList conversionPatterns; - conversionTarget.addIllegalDialect(); + conversionTarget.addIllegalDialect(); // Consider the mhlo dialect legal for tests. conversionTarget.addLegalDialect(); // The conversion uses helpers from the Standard dialect. @@ -40,7 +41,7 @@ struct TestChloLegalizeToHloPass conversionTarget.addLegalDialect(); conversionTarget.addLegalDialect(); - PopulateLegalizeChloToHloPatterns(&getContext(), &conversionPatterns); + chlo::PopulateLegalizeChloToHloPatterns(&getContext(), &conversionPatterns); if (failed(applyPartialConversion(getFunction(), conversionTarget, conversionPatterns))) { @@ -51,9 +52,10 @@ struct TestChloLegalizeToHloPass } // namespace -} // namespace chlo +std::unique_ptr createTestChloLegalizeToHloPass() { + return std::make_unique(); +} + +} // namespace mhlo } // namespace mlir -static mlir::PassRegistration pass( - "mhlo-test-chlo-legalize-to-hlo", - "Test pass for applying chlo -> hlo legalization patterns"); diff --git a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc index d14072a..a8c3ad1 100644 --- a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc +++ b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc @@ -15,26 +15,25 @@ limitations under the License. // This file implements logic for lowering HLO dialect to LHLO dialect. -#include "third_party/absl/memory/memory.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/AffineMap.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/BlockAndValueMapping.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/BufferPlacement.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h" +#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/BufferPlacement.h" +#include "mlir/Transforms/DialectConversion.h" namespace mlir { namespace mhlo { @@ -511,11 +510,8 @@ void populateHLOToLHLOConversionPattern( std::unique_ptr> createLegalizeToLhloPass( bool results_escape_function) { - return absl::make_unique(results_escape_function); + return std::make_unique(results_escape_function); } -static PassRegistration legalize_pass( - "hlo-legalize-to-lhlo", "Legalize from HLO dialect to LHLO dialect"); - } // namespace mhlo } // namespace mlir diff --git a/lib/Dialect/mhlo/transforms/legalize_control_flow.cc b/lib/Dialect/mhlo/transforms/legalize_control_flow.cc index 766d3c3..b6e23a6 100644 --- a/lib/Dialect/mhlo/transforms/legalize_control_flow.cc +++ b/lib/Dialect/mhlo/transforms/legalize_control_flow.cc @@ -15,30 +15,30 @@ limitations under the License. // This file implements logic for lowering MHLO dialect to Standard dialect. -#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h" -#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringSwitch.h" -#include "third_party/llvm/llvm-project/llvm/include/llvm/Support/Casting.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Block.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/BlockAndValueMapping.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/PassRegistry.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Support/LogicalResult.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringSwitch.h" +#include "llvm/Support/Casting.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LogicalResult.h" using mlir::PassRegistration; namespace mlir { namespace mhlo { namespace { -struct LegalizeControlFlow - : public mlir::PassWrapper { +struct LegalizeControlFlowPass + : public mlir::PassWrapper { // Perform the lowering to MLIR control flow. void runOnFunction() override; }; @@ -206,7 +206,7 @@ LogicalResult LowerWhileOp(mlir::mhlo::WhileOp while_op) { return success(); } -void LegalizeControlFlow::runOnFunction() { +void LegalizeControlFlowPass::runOnFunction() { auto func = getFunction(); llvm::SmallVector if_ops; func.walk([&](IfOp op) { if_ops.push_back(op); }); @@ -228,9 +228,5 @@ void LegalizeControlFlow::runOnFunction() { std::unique_ptr> mlir::mhlo::createLegalizeControlFlowPass() { - return std::make_unique(); + return std::make_unique(); } - -static PassRegistration legalize_cf_pass( - "mhlo-legalize-control-flow", - "Legalize from MHLO control flow to CFG control flow"); diff --git a/lib/Dialect/mhlo/transforms/legalize_gather_to_torch_index_select.cc b/lib/Dialect/mhlo/transforms/legalize_gather_to_torch_index_select.cc index 1fb2c13..59cd338 100644 --- a/lib/Dialect/mhlo/transforms/legalize_gather_to_torch_index_select.cc +++ b/lib/Dialect/mhlo/transforms/legalize_gather_to_torch_index_select.cc @@ -13,13 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "third_party/absl/memory/memory.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" namespace mlir { @@ -128,8 +127,8 @@ struct GatherIsTorchIndexSelect : public OpRewritePattern { } }; -struct LegalizeGatherToTorchIndexSelect - : public PassWrapper { +struct LegalizeGatherToTorchIndexSelectPass + : public PassWrapper { /// Perform the lowering of standard dialect operations to approximations. void runOnFunction() override { OwningRewritePatternList patterns; @@ -144,9 +143,9 @@ void PopulateGatherToTorchIndexSelectPatterns( patterns->insert(context); } -static PassRegistration legalize_hlo_pass( - "mhlo-legalize-gather-to-torch-index-select", - "Legalizes gathers to a torch index select."); +std::unique_ptr createLegalizeGatherToTorchIndexSelectPass() { + return std::make_unique(); +} } // namespace mhlo } // namespace mlir diff --git a/lib/Dialect/mhlo/transforms/legalize_tanh_to_approximation.cc b/lib/Dialect/mhlo/transforms/legalize_tanh_to_approximation.cc index dfd05bb..57c494f 100644 --- a/lib/Dialect/mhlo/transforms/legalize_tanh_to_approximation.cc +++ b/lib/Dialect/mhlo/transforms/legalize_tanh_to_approximation.cc @@ -16,15 +16,15 @@ limitations under the License. // This file implements logic for lowering the tanh standard ops to an // approximation. -#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" namespace mlir { -namespace hlo { +namespace mhlo { namespace { /// Emits the fast tanh approximation that is also used by XLA. @@ -126,8 +126,8 @@ class ApproximateTanhLowering : public OpRewritePattern { } }; -struct LegalizeTanhToApproximation - : public PassWrapper { +struct LegalizeTanhToApproximationPass + : public PassWrapper { /// Perform the lowering of standard dialect operations to approximations. void runOnFunction() override { OwningRewritePatternList patterns; @@ -140,7 +140,7 @@ struct LegalizeTanhToApproximation std::unique_ptr> createLegalizeTanhToApproximationPass() { - return std::make_unique(); + return std::make_unique(); } void PopulateTanhToApproximationPatterns(mlir::MLIRContext *context, @@ -148,9 +148,5 @@ void PopulateTanhToApproximationPatterns(mlir::MLIRContext *context, patterns->insert(context); } -static PassRegistration legalize_pass( - "mhlo-legalize-tanh-to-approximation", - "Legalize tanh from standard dialect to an approximation"); - -} // namespace hlo +} // namespace mhlo } // namespace mlir diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index 3eb8099..f47f2c2 100644 --- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -15,26 +15,25 @@ limitations under the License. // This file implements logic for lowering HLO/LHLO dialect to Linalg dialect. -#include "third_party/absl/memory/memory.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/AffineExpr.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h" +#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Linalg/IR/LinalgTypes.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" namespace mlir { namespace { @@ -826,8 +825,8 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context, // indexing_maps = [#map0, #map0, #map0], // iterator_types = ["parallel", "parallel"], // } : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () -struct LhloLegalizeToLinalg - : public PassWrapper { +struct LhloLegalizeToLinalgPass + : public PassWrapper { void runOnFunction() override { OwningRewritePatternList patterns; ConversionTarget target(getContext()); @@ -842,8 +841,8 @@ struct LhloLegalizeToLinalg } }; -struct HloLegalizeToLinalg - : public PassWrapper { +struct HloLegalizeToLinalgPass + : public PassWrapper { void runOnFunction() override { OwningRewritePatternList patterns; ConversionTarget target(getContext()); @@ -861,11 +860,8 @@ struct HloLegalizeToLinalg namespace lmhlo { std::unique_ptr> createLegalizeLhloToLinalgPass() { - return absl::make_unique(); + return std::make_unique(); } - -static PassRegistration legalize_lhlo_pass( - "lhlo-legalize-to-linalg", "Legalize from LHLO dialect to Linalg dialect"); } // namespace lmhlo namespace mhlo { @@ -906,10 +902,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context, } std::unique_ptr> createLegalizeHloToLinalgPass() { - return absl::make_unique(); + return std::make_unique(); } - -static PassRegistration legalize_hlo_pass( - "hlo-legalize-to-linalg", "Legalize from HLO dialect to Linalg dialect"); } // namespace mhlo } // namespace mlir diff --git a/lib/Dialect/mhlo/transforms/legalize_to_standard.cc b/lib/Dialect/mhlo/transforms/legalize_to_standard.cc index e3104e4..cc574e0 100644 --- a/lib/Dialect/mhlo/transforms/legalize_to_standard.cc +++ b/lib/Dialect/mhlo/transforms/legalize_to_standard.cc @@ -15,18 +15,18 @@ limitations under the License. // This file implements logic for lowering MHLO dialect to Standard dialect. -#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringSwitch.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "llvm/ADT/StringSwitch.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" namespace mlir { namespace { -#include "third_party/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/generated_legalize_to_standard.inc" +#include "generated_legalize_to_standard.inc" } // end anonymous namespace namespace mhlo { namespace { @@ -176,15 +176,15 @@ class ConvertIotaOp : public OpRewritePattern { } // end anonymous namespace namespace { -struct LegalizeToStandard - : public PassWrapper { +struct LegalizeToStandardPass + : public PassWrapper { /// Perform the lowering to Standard dialect. void runOnFunction() override; }; } // end anonymous namespace std::unique_ptr> createLegalizeToStdPass() { - return std::make_unique(); + return std::make_unique(); } void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns, @@ -194,14 +194,11 @@ void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns, } /// Perform the lowering to standard dialect. -void LegalizeToStandard::runOnFunction() { +void LegalizeToStandardPass::runOnFunction() { OwningRewritePatternList patterns; mlir::mhlo::PopulateMhloToStdPatterns(&patterns, &getContext()); applyPatternsAndFoldGreedily(getFunction(), patterns); } -static PassRegistration legalize_pass( - "mhlo-legalize-to-std", "Legalize from MHLO dialect to standard dialect"); - } // end namespace mhlo } // end namespace mlir diff --git a/lib/Dialect/mhlo/transforms/legalize_to_standard_patterns.td b/lib/Dialect/mhlo/transforms/legalize_to_standard_patterns.td index a1c3fb7..ea67c05 100644 --- a/lib/Dialect/mhlo/transforms/legalize_to_standard_patterns.td +++ b/lib/Dialect/mhlo/transforms/legalize_to_standard_patterns.td @@ -15,9 +15,9 @@ limitations under the License. // This is the legalization pattern definition file for MHLO to StandardOps. -include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td" -include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td" -include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td" +include "mlir/IR/OpBase.td" +include "mlir/Dialect/StandardOps/IR/Ops.td" +include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td" //===----------------------------------------------------------------------===// // Nullary op patterns. diff --git a/lib/Dialect/mhlo/transforms/lhlo_copy_removal.cc b/lib/Dialect/mhlo/transforms/lhlo_copy_removal.cc index 3310170..7a44184 100644 --- a/lib/Dialect/mhlo/transforms/lhlo_copy_removal.cc +++ b/lib/Dialect/mhlo/transforms/lhlo_copy_removal.cc @@ -15,12 +15,11 @@ limitations under the License. // This file implements a pass to remove redundant LHLO copy operations. -#include "third_party/absl/memory/memory.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Operation.h" +#include "mlir/Pass/Pass.h" namespace mlir { namespace lmhlo { @@ -30,7 +29,8 @@ namespace { // arguments. All uses of each buffer are replaced with the corresponding block // argument and the buffer is freed. Note that this pass only works in regions // with a single block. -struct LhloCopyRemoval : mlir::PassWrapper> { +struct LhloCopyRemovalPass + : mlir::PassWrapper> { void runOnOperation() override { llvm::SmallVector eraseList; auto operation = getOperation(); @@ -95,11 +95,8 @@ struct LhloCopyRemoval : mlir::PassWrapper> { } // namespace std::unique_ptr createLhloCopyRemovalPass() { - return absl::make_unique(); + return std::make_unique(); } -static PassRegistration copy_removal_pass( - "lhlo-copy-removal", "Removes redundant LHLO copy operations"); - } // namespace lmhlo } // namespace mlir diff --git a/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc b/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc index 01aba61..1467f01 100644 --- a/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc +++ b/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc @@ -16,15 +16,14 @@ limitations under the License. // This file implements logic for fusing linalg ops obtained after LHLO // lowering. +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" -#include "third_party/absl/memory/memory.h" -#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/ArrayRef.h" -#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/FoldUtils.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/FoldUtils.h" namespace mlir { namespace lmhlo { @@ -32,11 +31,13 @@ namespace { using linalg::LinalgOp; -class LhloFuseLinalg : public PassWrapper { +class LhloFuseLinalgPass + : public PassWrapper { public: - LhloFuseLinalg() = default; - LhloFuseLinalg(const LhloFuseLinalg&) {} - LhloFuseLinalg(bool use_parallel_loops, llvm::ArrayRef tile_sizes) { + LhloFuseLinalgPass() = default; + LhloFuseLinalgPass(const LhloFuseLinalgPass&) {} + LhloFuseLinalgPass(bool use_parallel_loops, + llvm::ArrayRef tile_sizes) { tile_sizes_ = tile_sizes; use_parallel_loops_.setValue(use_parallel_loops); } @@ -138,14 +139,10 @@ class LhloFuseLinalg : public PassWrapper { } // namespace -std::unique_ptr> createLhloFuseLinalg( +std::unique_ptr createLhloFuseLinalgPass( bool use_parallel_loops, ArrayRef tile_sizes) { - return absl::make_unique(use_parallel_loops, tile_sizes); + return std::make_unique(use_parallel_loops, tile_sizes); } -static PassRegistration legalize_pass( - "lhlo-fuse-linalg", - "Greedily fuse linalg ops obtained after LHLO lowering."); - } // namespace lmhlo } // namespace mlir diff --git a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc index 2da3e8a..0789132 100644 --- a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc +++ b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc @@ -15,17 +15,16 @@ limitations under the License. // This file implements logic for lowering LHLO dialect to Affine dialect. -#include "third_party/absl/memory/memory.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Pass/Pass.h" namespace mlir { namespace lmhlo { @@ -138,8 +137,8 @@ void populateLHLOToAffineConversionPattern(MLIRContext* context, // clang-format on } -struct LhloLegalizeToAffine - : public PassWrapper { +struct LhloLegalizeToAffinePass + : public PassWrapper { void runOnFunction() override { OwningRewritePatternList patterns; auto func = getFunction(); @@ -150,12 +149,9 @@ struct LhloLegalizeToAffine } // namespace -std::unique_ptr> createLegalizeToAffinePass() { - return absl::make_unique(); +std::unique_ptr> createLhloLegalizeToAffinePass() { + return std::make_unique(); } -static PassRegistration legalize_pass( - "lhlo-legalize-to-affine", "Legalize from LHLO dialect to affine dialect"); - } // namespace lmhlo } // namespace mlir diff --git a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc index 4bee2cc..0d0b8b0 100644 --- a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc +++ b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc @@ -17,25 +17,24 @@ limitations under the License. #include -#include "third_party/absl/memory/memory.h" -#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/ArrayRef.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/GPU/GPUDialect.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/SCF/SCF.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/BlockAndValueMapping.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h" +#include "llvm/ADT/ArrayRef.h" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h" +#include "mlir/Dialect/GPU/GPUDialect.h" +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" namespace mlir { namespace lmhlo { @@ -168,7 +167,8 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern { }; }; -struct LhloLegalizeToGpu : public PassWrapper { +struct LhloLegalizeToGpuPass + : public PassWrapper { void runOnFunction() override { OwningRewritePatternList patterns; ConversionTarget target(getContext()); @@ -185,12 +185,9 @@ struct LhloLegalizeToGpu : public PassWrapper { } // namespace -std::unique_ptr> createLegalizeToGpuPass() { - return absl::make_unique(); +std::unique_ptr createLegalizeToGpuPass() { + return std::make_unique(); } -static PassRegistration legalize_pass( - "lhlo-legalize-to-gpu", "Legalize from LHLO dialect to GPU dialect"); - } // namespace lmhlo } // namespace mlir diff --git a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm.cc b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm.cc index 0ed1b18..35bbea7 100644 --- a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm.cc +++ b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm.cc @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "third_party/llvm/llvm-project/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Transforms/DialectConversion.h" namespace mlir { namespace lmhlo { diff --git a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc index 99f8391..2ed0182 100644 --- a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc +++ b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc @@ -13,16 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "third_party/llvm/llvm-project/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Conversion/SCFToStandard/SCFToStandard.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" +#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Pass/Pass.h" namespace mlir { namespace lmhlo { @@ -57,8 +57,9 @@ class TestLhloToLLVMPass } // namespace -static PassRegistration legalize_lhlo_pass( - "test-lhlo-legalize-to-llvm", "Legalize from LHLO dialect to LLVM."); +std::unique_ptr createTestLhloToLLVMPass() { + return std::make_unique(); +} } // namespace lmhlo } // namespace mlir diff --git a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc index 3736763..19f47d0 100644 --- a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc +++ b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc @@ -13,17 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "third_party/absl/memory/memory.h" -#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/ArrayRef.h" -#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h" -#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/SmallVector.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/SCF/SCF.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" namespace mlir { namespace lmhlo { @@ -690,8 +689,8 @@ class SelectAndScatterOpConverter } }; -struct LhloLegalizeToParallelLoops - : public PassWrapper { +struct LhloLegalizeToParallelLoopsPass + : public PassWrapper { void runOnFunction() override { auto func = getFunction(); @@ -715,16 +714,11 @@ struct LhloLegalizeToParallelLoops } } }; - } // namespace std::unique_ptr> createLegalizeLhloToParallelLoopsPass() { - return absl::make_unique(); + return std::make_unique(); } -static PassRegistration legalize_lhlo_pass( - "lhlo-legalize-to-parallel-loops", - "Legalize from LHLO dialect to parallel loops."); - } // namespace lmhlo } // namespace mlir diff --git a/lib/Dialect/mhlo/transforms/lower_complex.cc b/lib/Dialect/mhlo/transforms/lower_complex.cc index 083da64..9f7c946 100644 --- a/lib/Dialect/mhlo/transforms/lower_complex.cc +++ b/lib/Dialect/mhlo/transforms/lower_complex.cc @@ -22,18 +22,18 @@ limitations under the License. #include #include -#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/PassRegistry.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h" +#include "llvm/ADT/STLExtras.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "mlir-hlo/utils/hlo_utils.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Types.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" using mlir::FunctionPass; using mlir::OwningRewritePatternList; @@ -41,9 +41,9 @@ using mlir::PassRegistration; using mlir::PassWrapper; namespace { -class LowerComplex : public PassWrapper { +class LowerComplexPass : public PassWrapper { public: - explicit LowerComplex() : PassWrapper() {} + explicit LowerComplexPass() : PassWrapper() {} /// Performs the lowering to MHLO dialect. void runOnFunction() override; @@ -54,7 +54,7 @@ namespace mlir { namespace mhlo { namespace { -#include "third_party/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/generated_lower_complex.inc" +#include "generated_lower_complex.inc" } // end anonymous namespace @@ -66,7 +66,7 @@ void PopulateComplexLoweringPatterns(MLIRContext* context, } // end namespace mlir // Lowers the complex operations that can be represented using other operations. -void LowerComplex::runOnFunction() { +void LowerComplexPass::runOnFunction() { // Add lowering patterns to the list. OwningRewritePatternList patterns; mlir::mhlo::PopulateComplexLoweringPatterns(&getContext(), &patterns); @@ -74,6 +74,6 @@ void LowerComplex::runOnFunction() { applyPatternsAndFoldGreedily(getFunction(), patterns); } -static PassRegistration pass( - "mhlo-test-lower-complex", - "Lower complex operations into non-complex operations"); +std::unique_ptr mlir::mhlo::createLowerComplexPass() { + return std::make_unique(); +} diff --git a/lib/Dialect/mhlo/transforms/lower_complex_patterns.td b/lib/Dialect/mhlo/transforms/lower_complex_patterns.td index e4e7067..2cc97c9 100644 --- a/lib/Dialect/mhlo/transforms/lower_complex_patterns.td +++ b/lib/Dialect/mhlo/transforms/lower_complex_patterns.td @@ -16,9 +16,9 @@ limitations under the License. // This is the legalization pattern that converts complex operations into // equivalent real value operations. -include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td" -include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td" -include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td" +include "mlir/IR/OpBase.td" +include "mlir/Dialect/StandardOps/IR/Ops.td" +include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td" //===----------------------------------------------------------------------===// // Binary op patterns. diff --git a/lib/Dialect/mhlo/transforms/lower_general_dot.cc b/lib/Dialect/mhlo/transforms/lower_general_dot.cc index 1b3b1ac..2bbd469 100644 --- a/lib/Dialect/mhlo/transforms/lower_general_dot.cc +++ b/lib/Dialect/mhlo/transforms/lower_general_dot.cc @@ -15,20 +15,20 @@ limitations under the License. // This file implements logic for lowering MHLO general dot to a regular dot. -#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h" -#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringSwitch.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringSwitch.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Pass/Pass.h" using mlir::DenseIntElementsAttr; using mlir::ElementsAttr; @@ -170,8 +170,8 @@ struct GeneralDotConvert : public OpRewritePattern { } }; -struct LegalizeGeneralDot - : public PassWrapper { +struct LegalizeGeneralDotPass + : public PassWrapper { /// Lower all general dots that can be represented as a non-batched matmul. void runOnFunction() override { OwningRewritePatternList patterns; @@ -187,6 +187,6 @@ void mlir::mhlo::PopulateGeneralDotOpLoweringPatterns( patterns->insert(ctx); } -static PassRegistration legalize_pass( - "mhlo-test-lower-general-dot", - "Tests lowering general dot to a non-batched dot when possible"); +std::unique_ptr<::mlir::Pass> mlir::mhlo::createLegalizeGeneralDotPass() { + return std::make_unique(); +} diff --git a/lib/Dialect/mhlo/transforms/materialize_broadcasts.cc b/lib/Dialect/mhlo/transforms/materialize_broadcasts.cc index 8abc099..445cf2e 100644 --- a/lib/Dialect/mhlo/transforms/materialize_broadcasts.cc +++ b/lib/Dialect/mhlo/transforms/materialize_broadcasts.cc @@ -15,12 +15,12 @@ limitations under the License. #include -#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" namespace mlir { namespace mhlo { diff --git a/lib/Dialect/mhlo/transforms/materialize_broadcasts_pass.cc b/lib/Dialect/mhlo/transforms/materialize_broadcasts_pass.cc index a418e4c..3909f04 100644 --- a/lib/Dialect/mhlo/transforms/materialize_broadcasts_pass.cc +++ b/lib/Dialect/mhlo/transforms/materialize_broadcasts_pass.cc @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" namespace mlir { namespace mhlo { @@ -50,9 +50,9 @@ struct TestMaterializeBroadcastsPass } // namespace +std::unique_ptr<::mlir::Pass> createTestMaterializeBroadcastsPass() { + return std::make_unique(); +} + } // namespace mhlo } // namespace mlir - -static mlir::PassRegistration pass( - "mhlo-test-materialize-broadcasts", - "Test pass for materializing 'broadcast_dimensions' attributes"); diff --git a/lib/Dialect/mhlo/transforms/mhlo_fusion.cc b/lib/Dialect/mhlo/transforms/mhlo_fusion.cc index f7158a0..233d95a 100644 --- a/lib/Dialect/mhlo/transforms/mhlo_fusion.cc +++ b/lib/Dialect/mhlo/transforms/mhlo_fusion.cc @@ -18,14 +18,14 @@ limitations under the License. #include #include +#include "llvm/ADT/EquivalenceClasses.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/utils/cycle_detector.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/MLIRContext.h" // TF:llvm-project #include "mlir/IR/Matchers.h" #include "mlir/Pass/Pass.h" // TF:local_config_mlir #include "mlir/Transforms/RegionUtils.h" // TF:llvm-project -#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/EquivalenceClasses.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/cycle_detector.h" // This pass has similar functionality of the fusion pass in XLA stack. // However, unlike XLA, it targets the fully dynamic shape scenario. @@ -479,7 +479,7 @@ class FusionPlanner { EquivalenceClasses leader_for_node_; }; -struct MhloFusion : public mlir::PassWrapper { +struct MhloFusionPass : public mlir::PassWrapper { void runOnFunction() override { FuncOp func = getFunction(); if (!IsTargetFunc(func)) { @@ -568,12 +568,9 @@ struct MhloFusion : public mlir::PassWrapper { } // namespace -std::unique_ptr> createMhloFusion() { - return std::make_unique(); +std::unique_ptr> createMhloFusionPass() { + return std::make_unique(); } -static PassRegistration mhlo_fusion_pass( - "mhlo-fusion", "fuse mhlo ops to kLoop/kInput fusion patterns."); - } // namespace mhlo } // namespace mlir diff --git a/lib/Dialect/mhlo/transforms/optimize_mhlo.cc b/lib/Dialect/mhlo/transforms/optimize_mhlo.cc index 0e49c73..43de470 100644 --- a/lib/Dialect/mhlo/transforms/optimize_mhlo.cc +++ b/lib/Dialect/mhlo/transforms/optimize_mhlo.cc @@ -21,18 +21,18 @@ limitations under the License. #include #include -#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/PassRegistry.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h" +#include "llvm/ADT/STLExtras.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "mlir-hlo/utils/hlo_utils.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Types.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" using mlir::OwningRewritePatternList; diff --git a/lib/Dialect/mhlo/transforms/optimize_mhlo_pass.cc b/lib/Dialect/mhlo/transforms/optimize_mhlo_pass.cc index b4184e2..32a846e 100644 --- a/lib/Dialect/mhlo/transforms/optimize_mhlo_pass.cc +++ b/lib/Dialect/mhlo/transforms/optimize_mhlo_pass.cc @@ -13,23 +13,24 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" using mlir::FunctionPass; using mlir::PassRegistration; using mlir::PassWrapper; namespace { -class OptimizeMhlo : public PassWrapper { +class OptimizeMhloPass : public PassWrapper { public: - explicit OptimizeMhlo() : PassWrapper() {} + explicit OptimizeMhloPass() : PassWrapper() {} /// Performs the lowering to MHLO dialect. void runOnFunction() override; @@ -37,7 +38,7 @@ class OptimizeMhlo : public PassWrapper { } // end anonymous namespace // Lowers the complex operations that can be represented using other operations. -void OptimizeMhlo::runOnFunction() { +void OptimizeMhloPass::runOnFunction() { // Add lowering patterns to the list. mlir::OwningRewritePatternList patterns; mlir::mhlo::PopulateOptimizeMHLOPatterns(&getContext(), &patterns); @@ -45,5 +46,6 @@ void OptimizeMhlo::runOnFunction() { applyPatternsAndFoldGreedily(getFunction(), patterns); } -static PassRegistration pass("mhlo-test-optimize", - "Run optional HLO optimizations."); +std::unique_ptr mlir::mhlo::createOptimizeMhloPass() { + return std::make_unique(); +} diff --git a/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc b/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc index 446bfa9..0f31e61 100644 --- a/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc +++ b/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/DenseMap.h" -#include "third_party/llvm/llvm-project/llvm/include/llvm/Support/Casting.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/PassManager.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Support/LLVM.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/RegionUtils.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/Support/Casting.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir/IR/Operation.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/RegionUtils.h" namespace mlir { namespace mhlo { @@ -29,8 +29,8 @@ namespace { // A pass that sinks constants implicitly captured in control flow regions. This // is necessary to export to XLA. -class SinkConstantsToControlFlow - : public mlir::PassWrapper { +class SinkConstantsToControlFlowPass + : public mlir::PassWrapper { void runOnFunction() override { getFunction().walk([](Operation* op) { if (auto while_op = llvm::dyn_cast(op)) { @@ -70,15 +70,10 @@ class SinkConstantsToControlFlow } }; -static mlir::PassRegistration pass( - "mhlo-sink-constants-to-control-flow", - "Sink constants implicitly captured in control flow regions. This is " - "necessary to export to XLA."); - } // anonymous namespace std::unique_ptr> createSinkConstantsToControlFlowPass() { - return std::make_unique(); + return std::make_unique(); } } // namespace mhlo diff --git a/lib/Dialect/mhlo/transforms/test_infer_shaped_type_pass.cc b/lib/Dialect/mhlo/transforms/test_infer_shaped_type_pass.cc index d258457..35e5a18 100644 --- a/lib/Dialect/mhlo/transforms/test_infer_shaped_type_pass.cc +++ b/lib/Dialect/mhlo/transforms/test_infer_shaped_type_pass.cc @@ -13,16 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Identifier.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OperationSupport.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/InferTypeOpInterface.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Identifier.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Pass/Pass.h" namespace mlir { -namespace hlo { +namespace mhlo { namespace { struct InferReturnTypeComponentsPattern : public RewritePattern { @@ -92,9 +92,10 @@ struct TestInferShapedTypeMethodsPass }; } // namespace -} // namespace hlo -} // namespace mlir -static mlir::PassRegistration pass( - "mhlo-test-infer-shaped-type-methods", - "Uses test ops to invoke InferShapedTypeOpInterface methods"); +std::unique_ptr createTestInferShapedTypeMethodsPass() { + return std::make_unique(); +} + +} // namespace mhlo +} // namespace mlir diff --git a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc index 258f581..8db5d84 100644 --- a/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc +++ b/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc @@ -14,18 +14,17 @@ limitations under the License. ==============================================================================*/ -#include "third_party/absl/memory/memory.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Shape/IR/Shape.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir/Dialect/Shape/IR/Shape.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" namespace mlir { namespace mhlo { @@ -204,9 +203,9 @@ void PopulateTransformUnrankedHloPatterns(MLIRContext *context, // clang-format on } -static PassRegistration transform_unranked_hlo_pass( - "transform-unranked-hlo", - "Realize element-wise operations on ranked tensors where possible"); +std::unique_ptr<::mlir::Pass> createTransformUnrankedHloPass() { + return std::make_unique(); +} } // namespace mhlo } // namespace mlir diff --git a/lib/Dialect/mhlo/transforms/unfuse_batch_norm.cc b/lib/Dialect/mhlo/transforms/unfuse_batch_norm.cc index 5028e28..1458e5f 100644 --- a/lib/Dialect/mhlo/transforms/unfuse_batch_norm.cc +++ b/lib/Dialect/mhlo/transforms/unfuse_batch_norm.cc @@ -13,16 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/SmallVector.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Types.h" +#include "mlir/Transforms/DialectConversion.h" namespace mlir { namespace mhlo { diff --git a/lib/Dialect/mhlo/transforms/unfuse_batch_norm_pass.cc b/lib/Dialect/mhlo/transforms/unfuse_batch_norm_pass.cc index 33f21d4..f187a74 100644 --- a/lib/Dialect/mhlo/transforms/unfuse_batch_norm_pass.cc +++ b/lib/Dialect/mhlo/transforms/unfuse_batch_norm_pass.cc @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" namespace mlir { namespace mhlo { @@ -38,9 +38,9 @@ struct TestUnfuseBatchNormPass } // namespace +std::unique_ptr<::mlir::Pass> createTestUnfuseBatchNormPass() { + return std::make_unique(); +} + } // namespace mhlo } // namespace mlir - -static mlir::PassRegistration pass( - "mhlo-test-unfuse-batch-norm", - "Test pass for materializing 'broadcast_dimensions' attributes"); diff --git a/lib/utils/broadcast_utils.cc b/lib/utils/broadcast_utils.cc index 1350725..73111c0 100644 --- a/lib/utils/broadcast_utils.cc +++ b/lib/utils/broadcast_utils.cc @@ -13,15 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h" +#include "mlir-hlo/utils/broadcast_utils.h" #include -#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/Sequence.h" -#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/SmallVector.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Shape/IR/Shape.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Diagnostics.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Shape/IR/Shape.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/StandardTypes.h" namespace mlir { namespace hlo { diff --git a/lib/utils/convert_op_folder.cc b/lib/utils/convert_op_folder.cc index 5893b37..0751d2c 100644 --- a/lib/utils/convert_op_folder.cc +++ b/lib/utils/convert_op_folder.cc @@ -15,11 +15,11 @@ limitations under the License. // This file defines helpers useful when creating or manipulating lhlo/hlo. -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/convert_op_folder.h" +#include "mlir-hlo/utils/convert_op_folder.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" namespace mlir { namespace hlo { diff --git a/lib/utils/cycle_detector.cc b/lib/utils/cycle_detector.cc index b3b51dd..0914460 100644 --- a/lib/utils/cycle_detector.cc +++ b/lib/utils/cycle_detector.cc @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/cycle_detector.h" +#include "mlir-hlo/utils/cycle_detector.h" #include -#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/DenseSet.h" +#include "llvm/ADT/DenseSet.h" namespace mlir { diff --git a/lib/utils/cycle_detector_test.cc b/lib/utils/cycle_detector_test.cc index bee96d2..1ad8def 100644 --- a/lib/utils/cycle_detector_test.cc +++ b/lib/utils/cycle_detector_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/cycle_detector.h" +#include "mlir-hlo/utils/cycle_detector.h" #include "third_party/tensorflow/compiler/xla/test.h" diff --git a/lib/utils/hlo_utils.cc b/lib/utils/hlo_utils.cc index 4528474..df2442c 100644 --- a/lib/utils/hlo_utils.cc +++ b/lib/utils/hlo_utils.cc @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h" +#include "mlir-hlo/utils/hlo_utils.h" #include -#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" +#include "mlir/IR/Attributes.h" namespace mlir { namespace hlo { diff --git a/tools/mlir-hlo-opt/mlir-hlo-opt.cpp b/tools/mlir-hlo-opt/mlir-hlo-opt.cpp new file mode 100644 index 0000000..70fc21d --- /dev/null +++ b/tools/mlir-hlo-opt/mlir-hlo-opt.cpp @@ -0,0 +1,121 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/InitLLVM.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/ToolOutputFile.h" +#include "mlir-hlo/Dialect/mhlo/IR/register.h" +#include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/InitAllDialects.h" +#include "mlir/InitAllPasses.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/FileUtilities.h" +#include "mlir/Support/MlirOptMain.h" + +// NOLINTNEXTLINE +static llvm::cl::opt inputFilename(llvm::cl::Positional, + llvm::cl::desc(""), + llvm::cl::init("-")); + +// NOLINTNEXTLINE +static llvm::cl::opt outputFilename( + "o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"), + llvm::cl::init("-")); + +// NOLINTNEXTLINE +static llvm::cl::opt splitInputFile( + "split-input-file", + llvm::cl::desc("Split the input file into pieces and process each " + "chunk independently"), + llvm::cl::init(false)); + +// NOLINTNEXTLINE +static llvm::cl::opt verifyDiagnostics( + "verify-diagnostics", + llvm::cl::desc("Check that emitted diagnostics match " + "expected-* lines on the corresponding line"), + llvm::cl::init(false)); + +// NOLINTNEXTLINE +static llvm::cl::opt verifyPasses( + "verify-each", + llvm::cl::desc("Run the verifier after each transformation pass"), + llvm::cl::init(true)); + +// NOLINTNEXTLINE +static llvm::cl::opt allowUnregisteredDialects( + "allow-unregistered-dialect", + llvm::cl::desc("Allow operation with no registered dialects"), + llvm::cl::init(false)); + +// NOLINTNEXTLINE +static llvm::cl::opt showDialects( + "show-dialects", llvm::cl::desc("Print the list of registered dialects"), + llvm::cl::init(false)); + +int main(int argc, char **argv) { + mlir::registerAllDialects(); + mlir::registerAllPasses(); + + mlir::mhlo::registerAllDialects(); + mlir::mhlo::registerAllMhloPasses(); + mlir::lmhlo::registerAllLmhloPasses(); + + llvm::InitLLVM y(argc, argv); + + // Register any pass manager command line options. + mlir::registerPassManagerCLOptions(); + mlir::PassPipelineCLParser passPipeline("", "Compiler passes to run"); + + // Parse pass names in main to ensure static initialization completed. + llvm::cl::ParseCommandLineOptions(argc, argv, + "MLIR modular optimizer driver\n"); + + if (showDialects) { + mlir::MLIRContext context; + llvm::outs() << "Registered Dialects:\n"; + for (mlir::Dialect *dialect : context.getRegisteredDialects()) { + llvm::outs() << dialect->getNamespace() << "\n"; + } + return 0; + } + + // Set up the input file. + std::string errorMessage; + auto file = mlir::openInputFile(inputFilename, &errorMessage); + if (!file) { + llvm::errs() << errorMessage << "\n"; + return 1; + } + + auto output = mlir::openOutputFile(outputFilename, &errorMessage); + if (!output) { + llvm::errs() << errorMessage << "\n"; + exit(1); + } + + if (failed(MlirOptMain(output->os(), std::move(file), passPipeline, + splitInputFile, verifyDiagnostics, verifyPasses, + allowUnregisteredDialects))) { + return 1; + } + // Keep the output file if the invocation of MlirOptMain was successful. + output->keep(); + return 0; +}