More cleanup in mlir-hlo to prepare for the standalone build

Shuffle files around, use TableGen to register passes, and introduce
a `mlir-hlo-opt.cpp` file to hold the main entry point of the -opt tool
and stop relying on static registration for dialect/passes.

PiperOrigin-RevId: 323674455
This commit is contained in:
Mehdi Amini 2020-07-28 16:12:08 -07:00 committed by TensorFlow MLIR Team
parent effd3fb4f9
commit cd01bb4c4e
65 changed files with 976 additions and 619 deletions

View File

@ -16,16 +16,16 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_CHLO_OPS_H_ #ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_CHLO_OPS_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_CHLO_OPS_H_ #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_CHLO_OPS_H_
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h" #include "llvm/ADT/StringRef.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Dialect.h" #include "mlir/IR/Dialect.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/DialectImplementation.h" #include "mlir/IR/DialectImplementation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" #include "mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpDefinition.h" #include "mlir/IR/OpDefinition.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" #include "mlir/IR/Operation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" #include "mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h" #include "mlir/IR/Types.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h"
namespace mlir { namespace mlir {
namespace chlo { namespace chlo {
@ -37,7 +37,7 @@ class HloClientDialect : public Dialect {
}; };
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h.inc" #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h.inc"
} // namespace chlo } // namespace chlo
} // namespace mlir } // namespace mlir

View File

@ -29,9 +29,9 @@ limitations under the License.
#ifndef CHLO_OPS #ifndef CHLO_OPS
#define CHLO_OPS #define CHLO_OPS
include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td" include "mlir/IR/OpBase.td"
include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td"
include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td" include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
def HLOClient_Dialect : Dialect { def HLOClient_Dialect : Dialect {

View File

@ -18,24 +18,24 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_ #ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_ #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h" #include "llvm/ADT/StringRef.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" #include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Dialect.h" #include "mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/DialectImplementation.h" #include "mlir/IR/Dialect.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h" #include "mlir/IR/DialectImplementation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" #include "mlir/IR/Location.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpDefinition.h" #include "mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" #include "mlir/IR/OpDefinition.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" #include "mlir/IR/Operation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h" #include "mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/IR/Types.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/InferTypeOpInterface.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h" #include "mlir/Interfaces/SideEffectInterfaces.h"
namespace mlir { namespace mlir {
class OpBuilder; class OpBuilder;
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_structs.h.inc" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_structs.h.inc"
namespace mhlo { namespace mhlo {
@ -91,7 +91,7 @@ LogicalResult deriveShapeFromFirstOperand(
SmallVectorImpl<Value> *reifiedReturnShapes); SmallVectorImpl<Value> *reifiedReturnShapes);
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc"
} // end namespace mhlo } // end namespace mhlo
} // end namespace mlir } // end namespace mlir

View File

@ -18,12 +18,12 @@ limitations under the License.
#ifndef HLO_OPS #ifndef HLO_OPS
#define HLO_OPS #define HLO_OPS
include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td" include "mlir/IR/OpBase.td"
include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td"
include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td"
include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td" include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td" include "mlir-hlo/Dialect/mhlo/IR/hlo_utils.td"
include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td" include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td"
def HLO_Dialect : Dialect { def HLO_Dialect : Dialect {
let name = "mhlo"; let name = "mhlo";

View File

@ -16,7 +16,7 @@ limitations under the License.
#ifndef HLO_OPS_BASE #ifndef HLO_OPS_BASE
#define HLO_OPS_BASE #define HLO_OPS_BASE
include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td" include "mlir/IR/OpBase.td"
def HLO_Pred : TypeAlias<I1, "pred (AKA boolean or 1-bit integer)">; def HLO_Pred : TypeAlias<I1, "pred (AKA boolean or 1-bit integer)">;

View File

@ -18,7 +18,7 @@ limitations under the License.
#ifndef HLO_UTILS #ifndef HLO_UTILS
#define HLO_UTILS #define HLO_UTILS
include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td" include "mlir/IR/OpBase.td"
def NullArrayAttr : NativeCodeCall<"ArrayAttr()">; def NullArrayAttr : NativeCodeCall<"ArrayAttr()">;

View File

@ -21,7 +21,7 @@ limitations under the License.
namespace mlir { namespace mlir {
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h.inc" #include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h.inc"
} // namespace mlir } // namespace mlir

View File

@ -19,7 +19,7 @@ limitations under the License.
#ifndef MLIR_INFER_FUSIBILITY_OP_INTERFACE #ifndef MLIR_INFER_FUSIBILITY_OP_INTERFACE
#define MLIR_INFER_FUSIBILITY_OP_INTERFACE #define MLIR_INFER_FUSIBILITY_OP_INTERFACE
include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td" include "mlir/IR/OpBase.td"
// OpInterface to query if an op is fusible and to query the shape equality // OpInterface to query if an op is fusible and to query the shape equality
// constraint among the inputs and outputs of an op. // constraint among the inputs and outputs of an op.

View File

@ -18,22 +18,22 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_ #ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_ #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h" #include "llvm/ADT/StringRef.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" #include "mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Dialect.h" #include "mlir/IR/Dialect.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h" #include "mlir/IR/Location.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" #include "mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpDefinition.h" #include "mlir/IR/OpDefinition.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" #include "mlir/IR/Operation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" #include "mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h" #include "mlir/IR/Types.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/ViewLikeInterface.h" #include "mlir/Interfaces/ViewLikeInterface.h"
namespace mlir { namespace mlir {
class OpBuilder; class OpBuilder;
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_structs.h.inc" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h.inc"
namespace lmhlo { namespace lmhlo {
@ -44,7 +44,7 @@ class LmhloDialect : public Dialect {
}; };
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc"
} // namespace lmhlo } // namespace lmhlo
} // end namespace mlir } // end namespace mlir

View File

@ -33,9 +33,9 @@ limitations under the License.
#ifndef LHLO_OPS #ifndef LHLO_OPS
#define LHLO_OPS #define LHLO_OPS
include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td" include "mlir/IR/OpBase.td"
include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td"
include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/ViewLikeInterface.td" include "mlir/Interfaces/ViewLikeInterface.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td" include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
def LHLO_Dialect : Dialect { def LHLO_Dialect : Dialect {

View File

@ -0,0 +1,27 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MLIR_HLO_DIALECT_MHLO_IR_REGISTER_H_
#define MLIR_HLO_DIALECT_MHLO_IR_REGISTER_H_
namespace mlir {
namespace mhlo {
void registerAllDialects();
}
} // namespace mlir
#endif // MLIR_HLO_DIALECT_MHLO_IR_REGISTER_H_

View File

@ -0,0 +1,65 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
include "mlir/Pass/PassBase.td"
def LhloCopyRemovalPass : Pass<"lhlo-copy-removal", "FuncOp"> {
let summary = "Removes redundant LHLO copy operations.";
let constructor = "createLhloCopyRemovalPass()";
}
def LhloLegalizeToLinalgPass : Pass<"lhlo-legalize-to-linalg", "FuncOp"> {
let summary = "Legalize from LHLO dialect to Linalg dialect.";
let constructor = "createLegalizeLhloToLinalgPass()";
}
def LhloFuseLinalgPass : Pass<"lhlo-fuse-linalg", "FuncOp"> {
let summary = "Greedily fuse linalg ops obtained after LHLO lowering.";
let constructor = "createLhloFuseLinalgPass()";
let options = [
Option<"use_parallel_loops_", "use-parallel-loops", "bool",
/*default=*/"false", "Tiles GenericOp consumer to parallel loops before linalg fusion">,
ListOption<"tile_sizes_", "tile-sizes", "unsigned",
"Faster memory space number to promote fusion buffers to",
"llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">,
];
}
def LhloLegalizeToAffinePass : Pass<"lhlo-legalize-to-affine", "FuncOp"> {
let summary = "Legalize from LHLO dialect to affine dialect.";
let constructor = "createLhloLegalizeToAffinePass()";
}
def LhloLegalizeToGpuPass : Pass<"lhlo-legalize-to-gpu", "FuncOp"> {
let summary = "Legalize from LHLO dialect to GPU dialect.";
let constructor = "createLegalizeToGpuPass()";
}
def TestLhloToLLVMPass : Pass<"test-lhlo-legalize-to-llvm", "FuncOp"> {
let summary = "Legalize from LHLO dialect to LLVM.";
let constructor = "createTestLhloToLLVMPass()";
}
def LhloLegalizeToParallelLoopsPass : Pass<"lhlo-legalize-to-parallel-loops", "FuncOp"> {
let summary = "Legalize from LHLO dialect to parallel loops.";
let constructor = "createLegalizeLhloToParallelLoopsPass()";
}

View File

@ -18,8 +18,8 @@ limitations under the License.
#include <type_traits> #include <type_traits>
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
namespace mlir { namespace mlir {
namespace mhlo { namespace mhlo {

View File

@ -16,12 +16,12 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_ #ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_ #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h" #include "llvm/ADT/StringRef.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringSwitch.h" #include "llvm/ADT/StringSwitch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
namespace mlir { namespace mlir {
namespace lmhlo { namespace lmhlo {

View File

@ -0,0 +1,108 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
include "mlir/Pass/PassBase.td"
def TestChloLegalizeToHloPass : Pass<"mhlo-test-chlo-legalize-to-hlo", "FuncOp"> {
let summary = "Test pass for applying chlo -> hlo legalization patterns.";
let constructor = "createTestChloLegalizeToHloPass()";
}
def HloLegalizeToLhloPass : Pass<"hlo-legalize-to-lhlo", "ModuleOp"> {
let summary = "Legalize from HLO dialect to LHLO dialect.";
let constructor = "createLegalizeToLhloPass()";
}
def LegalizeControlFlowPass : Pass<"mhlo-legalize-control-flow", "FuncOp"> {
let summary = "Legalize from MHLO control flow to CFG control flow.";
let constructor = "createLegalizeControlFlowPass()";
}
def LegalizeGatherToTorchIndexSelectPass : Pass<"mhlo-legalize-gather-to-torch-index-select", "FuncOp"> {
let summary = "Legalizes gathers to a torch index select.";
let constructor = "createLegalizeGatherToTorchIndexSelectPass()";
}
def LegalizeTanhToApproximationPass : Pass<"mhlo-legalize-tanh-to-approximation", "FuncOp"> {
let summary = "Legalize tanh from standard dialect to an approximation.";
let constructor = "createLegalizeTanhToApproximationPass()";
}
def HloLegalizeToLinalgPass : Pass<"hlo-legalize-to-linalg", "FuncOp"> {
let summary = "Legalize from HLO dialect to Linalg dialect.";
let constructor = "createLegalizeHloToLinalgPass()";
}
def LegalizeToStandardPass : Pass<"mhlo-legalize-to-std", "FuncOp"> {
let summary = "Legalize from MHLO dialect to standard dialect.";
let constructor = "createLegalizeToStdPass()";
}
def LowerComplexPass : Pass<"mhlo-test-lower-complex", "FuncOp"> {
let summary = "Lower complex operations into non-complex operations.";
let constructor = "createLowerComplexPass()";
}
def LegalizeGeneralDotPass : Pass<"mhlo-test-lower-general-dot", "FuncOp"> {
let summary = "Tests lowering general dot to a non-batched dot when possible.";
let constructor = "createLegalizeGeneralDotPass()";
}
def TestMaterializeBroadcastsPass : Pass<"mhlo-test-materialize-broadcasts", "FuncOp"> {
let summary = "Test pass for materializing 'broadcast_dimensions' attributes.";
let constructor = "createTestMaterializeBroadcastsPass()";
}
def MhloFusionPass : Pass<"mhlo-fusion", "FuncOp"> {
let summary = "Fuse mhlo ops to kLoop/kInput fusion patterns.";
let constructor = "createMhloFusionPass()";
}
def OptimizeMhloPass : Pass<"mhlo-test-optimize", "FuncOp"> {
let summary = "Run optional HLO optimizations.";
let constructor = "createOptimizeMhloPass()";
}
def SinkConstantsToControlFlowPass : Pass<"mhlo-sink-constants-to-control-flow", "FuncOp"> {
let summary = "Sink constants implicitly captured in control flow regions. This "
"is necessary to export to XLA.";
let constructor = "createSinkConstantsToControlFlowPass()";
}
def TestInferShapedTypeMethodsPass : Pass<"mhlo-test-infer-shaped-type-methods", "FuncOp"> {
let summary = "Uses test ops to invoke InferShapedTypeOpInterface methods.";
let constructor = "createTestInferShapedTypeMethodsPass()";
}
def TransformUnrankedHloPass : Pass<"transform-unranked-hlo", "FuncOp"> {
let summary = "Realize element-wise operations on ranked tensors where possible.";
let constructor = "createTransformUnrankedHloPass()";
}
def TestUnfuseBatchNormPass : Pass<"mhlo-test-unfuse-batch-norm", "FuncOp"> {
let summary = "Test pass for materializing 'broadcast_dimensions' attributes.";
let constructor = "createTestUnfuseBatchNormPass()";
}

View File

@ -18,11 +18,12 @@ limitations under the License.
#include <memory> #include <memory>
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/ArrayRef.h" #include "llvm/ADT/ArrayRef.h"
namespace mlir { namespace mlir {
class FuncOp; class FuncOp;
class FunctionPass;
class ModuleOp; class ModuleOp;
class Operation; class Operation;
template <typename T> template <typename T>
@ -58,18 +59,26 @@ std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass();
// fuse mhlo ops to kLoop/kInput fusion patterns // fuse mhlo ops to kLoop/kInput fusion patterns
std::unique_ptr<OperationPass<FuncOp>> createMhloFusionPass(); std::unique_ptr<OperationPass<FuncOp>> createMhloFusionPass();
/// Lowers the standard TanhOp to an approximation that does not use intrinsics.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTanhToApproximationPass();
std::unique_ptr<FunctionPass> createOptimizeMhloPass();
std::unique_ptr<FunctionPass> createLowerComplexPass();
std::unique_ptr<::mlir::Pass> createLegalizeGeneralDotPass();
std::unique_ptr<FunctionPass> createLegalizeGatherToTorchIndexSelectPass();
} // namespace mhlo } // namespace mhlo
namespace lmhlo { namespace lmhlo {
// Lowers from LHLO dialect to Affine dialect. // Lowers from LHLO dialect to Affine dialect.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeToAffinePass(); std::unique_ptr<OperationPass<FuncOp>> createLhloLegalizeToAffinePass();
// Lowers from LHLO dialect to Linalg dialect. // Lowers from LHLO dialect to Linalg dialect.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToLinalgPass(); std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToLinalgPass();
// Lowers from LHLO dialect to GPU dialect. // Lowers from LHLO dialect to GPU dialect.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeToGpuPass(); std::unique_ptr<FunctionPass> createLegalizeToGpuPass();
// Fuses linalg ops obtained after LHLO lowering. To enable fusion, // Fuses linalg ops obtained after LHLO lowering. To enable fusion,
// operations are first tiled. // operations are first tiled.
@ -80,7 +89,7 @@ std::unique_ptr<OperationPass<FuncOp>> createLegalizeToGpuPass();
// 'tile_sizes' provides the tile sizes to use for tiling. If the linalg // 'tile_sizes' provides the tile sizes to use for tiling. If the linalg
// operation has more dimensions than tile sizes provided, 1 is used as // operation has more dimensions than tile sizes provided, 1 is used as
// default. // default.
std::unique_ptr<OperationPass<FuncOp>> createLhloFuseLinalg( std::unique_ptr<FunctionPass> createLhloFuseLinalgPass(
bool use_parallel_loops = false, llvm::ArrayRef<unsigned> tile_sizes = {}); bool use_parallel_loops = false, llvm::ArrayRef<unsigned> tile_sizes = {});
// Removes unnecessary LHLO copies which copy from the allocated buffers to the // Removes unnecessary LHLO copies which copy from the allocated buffers to the
@ -94,12 +103,6 @@ std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToParallelLoopsPass();
} // namespace lmhlo } // namespace lmhlo
namespace hlo {
/// Lowers the standard TanhOp to an approximation that does not use intrinsics.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTanhToApproximationPass();
} // namespace hlo
} // namespace mlir } // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_PASSES_H_ #endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_PASSES_H_

View File

@ -0,0 +1,49 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MLIR_HLO_DIALECT_MHLO_TRANSFORMS_REGISTER_PASSES_H_
#define MLIR_HLO_DIALECT_MHLO_TRANSFORMS_REGISTER_PASSES_H_
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace mhlo {
std::unique_ptr<Pass> createTestChloLegalizeToHloPass();
std::unique_ptr<FunctionPass> createTestInferShapedTypeMethodsPass();
std::unique_ptr<Pass> createTestMaterializeBroadcastsPass();
std::unique_ptr<Pass> createTestUnfuseBatchNormPass();
inline void registerAllMhloPasses() {
#define GEN_PASS_REGISTRATION
#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc"
}
} // namespace mhlo
namespace lmhlo {
std::unique_ptr<Pass> createTestLhloToLLVMPass();
inline void registerAllLmhloPasses() {
#define GEN_PASS_REGISTRATION
#include "mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.h.inc"
}
} // namespace lmhlo
} // namespace mlir
#endif // MLIR_HLO_DIALECT_MHLO_TRANSFORMS_REGISTER_PASSES_H_

View File

@ -18,9 +18,9 @@ limitations under the License.
#include <memory> #include <memory>
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" #include "mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
namespace mlir { namespace mlir {
class LLVMTypeConverter; class LLVMTypeConverter;
@ -80,6 +80,11 @@ void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
void PopulateUnfuseBatchNormPatterns(MLIRContext *context, void PopulateUnfuseBatchNormPatterns(MLIRContext *context,
OwningRewritePatternList *patterns); OwningRewritePatternList *patterns);
// Populates a pattern that translates the standard TanhOp to an approximation
// that does not use intrinsics.
void PopulateTanhToApproximationPatterns(MLIRContext *context,
OwningRewritePatternList *patterns);
} // namespace mhlo } // namespace mhlo
namespace lmhlo { namespace lmhlo {
@ -100,14 +105,6 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
} // namespace chlo } // namespace chlo
namespace hlo {
// Populates a pattern that translates the standard TanhOp to an approximation
// that does not use intrinsics.
void PopulateTanhToApproximationPatterns(MLIRContext *context,
OwningRewritePatternList *patterns);
} // namespace hlo
} // namespace mlir } // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_REWRITERS_H_ #endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_REWRITERS_H_

View File

@ -19,12 +19,12 @@ limitations under the License.
// Utilities relating to implementing HLO broadcasting. // Utilities relating to implementing HLO broadcasting.
// Note: This file should not depend on any non-MLIR TensorFlow libraries. // Note: This file should not depend on any non-MLIR TensorFlow libraries.
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" #include "mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h" #include "mlir/IR/Location.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" #include "mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Support/LLVM.h" #include "mlir/Support/LLVM.h"
namespace mlir { namespace mlir {
namespace hlo { namespace hlo {

View File

@ -16,8 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_CONVERT_OP_FOLDER_H_ #ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_CONVERT_OP_FOLDER_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_CONVERT_OP_FOLDER_H_ #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_CONVERT_OP_FOLDER_H_
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" #include "mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" #include "mlir/IR/StandardTypes.h"
namespace mlir { namespace mlir {
namespace hlo { namespace hlo {

View File

@ -18,7 +18,7 @@ limitations under the License.
#include <vector> #include <vector>
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseMap.h"
namespace mlir { namespace mlir {

View File

@ -16,11 +16,11 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_HLO_UTILS_H_ #ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_HLO_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_HLO_UTILS_H_ #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_HLO_UTILS_H_
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" #include "mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" #include "mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h" #include "mlir/IR/TypeUtilities.h"
namespace mlir { namespace mlir {
namespace hlo { namespace hlo {

View File

@ -13,14 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" #include "mlir-hlo/utils/broadcast_utils.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h" #include "mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Diagnostics.h" #include "mlir/IR/Builders.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" #include "mlir/IR/Diagnostics.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h" #include "mlir/IR/StandardTypes.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h" #include "mlir/IR/TypeUtilities.h"
namespace mlir { namespace mlir {
namespace chlo { namespace chlo {
@ -260,7 +260,7 @@ BROADCAST_BINARY_OP_DEFS(BroadcastXorOp);
#undef BROADCAST_BINARY_OP_DEFS #undef BROADCAST_BINARY_OP_DEFS
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc" #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc"
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// chlo Dialect Constructor // chlo Dialect Constructor
@ -270,7 +270,7 @@ HloClientDialect::HloClientDialect(MLIRContext* context)
: Dialect(getDialectNamespace(), context) { : Dialect(getDialectNamespace(), context) {
addOperations< addOperations<
#define GET_OP_LIST #define GET_OP_LIST
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc" #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc"
>(); >();
} }

View File

@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
// Static initialization for *HLO dialects registration. // Static initialization for *HLO dialects registration.
static mlir::DialectRegistration<mlir::mhlo::MhloDialect> mhlo_ops; static mlir::DialectRegistration<mlir::mhlo::MhloDialect> mhlo_ops;

View File

@ -15,7 +15,7 @@ limitations under the License.
// This file defines the operations used in the MHLO dialect. // This file defines the operations used in the MHLO dialect.
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include <assert.h> #include <assert.h>
#include <stddef.h> #include <stddef.h>
@ -24,44 +24,43 @@ limitations under the License.
#include <algorithm> #include <algorithm>
#include <functional> #include <functional>
#include "third_party/absl/container/flat_hash_set.h" #include "llvm/ADT/APFloat.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/APFloat.h" #include "llvm/ADT/APInt.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h" #include "llvm/ADT/iterator_range.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/iterator_range.h" #include "llvm/Support/Casting.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/Support/FormatVariadic.h" #include "llvm/Support/MathExtras.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/Support/MathExtras.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Shape/IR/Shape.h" #include "mlir-hlo/utils/convert_op_folder.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir-hlo/utils/hlo_utils.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" #include "mlir/Dialect/Shape/IR/Shape.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Dialect.h" #include "mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h" #include "mlir/IR/Builders.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" #include "mlir/IR/Dialect.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Matchers.h" #include "mlir/IR/Location.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpDefinition.h" #include "mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpImplementation.h" #include "mlir/IR/Matchers.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" #include "mlir/IR/OpDefinition.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OperationSupport.h" #include "mlir/IR/OpImplementation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" #include "mlir/IR/Operation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" #include "mlir/IR/OperationSupport.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h" #include "mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h" #include "mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Value.h" #include "mlir/IR/TypeUtilities.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Support/LLVM.h" #include "mlir/IR/Types.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Support/LogicalResult.h" #include "mlir/IR/Value.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/InliningUtils.h" #include "mlir/Support/LLVM.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc" #include "mlir/Support/LogicalResult.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/convert_op_folder.h" #include "mlir/Transforms/InliningUtils.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h"
namespace mlir { namespace mlir {
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_patterns.cc.inc" #include "hlo_patterns.cc.inc"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_structs.cc.inc" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_structs.cc.inc"
namespace mhlo { namespace mhlo {
Operation* MhloDialect::materializeConstant(OpBuilder& builder, Attribute value, Operation* MhloDialect::materializeConstant(OpBuilder& builder, Attribute value,
@ -106,7 +105,7 @@ DenseIntElementsAttr BuildSliceLimits(DenseIntElementsAttr start_indices,
return GetI64ElementsAttr(slice_limits, builder); return GetI64ElementsAttr(slice_limits, builder);
} }
#include "third_party/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/generated_canonicalize.inc" #include "mhlo_canonicalize.inc"
} // namespace } // namespace
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -375,8 +374,8 @@ static LogicalResult Verify(CollectivePermuteOp op) {
<< "expect source_target_pairs attribute of shape (N, 2), but got (" << "expect source_target_pairs attribute of shape (N, 2), but got ("
<< type.getShape() << ")"; << type.getShape() << ")";
// Check source target pairs for duplicate sources or targets // Check source target pairs for duplicate sources or targets
absl::flat_hash_set<int64_t> sources; llvm::DenseSet<int64_t> sources;
absl::flat_hash_set<int64_t> targets; llvm::DenseSet<int64_t> targets;
for (auto i = op.source_target_pairs().begin(), for (auto i = op.source_target_pairs().begin(),
e = op.source_target_pairs().end(); e = op.source_target_pairs().end();
i != e; ++i) { i != e; ++i) {
@ -2123,7 +2122,7 @@ void CompareOp::build(OpBuilder& builder, OperationState& result, Value lhs,
} }
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc"
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// mhlo Dialect Interfaces // mhlo Dialect Interfaces
@ -2154,7 +2153,7 @@ MhloDialect::MhloDialect(MLIRContext* context)
: Dialect(getDialectNamespace(), context) { : Dialect(getDialectNamespace(), context) {
addOperations< addOperations<
#define GET_OP_LIST #define GET_OP_LIST
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc"
>(); >();
addInterfaces<HLOInlinerInterface>(); addInterfaces<HLOInlinerInterface>();
addTypes<TokenType>(); addTypes<TokenType>();

View File

@ -15,8 +15,8 @@ limitations under the License.
// Canonicalization patterns for the MHLO dialect. // Canonicalization patterns for the MHLO dialect.
include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td" include "mlir/Dialect/Shape/IR/ShapeOps.td"
include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td" include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"
def EqualBinaryOperands : Constraint<CPred<"$0 == $1">>; def EqualBinaryOperands : Constraint<CPred<"$0 == $1">>;

View File

@ -13,10 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h" #include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h"
namespace mlir { namespace mlir {
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.cc.inc" #include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.cpp.inc"
} // namespace mlir } // namespace mlir

View File

@ -0,0 +1,33 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/register.h"
// Static initialization for *HLO dialects registration.
void mlir::mhlo::registerAllDialects() {
static bool init_once = []() {
registerDialect<mlir::chlo::HloClientDialect>();
registerDialect<mlir::lmhlo::LmhloDialect>();
registerDialect<mlir::mhlo::MhloDialect>();
return true;
}();
(void)init_once;
// Dependent dialects
}

View File

@ -15,44 +15,44 @@ limitations under the License.
// This file defines the operations used in the LMHLO dialect. // This file defines the operations used in the LMHLO dialect.
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include <assert.h> #include <assert.h>
#include <stddef.h> #include <stddef.h>
#include <stdint.h> #include <stdint.h>
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/APFloat.h" #include "llvm/ADT/APFloat.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/APInt.h" #include "llvm/ADT/APInt.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/ArrayRef.h" #include "llvm/ADT/ArrayRef.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h" #include "llvm/ADT/StringRef.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/Support/FormatVariadic.h" #include "llvm/Support/FormatVariadic.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h" #include "mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Dialect.h" #include "mlir/IR/Builders.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h" #include "mlir/IR/Dialect.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" #include "mlir/IR/Location.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpDefinition.h" #include "mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpImplementation.h" #include "mlir/IR/OpDefinition.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" #include "mlir/IR/OpImplementation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OperationSupport.h" #include "mlir/IR/Operation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" #include "mlir/IR/OperationSupport.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" #include "mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h" #include "mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h" #include "mlir/IR/TypeUtilities.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Value.h" #include "mlir/IR/Types.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc" #include "mlir/IR/Value.h"
namespace mlir { namespace mlir {
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_structs.cc.inc" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.cc.inc"
namespace lmhlo { namespace lmhlo {
LmhloDialect::LmhloDialect(MLIRContext *context) LmhloDialect::LmhloDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context) { : Dialect(getDialectNamespace(), context) {
addOperations< addOperations<
#define GET_OP_LIST #define GET_OP_LIST
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.cc.inc" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.cc.inc"
>(); >();
} }
@ -127,7 +127,7 @@ static LogicalResult Verify(ReshapeMemRefCastOp op) {
} }
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.cc.inc" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.cc.inc"
// TODO(cheshire): Support folding, reuse code from hlo_ops.cc. // TODO(cheshire): Support folding, reuse code from hlo_ops.cc.

View File

@ -15,7 +15,7 @@ limitations under the License.
// This is the canonicalize pattern definition file. // This is the canonicalize pattern definition file.
include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td" include "mlir/IR/OpBase.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td" include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_utils.td" include "mlir-hlo/Dialect/mhlo/IR/hlo_utils.td"

View File

@ -13,19 +13,19 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/SCF/SCF.h" #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Shape/IR/Shape.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" #include "mlir-hlo/utils/broadcast_utils.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" #include "mlir/Dialect/SCF/SCF.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OperationSupport.h" #include "mlir/Dialect/Shape/IR/Shape.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" #include "mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h" #include "mlir/IR/MLIRContext.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" #include "mlir/IR/OperationSupport.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir/IR/PatternMatch.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "mlir/IR/StandardTypes.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h" #include "mlir/Transforms/DialectConversion.h"
namespace mlir { namespace mlir {
namespace chlo { namespace chlo {

View File

@ -13,16 +13,17 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/SCF/SCF.h" #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Shape/IR/Shape.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" #include "mlir/Dialect/SCF/SCF.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir/Dialect/Shape/IR/Shape.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Pass/Pass.h"
namespace mlir { namespace mlir {
namespace chlo { namespace mhlo {
namespace { namespace {
@ -32,7 +33,7 @@ struct TestChloLegalizeToHloPass
ConversionTarget conversionTarget(getContext()); ConversionTarget conversionTarget(getContext());
OwningRewritePatternList conversionPatterns; OwningRewritePatternList conversionPatterns;
conversionTarget.addIllegalDialect<HloClientDialect>(); conversionTarget.addIllegalDialect<chlo::HloClientDialect>();
// Consider the mhlo dialect legal for tests. // Consider the mhlo dialect legal for tests.
conversionTarget.addLegalDialect<mhlo::MhloDialect>(); conversionTarget.addLegalDialect<mhlo::MhloDialect>();
// The conversion uses helpers from the Standard dialect. // The conversion uses helpers from the Standard dialect.
@ -40,7 +41,7 @@ struct TestChloLegalizeToHloPass
conversionTarget.addLegalDialect<mlir::shape::ShapeDialect>(); conversionTarget.addLegalDialect<mlir::shape::ShapeDialect>();
conversionTarget.addLegalDialect<mlir::scf::SCFDialect>(); conversionTarget.addLegalDialect<mlir::scf::SCFDialect>();
PopulateLegalizeChloToHloPatterns(&getContext(), &conversionPatterns); chlo::PopulateLegalizeChloToHloPatterns(&getContext(), &conversionPatterns);
if (failed(applyPartialConversion(getFunction(), conversionTarget, if (failed(applyPartialConversion(getFunction(), conversionTarget,
conversionPatterns))) { conversionPatterns))) {
@ -51,9 +52,10 @@ struct TestChloLegalizeToHloPass
} // namespace } // namespace
} // namespace chlo std::unique_ptr<FunctionPass> createTestChloLegalizeToHloPass() {
return std::make_unique<TestChloLegalizeToHloPass>();
}
} // namespace mhlo
} // namespace mlir } // namespace mlir
static mlir::PassRegistration<mlir::chlo::TestChloLegalizeToHloPass> pass(
"mhlo-test-chlo-legalize-to-hlo",
"Test pass for applying chlo -> hlo legalization patterns");

View File

@ -15,26 +15,25 @@ limitations under the License.
// This file implements logic for lowering HLO dialect to LHLO dialect. // This file implements logic for lowering HLO dialect to LHLO dialect.
#include "third_party/absl/memory/memory.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/AffineMap.h" #include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/BlockAndValueMapping.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h" #include "mlir/IR/AffineMap.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h" #include "mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" #include "mlir/IR/BlockAndValueMapping.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" #include "mlir/IR/Builders.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" #include "mlir/IR/Function.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" #include "mlir/IR/Location.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" #include "mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/BufferPlacement.h" #include "mlir/IR/Operation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h" #include "mlir/IR/PatternMatch.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir/IR/StandardTypes.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "mlir/Pass/Pass.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h" #include "mlir/Transforms/BufferPlacement.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "mlir/Transforms/DialectConversion.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
namespace mlir { namespace mlir {
namespace mhlo { namespace mhlo {
@ -511,11 +510,8 @@ void populateHLOToLHLOConversionPattern(
std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass( std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass(
bool results_escape_function) { bool results_escape_function) {
return absl::make_unique<HloLegalizeToLhlo>(results_escape_function); return std::make_unique<HloLegalizeToLhlo>(results_escape_function);
} }
static PassRegistration<HloLegalizeToLhlo> legalize_pass(
"hlo-legalize-to-lhlo", "Legalize from HLO dialect to LHLO dialect");
} // namespace mhlo } // namespace mhlo
} // namespace mlir } // namespace mlir

View File

@ -15,30 +15,30 @@ limitations under the License.
// This file implements logic for lowering MHLO dialect to Standard dialect. // This file implements logic for lowering MHLO dialect to Standard dialect.
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringSwitch.h" #include "llvm/ADT/StringSwitch.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/Support/Casting.h" #include "llvm/Support/Casting.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Block.h" #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/BlockAndValueMapping.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h" #include "mlir/IR/Block.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h" #include "mlir/IR/BlockAndValueMapping.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" #include "mlir/IR/Builders.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" #include "mlir/IR/Function.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h" #include "mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" #include "mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/PassRegistry.h" #include "mlir/IR/TypeUtilities.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Support/LogicalResult.h" #include "mlir/Pass/Pass.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir/Pass/PassRegistry.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "mlir/Support/LogicalResult.h"
using mlir::PassRegistration; using mlir::PassRegistration;
namespace mlir { namespace mlir {
namespace mhlo { namespace mhlo {
namespace { namespace {
struct LegalizeControlFlow struct LegalizeControlFlowPass
: public mlir::PassWrapper<LegalizeControlFlow, FunctionPass> { : public mlir::PassWrapper<LegalizeControlFlowPass, FunctionPass> {
// Perform the lowering to MLIR control flow. // Perform the lowering to MLIR control flow.
void runOnFunction() override; void runOnFunction() override;
}; };
@ -206,7 +206,7 @@ LogicalResult LowerWhileOp(mlir::mhlo::WhileOp while_op) {
return success(); return success();
} }
void LegalizeControlFlow::runOnFunction() { void LegalizeControlFlowPass::runOnFunction() {
auto func = getFunction(); auto func = getFunction();
llvm::SmallVector<IfOp, 4> if_ops; llvm::SmallVector<IfOp, 4> if_ops;
func.walk([&](IfOp op) { if_ops.push_back(op); }); func.walk([&](IfOp op) { if_ops.push_back(op); });
@ -228,9 +228,5 @@ void LegalizeControlFlow::runOnFunction() {
std::unique_ptr<mlir::OperationPass<mlir::FuncOp>> std::unique_ptr<mlir::OperationPass<mlir::FuncOp>>
mlir::mhlo::createLegalizeControlFlowPass() { mlir::mhlo::createLegalizeControlFlowPass() {
return std::make_unique<LegalizeControlFlow>(); return std::make_unique<LegalizeControlFlowPass>();
} }
static PassRegistration<mlir::mhlo::LegalizeControlFlow> legalize_cf_pass(
"mhlo-legalize-control-flow",
"Legalize from MHLO control flow to CFG control flow");

View File

@ -13,13 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "third_party/absl/memory/memory.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h" #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" #include "mlir/IR/Function.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir/IR/PatternMatch.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "mlir/Pass/Pass.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
namespace mlir { namespace mlir {
@ -128,8 +127,8 @@ struct GatherIsTorchIndexSelect : public OpRewritePattern<GatherOp> {
} }
}; };
struct LegalizeGatherToTorchIndexSelect struct LegalizeGatherToTorchIndexSelectPass
: public PassWrapper<LegalizeGatherToTorchIndexSelect, FunctionPass> { : public PassWrapper<LegalizeGatherToTorchIndexSelectPass, FunctionPass> {
/// Perform the lowering of standard dialect operations to approximations. /// Perform the lowering of standard dialect operations to approximations.
void runOnFunction() override { void runOnFunction() override {
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
@ -144,9 +143,9 @@ void PopulateGatherToTorchIndexSelectPatterns(
patterns->insert<GatherIsTorchIndexSelect>(context); patterns->insert<GatherIsTorchIndexSelect>(context);
} }
static PassRegistration<LegalizeGatherToTorchIndexSelect> legalize_hlo_pass( std::unique_ptr<FunctionPass> createLegalizeGatherToTorchIndexSelectPass() {
"mhlo-legalize-gather-to-torch-index-select", return std::make_unique<LegalizeGatherToTorchIndexSelectPass>();
"Legalizes gathers to a torch index select."); }
} // namespace mhlo } // namespace mhlo
} // namespace mlir } // namespace mlir

View File

@ -16,15 +16,15 @@ limitations under the License.
// This file implements logic for lowering the tanh standard ops to an // This file implements logic for lowering the tanh standard ops to an
// approximation. // approximation.
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" #include "mlir/IR/Function.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "mlir/IR/PatternMatch.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "mlir/Pass/Pass.h"
namespace mlir { namespace mlir {
namespace hlo { namespace mhlo {
namespace { namespace {
/// Emits the fast tanh approximation that is also used by XLA. /// Emits the fast tanh approximation that is also used by XLA.
@ -126,8 +126,8 @@ class ApproximateTanhLowering : public OpRewritePattern<TanhOp> {
} }
}; };
struct LegalizeTanhToApproximation struct LegalizeTanhToApproximationPass
: public PassWrapper<LegalizeTanhToApproximation, FunctionPass> { : public PassWrapper<LegalizeTanhToApproximationPass, FunctionPass> {
/// Perform the lowering of standard dialect operations to approximations. /// Perform the lowering of standard dialect operations to approximations.
void runOnFunction() override { void runOnFunction() override {
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
@ -140,7 +140,7 @@ struct LegalizeTanhToApproximation
std::unique_ptr<mlir::OperationPass<mlir::FuncOp>> std::unique_ptr<mlir::OperationPass<mlir::FuncOp>>
createLegalizeTanhToApproximationPass() { createLegalizeTanhToApproximationPass() {
return std::make_unique<LegalizeTanhToApproximation>(); return std::make_unique<LegalizeTanhToApproximationPass>();
} }
void PopulateTanhToApproximationPatterns(mlir::MLIRContext *context, void PopulateTanhToApproximationPatterns(mlir::MLIRContext *context,
@ -148,9 +148,5 @@ void PopulateTanhToApproximationPatterns(mlir::MLIRContext *context,
patterns->insert<ApproximateTanhLowering>(context); patterns->insert<ApproximateTanhLowering>(context);
} }
static PassRegistration<LegalizeTanhToApproximation> legalize_pass( } // namespace mhlo
"mhlo-legalize-tanh-to-approximation",
"Legalize tanh from standard dialect to an approximation");
} // namespace hlo
} // namespace mlir } // namespace mlir

View File

@ -15,26 +15,25 @@ limitations under the License.
// This file implements logic for lowering HLO/LHLO dialect to Linalg dialect. // This file implements logic for lowering HLO/LHLO dialect to Linalg dialect.
#include "third_party/absl/memory/memory.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/AffineExpr.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h" #include "mlir/IR/AffineExpr.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h" #include "mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" #include "mlir/IR/Builders.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" #include "mlir/IR/Function.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" #include "mlir/IR/Location.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" #include "mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" #include "mlir/IR/Operation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h" #include "mlir/IR/PatternMatch.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir/IR/StandardTypes.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "mlir/Pass/Pass.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h" #include "mlir/Transforms/DialectConversion.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
namespace mlir { namespace mlir {
namespace { namespace {
@ -826,8 +825,8 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
// indexing_maps = [#map0, #map0, #map0], // indexing_maps = [#map0, #map0, #map0],
// iterator_types = ["parallel", "parallel"], // iterator_types = ["parallel", "parallel"],
// } : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () // } : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
struct LhloLegalizeToLinalg struct LhloLegalizeToLinalgPass
: public PassWrapper<LhloLegalizeToLinalg, FunctionPass> { : public PassWrapper<LhloLegalizeToLinalgPass, FunctionPass> {
void runOnFunction() override { void runOnFunction() override {
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
ConversionTarget target(getContext()); ConversionTarget target(getContext());
@ -842,8 +841,8 @@ struct LhloLegalizeToLinalg
} }
}; };
struct HloLegalizeToLinalg struct HloLegalizeToLinalgPass
: public PassWrapper<HloLegalizeToLinalg, FunctionPass> { : public PassWrapper<HloLegalizeToLinalgPass, FunctionPass> {
void runOnFunction() override { void runOnFunction() override {
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
ConversionTarget target(getContext()); ConversionTarget target(getContext());
@ -861,11 +860,8 @@ struct HloLegalizeToLinalg
namespace lmhlo { namespace lmhlo {
std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToLinalgPass() { std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToLinalgPass() {
return absl::make_unique<LhloLegalizeToLinalg>(); return std::make_unique<LhloLegalizeToLinalgPass>();
} }
static PassRegistration<LhloLegalizeToLinalg> legalize_lhlo_pass(
"lhlo-legalize-to-linalg", "Legalize from LHLO dialect to Linalg dialect");
} // namespace lmhlo } // namespace lmhlo
namespace mhlo { namespace mhlo {
@ -906,10 +902,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
} }
std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass() { std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass() {
return absl::make_unique<HloLegalizeToLinalg>(); return std::make_unique<HloLegalizeToLinalgPass>();
} }
static PassRegistration<HloLegalizeToLinalg> legalize_hlo_pass(
"hlo-legalize-to-linalg", "Legalize from HLO dialect to Linalg dialect");
} // namespace mhlo } // namespace mhlo
} // namespace mlir } // namespace mlir

View File

@ -15,18 +15,18 @@ limitations under the License.
// This file implements logic for lowering MHLO dialect to Standard dialect. // This file implements logic for lowering MHLO dialect to Standard dialect.
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringSwitch.h" #include "llvm/ADT/StringSwitch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h" #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir/IR/Function.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "mlir/IR/PatternMatch.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "mlir/Pass/Pass.h"
namespace mlir { namespace mlir {
namespace { namespace {
#include "third_party/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/generated_legalize_to_standard.inc" #include "generated_legalize_to_standard.inc"
} // end anonymous namespace } // end anonymous namespace
namespace mhlo { namespace mhlo {
namespace { namespace {
@ -176,15 +176,15 @@ class ConvertIotaOp : public OpRewritePattern<mhlo::IotaOp> {
} // end anonymous namespace } // end anonymous namespace
namespace { namespace {
struct LegalizeToStandard struct LegalizeToStandardPass
: public PassWrapper<LegalizeToStandard, FunctionPass> { : public PassWrapper<LegalizeToStandardPass, FunctionPass> {
/// Perform the lowering to Standard dialect. /// Perform the lowering to Standard dialect.
void runOnFunction() override; void runOnFunction() override;
}; };
} // end anonymous namespace } // end anonymous namespace
std::unique_ptr<mlir::OperationPass<mlir::FuncOp>> createLegalizeToStdPass() { std::unique_ptr<mlir::OperationPass<mlir::FuncOp>> createLegalizeToStdPass() {
return std::make_unique<LegalizeToStandard>(); return std::make_unique<LegalizeToStandardPass>();
} }
void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns, void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns,
@ -194,14 +194,11 @@ void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns,
} }
/// Perform the lowering to standard dialect. /// Perform the lowering to standard dialect.
void LegalizeToStandard::runOnFunction() { void LegalizeToStandardPass::runOnFunction() {
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
mlir::mhlo::PopulateMhloToStdPatterns(&patterns, &getContext()); mlir::mhlo::PopulateMhloToStdPatterns(&patterns, &getContext());
applyPatternsAndFoldGreedily(getFunction(), patterns); applyPatternsAndFoldGreedily(getFunction(), patterns);
} }
static PassRegistration<LegalizeToStandard> legalize_pass(
"mhlo-legalize-to-std", "Legalize from MHLO dialect to standard dialect");
} // end namespace mhlo } // end namespace mhlo
} // end namespace mlir } // end namespace mlir

View File

@ -15,9 +15,9 @@ limitations under the License.
// This is the legalization pattern definition file for MHLO to StandardOps. // This is the legalization pattern definition file for MHLO to StandardOps.
include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td" include "mlir/IR/OpBase.td"
include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td" include "mlir/Dialect/StandardOps/IR/Ops.td"
include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td" include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Nullary op patterns. // Nullary op patterns.

View File

@ -15,12 +15,11 @@ limitations under the License.
// This file implements a pass to remove redundant LHLO copy operations. // This file implements a pass to remove redundant LHLO copy operations.
#include "third_party/absl/memory/memory.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" #include "mlir/IR/Operation.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "mlir/Pass/Pass.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
namespace mlir { namespace mlir {
namespace lmhlo { namespace lmhlo {
@ -30,7 +29,8 @@ namespace {
// arguments. All uses of each buffer are replaced with the corresponding block // arguments. All uses of each buffer are replaced with the corresponding block
// argument and the buffer is freed. Note that this pass only works in regions // argument and the buffer is freed. Note that this pass only works in regions
// with a single block. // with a single block.
struct LhloCopyRemoval : mlir::PassWrapper<LhloCopyRemoval, OperationPass<>> { struct LhloCopyRemovalPass
: mlir::PassWrapper<LhloCopyRemovalPass, OperationPass<>> {
void runOnOperation() override { void runOnOperation() override {
llvm::SmallVector<mlir::Operation*, 2> eraseList; llvm::SmallVector<mlir::Operation*, 2> eraseList;
auto operation = getOperation(); auto operation = getOperation();
@ -95,11 +95,8 @@ struct LhloCopyRemoval : mlir::PassWrapper<LhloCopyRemoval, OperationPass<>> {
} // namespace } // namespace
std::unique_ptr<Pass> createLhloCopyRemovalPass() { std::unique_ptr<Pass> createLhloCopyRemovalPass() {
return absl::make_unique<LhloCopyRemoval>(); return std::make_unique<LhloCopyRemovalPass>();
} }
static PassRegistration<LhloCopyRemoval> copy_removal_pass(
"lhlo-copy-removal", "Removes redundant LHLO copy operations");
} // namespace lmhlo } // namespace lmhlo
} // namespace mlir } // namespace mlir

View File

@ -16,15 +16,14 @@ limitations under the License.
// This file implements logic for fusing linalg ops obtained after LHLO // This file implements logic for fusing linalg ops obtained after LHLO
// lowering. // lowering.
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "third_party/absl/memory/memory.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/ArrayRef.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h" #include "mlir/Pass/Pass.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Transforms/FoldUtils.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/FoldUtils.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
namespace mlir { namespace mlir {
namespace lmhlo { namespace lmhlo {
@ -32,11 +31,13 @@ namespace {
using linalg::LinalgOp; using linalg::LinalgOp;
class LhloFuseLinalg : public PassWrapper<LhloFuseLinalg, FunctionPass> { class LhloFuseLinalgPass
: public PassWrapper<LhloFuseLinalgPass, FunctionPass> {
public: public:
LhloFuseLinalg() = default; LhloFuseLinalgPass() = default;
LhloFuseLinalg(const LhloFuseLinalg&) {} LhloFuseLinalgPass(const LhloFuseLinalgPass&) {}
LhloFuseLinalg(bool use_parallel_loops, llvm::ArrayRef<unsigned> tile_sizes) { LhloFuseLinalgPass(bool use_parallel_loops,
llvm::ArrayRef<unsigned> tile_sizes) {
tile_sizes_ = tile_sizes; tile_sizes_ = tile_sizes;
use_parallel_loops_.setValue(use_parallel_loops); use_parallel_loops_.setValue(use_parallel_loops);
} }
@ -138,14 +139,10 @@ class LhloFuseLinalg : public PassWrapper<LhloFuseLinalg, FunctionPass> {
} // namespace } // namespace
std::unique_ptr<OperationPass<FuncOp>> createLhloFuseLinalg( std::unique_ptr<FunctionPass> createLhloFuseLinalgPass(
bool use_parallel_loops, ArrayRef<unsigned> tile_sizes) { bool use_parallel_loops, ArrayRef<unsigned> tile_sizes) {
return absl::make_unique<LhloFuseLinalg>(use_parallel_loops, tile_sizes); return std::make_unique<LhloFuseLinalgPass>(use_parallel_loops, tile_sizes);
} }
static PassRegistration<LhloFuseLinalg> legalize_pass(
"lhlo-fuse-linalg",
"Greedily fuse linalg ops obtained after LHLO lowering.");
} // namespace lmhlo } // namespace lmhlo
} // namespace mlir } // namespace mlir

View File

@ -15,17 +15,16 @@ limitations under the License.
// This file implements logic for lowering LHLO dialect to Affine dialect. // This file implements logic for lowering LHLO dialect to Affine dialect.
#include "third_party/absl/memory/memory.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h" #include "mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" #include "mlir/IR/Location.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" #include "mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" #include "mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" #include "mlir/IR/StandardTypes.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "mlir/Pass/Pass.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
namespace mlir { namespace mlir {
namespace lmhlo { namespace lmhlo {
@ -138,8 +137,8 @@ void populateLHLOToAffineConversionPattern(MLIRContext* context,
// clang-format on // clang-format on
} }
struct LhloLegalizeToAffine struct LhloLegalizeToAffinePass
: public PassWrapper<LhloLegalizeToAffine, FunctionPass> { : public PassWrapper<LhloLegalizeToAffinePass, FunctionPass> {
void runOnFunction() override { void runOnFunction() override {
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
auto func = getFunction(); auto func = getFunction();
@ -150,12 +149,9 @@ struct LhloLegalizeToAffine
} // namespace } // namespace
std::unique_ptr<OperationPass<FuncOp>> createLegalizeToAffinePass() { std::unique_ptr<OperationPass<FuncOp>> createLhloLegalizeToAffinePass() {
return absl::make_unique<LhloLegalizeToAffine>(); return std::make_unique<LhloLegalizeToAffinePass>();
} }
static PassRegistration<LhloLegalizeToAffine> legalize_pass(
"lhlo-legalize-to-affine", "Legalize from LHLO dialect to affine dialect");
} // namespace lmhlo } // namespace lmhlo
} // namespace mlir } // namespace mlir

View File

@ -17,25 +17,24 @@ limitations under the License.
#include <cstdint> #include <cstdint>
#include "third_party/absl/memory/memory.h" #include "llvm/ADT/ArrayRef.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/ArrayRef.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/GPU/GPUDialect.h" #include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/GPU/GPUDialect.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/SCF/SCF.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h" #include "mlir/IR/BlockAndValueMapping.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h" #include "mlir/IR/Builders.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h" #include "mlir/IR/Function.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" #include "mlir/IR/Location.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" #include "mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" #include "mlir/IR/Operation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" #include "mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" #include "mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h" #include "mlir/Pass/Pass.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "mlir/Transforms/DialectConversion.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
namespace mlir { namespace mlir {
namespace lmhlo { namespace lmhlo {
@ -168,7 +167,8 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern<ReduceOp> {
}; };
}; };
struct LhloLegalizeToGpu : public PassWrapper<LhloLegalizeToGpu, FunctionPass> { struct LhloLegalizeToGpuPass
: public PassWrapper<LhloLegalizeToGpuPass, FunctionPass> {
void runOnFunction() override { void runOnFunction() override {
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
ConversionTarget target(getContext()); ConversionTarget target(getContext());
@ -185,12 +185,9 @@ struct LhloLegalizeToGpu : public PassWrapper<LhloLegalizeToGpu, FunctionPass> {
} // namespace } // namespace
std::unique_ptr<OperationPass<FuncOp>> createLegalizeToGpuPass() { std::unique_ptr<FunctionPass> createLegalizeToGpuPass() {
return absl::make_unique<LhloLegalizeToGpu>(); return std::make_unique<LhloLegalizeToGpuPass>();
} }
static PassRegistration<LhloLegalizeToGpu> legalize_pass(
"lhlo-legalize-to-gpu", "Legalize from LHLO dialect to GPU dialect");
} // namespace lmhlo } // namespace lmhlo
} // namespace mlir } // namespace mlir

View File

@ -13,12 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "third_party/llvm/llvm-project/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h" #include "mlir/IR/StandardTypes.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "mlir/Transforms/DialectConversion.h"
namespace mlir { namespace mlir {
namespace lmhlo { namespace lmhlo {

View File

@ -13,16 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "third_party/llvm/llvm-project/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Conversion/SCFToStandard/SCFToStandard.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "mlir/IR/StandardTypes.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "mlir/Pass/Pass.h"
namespace mlir { namespace mlir {
namespace lmhlo { namespace lmhlo {
@ -57,8 +57,9 @@ class TestLhloToLLVMPass
} // namespace } // namespace
static PassRegistration<TestLhloToLLVMPass> legalize_lhlo_pass( std::unique_ptr<Pass> createTestLhloToLLVMPass() {
"test-lhlo-legalize-to-llvm", "Legalize from LHLO dialect to LLVM."); return std::make_unique<TestLhloToLLVMPass>();
}
} // namespace lmhlo } // namespace lmhlo
} // namespace mlir } // namespace mlir

View File

@ -13,17 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "third_party/absl/memory/memory.h" #include "llvm/ADT/ArrayRef.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/SmallVector.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/SCF/SCF.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" #include "mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
namespace mlir { namespace mlir {
namespace lmhlo { namespace lmhlo {
@ -690,8 +689,8 @@ class SelectAndScatterOpConverter
} }
}; };
struct LhloLegalizeToParallelLoops struct LhloLegalizeToParallelLoopsPass
: public PassWrapper<LhloLegalizeToParallelLoops, FunctionPass> { : public PassWrapper<LhloLegalizeToParallelLoopsPass, FunctionPass> {
void runOnFunction() override { void runOnFunction() override {
auto func = getFunction(); auto func = getFunction();
@ -715,16 +714,11 @@ struct LhloLegalizeToParallelLoops
} }
} }
}; };
} // namespace } // namespace
std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToParallelLoopsPass() { std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToParallelLoopsPass() {
return absl::make_unique<LhloLegalizeToParallelLoops>(); return std::make_unique<LhloLegalizeToParallelLoopsPass>();
} }
static PassRegistration<LhloLegalizeToParallelLoops> legalize_lhlo_pass(
"lhlo-legalize-to-parallel-loops",
"Legalize from LHLO dialect to parallel loops.");
} // namespace lmhlo } // namespace lmhlo
} // namespace mlir } // namespace mlir

View File

@ -22,18 +22,18 @@ limitations under the License.
#include <iterator> #include <iterator>
#include <numeric> #include <numeric>
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" #include "mlir-hlo/utils/hlo_utils.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" #include "mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h" #include "mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h" #include "mlir/IR/Operation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" #include "mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/PassRegistry.h" #include "mlir/IR/TypeUtilities.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir/IR/Types.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "mlir/Pass/Pass.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h" #include "mlir/Pass/PassRegistry.h"
using mlir::FunctionPass; using mlir::FunctionPass;
using mlir::OwningRewritePatternList; using mlir::OwningRewritePatternList;
@ -41,9 +41,9 @@ using mlir::PassRegistration;
using mlir::PassWrapper; using mlir::PassWrapper;
namespace { namespace {
class LowerComplex : public PassWrapper<LowerComplex, FunctionPass> { class LowerComplexPass : public PassWrapper<LowerComplexPass, FunctionPass> {
public: public:
explicit LowerComplex() : PassWrapper<LowerComplex, FunctionPass>() {} explicit LowerComplexPass() : PassWrapper<LowerComplexPass, FunctionPass>() {}
/// Performs the lowering to MHLO dialect. /// Performs the lowering to MHLO dialect.
void runOnFunction() override; void runOnFunction() override;
@ -54,7 +54,7 @@ namespace mlir {
namespace mhlo { namespace mhlo {
namespace { namespace {
#include "third_party/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/generated_lower_complex.inc" #include "generated_lower_complex.inc"
} // end anonymous namespace } // end anonymous namespace
@ -66,7 +66,7 @@ void PopulateComplexLoweringPatterns(MLIRContext* context,
} // end namespace mlir } // end namespace mlir
// Lowers the complex operations that can be represented using other operations. // Lowers the complex operations that can be represented using other operations.
void LowerComplex::runOnFunction() { void LowerComplexPass::runOnFunction() {
// Add lowering patterns to the list. // Add lowering patterns to the list.
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
mlir::mhlo::PopulateComplexLoweringPatterns(&getContext(), &patterns); mlir::mhlo::PopulateComplexLoweringPatterns(&getContext(), &patterns);
@ -74,6 +74,6 @@ void LowerComplex::runOnFunction() {
applyPatternsAndFoldGreedily(getFunction(), patterns); applyPatternsAndFoldGreedily(getFunction(), patterns);
} }
static PassRegistration<LowerComplex> pass( std::unique_ptr<FunctionPass> mlir::mhlo::createLowerComplexPass() {
"mhlo-test-lower-complex", return std::make_unique<LowerComplexPass>();
"Lower complex operations into non-complex operations"); }

View File

@ -16,9 +16,9 @@ limitations under the License.
// This is the legalization pattern that converts complex operations into // This is the legalization pattern that converts complex operations into
// equivalent real value operations. // equivalent real value operations.
include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td" include "mlir/IR/OpBase.td"
include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td" include "mlir/Dialect/StandardOps/IR/Ops.td"
include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td" include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Binary op patterns. // Binary op patterns.

View File

@ -15,20 +15,20 @@ limitations under the License.
// This file implements logic for lowering MHLO general dot to a regular dot. // This file implements logic for lowering MHLO general dot to a regular dot.
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringSwitch.h" #include "llvm/ADT/StringSwitch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" #include "mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" #include "mlir/IR/Function.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" #include "mlir/IR/Location.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h" #include "mlir/IR/Operation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" #include "mlir/IR/PatternMatch.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir/IR/StandardTypes.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "mlir/IR/TypeUtilities.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "mlir/Pass/Pass.h"
using mlir::DenseIntElementsAttr; using mlir::DenseIntElementsAttr;
using mlir::ElementsAttr; using mlir::ElementsAttr;
@ -170,8 +170,8 @@ struct GeneralDotConvert : public OpRewritePattern<mlir::mhlo::DotGeneralOp> {
} }
}; };
struct LegalizeGeneralDot struct LegalizeGeneralDotPass
: public PassWrapper<LegalizeGeneralDot, FunctionPass> { : public PassWrapper<LegalizeGeneralDotPass, FunctionPass> {
/// Lower all general dots that can be represented as a non-batched matmul. /// Lower all general dots that can be represented as a non-batched matmul.
void runOnFunction() override { void runOnFunction() override {
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
@ -187,6 +187,6 @@ void mlir::mhlo::PopulateGeneralDotOpLoweringPatterns(
patterns->insert<GeneralDotConvert>(ctx); patterns->insert<GeneralDotConvert>(ctx);
} }
static PassRegistration<LegalizeGeneralDot> legalize_pass( std::unique_ptr<::mlir::Pass> mlir::mhlo::createLegalizeGeneralDotPass() {
"mhlo-test-lower-general-dot", return std::make_unique<LegalizeGeneralDotPass>();
"Tests lowering general dot to a non-batched dot when possible"); }

View File

@ -15,12 +15,12 @@ limitations under the License.
#include <numeric> #include <numeric>
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" #include "mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" #include "mlir/IR/Operation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h" #include "mlir/IR/PatternMatch.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir/Transforms/DialectConversion.h"
namespace mlir { namespace mlir {
namespace mhlo { namespace mhlo {

View File

@ -13,14 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" #include "mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" #include "mlir/IR/Operation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h" #include "mlir/IR/PatternMatch.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir/Pass/Pass.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "mlir/Transforms/DialectConversion.h"
namespace mlir { namespace mlir {
namespace mhlo { namespace mhlo {
@ -50,9 +50,9 @@ struct TestMaterializeBroadcastsPass
} // namespace } // namespace
std::unique_ptr<::mlir::Pass> createTestMaterializeBroadcastsPass() {
return std::make_unique<TestMaterializeBroadcastsPass>();
}
} // namespace mhlo } // namespace mhlo
} // namespace mlir } // namespace mlir
static mlir::PassRegistration<mlir::mhlo::TestMaterializeBroadcastsPass> pass(
"mhlo-test-materialize-broadcasts",
"Test pass for materializing 'broadcast_dimensions' attributes");

View File

@ -18,14 +18,14 @@ limitations under the License.
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "llvm/ADT/EquivalenceClasses.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/utils/cycle_detector.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project #include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Matchers.h" #include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h" // TF:local_config_mlir #include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Transforms/RegionUtils.h" // TF:llvm-project #include "mlir/Transforms/RegionUtils.h" // TF:llvm-project
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/EquivalenceClasses.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/cycle_detector.h"
// This pass has similar functionality of the fusion pass in XLA stack. // This pass has similar functionality of the fusion pass in XLA stack.
// However, unlike XLA, it targets the fully dynamic shape scenario. // However, unlike XLA, it targets the fully dynamic shape scenario.
@ -479,7 +479,7 @@ class FusionPlanner {
EquivalenceClasses<int32_t> leader_for_node_; EquivalenceClasses<int32_t> leader_for_node_;
}; };
struct MhloFusion : public mlir::PassWrapper<MhloFusion, FunctionPass> { struct MhloFusionPass : public mlir::PassWrapper<MhloFusionPass, FunctionPass> {
void runOnFunction() override { void runOnFunction() override {
FuncOp func = getFunction(); FuncOp func = getFunction();
if (!IsTargetFunc(func)) { if (!IsTargetFunc(func)) {
@ -568,12 +568,9 @@ struct MhloFusion : public mlir::PassWrapper<MhloFusion, FunctionPass> {
} // namespace } // namespace
std::unique_ptr<OperationPass<FuncOp>> createMhloFusion() { std::unique_ptr<OperationPass<FuncOp>> createMhloFusionPass() {
return std::make_unique<MhloFusion>(); return std::make_unique<MhloFusionPass>();
} }
static PassRegistration<MhloFusion> mhlo_fusion_pass(
"mhlo-fusion", "fuse mhlo ops to kLoop/kInput fusion patterns.");
} // namespace mhlo } // namespace mhlo
} // namespace mlir } // namespace mlir

View File

@ -21,18 +21,18 @@ limitations under the License.
#include <iterator> #include <iterator>
#include <numeric> #include <numeric>
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" #include "mlir-hlo/utils/hlo_utils.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" #include "mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h" #include "mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h" #include "mlir/IR/Operation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" #include "mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/PassRegistry.h" #include "mlir/IR/TypeUtilities.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir/IR/Types.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "mlir/Pass/Pass.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h" #include "mlir/Pass/PassRegistry.h"
using mlir::OwningRewritePatternList; using mlir::OwningRewritePatternList;

View File

@ -13,23 +13,24 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" #include "mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h" #include "mlir/IR/Operation.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir/IR/PatternMatch.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
using mlir::FunctionPass; using mlir::FunctionPass;
using mlir::PassRegistration; using mlir::PassRegistration;
using mlir::PassWrapper; using mlir::PassWrapper;
namespace { namespace {
class OptimizeMhlo : public PassWrapper<OptimizeMhlo, FunctionPass> { class OptimizeMhloPass : public PassWrapper<OptimizeMhloPass, FunctionPass> {
public: public:
explicit OptimizeMhlo() : PassWrapper<OptimizeMhlo, FunctionPass>() {} explicit OptimizeMhloPass() : PassWrapper<OptimizeMhloPass, FunctionPass>() {}
/// Performs the lowering to MHLO dialect. /// Performs the lowering to MHLO dialect.
void runOnFunction() override; void runOnFunction() override;
@ -37,7 +38,7 @@ class OptimizeMhlo : public PassWrapper<OptimizeMhlo, FunctionPass> {
} // end anonymous namespace } // end anonymous namespace
// Lowers the complex operations that can be represented using other operations. // Lowers the complex operations that can be represented using other operations.
void OptimizeMhlo::runOnFunction() { void OptimizeMhloPass::runOnFunction() {
// Add lowering patterns to the list. // Add lowering patterns to the list.
mlir::OwningRewritePatternList patterns; mlir::OwningRewritePatternList patterns;
mlir::mhlo::PopulateOptimizeMHLOPatterns(&getContext(), &patterns); mlir::mhlo::PopulateOptimizeMHLOPatterns(&getContext(), &patterns);
@ -45,5 +46,6 @@ void OptimizeMhlo::runOnFunction() {
applyPatternsAndFoldGreedily(getFunction(), patterns); applyPatternsAndFoldGreedily(getFunction(), patterns);
} }
static PassRegistration<OptimizeMhlo> pass("mhlo-test-optimize", std::unique_ptr<mlir::FunctionPass> mlir::mhlo::createOptimizeMhloPass() {
"Run optional HLO optimizations."); return std::make_unique<OptimizeMhloPass>();
}

View File

@ -13,14 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseMap.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/Support/Casting.h" #include "llvm/Support/Casting.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" #include "mlir/IR/Operation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/PassManager.h" #include "mlir/Pass/Pass.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Support/LLVM.h" #include "mlir/Pass/PassManager.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/RegionUtils.h" #include "mlir/Support/LLVM.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir/Transforms/RegionUtils.h"
namespace mlir { namespace mlir {
namespace mhlo { namespace mhlo {
@ -29,8 +29,8 @@ namespace {
// A pass that sinks constants implicitly captured in control flow regions. This // A pass that sinks constants implicitly captured in control flow regions. This
// is necessary to export to XLA. // is necessary to export to XLA.
class SinkConstantsToControlFlow class SinkConstantsToControlFlowPass
: public mlir::PassWrapper<SinkConstantsToControlFlow, FunctionPass> { : public mlir::PassWrapper<SinkConstantsToControlFlowPass, FunctionPass> {
void runOnFunction() override { void runOnFunction() override {
getFunction().walk([](Operation* op) { getFunction().walk([](Operation* op) {
if (auto while_op = llvm::dyn_cast<WhileOp>(op)) { if (auto while_op = llvm::dyn_cast<WhileOp>(op)) {
@ -70,15 +70,10 @@ class SinkConstantsToControlFlow
} }
}; };
static mlir::PassRegistration<SinkConstantsToControlFlow> pass(
"mhlo-sink-constants-to-control-flow",
"Sink constants implicitly captured in control flow regions. This is "
"necessary to export to XLA.");
} // anonymous namespace } // anonymous namespace
std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass() { std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass() {
return std::make_unique<SinkConstantsToControlFlow>(); return std::make_unique<SinkConstantsToControlFlowPass>();
} }
} // namespace mhlo } // namespace mhlo

View File

@ -13,16 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" #include "mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Identifier.h" #include "mlir/IR/Identifier.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" #include "mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OperationSupport.h" #include "mlir/IR/OperationSupport.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
namespace mlir { namespace mlir {
namespace hlo { namespace mhlo {
namespace { namespace {
struct InferReturnTypeComponentsPattern : public RewritePattern { struct InferReturnTypeComponentsPattern : public RewritePattern {
@ -92,9 +92,10 @@ struct TestInferShapedTypeMethodsPass
}; };
} // namespace } // namespace
} // namespace hlo
} // namespace mlir
static mlir::PassRegistration<mlir::hlo::TestInferShapedTypeMethodsPass> pass( std::unique_ptr<FunctionPass> createTestInferShapedTypeMethodsPass() {
"mhlo-test-infer-shaped-type-methods", return std::make_unique<TestInferShapedTypeMethodsPass>();
"Uses test ops to invoke InferShapedTypeOpInterface methods"); }
} // namespace mhlo
} // namespace mlir

View File

@ -14,18 +14,17 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "third_party/absl/memory/memory.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Shape/IR/Shape.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Shape/IR/Shape.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" #include "mlir/IR/Function.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" #include "mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" #include "mlir/IR/Operation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" #include "mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" #include "mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h" #include "mlir/Pass/Pass.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir/Transforms/DialectConversion.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
namespace mlir { namespace mlir {
namespace mhlo { namespace mhlo {
@ -204,9 +203,9 @@ void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
// clang-format on // clang-format on
} }
static PassRegistration<TransformUnrankedHloPass> transform_unranked_hlo_pass( std::unique_ptr<::mlir::Pass> createTransformUnrankedHloPass() {
"transform-unranked-hlo", return std::make_unique<TransformUnrankedHloPass>();
"Realize element-wise operations on ranked tensors where possible"); }
} // namespace mhlo } // namespace mhlo
} // namespace mlir } // namespace mlir

View File

@ -13,16 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h" #include "mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" #include "mlir/IR/Builders.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" #include "mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" #include "mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h" #include "mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h" #include "mlir/IR/Types.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir/Transforms/DialectConversion.h"
namespace mlir { namespace mlir {
namespace mhlo { namespace mhlo {

View File

@ -13,14 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" #include "mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" #include "mlir/IR/Operation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h" #include "mlir/IR/PatternMatch.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir/Pass/Pass.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "mlir/Transforms/DialectConversion.h"
namespace mlir { namespace mlir {
namespace mhlo { namespace mhlo {
@ -38,9 +38,9 @@ struct TestUnfuseBatchNormPass
} // namespace } // namespace
std::unique_ptr<::mlir::Pass> createTestUnfuseBatchNormPass() {
return std::make_unique<TestUnfuseBatchNormPass>();
}
} // namespace mhlo } // namespace mhlo
} // namespace mlir } // namespace mlir
static mlir::PassRegistration<mlir::mhlo::TestUnfuseBatchNormPass> pass(
"mhlo-test-unfuse-batch-norm",
"Test pass for materializing 'broadcast_dimensions' attributes");

View File

@ -13,15 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h" #include "mlir-hlo/utils/broadcast_utils.h"
#include <algorithm> #include <algorithm>
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/Sequence.h" #include "llvm/ADT/Sequence.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Shape/IR/Shape.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Diagnostics.h" #include "mlir/IR/Diagnostics.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" #include "mlir/IR/StandardTypes.h"
namespace mlir { namespace mlir {
namespace hlo { namespace hlo {

View File

@ -15,11 +15,11 @@ limitations under the License.
// This file defines helpers useful when creating or manipulating lhlo/hlo. // This file defines helpers useful when creating or manipulating lhlo/hlo.
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/convert_op_folder.h" #include "mlir-hlo/utils/convert_op_folder.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" #include "mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" #include "mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h" #include "mlir/IR/TypeUtilities.h"
namespace mlir { namespace mlir {
namespace hlo { namespace hlo {

View File

@ -13,11 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/cycle_detector.h" #include "mlir-hlo/utils/cycle_detector.h"
#include <algorithm> #include <algorithm>
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/DenseSet.h" #include "llvm/ADT/DenseSet.h"
namespace mlir { namespace mlir {

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/cycle_detector.h" #include "mlir-hlo/utils/cycle_detector.h"
#include "third_party/tensorflow/compiler/xla/test.h" #include "third_party/tensorflow/compiler/xla/test.h"

View File

@ -13,11 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h" #include "mlir-hlo/utils/hlo_utils.h"
#include <numeric> #include <numeric>
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" #include "mlir/IR/Attributes.h"
namespace mlir { namespace mlir {
namespace hlo { namespace hlo {

View File

@ -0,0 +1,121 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir-hlo/Dialect/mhlo/IR/register.h"
#include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/InitAllDialects.h"
#include "mlir/InitAllPasses.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Support/MlirOptMain.h"
// NOLINTNEXTLINE
static llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
llvm::cl::desc("<input file>"),
llvm::cl::init("-"));
// NOLINTNEXTLINE
static llvm::cl::opt<std::string> outputFilename(
"o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
llvm::cl::init("-"));
// NOLINTNEXTLINE
static llvm::cl::opt<bool> splitInputFile(
"split-input-file",
llvm::cl::desc("Split the input file into pieces and process each "
"chunk independently"),
llvm::cl::init(false));
// NOLINTNEXTLINE
static llvm::cl::opt<bool> verifyDiagnostics(
"verify-diagnostics",
llvm::cl::desc("Check that emitted diagnostics match "
"expected-* lines on the corresponding line"),
llvm::cl::init(false));
// NOLINTNEXTLINE
static llvm::cl::opt<bool> verifyPasses(
"verify-each",
llvm::cl::desc("Run the verifier after each transformation pass"),
llvm::cl::init(true));
// NOLINTNEXTLINE
static llvm::cl::opt<bool> allowUnregisteredDialects(
"allow-unregistered-dialect",
llvm::cl::desc("Allow operation with no registered dialects"),
llvm::cl::init(false));
// NOLINTNEXTLINE
static llvm::cl::opt<bool> showDialects(
"show-dialects", llvm::cl::desc("Print the list of registered dialects"),
llvm::cl::init(false));
int main(int argc, char **argv) {
mlir::registerAllDialects();
mlir::registerAllPasses();
mlir::mhlo::registerAllDialects();
mlir::mhlo::registerAllMhloPasses();
mlir::lmhlo::registerAllLmhloPasses();
llvm::InitLLVM y(argc, argv);
// Register any pass manager command line options.
mlir::registerPassManagerCLOptions();
mlir::PassPipelineCLParser passPipeline("", "Compiler passes to run");
// Parse pass names in main to ensure static initialization completed.
llvm::cl::ParseCommandLineOptions(argc, argv,
"MLIR modular optimizer driver\n");
if (showDialects) {
mlir::MLIRContext context;
llvm::outs() << "Registered Dialects:\n";
for (mlir::Dialect *dialect : context.getRegisteredDialects()) {
llvm::outs() << dialect->getNamespace() << "\n";
}
return 0;
}
// Set up the input file.
std::string errorMessage;
auto file = mlir::openInputFile(inputFilename, &errorMessage);
if (!file) {
llvm::errs() << errorMessage << "\n";
return 1;
}
auto output = mlir::openOutputFile(outputFilename, &errorMessage);
if (!output) {
llvm::errs() << errorMessage << "\n";
exit(1);
}
if (failed(MlirOptMain(output->os(), std::move(file), passPipeline,
splitInputFile, verifyDiagnostics, verifyPasses,
allowUnregisteredDialects))) {
return 1;
}
// Keep the output file if the invocation of MlirOptMain was successful.
output->keep();
return 0;
}